aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Pike <r@golang.org>2010-09-04 23:41:54 +1000
committerRob Pike <r@golang.org>2010-09-04 23:41:54 +1000
commitd54b921c9baf19105d074896da98d6dbf3c9e219 (patch)
tree7fd411c84eaad40f29b7f72dba917d7a8e2ba690
parent6405ab0fae63b9b5e051c8d19ba2ee3666cb327e (diff)
downloadgo-d54b921c9baf19105d074896da98d6dbf3c9e219.tar.gz
go-d54b921c9baf19105d074896da98d6dbf3c9e219.zip
netchan: use acknowledgements on export send.
Also add exporter.Drain() to wait for completion. This makes it possible for an Exporter to fire off a message and wait (by calling Drain) for the message to be received, even if a client has yet to call to retrieve it. Once this design is settled, I'll do the same for import send. Testing strategies welcome. I have some working stand-alone tests. R=rsc CC=golang-dev https://golang.org/cl/2137041
-rw-r--r--src/pkg/netchan/common.go98
-rw-r--r--src/pkg/netchan/export.go162
-rw-r--r--src/pkg/netchan/import.go26
-rw-r--r--src/pkg/netchan/netchan_test.go43
4 files changed, 268 insertions, 61 deletions
diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go
index 624397ef46..c5fd5698cf 100644
--- a/src/pkg/netchan/common.go
+++ b/src/pkg/netchan/common.go
@@ -10,6 +10,7 @@ import (
"os"
"reflect"
"sync"
+ "time"
)
// The direction of a connection from the client's perspective.
@@ -25,6 +26,7 @@ const (
payRequest = iota // request structure follows
payError // error structure follows
payData // user payload follows
+ payAck // acknowledgement; no payload
)
// A header is sent as a prefix to every transmission. It will be followed by
@@ -32,13 +34,14 @@ const (
type header struct {
name string
payloadType int
+ seqNum int64
}
// Sent with a header once per channel from importer to exporter to report
// that it wants to bind to a channel with the specified direction for count
// messages. If count is zero, it means unlimited.
type request struct {
- count int
+ count int64
dir Dir
}
@@ -47,6 +50,27 @@ type error struct {
error string
}
+// Used to unify management of acknowledgements for import and export.
+type unackedCounter interface {
+ unackedCount() int64
+ ack() int64
+ seq() int64
+}
+
+// A channel and its direction.
+type chanDir struct {
+ ch *reflect.ChanValue
+ dir Dir
+}
+
+// clientSet contains the objects and methods needed for tracking
+// clients of an exporter and draining outstanding messages.
+type clientSet struct {
+ mu sync.Mutex // protects access to channel and client maps
+ chans map[string]*chanDir
+ clients map[unackedCounter]bool
+}
+
// Mutex-protected encoder and decoder pair.
type encDec struct {
decLock sync.Mutex
@@ -79,10 +103,78 @@ func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.E
hdr.payloadType = payloadType
err := ed.enc.Encode(hdr)
if err == nil {
- err = ed.enc.Encode(payload)
- } else {
+ if payload != nil {
+ err = ed.enc.Encode(payload)
+ }
+ }
+ if err != nil {
// TODO: tear down connection if there is an error?
}
ed.encLock.Unlock()
return err
}
+
+// See the comment for Exporter.Drain.
+func (cs *clientSet) drain(timeout int64) os.Error {
+ startTime := time.Nanoseconds()
+ for {
+ pending := false
+ cs.mu.Lock()
+ // Any messages waiting for a client?
+ for _, chDir := range cs.chans {
+ if chDir.ch.Len() > 0 {
+ pending = true
+ }
+ }
+ // Any unacknowledged messages?
+ for client := range cs.clients {
+ n := client.unackedCount()
+ if n > 0 { // Check for > rather than != just to be safe.
+ pending = true
+ break
+ }
+ }
+ cs.mu.Unlock()
+ if !pending {
+ break
+ }
+ if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+ return os.ErrorString("timeout")
+ }
+ time.Sleep(100 * 1e6) // 100 milliseconds
+ }
+ return nil
+}
+
+// See the comment for Exporter.Sync.
+func (cs *clientSet) sync(timeout int64) os.Error {
+ startTime := time.Nanoseconds()
+ // seq remembers the clients and their seqNum at point of entry.
+ seq := make(map[unackedCounter]int64)
+ for client := range cs.clients {
+ seq[client] = client.seq()
+ }
+ for {
+ pending := false
+ cs.mu.Lock()
+ // Any unacknowledged messages? Look only at clients that existed
+ // when we started and are still in this client set.
+ for client := range seq {
+ if _, ok := cs.clients[client]; ok {
+ if client.ack() < seq[client] {
+ pending = true
+ break
+ }
+ }
+ }
+ cs.mu.Unlock()
+ if !pending {
+ break
+ }
+ if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+ return os.ErrorString("timeout")
+ }
+ time.Sleep(100 * 1e6) // 100 milliseconds
+ }
+ return nil
+}
diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go
index 3142eebf73..c42e35c56d 100644
--- a/src/pkg/netchan/export.go
+++ b/src/pkg/netchan/export.go
@@ -31,59 +31,47 @@ import (
// Export
-// A channel and its associated information: a direction plus
-// a handy marshaling place for its data.
-type exportChan struct {
- ch *reflect.ChanValue
- dir Dir
-}
-
// An Exporter allows a set of channels to be published on a single
// network port. A single machine may have multiple Exporters
// but they must use different ports.
type Exporter struct {
+ *clientSet
listener net.Listener
- chanLock sync.Mutex // protects access to channel map
- chans map[string]*exportChan
}
type expClient struct {
*encDec
- exp *Exporter
+ exp *Exporter
+ mu sync.Mutex // protects remaining fields
+ errored bool // client has been sent an error
+ seqNum int64 // sequences messages sent to client; has value of highest sent
+ ackNum int64 // highest sequence number acknowledged
}
func newClient(exp *Exporter, conn net.Conn) *expClient {
client := new(expClient)
client.exp = exp
client.encDec = newEncDec(conn)
+ client.seqNum = 0
+ client.ackNum = 0
return client
}
-// Wait for incoming connections, start a new runner for each
-func (exp *Exporter) listen() {
- for {
- conn, err := exp.listener.Accept()
- if err != nil {
- log.Stderr("exporter.listen:", err)
- break
- }
- client := newClient(exp, conn)
- go client.run()
- }
-}
-
func (client *expClient) sendError(hdr *header, err string) {
error := &error{err}
log.Stderr("export:", error.error)
client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+ client.mu.Lock()
+ client.errored = true
+ client.mu.Unlock()
}
-func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
+func (client *expClient) getChan(hdr *header, dir Dir) *chanDir {
exp := client.exp
- exp.chanLock.Lock()
+ exp.mu.Lock()
ech, ok := exp.chans[hdr.name]
- exp.chanLock.Unlock()
+ exp.mu.Unlock()
if !ok {
client.sendError(hdr, "no such channel: "+hdr.name)
return nil
@@ -95,9 +83,10 @@ func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
return ech
}
-// Manage sends and receives for a single client. For each (client Recv) request,
-// this will launch a serveRecv goroutine to deliver the data for that channel,
-// while (client Send) requests are handled as data arrives from the client.
+// The function run manages sends and receives for a single client. For each
+// (client Recv) request, this will launch a serveRecv goroutine to deliver
+// the data for that channel, while (client Send) requests are handled as
+// data arrives from the client.
func (client *expClient) run() {
hdr := new(header)
hdrValue := reflect.NewValue(hdr)
@@ -107,15 +96,13 @@ func (client *expClient) run() {
for {
if err := client.decode(hdrValue); err != nil {
log.Stderr("error decoding client header:", err)
- // TODO: tear down connection
- return
+ break
}
switch hdr.payloadType {
case payRequest:
if err := client.decode(reqValue); err != nil {
log.Stderr("error decoding client request:", err)
- // TODO: tear down connection
- return
+ break
}
switch req.dir {
case Recv:
@@ -132,13 +119,27 @@ func (client *expClient) run() {
}
case payData:
client.serveSend(*hdr)
+ case payAck:
+ client.mu.Lock()
+ if client.ackNum != hdr.seqNum-1 {
+ // Since the sequence number is incremented and the message is sent
+ // in a single instance of locking client.mu, the messages are guaranteed
+ // to be sent in order. Therefore receipt of acknowledgement N means
+ // all messages <=N have been seen by the recipient. We check anyway.
+ log.Stderr("netchan export: sequence out of order:", client.ackNum, hdr.seqNum)
+ }
+ if client.ackNum < hdr.seqNum { // If there has been an error, don't back up the count.
+ client.ackNum = hdr.seqNum
+ }
+ client.mu.Unlock()
}
}
+ client.exp.delClient(client)
}
// Send all the data on a single channel to a client asking for a Recv.
// The header is passed by value to avoid issues of overwriting.
-func (client *expClient) serveRecv(hdr header, count int) {
+func (client *expClient) serveRecv(hdr header, count int64) {
ech := client.getChan(&hdr, Send)
if ech == nil {
return
@@ -149,7 +150,16 @@ func (client *expClient) serveRecv(hdr header, count int) {
client.sendError(&hdr, os.EOF.String())
break
}
- if err := client.encode(&hdr, payData, val.Interface()); err != nil {
+ // We hold the lock during transmission to guarantee messages are
+ // sent in sequence number order. Also, we increment first so the
+ // value of client.seqNum is the value of the highest used sequence
+ // number, not one beyond.
+ client.mu.Lock()
+ client.seqNum++
+ hdr.seqNum = client.seqNum
+ err := client.encode(&hdr, payData, val.Interface())
+ client.mu.Unlock()
+ if err != nil {
log.Stderr("error encoding client response:", err)
client.sendError(&hdr, err.String())
break
@@ -180,6 +190,40 @@ func (client *expClient) serveSend(hdr header) {
// TODO count
}
+func (client *expClient) unackedCount() int64 {
+ client.mu.Lock()
+ n := client.seqNum - client.ackNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) seq() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+func (client *expClient) ack() int64 {
+ client.mu.Lock()
+ n := client.seqNum
+ client.mu.Unlock()
+ return n
+}
+
+// Wait for incoming connections, start a new runner for each
+func (exp *Exporter) listen() {
+ for {
+ conn, err := exp.listener.Accept()
+ if err != nil {
+ log.Stderr("exporter.listen:", err)
+ break
+ }
+ client := exp.addClient(conn)
+ go client.run()
+ }
+}
+
// NewExporter creates a new Exporter to export channels
// on the network and local address defined as in net.Listen.
func NewExporter(network, localaddr string) (*Exporter, os.Error) {
@@ -189,12 +233,52 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) {
}
e := &Exporter{
listener: listener,
- chans: make(map[string]*exportChan),
+ clientSet: &clientSet{
+ chans: make(map[string]*chanDir),
+ clients: make(map[unackedCounter]bool),
+ },
}
go e.listen()
return e, nil
}
+// addClient creates a new expClient and records its existence
+func (exp *Exporter) addClient(conn net.Conn) *expClient {
+ client := newClient(exp, conn)
+ exp.clients[client] = true
+ exp.mu.Unlock()
+ return client
+}
+
+// delClient forgets the client existed
+func (exp *Exporter) delClient(client *expClient) {
+ exp.mu.Lock()
+ exp.clients[client] = false, false
+ exp.mu.Unlock()
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any client and possibly including those sent while
+// Drain was executing, have been received by the importer. In short, it
+// waits until all the exporter's messages have been received by a client.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (exp *Exporter) Drain(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.drain(timeout)
+}
+
+// Sync waits until all clients of the exporter have received the messages
+// that were sent at the time Sync was invoked. Unlike Drain, it does not
+// wait for messages sent while it is running or messages that have not been
+// dispatched to any client. If the timeout (measured in nanoseconds) is
+// positive and Sync takes longer than that to complete, an error is
+// returned.
+func (exp *Exporter) Sync(timeout int64) os.Error {
+ // This wrapper function is here so the method's comment will appear in godoc.
+ return exp.clientSet.sync(timeout)
+}
+
// Addr returns the Exporter's local network address.
func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() }
@@ -230,12 +314,12 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
if err != nil {
return err
}
- exp.chanLock.Lock()
- defer exp.chanLock.Unlock()
+ exp.mu.Lock()
+ defer exp.mu.Unlock()
_, present := exp.chans[name]
if present {
return os.ErrorString("channel name already being exported:" + name)
}
- exp.chans[name] = &exportChan{ch, dir}
+ exp.chans[name] = &chanDir{ch, dir}
return nil
}
diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go
index 1effbaef4a..6a065543b5 100644
--- a/src/pkg/netchan/import.go
+++ b/src/pkg/netchan/import.go
@@ -14,13 +14,6 @@ import (
// Import
-// A channel and its associated information: a template value and direction,
-// plus a handy marshaling place for its data.
-type importChan struct {
- ch *reflect.ChanValue
- dir Dir
-}
-
// An Importer allows a set of channels to be imported from a single
// remote machine/network port. A machine may have multiple
// importers, even from the same machine/network port.
@@ -28,7 +21,7 @@ type Importer struct {
*encDec
conn net.Conn
chanLock sync.Mutex // protects access to channel map
- chans map[string]*importChan
+ chans map[string]*chanDir
}
// NewImporter creates a new Importer object to import channels
@@ -43,7 +36,7 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
imp := new(Importer)
imp.encDec = newEncDec(conn)
imp.conn = conn
- imp.chans = make(map[string]*importChan)
+ imp.chans = make(map[string]*chanDir)
go imp.run()
return imp, nil
}
@@ -67,6 +60,7 @@ func (imp *Importer) run() {
// Loop on responses; requests are sent by ImportNValues()
hdr := new(header)
hdrValue := reflect.NewValue(hdr)
+ ackHdr := new(header)
err := new(error)
errValue := reflect.NewValue(err)
for {
@@ -103,6 +97,10 @@ func (imp *Importer) run() {
log.Stderr("cannot happen: receive from non-Recv channel")
return
}
+ // Acknowledge receipt
+ ackHdr.name = hdr.name
+ ackHdr.seqNum = hdr.seqNum
+ imp.encode(ackHdr, payAck, nil)
// Create a new value for each received item.
value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem())
if e := imp.decode(value); e != nil {
@@ -144,14 +142,10 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
if present {
return os.ErrorString("channel name already being imported:" + name)
}
- imp.chans[name] = &importChan{ch, dir}
+ imp.chans[name] = &chanDir{ch, dir}
// Tell the other side about this channel.
- hdr := new(header)
- hdr.name = name
- hdr.payloadType = payRequest
- req := new(request)
- req.dir = dir
- req.count = n
+ hdr := &header{name: name, payloadType: payRequest}
+ req := &request{count: int64(n), dir: dir}
if err := imp.encode(hdr, payRequest, req); err != nil {
log.Stderr("importer request encode:", err)
return err
diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go
index eb5a11ea44..1bd4c9d4f8 100644
--- a/src/pkg/netchan/netchan_test.go
+++ b/src/pkg/netchan/netchan_test.go
@@ -37,7 +37,7 @@ func exportReceive(exp *Exporter, t *testing.T) {
}
}
-func importReceive(imp *Importer, t *testing.T) {
+func importReceive(imp *Importer, t *testing.T, done chan bool) {
ch := make(chan int)
err := imp.ImportNValues("exportedSend", ch, Recv, count)
if err != nil {
@@ -55,6 +55,9 @@ func importReceive(imp *Importer, t *testing.T) {
t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v)
}
}
+ if done != nil {
+ done <- true
+ }
}
func importSend(imp *Importer, t *testing.T) {
@@ -78,7 +81,7 @@ func TestExportSendImportReceive(t *testing.T) {
t.Fatal("new importer:", err)
}
exportSend(exp, count, t)
- importReceive(imp, t)
+ importReceive(imp, t, nil)
}
func TestExportReceiveImportSend(t *testing.T) {
@@ -104,5 +107,39 @@ func TestClosingExportSendImportReceive(t *testing.T) {
t.Fatal("new importer:", err)
}
exportSend(exp, closeCount, t)
- importReceive(imp, t)
+ importReceive(imp, t, nil)
+}
+
+// Not a great test but it does at least invoke Drain.
+func TestExportDrain(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+ done := make(chan bool)
+ go exportSend(exp, closeCount, t)
+ go importReceive(imp, t, done)
+ exp.Drain(0)
+ <-done
+}
+
+// Not a great test but it does at least invoke Sync.
+func TestExportSync(t *testing.T) {
+ exp, err := NewExporter("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal("new exporter:", err)
+ }
+ imp, err := NewImporter("tcp", exp.Addr().String())
+ if err != nil {
+ t.Fatal("new importer:", err)
+ }
+ done := make(chan bool)
+ go importReceive(imp, t, done)
+ exportSend(exp, closeCount, t)
+ exp.Sync(0)
+ <-done
}