aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuss Cox <rsc@golang.org>2011-11-01 00:29:41 -0400
committerRuss Cox <rsc@golang.org>2011-11-01 00:29:41 -0400
commit2e79e8e54920c005af29447a85d7b241460c34cb (patch)
tree56ab3eff731cb345792b0465596b61a22f6397d2
parent7b04471dfaaddc49efad470275fcc45546870a73 (diff)
downloadgo-2e79e8e54920c005af29447a85d7b241460c34cb.tar.gz
go-2e79e8e54920c005af29447a85d7b241460c34cb.zip
rpc: avoid infinite loop on input error
Fixes #1828. Fixes #2179. R=golang-dev, r CC=golang-dev https://golang.org/cl/5305084
-rw-r--r--src/pkg/rpc/jsonrpc/all_test.go65
-rw-r--r--src/pkg/rpc/server.go20
-rw-r--r--src/pkg/rpc/server_test.go3
3 files changed, 79 insertions, 9 deletions
diff --git a/src/pkg/rpc/jsonrpc/all_test.go b/src/pkg/rpc/jsonrpc/all_test.go
index c1a9e8ecbc..99253baf3c 100644
--- a/src/pkg/rpc/jsonrpc/all_test.go
+++ b/src/pkg/rpc/jsonrpc/all_test.go
@@ -6,6 +6,7 @@ package jsonrpc
import (
"fmt"
+ "io"
"json"
"net"
"os"
@@ -154,3 +155,67 @@ func TestClient(t *testing.T) {
t.Error("Div: expected divide by zero error; got", err)
}
}
+
+func TestMalformedInput(t *testing.T) {
+ cli, srv := net.Pipe()
+ go cli.Write([]byte(`{id:1}`)) // invalid json
+ ServeConn(srv) // must return, not loop
+}
+
+func TestUnexpectedError(t *testing.T) {
+ cli, srv := myPipe()
+ go cli.PipeWriter.CloseWithError(os.NewError("unexpected error!")) // reader will get this error
+ ServeConn(srv) // must return, not loop
+}
+
+// Copied from package net.
+func myPipe() (*pipe, *pipe) {
+ r1, w1 := io.Pipe()
+ r2, w2 := io.Pipe()
+
+ return &pipe{r1, w2}, &pipe{r2, w1}
+}
+
+type pipe struct {
+ *io.PipeReader
+ *io.PipeWriter
+}
+
+type pipeAddr int
+
+func (pipeAddr) Network() string {
+ return "pipe"
+}
+
+func (pipeAddr) String() string {
+ return "pipe"
+}
+
+func (p *pipe) Close() os.Error {
+ err := p.PipeReader.Close()
+ err1 := p.PipeWriter.Close()
+ if err == nil {
+ err = err1
+ }
+ return err
+}
+
+func (p *pipe) LocalAddr() net.Addr {
+ return pipeAddr(0)
+}
+
+func (p *pipe) RemoteAddr() net.Addr {
+ return pipeAddr(0)
+}
+
+func (p *pipe) SetTimeout(nsec int64) os.Error {
+ return os.NewError("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetReadTimeout(nsec int64) os.Error {
+ return os.NewError("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetWriteTimeout(nsec int64) os.Error {
+ return os.NewError("net.Pipe does not support timeouts")
+}
diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go
index f03710061a..142bf8a529 100644
--- a/src/pkg/rpc/server.go
+++ b/src/pkg/rpc/server.go
@@ -394,12 +394,12 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
for {
- service, mtype, req, argv, replyv, err := server.readRequest(codec)
+ service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
if err != os.EOF {
log.Println("rpc:", err)
}
- if err == os.EOF || err == io.ErrUnexpectedEOF {
+ if !keepReading {
break
}
// send a response if we actually managed to read a header.
@@ -418,9 +418,9 @@ func (server *Server) ServeCodec(codec ServerCodec) {
// It does not close the codec upon completion.
func (server *Server) ServeRequest(codec ServerCodec) os.Error {
sending := new(sync.Mutex)
- service, mtype, req, argv, replyv, err := server.readRequest(codec)
+ service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
if err != nil {
- if err == os.EOF || err == io.ErrUnexpectedEOF {
+ if !keepReading {
return err
}
// send a response if we actually managed to read a header.
@@ -474,10 +474,10 @@ func (server *Server) freeResponse(resp *Response) {
server.respLock.Unlock()
}
-func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, err os.Error) {
- service, mtype, req, err = server.readRequestHeader(codec)
+func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err os.Error) {
+ service, mtype, req, keepReading, err = server.readRequestHeader(codec)
if err != nil {
- if err == os.EOF || err == io.ErrUnexpectedEOF {
+ if !keepReading {
return
}
// discard body
@@ -505,7 +505,7 @@ func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *m
return
}
-func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, err os.Error) {
+func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err os.Error) {
// Grab the request header.
req = server.getRequest()
err = codec.ReadRequestHeader(req)
@@ -518,6 +518,10 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
return
}
+ // We read the header successfully. If we see an error now,
+ // we can still recover and move on to the next request.
+ keepReading = true
+
serviceMethod := strings.Split(req.ServiceMethod, ".")
if len(serviceMethod) != 2 {
err = os.NewError("rpc: service/method request ill-formed: " + req.ServiceMethod)
diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go
index 029741b28b..3e9fe297d4 100644
--- a/src/pkg/rpc/server_test.go
+++ b/src/pkg/rpc/server_test.go
@@ -311,8 +311,9 @@ func (codec *CodecEmulator) ReadRequestBody(argv interface{}) os.Error {
func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) os.Error {
if resp.Error != "" {
codec.err = os.NewError(resp.Error)
+ } else {
+ *codec.reply = *(reply.(*Reply))
}
- *codec.reply = *(reply.(*Reply))
return nil
}