aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Pike <r@golang.org>2011-03-04 12:25:18 -0800
committerRob Pike <r@golang.org>2011-03-04 12:25:18 -0800
commitc91daefb2489febd6fc738554d1539fb151404d6 (patch)
treef1f6734a9113987c7bb0acaf2ac00daa456365f9
parent7b563be51647fe13c1cbfa45b310ee0068833e55 (diff)
downloadgo-c91daefb2489febd6fc738554d1539fb151404d6.tar.gz
go-c91daefb2489febd6fc738554d1539fb151404d6.zip
gob: beginning of support for GobEncoder/GobDecoder interfaces.
This allows a data item that can marshal itself to be transmitted by its own encoding, enabling some types to be handled that cannot be normally, plus providing a way to use gobs on data with unexported fields. In this CL, the necessary methods are protected by leading _, so only package gob can use the facilities (in its tests, of course); this code is not ready for real use yet. I could be talked into enabling it for experimentation, though. The main drawback is that the methods must be implemented by the actual type passed through, not by an indirection from it. For instance, if *T implements GobEncoder, you must send a *T, not a T. This will be addressed in due course. Also there is improved commentary and a couple of unrelated minor bug fixes. R=rsc CC=golang-dev https://golang.org/cl/4243056
-rw-r--r--src/pkg/gob/codec_test.go4
-rw-r--r--src/pkg/gob/debug.go35
-rw-r--r--src/pkg/gob/decode.go301
-rw-r--r--src/pkg/gob/decoder.go2
-rw-r--r--src/pkg/gob/encode.go134
-rw-r--r--src/pkg/gob/encoder.go107
-rw-r--r--src/pkg/gob/gobencdec_test.go331
-rw-r--r--src/pkg/gob/type.go195
-rw-r--r--src/pkg/gob/type_test.go2
9 files changed, 936 insertions, 175 deletions
diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go
index c822d6863a..4562e19309 100644
--- a/src/pkg/gob/codec_test.go
+++ b/src/pkg/gob/codec_test.go
@@ -303,7 +303,7 @@ func TestScalarEncInstructions(t *testing.T) {
}
}
-func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) {
+func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, p unsafe.Pointer) {
defer testError(t)
v := int(state.decodeUint())
if v+state.fieldnum != 6 {
@@ -313,7 +313,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un
state.fieldnum = 6
}
-func newDecodeStateFromData(data []byte) *decodeState {
+func newDecodeStateFromData(data []byte) *decoderState {
b := bytes.NewBuffer(data)
state := newDecodeState(nil, b)
state.fieldnum = -1
diff --git a/src/pkg/gob/debug.go b/src/pkg/gob/debug.go
index e4583901e9..69c83bda78 100644
--- a/src/pkg/gob/debug.go
+++ b/src/pkg/gob/debug.go
@@ -155,6 +155,16 @@ func (deb *debugger) dump(format string, args ...interface{}) {
// Debug prints a human-readable representation of the gob data read from r.
func Debug(r io.Reader) {
+ err := debug(r)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "gob debug: %s\n", err)
+ }
+}
+
+// debug implements Debug, but catches panics and returns
+// them as errors to be printed by Debug.
+func debug(r io.Reader) (err os.Error) {
+ defer catchError(&err)
fmt.Fprintln(os.Stderr, "Start of debugging")
deb := &debugger{
r: newPeekReader(r),
@@ -166,6 +176,7 @@ func Debug(r io.Reader) {
deb.remainingKnown = true
}
deb.gobStream()
+ return
}
// note that we've consumed some bytes
@@ -386,11 +397,15 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) {
// Field number 1 is type Id of key
deb.delta(1)
keyId := deb.typeId()
- wire.SliceT = &sliceType{com, id}
// Field number 2 is type Id of elem
deb.delta(1)
elemId := deb.typeId()
wire.MapT = &mapType{com, keyId, elemId}
+ case 4: // GobEncoder type, one field of {{Common}}
+ // Field number 0 is CommonType
+ deb.delta(1)
+ com := deb.common()
+ wire.GobEncoderT = &gobEncoderType{com}
default:
errorf("bad field in type %d", fieldNum)
}
@@ -507,6 +522,8 @@ func (deb *debugger) printWireType(indent tab, wire *wireType) {
for i, field := range wire.StructT.Field {
fmt.Fprintf(os.Stderr, "%sfield %d:\t%s\tid=%d\n", indent+1, i, field.Name, field.Id)
}
+ case wire.GobEncoderT != nil:
+ deb.printCommonType(indent, "GobEncoder", &wire.GobEncoderT.CommonType)
}
indent--
fmt.Fprintf(os.Stderr, "%s}\n", indent)
@@ -538,6 +555,8 @@ func (deb *debugger) fieldValue(indent tab, id typeId) {
deb.sliceValue(indent, wire)
case wire.StructT != nil:
deb.structValue(indent, id)
+ case wire.GobEncoderT != nil:
+ deb.gobEncoderValue(indent, id)
default:
panic("bad wire type for field")
}
@@ -654,3 +673,17 @@ func (deb *debugger) structValue(indent tab, id typeId) {
fmt.Fprintf(os.Stderr, "%s} // end %s struct\n", indent, id.name())
deb.dump(">> End of struct value of type %d %q", id, id.name())
}
+
+// GobEncoderValue:
+// uint(n) byte*n
+func (deb *debugger) gobEncoderValue(indent tab, id typeId) {
+ len := deb.uint64()
+ deb.dump("GobEncoder value of %q id=%d, length %d\n", id.name(), id, len)
+ fmt.Fprintf(os.Stderr, "%s%s (implements GobEncoder)\n", indent, id.name())
+ data := make([]byte, len)
+ _, err := deb.r.Read(data)
+ if err != nil {
+ errorf("gobEncoder data read: %s", err)
+ }
+ fmt.Fprintf(os.Stderr, "%s[% .2x]\n", indent+1, data)
+}
diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go
index 8f599e1004..37f49312a8 100644
--- a/src/pkg/gob/decode.go
+++ b/src/pkg/gob/decode.go
@@ -24,9 +24,9 @@ var (
errRange = os.ErrorString("gob: internal error: field numbers out of bounds")
)
-// The execution state of an instance of the decoder. A new state
+// decoderState is the execution state of an instance of the decoder. A new state
// is created for nested objects.
-type decodeState struct {
+type decoderState struct {
dec *Decoder
// The buffer is stored with an extra indirection because it may be replaced
// if we load a type during decode (when reading an interface value).
@@ -37,8 +37,8 @@ type decodeState struct {
// We pass the bytes.Buffer separately for easier testing of the infrastructure
// without requiring a full Decoder.
-func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState {
- d := new(decodeState)
+func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decoderState {
+ d := new(decoderState)
d.dec = dec
d.b = buf
d.buf = make([]byte, uint64Size)
@@ -85,7 +85,7 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err os.Erro
// decodeUint reads an encoded unsigned integer from state.r.
// Does not check for overflow.
-func (state *decodeState) decodeUint() (x uint64) {
+func (state *decoderState) decodeUint() (x uint64) {
b, err := state.b.ReadByte()
if err != nil {
error(err)
@@ -112,7 +112,7 @@ func (state *decodeState) decodeUint() (x uint64) {
// decodeInt reads an encoded signed integer from state.r.
// Does not check for overflow.
-func (state *decodeState) decodeInt() int64 {
+func (state *decoderState) decodeInt() int64 {
x := state.decodeUint()
if x&1 != 0 {
return ^int64(x >> 1)
@@ -120,7 +120,8 @@ func (state *decodeState) decodeInt() int64 {
return int64(x >> 1)
}
-type decOp func(i *decInstr, state *decodeState, p unsafe.Pointer)
+// decOp is the signature of a decoding operator for a given type.
+type decOp func(i *decInstr, state *decoderState, p unsafe.Pointer)
// The 'instructions' of the decoding machine
type decInstr struct {
@@ -150,26 +151,31 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
return p
}
-func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// ignoreUint discards a uint value with no destination.
+func ignoreUint(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.decodeUint()
}
-func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// ignoreTwoUints discards a uint value with no destination. It's used to skip
+// complex values.
+func ignoreTwoUints(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.decodeUint()
state.decodeUint()
}
-func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decBool decodes a uiint and stores it as a boolean through p.
+func decBool(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(bool))
}
p = *(*unsafe.Pointer)(p)
}
- *(*bool)(p) = state.decodeInt() != 0
+ *(*bool)(p) = state.decodeUint() != 0
}
-func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt8 decodes an integer and stores it as an int8 through p.
+func decInt8(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int8))
@@ -184,7 +190,8 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
}
-func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint8 decodes an unsigned integer and stores it as a uint8 through p.
+func decUint8(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint8))
@@ -199,7 +206,8 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
}
-func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt16 decodes an integer and stores it as an int16 through p.
+func decInt16(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int16))
@@ -214,7 +222,8 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
}
-func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint16 decodes an unsigned integer and stores it as a uint16 through p.
+func decUint16(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint16))
@@ -229,7 +238,8 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
}
-func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt32 decodes an integer and stores it as an int32 through p.
+func decInt32(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int32))
@@ -244,7 +254,8 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
}
-func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint32 decodes an unsigned integer and stores it as a uint32 through p.
+func decUint32(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint32))
@@ -259,7 +270,8 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
}
-func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decInt64 decodes an integer and stores it as an int64 through p.
+func decInt64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(int64))
@@ -269,7 +281,8 @@ func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*int64)(p) = int64(state.decodeInt())
}
-func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decUint64 decodes an unsigned integer and stores it as a uint64 through p.
+func decUint64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint64))
@@ -294,7 +307,9 @@ func floatFromBits(u uint64) float64 {
return math.Float64frombits(v)
}
-func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// storeFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
+// number, and stores it through p. It's a helper function for float32 and complex64.
+func storeFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) {
v := floatFromBits(state.decodeUint())
av := v
if av < 0 {
@@ -308,7 +323,9 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
}
}
-func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point
+// number, and stores it through p.
+func decFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32))
@@ -318,7 +335,9 @@ func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) {
storeFloat32(i, state, p)
}
-func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point
+// number, and stores it through p.
+func decFloat64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(float64))
@@ -328,8 +347,10 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*float64)(p) = floatFromBits(uint64(state.decodeUint()))
}
-// Complex numbers are just a pair of floating-point numbers, real part first.
-func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decComplex64 decodes a pair of unsigned integers, treats them as a
+// pair of floating point numbers, and stores them as a complex64 through p.
+// The real part comes first.
+func decComplex64(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64))
@@ -340,7 +361,10 @@ func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) {
storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float32(0)))))
}
-func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// decComplex128 decodes a pair of unsigned integers, treats them as a
+// pair of floating point numbers, and stores them as a complex128 through p.
+// The real part comes first.
+func decComplex128(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128))
@@ -352,8 +376,10 @@ func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*complex128)(p) = complex(real, imag)
}
+// decUint8Array decodes byte array and stores through p a slice header
+// describing the data.
// uint8 arrays are encoded as an unsigned count followed by the raw bytes.
-func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
+func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new([]uint8))
@@ -365,8 +391,10 @@ func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*[]uint8)(p) = b
}
+// decString decodes byte array and stores through p a string header
+// describing the data.
// Strings are encoded as an unsigned count followed by the raw bytes.
-func decString(i *decInstr, state *decodeState, p unsafe.Pointer) {
+func decString(i *decInstr, state *decoderState, p unsafe.Pointer) {
if i.indir > 0 {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte))
@@ -378,7 +406,8 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) {
*(*string)(p) = string(b)
}
-func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
+// ignoreUint8Array skips over the data for a byte slice value with no destination.
+func ignoreUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) {
b := make([]byte, state.decodeUint())
state.b.Read(b)
}
@@ -409,8 +438,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
return *(*uintptr)(up)
}
+// decodeSingle decodes a top-level value that is not a struct and stores it through p.
+// Such values are preceded by a zero, making them have the memory layout of a
+// struct field (although with an illegal field number).
func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
- p = allocate(ut.base, p, ut.indir)
+ indir := ut.indir
+ if ut.isGobDecoder {
+ indir = int(ut.decIndir)
+ }
+ p = allocate(ut.base, p, indir)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField
basep := p
@@ -427,6 +463,7 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr)
return nil
}
+// decodeSingle decodes a top-level struct and stores it through p.
// Indir is for the value, not the type. At the time of the call it may
// differ from ut.indir, which was computed when the engine was built.
// This state cannot arise for decodeSingle, which is called directly
@@ -460,6 +497,7 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr,
return nil
}
+// ignoreStruct discards the data for a struct with no destination.
func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1
@@ -482,6 +520,8 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
return nil
}
+// ignoreSingle discards the data for a top-level non-struct value with no
+// destination. It's used when calling Decode with a nil value.
func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField
@@ -494,7 +534,8 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
return nil
}
-func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) {
+// decodeArrayHelper does the work for decoding arrays and slices.
+func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) {
instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl}
for i := 0; i < length; i++ {
up := unsafe.Pointer(p)
@@ -506,7 +547,10 @@ func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decO
}
}
-func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) {
+// decodeArray decodes an array and stores it through p, that is, p points to the zeroth element.
+// The length is an unsigned integer preceding the elements. Even though the length is redundant
+// (it's part of the type), it's a useful check and is included in the encoding.
+func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) {
if indir > 0 {
p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect
}
@@ -516,7 +560,9 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u
dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl)
}
-func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
+// decodeIntoValue is a helper for map decoding. Since maps are decoded using reflection,
+// unlike the other items we can't use a pointer directly.
+func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value {
instr := &decInstr{op, 0, indir, 0, ovfl}
up := unsafe.Pointer(v.UnsafeAddr())
if indir > 1 {
@@ -526,7 +572,11 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o
return v
}
-func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) {
+// decodeMap decodes a map and stores its header through p.
+// Maps are encoded as a length followed by key:value pairs.
+// Because the internals of maps are not visible to us, we must
+// use reflection rather than pointer magic.
+func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decoderState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) {
if indir > 0 {
p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect
}
@@ -538,7 +588,7 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp
// Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for
// the iteration.
- v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue)
+ v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))).(*reflect.MapValue)
n := int(state.decodeUint())
for i := 0; i < n; i++ {
key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl)
@@ -547,21 +597,24 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp
}
}
-func (dec *Decoder) ignoreArrayHelper(state *decodeState, elemOp decOp, length int) {
+// ignoreArrayHelper does the work for discarding arrays and slices.
+func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) {
instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
for i := 0; i < length; i++ {
elemOp(instr, state, nil)
}
}
-func (dec *Decoder) ignoreArray(state *decodeState, elemOp decOp, length int) {
+// ignoreArray discards the data for an array value with no destination.
+func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) {
if n := state.decodeUint(); n != uint64(length) {
errorf("gob: length mismatch in ignoreArray")
}
dec.ignoreArrayHelper(state, elemOp, length)
}
-func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) {
+// ignoreMap discards the data for a map value with no destination.
+func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) {
n := int(state.decodeUint())
keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")}
elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")}
@@ -571,7 +624,9 @@ func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) {
}
}
-func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) {
+// decodeSlice decodes a slice and stores the slice header through p.
+// Slices are encoded as an unsigned length followed by the elements.
+func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) {
n := int(uintptr(state.decodeUint()))
if indir > 0 {
up := unsafe.Pointer(p)
@@ -590,7 +645,8 @@ func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p u
dec.decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl)
}
-func (dec *Decoder) ignoreSlice(state *decodeState, elemOp decOp) {
+// ignoreSlice skips over the data for a slice value with no destination.
+func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) {
dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint()))
}
@@ -609,9 +665,10 @@ func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) {
ivalue.Set(value)
}
-// decodeInterface receives the name of a concrete type followed by its value.
+// decodeInterface decodes an interface value and stores it through p.
+// Interfaces are encoded as the name of a concrete type followed by a value.
// If the name is empty, the value is nil and no value is sent.
-func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) {
+func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderState, p uintptr, indir int) {
// Create an interface reflect.Value. We need one even for the nil case.
ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue)
// Read the name of the concrete type.
@@ -655,7 +712,8 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt
*(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get()
}
-func (dec *Decoder) ignoreInterface(state *decodeState) {
+// ignoreInterface discards the data for an interface value with no destination.
+func (dec *Decoder) ignoreInterface(state *decoderState) {
// Read the name of the concrete type.
b := make([]byte, state.decodeUint())
_, err := state.b.Read(b)
@@ -670,6 +728,32 @@ func (dec *Decoder) ignoreInterface(state *decodeState) {
state.b.Next(int(state.decodeUint()))
}
+// decodeGobDecoder decodes something implementing the GobDecoder interface.
+// The data is encoded as a byte slice.
+func (dec *Decoder) decodeGobDecoder(state *decoderState, v reflect.Value, index int) {
+ // Read the bytes for the value.
+ b := make([]byte, state.decodeUint())
+ _, err := state.b.Read(b)
+ if err != nil {
+ error(err)
+ }
+ // We know it's a GobDecoder, so just call the method directly.
+ err = v.Interface().(_GobDecoder)._GobDecode(b)
+ if err != nil {
+ error(err)
+ }
+}
+
+// ignoreGobDecoder discards the data for a GobDecoder value with no destination.
+func (dec *Decoder) ignoreGobDecoder(state *decoderState) {
+ // Read the bytes for the value.
+ b := make([]byte, state.decodeUint())
+ _, err := state.b.Read(b)
+ if err != nil {
+ error(err)
+ }
+}
+
// Index by Go types.
var decOpTable = [...]decOp{
reflect.Bool: decBool,
@@ -699,10 +783,14 @@ var decIgnoreOpMap = map[typeId]decOp{
tComplex: ignoreTwoUints,
}
-// Return the decoding op for the base type under rt and
+// decOpFor returns the decoding op for the base type under rt and
// the indirection count to reach it.
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) {
ut := userType(rt)
+ // If the type implements GobEncoder, we handle it without further processing.
+ if ut.isGobDecoder {
+ return dec.gobDecodeOpFor(ut)
+ }
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
@@ -724,7 +812,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
elemId := dec.wireType[wireId].ArrayT.Elem
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl)
}
@@ -735,7 +823,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress)
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
up := unsafe.Pointer(p)
state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl)
}
@@ -754,17 +842,17 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
}
elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress)
ovfl := overflow(name)
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl)
}
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
- enginePtr, err := dec.getDecEnginePtr(wireId, typ)
+ enginePtr, err := dec.getDecEnginePtr(wireId, userType(typ))
if err != nil {
error(err)
}
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs.
err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir)
if err != nil {
@@ -772,8 +860,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
}
}
case *reflect.InterfaceType:
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
- dec.decodeInterface(t, state, uintptr(p), i.indir)
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+ state.dec.decodeInterface(t, state, uintptr(p), i.indir)
}
}
}
@@ -783,15 +871,15 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg
return &op, indir
}
-// Return the decoding op for a field that has no destination.
+// decIgnoreOpFor returns the decoding op for a field that has no destination.
func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
op, ok := decIgnoreOpMap[wireId]
if !ok {
if wireId == tInterface {
// Special case because it's a method: the ignored item might
// define types and we need to record their state in the decoder.
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
- dec.ignoreInterface(state)
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+ state.dec.ignoreInterface(state)
}
return op
}
@@ -803,7 +891,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
case wire.ArrayT != nil:
elemId := wire.ArrayT.Elem
elemOp := dec.decIgnoreOpFor(elemId)
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len)
}
@@ -812,14 +900,14 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
elemId := dec.wireType[wireId].MapT.Elem
keyOp := dec.decIgnoreOpFor(keyId)
elemOp := dec.decIgnoreOpFor(elemId)
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreMap(state, keyOp, elemOp)
}
case wire.SliceT != nil:
elemId := wire.SliceT.Elem
elemOp := dec.decIgnoreOpFor(elemId)
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
state.dec.ignoreSlice(state, elemOp)
}
@@ -829,10 +917,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
if err != nil {
error(err)
}
- op = func(i *decInstr, state *decodeState, p unsafe.Pointer) {
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
// indirect through enginePtr to delay evaluation for recursive structs
state.dec.ignoreStruct(*enginePtr)
}
+
+ case wire.GobEncoderT != nil:
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+ state.dec.ignoreGobDecoder(state)
+ }
}
}
if op == nil {
@@ -841,16 +934,56 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp {
return op
}
-// Are these two gob Types compatible?
-// Answers the question for basic types, arrays, and slices.
+// gobDecodeOpFor returns the op for a type that is known to implement
+// GobDecoder.
+func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) {
+ rt := ut.user
+ if ut.decIndir != 0 {
+ errorf("gob: TODO: can't handle indirection to reach GobDecoder")
+ }
+ index := -1
+ for i := 0; i < rt.NumMethod(); i++ {
+ if rt.Method(i).Name == gobDecodeMethodName {
+ index = i
+ break
+ }
+ }
+ if index < 0 {
+ panic("can't find GobDecode method")
+ }
+ var op decOp
+ op = func(i *decInstr, state *decoderState, p unsafe.Pointer) {
+ // Allocate the underlying data, but hold on to the address we have,
+ // since it's known to be the receiver's address.
+ // TODO: fix this up when decIndir can be non-zero.
+ allocate(ut.base, uintptr(p), ut.indir)
+ v := reflect.NewValue(unsafe.Unreflect(rt, p))
+ state.dec.decodeGobDecoder(state, v, index)
+ }
+ return &op, int(ut.decIndir)
+
+}
+
+// compatibleType asks: Are these two gob Types compatible?
+// Answers the question for basic types, arrays, maps and slices, plus
+// GobEncoder/Decoder pairs.
// Structs are considered ok; fields will be checked later.
func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool {
if rhs, ok := inProgress[fr]; ok {
return rhs == fw
}
inProgress[fr] = fw
- fr = userType(fr).base
- switch t := fr.(type) {
+ ut := userType(fr)
+ wire, ok := dec.wireType[fw]
+ // If fr is a GobDecoder, the wire type must be GobEncoder.
+ // And if fr is not a GobDecoder, the wire type must not be either.
+ if ut.isGobDecoder != (ok && wire.GobEncoderT != nil) { // the parentheses look odd but are correct.
+ return false
+ }
+ if ut.isGobDecoder { // This test trumps all others.
+ return true
+ }
+ switch t := ut.base.(type) {
default:
// chan, etc: cannot handle.
return false
@@ -869,14 +1002,12 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re
case *reflect.InterfaceType:
return fw == tInterface
case *reflect.ArrayType:
- wire, ok := dec.wireType[fw]
if !ok || wire.ArrayT == nil {
return false
}
array := wire.ArrayT
return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress)
case *reflect.MapType:
- wire, ok := dec.wireType[fw]
if !ok || wire.MapT == nil {
return false
}
@@ -911,8 +1042,13 @@ func (dec *Decoder) typeString(remoteId typeId) string {
return dec.wireType[remoteId].string()
}
-
-func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
+// compileSingle compiles the decoder engine for a non-struct top-level value, including
+// GobDecoders.
+func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) {
+ rt := ut.base
+ if ut.isGobDecoder {
+ rt = ut.user
+ }
engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item
name := rt.String() // best we can do
@@ -926,6 +1062,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec
return
}
+// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded.
func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err os.Error) {
engine = new(decEngine)
engine.instr = make([]decInstr, 1) // one item
@@ -936,16 +1073,19 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err
return
}
-// Is this an exported - upper case - name?
+// isExported reports whether this is an exported - upper case - name.
func isExported(name string) bool {
rune, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(rune)
}
-func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
+// compileDec compiles the decoder engine for a value. If the value is not a struct,
+// it calls out to compileSingle.
+func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) {
+ rt := ut.base
srt, ok := rt.(*reflect.StructType)
- if !ok {
- return dec.compileSingle(remoteId, rt)
+ if !ok || ut.isGobDecoder {
+ return dec.compileSingle(remoteId, ut)
}
var wireStruct *structType
// Builtin types can come from global pool; the rest must be defined by the decoder.
@@ -990,7 +1130,9 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng
return
}
-func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
+// getDecEnginePtr returns the engine for the specified type.
+func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err os.Error) {
+ rt := ut.base
decoderMap, ok := dec.decoderCache[rt]
if !ok {
decoderMap = make(map[typeId]**decEngine)
@@ -1000,7 +1142,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr
// To handle recursive types, mark this engine as underway before compiling.
enginePtr = new(*decEngine)
decoderMap[remoteId] = enginePtr
- *enginePtr, err = dec.compileDec(remoteId, rt)
+ *enginePtr, err = dec.compileDec(remoteId, ut)
if err != nil {
decoderMap[remoteId] = nil, false
}
@@ -1008,11 +1150,12 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr
return
}
-// When ignoring struct data, in effect we compile it into this type
+// emptyStruct is the type we compile into when ignoring a struct value.
type emptyStruct struct{}
var emptyStructType = reflect.Typeof(emptyStruct{})
+// getDecEnginePtr returns the engine for the specified type when the value is to be discarded.
func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
var ok bool
if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
@@ -1021,7 +1164,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
dec.ignorerCache[wireId] = enginePtr
wire := dec.wireType[wireId]
if wire != nil && wire.StructT != nil {
- *enginePtr, err = dec.compileDec(wireId, emptyStructType)
+ *enginePtr, err = dec.compileDec(wireId, userType(emptyStructType))
} else {
*enginePtr, err = dec.compileIgnoreSingle(wireId)
}
@@ -1032,6 +1175,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
return
}
+// decodeValue decodes the data stream representing a value and stores it in val.
func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) {
defer catchError(&err)
// If the value is nil, it means we should just ignore this item.
@@ -1042,12 +1186,18 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error)
ut := userType(val.Type())
base := ut.base
indir := ut.indir
- enginePtr, err := dec.getDecEnginePtr(wireId, base)
+ if ut.isGobDecoder {
+ indir = int(ut.decIndir)
+ if indir != 0 {
+ errorf("TODO: can't handle indirection in GobDecoder value")
+ }
+ }
+ enginePtr, err := dec.getDecEnginePtr(wireId, ut)
if err != nil {
return err
}
engine := *enginePtr
- if st, ok := base.(*reflect.StructType); ok {
+ if st, ok := base.(*reflect.StructType); ok && !ut.isGobDecoder {
if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 {
name := base.Name()
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name)
@@ -1057,6 +1207,7 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error)
return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr()))
}
+// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it.
func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error {
enginePtr, err := dec.getIgnoreEnginePtr(wireId)
if err != nil {
diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go
index f7c994ffa7..7192745836 100644
--- a/src/pkg/gob/decoder.go
+++ b/src/pkg/gob/decoder.go
@@ -21,7 +21,7 @@ type Decoder struct {
wireType map[typeId]*wireType // map from remote ID to local description
decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines
ignorerCache map[typeId]**decEngine // ditto for ignored objects
- countState *decodeState // reads counts from wire
+ countState *decoderState // reads counts from wire
countBuf []byte // used for decoding integers while parsing messages
tmp []byte // temporary storage for i/o; saves reallocating
err os.Error
diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go
index e92db74ffd..d69e734ff9 100644
--- a/src/pkg/gob/encode.go
+++ b/src/pkg/gob/encode.go
@@ -15,7 +15,7 @@ import (
const uint64Size = unsafe.Sizeof(uint64(0))
-// The global execution state of an instance of the encoder.
+// encoderState is the global execution state of an instance of the encoder.
// Field numbers are delta encoded and always increase. The field
// number is initialized to -1 so 0 comes out as delta(1). A delta of
// 0 terminates the structure.
@@ -72,6 +72,7 @@ func (state *encoderState) encodeInt(i int64) {
state.encodeUint(uint64(x))
}
+// encOp is the signature of an encoding operator for a given type.
type encOp func(i *encInstr, state *encoderState, p unsafe.Pointer)
// The 'instructions' of the encoding machine
@@ -82,8 +83,8 @@ type encInstr struct {
offset uintptr // offset in the structure of the field to encode
}
-// Emit a field number and update the state to record its value for delta encoding.
-// If the instruction pointer is nil, do nothing
+// update emits a field number and updates the state to record its value for delta encoding.
+// If the instruction pointer is nil, it does nothing
func (state *encoderState) update(instr *encInstr) {
if instr != nil {
state.encodeUint(uint64(instr.field - state.fieldnum))
@@ -97,6 +98,7 @@ func (state *encoderState) update(instr *encInstr) {
// Otherwise, the output (for a scalar) is the field number, as an encoded integer,
// followed by the field data in its appropriate format.
+// encIndirect dereferences p indir times and returns the result.
func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
for ; indir > 0; indir-- {
p = *(*unsafe.Pointer)(p)
@@ -107,6 +109,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer {
return p
}
+// encBool encodes the bool with address p as an unsigned 0 or 1.
func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
b := *(*bool)(p)
if b || state.sendZero {
@@ -119,6 +122,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encInt encodes the int with address p.
func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int)(p))
if v != 0 || state.sendZero {
@@ -127,6 +131,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encUint encodes the uint with address p.
func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint)(p))
if v != 0 || state.sendZero {
@@ -135,6 +140,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encInt8 encodes the int8 with address p.
func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int8)(p))
if v != 0 || state.sendZero {
@@ -143,6 +149,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encUint8 encodes the uint8 with address p.
func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint8)(p))
if v != 0 || state.sendZero {
@@ -151,6 +158,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encInt16 encodes the int16 with address p.
func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int16)(p))
if v != 0 || state.sendZero {
@@ -159,6 +167,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encUint16 encodes the uint16 with address p.
func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint16)(p))
if v != 0 || state.sendZero {
@@ -167,6 +176,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encInt32 encodes the int32 with address p.
func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := int64(*(*int32)(p))
if v != 0 || state.sendZero {
@@ -175,6 +185,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encUint encodes the uint32 with address p.
func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uint32)(p))
if v != 0 || state.sendZero {
@@ -183,6 +194,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encInt64 encodes the int64 with address p.
func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := *(*int64)(p)
if v != 0 || state.sendZero {
@@ -191,6 +203,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encInt64 encodes the uint64 with address p.
func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := *(*uint64)(p)
if v != 0 || state.sendZero {
@@ -199,6 +212,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encUintptr encodes the uintptr with address p.
func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) {
v := uint64(*(*uintptr)(p))
if v != 0 || state.sendZero {
@@ -207,6 +221,7 @@ func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// floatBits returns a uint64 holding the bits of a floating-point number.
// Floating-point numbers are transmitted as uint64s holding the bits
// of the underlying representation. They are sent byte-reversed, with
// the exponent end coming out first, so integer floating point numbers
@@ -223,6 +238,7 @@ func floatBits(f float64) uint64 {
return v
}
+// encFloat32 encodes the float32 with address p.
func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
f := *(*float32)(p)
if f != 0 || state.sendZero {
@@ -232,6 +248,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encFloat64 encodes the float64 with address p.
func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
f := *(*float64)(p)
if f != 0 || state.sendZero {
@@ -241,6 +258,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encComplex64 encodes the complex64 with address p.
// Complex numbers are just a pair of floating-point numbers, real part first.
func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
c := *(*complex64)(p)
@@ -253,6 +271,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encComplex128 encodes the complex128 with address p.
func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
c := *(*complex128)(p)
if c != 0+0i || state.sendZero {
@@ -264,6 +283,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encUint8Array encodes the byte slice whose header has address p.
// Byte arrays are encoded as an unsigned count followed by the raw bytes.
func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
b := *(*[]byte)(p)
@@ -274,6 +294,7 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
+// encString encodes the string whose header has address p.
// Strings are encoded as an unsigned count followed by the raw bytes.
func encString(i *encInstr, state *encoderState, p unsafe.Pointer) {
s := *(*string)(p)
@@ -284,14 +305,15 @@ func encString(i *encInstr, state *encoderState, p unsafe.Pointer) {
}
}
-// The end of a struct is marked by a delta field number of 0.
+// encStructTerminator encodes the end of an encoded struct
+// as delta field number of 0.
func encStructTerminator(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.encodeUint(0)
}
// Execution engine
-// The encoder engine is an array of instructions indexed by field number of the encoding
+// encEngine an array of instructions indexed by field number of the encoding
// data, typically a struct. It is executed top to bottom, walking the struct.
type encEngine struct {
instr []encInstr
@@ -299,6 +321,7 @@ type encEngine struct {
const singletonField = 0
+// encodeSingle encodes a single top-level non-struct value.
func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintptr) {
state := newEncoderState(enc, b)
state.fieldnum = singletonField
@@ -315,6 +338,7 @@ func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintp
instr.op(instr, state, p)
}
+// encodeStruct encodes a single struct value.
func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintptr) {
state := newEncoderState(enc, b)
state.fieldnum = -1
@@ -330,6 +354,7 @@ func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintp
}
}
+// encodeArray encodes the array whose 0th element is at p.
func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) {
state := newEncoderState(enc, b)
state.fieldnum = -1
@@ -349,6 +374,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui
}
}
+// encodeReflectValue is a helper for maps. It encodes the value v.
func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) {
for i := 0; i < indir && v != nil; i++ {
v = reflect.Indirect(v)
@@ -359,6 +385,9 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in
op(nil, state, unsafe.Pointer(v.UnsafeAddr()))
}
+// encodeMap encodes a map as unsigned count followed by key:value pairs.
+// Because map internals are not exposed, we must use reflection rather than
+// addresses.
func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) {
state := newEncoderState(enc, b)
state.fieldnum = -1
@@ -371,6 +400,7 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elem
}
}
+// encodeInterface encodes the interface value iv.
// To send an interface, we send a string identifying the concrete type, followed
// by the type identifier (which might require defining that type right now), followed
// by the concrete value. A nil value gets sent as the empty string for the name,
@@ -414,6 +444,21 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue)
}
}
+// encGobEncoder encodes a value that implements the GobEncoder interface.
+// The data is sent as a byte array.
+func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, v reflect.Value, index int) {
+ // TODO: should we catch panics from the called method?
+ // We know it's a GobEncoder, so just call the method directly.
+ data, err := v.Interface().(_GobEncoder)._GobEncode()
+ if err != nil {
+ error(err)
+ }
+ state := newEncoderState(enc, b)
+ state.fieldnum = -1
+ state.encodeUint(uint64(len(data)))
+ state.b.Write(data)
+}
+
var encOpTable = [...]encOp{
reflect.Bool: encBool,
reflect.Int: encInt,
@@ -434,10 +479,14 @@ var encOpTable = [...]encOp{
reflect.String: encString,
}
-// Return (a pointer to) the encoding op for the base type under rt and
+// encOpFor returns (a pointer to) the encoding op for the base type under rt and
// the indirection count to reach it.
func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) {
ut := userType(rt)
+ // If the type implements GobEncoder, we handle it without further processing.
+ if ut.isGobEncoder {
+ return enc.gobEncodeOpFor(ut)
+ }
// If this type is already in progress, it's a recursive type (e.g. map[string]*T).
// Return the pointer to the op we're already building.
if opPtr := inProgress[rt]; opPtr != nil {
@@ -483,7 +532,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
// Maps cannot be accessed by moving addresses around the way
// that slices etc. can. We must recover a full reflection value for
// the iteration.
- v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p))))
+ v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p)))
mv := reflect.Indirect(v).(*reflect.MapValue)
if !state.sendZero && mv.Len() == 0 {
return
@@ -493,7 +542,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
}
case *reflect.StructType:
// Generate a closure that calls out to the engine for the nested type.
- enc.getEncEngine(typ)
+ enc.getEncEngine(userType(typ))
info := mustGetTypeInfo(typ)
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
state.update(i)
@@ -504,7 +553,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
// Interfaces transmit the name and contents of the concrete
// value they contain.
- v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p))))
+ v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p)))
iv := reflect.Indirect(v).(*reflect.InterfaceValue)
if !state.sendZero && (iv == nil || iv.IsNil()) {
return
@@ -520,12 +569,43 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp
return &op, indir
}
-// The local Type was compiled from the actual value, so we know it's compatible.
-func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
- srt, isStruct := rt.(*reflect.StructType)
+// gobEncodeOpFor returns the op for a type that is known to implement
+// GobEncoder.
+func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) {
+ rt := ut.user
+ if ut.encIndir != 0 {
+ errorf("gob: TODO: can't handle indirection to reach GobEncoder")
+ }
+ index := -1
+ for i := 0; i < rt.NumMethod(); i++ {
+ if rt.Method(i).Name == gobEncodeMethodName {
+ index = i
+ break
+ }
+ }
+ if index < 0 {
+ panic("can't find GobEncode method")
+ }
+ var op encOp
+ op = func(i *encInstr, state *encoderState, p unsafe.Pointer) {
+ // TODO: this will need fixing when ut.encIndr != 0.
+ v := reflect.NewValue(unsafe.Unreflect(rt, p))
+ state.update(i)
+ state.enc.encodeGobEncoder(state.b, v, index)
+ }
+ return &op, int(ut.encIndir)
+}
+
+// compileEnc returns the engine to compile the type.
+func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine {
+ srt, isStruct := ut.base.(*reflect.StructType)
engine := new(encEngine)
seen := make(map[reflect.Type]*encOp)
- if isStruct {
+ rt := ut.base
+ if ut.isGobEncoder {
+ rt = ut.user
+ }
+ if !ut.isGobEncoder && isStruct {
for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ {
f := srt.Field(fieldNum)
if !isExported(f.Name) {
@@ -546,35 +626,43 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine {
return engine
}
+// getEncEngine returns the engine to compile the type.
// typeLock must be held (or we're in initialization and guaranteed single-threaded).
-// The reflection type must have all its indirections processed out.
-func (enc *Encoder) getEncEngine(rt reflect.Type) *encEngine {
- info, err1 := getTypeInfo(rt)
+func (enc *Encoder) getEncEngine(ut *userTypeInfo) *encEngine {
+ info, err1 := getTypeInfo(ut)
if err1 != nil {
error(err1)
}
if info.encoder == nil {
// mark this engine as underway before compiling to handle recursive types.
info.encoder = new(encEngine)
- info.encoder = enc.compileEnc(rt)
+ info.encoder = enc.compileEnc(ut)
}
return info.encoder
}
-// Put this in a function so we can hold the lock only while compiling, not when encoding.
-func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine {
+// lockAndGetEncEngine is a function that locks and compiles.
+// This lets us hold the lock only while compiling, not when encoding.
+func (enc *Encoder) lockAndGetEncEngine(ut *userTypeInfo) *encEngine {
typeLock.Lock()
defer typeLock.Unlock()
- return enc.getEncEngine(rt)
+ return enc.getEncEngine(ut)
}
func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) {
defer catchError(&err)
- for i := 0; i < ut.indir; i++ {
+ engine := enc.lockAndGetEncEngine(ut)
+ indir := ut.indir
+ if ut.isGobEncoder {
+ indir = int(ut.encIndir)
+ if indir != 0 {
+ errorf("TODO: can't handle indirection in GobEncoder value")
+ }
+ }
+ for i := 0; i < indir; i++ {
value = reflect.Indirect(value)
}
- engine := enc.lockAndGetEncEngine(ut.base)
- if value.Type().Kind() == reflect.Struct {
+ if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct {
enc.encodeStruct(b, engine, value.UnsafeAddr())
} else {
enc.encodeSingle(b, engine, value.UnsafeAddr())
diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go
index 92d036c11c..4bfcf15c7f 100644
--- a/src/pkg/gob/encoder.go
+++ b/src/pkg/gob/encoder.go
@@ -78,12 +78,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) {
}
}
+// sendActualType sends the requested type, without further investigation, unless
+// it's been sent before.
+func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) {
+ if _, alreadySent := enc.sent[actual]; alreadySent {
+ return false
+ }
+ typeLock.Lock()
+ info, err := getTypeInfo(ut)
+ typeLock.Unlock()
+ if err != nil {
+ enc.setError(err)
+ return
+ }
+ // Send the pair (-id, type)
+ // Id:
+ state.encodeInt(-int64(info.id))
+ // Type:
+ enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
+ enc.writeMessage(w, state.b)
+ if enc.err != nil {
+ return
+ }
+
+ // Remember we've sent this type, both what the user gave us and the base type.
+ enc.sent[ut.base] = info.id
+ if ut.user != ut.base {
+ enc.sent[ut.user] = info.id
+ }
+ // Now send the inner types
+ switch st := actual.(type) {
+ case *reflect.StructType:
+ for i := 0; i < st.NumField(); i++ {
+ enc.sendType(w, state, st.Field(i).Type)
+ }
+ case reflect.ArrayOrSliceType:
+ enc.sendType(w, state, st.Elem())
+ }
+ return true
+}
+
+// sendType sends the type info to the other side, if necessary.
func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) {
- // Drill down to the base type.
ut := userType(origt)
- rt := ut.base
+ if ut.isGobEncoder {
+ // The rules are different: regardless of the underlying type's representation,
+ // we need to tell the other side that this exact type is a GobEncoder.
+ return enc.sendActualType(w, state, ut, ut.user)
+ }
- switch rt := rt.(type) {
+ // It's a concrete value, so drill down to the base type.
+ switch rt := ut.base.(type) {
default:
// Basic types and interfaces do not need to be described.
return
@@ -109,43 +154,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ
return
}
- // Have we already sent this type? This time we ask about the base type.
- if _, alreadySent := enc.sent[rt]; alreadySent {
- return
- }
-
- // Need to send it.
- typeLock.Lock()
- info, err := getTypeInfo(rt)
- typeLock.Unlock()
- if err != nil {
- enc.setError(err)
- return
- }
- // Send the pair (-id, type)
- // Id:
- state.encodeInt(-int64(info.id))
- // Type:
- enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo)
- enc.writeMessage(w, state.b)
- if enc.err != nil {
- return
- }
-
- // Remember we've sent this type.
- enc.sent[rt] = info.id
- // Remember we've sent the top-level, possibly indirect type too.
- enc.sent[origt] = info.id
- // Now send the inner types
- switch st := rt.(type) {
- case *reflect.StructType:
- for i := 0; i < st.NumField(); i++ {
- enc.sendType(w, state, st.Field(i).Type)
- }
- case reflect.ArrayOrSliceType:
- enc.sendType(w, state, st.Elem())
- }
- return true
+ return enc.sendActualType(w, state, ut, ut.base)
}
// Encode transmits the data item represented by the empty interface value,
@@ -159,11 +168,17 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
// sent.
func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) {
// Make sure the type is known to the other side.
- // First, have we already sent this (base) type?
- base := ut.base
- if _, alreadySent := enc.sent[base]; !alreadySent {
+ // First, have we already sent this type?
+ rt := ut.base
+ if ut.isGobEncoder {
+ rt = ut.user
+ if ut.encIndir != 0 {
+ panic("TODO: can't handle non-zero encIndir")
+ }
+ }
+ if _, alreadySent := enc.sent[rt]; !alreadySent {
// No, so send it.
- sent := enc.sendType(w, state, base)
+ sent := enc.sendType(w, state, rt)
if enc.err != nil {
return
}
@@ -172,13 +187,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use
// need to send the type info but we do need to update enc.sent.
if !sent {
typeLock.Lock()
- info, err := getTypeInfo(base)
+ info, err := getTypeInfo(ut)
typeLock.Unlock()
if err != nil {
enc.setError(err)
return
}
- enc.sent[base] = info.id
+ enc.sent[rt] = info.id
}
}
}
diff --git a/src/pkg/gob/gobencdec_test.go b/src/pkg/gob/gobencdec_test.go
new file mode 100644
index 0000000000..dbe7d3fe31
--- /dev/null
+++ b/src/pkg/gob/gobencdec_test.go
@@ -0,0 +1,331 @@
+// Copyright 20011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains tests of the GobEncoder/GobDecoder support.
+
+package gob
+
+import (
+ "bytes"
+ "fmt"
+ "os"
+ "strings"
+ "testing"
+)
+
+// Types that implement the GobEncoder/Decoder interfaces.
+
+type ByteStruct struct {
+ a byte // not an exported field
+}
+
+type StringStruct struct {
+ s string // not an exported field
+}
+
+type Gobber int
+
+type ValueGobber string // encodes with a value, decodes with a pointer.
+
+// The relevant methods
+
+func (g *ByteStruct) _GobEncode() ([]byte, os.Error) {
+ b := make([]byte, 3)
+ b[0] = g.a
+ b[1] = g.a + 1
+ b[2] = g.a + 2
+ return b, nil
+}
+
+func (g *ByteStruct) _GobDecode(data []byte) os.Error {
+ if g == nil {
+ return os.ErrorString("NIL RECEIVER")
+ }
+ // Expect N sequential-valued bytes.
+ if len(data) == 0 {
+ return os.EOF
+ }
+ g.a = data[0]
+ for i, c := range data {
+ if c != g.a+byte(i) {
+ return os.ErrorString("invalid data sequence")
+ }
+ }
+ return nil
+}
+
+func (g *StringStruct) _GobEncode() ([]byte, os.Error) {
+ return []byte(g.s), nil
+}
+
+func (g *StringStruct) _GobDecode(data []byte) os.Error {
+ // Expect N sequential-valued bytes.
+ if len(data) == 0 {
+ return os.EOF
+ }
+ a := data[0]
+ for i, c := range data {
+ if c != a+byte(i) {
+ return os.ErrorString("invalid data sequence")
+ }
+ }
+ g.s = string(data)
+ return nil
+}
+
+func (g *Gobber) _GobEncode() ([]byte, os.Error) {
+ return []byte(fmt.Sprintf("VALUE=%d", *g)), nil
+}
+
+func (g *Gobber) _GobDecode(data []byte) os.Error {
+ _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g))
+ return err
+}
+
+func (v ValueGobber) _GobEncode() ([]byte, os.Error) {
+ return []byte(fmt.Sprintf("VALUE=%s", v)), nil
+}
+
+func (v *ValueGobber) _GobDecode(data []byte) os.Error {
+ _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v))
+ return err
+}
+
+// Structs that include GobEncodable fields.
+
+type GobTest0 struct {
+ X int // guarantee we have something in common with GobTest*
+ G *ByteStruct
+}
+
+type GobTest1 struct {
+ X int // guarantee we have something in common with GobTest*
+ G *StringStruct
+}
+
+type GobTest2 struct {
+ X int // guarantee we have something in common with GobTest*
+ G string // not a GobEncoder - should give us errors
+}
+
+type GobTest3 struct {
+ X int // guarantee we have something in common with GobTest*
+ G *Gobber // TODO: should be able to satisfy interface without a pointer
+}
+
+type GobTest4 struct {
+ X int // guarantee we have something in common with GobTest*
+ V ValueGobber
+}
+
+type GobTest5 struct {
+ X int // guarantee we have something in common with GobTest*
+ V *ValueGobber
+}
+
+type GobTestIgnoreEncoder struct {
+ X int // guarantee we have something in common with GobTest*
+}
+
+func TestGobEncoderField(t *testing.T) {
+ b := new(bytes.Buffer)
+ // First a field that's a structure.
+ enc := NewEncoder(b)
+ err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ x := new(GobTest0)
+ err = dec.Decode(x)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if x.G.a != 'A' {
+ t.Errorf("expected 'A' got %c", x.G.a)
+ }
+ // Now a field that's not a structure.
+ b.Reset()
+ gobber := Gobber(23)
+ err = enc.Encode(GobTest3{17, &gobber})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ y := new(GobTest3)
+ err = dec.Decode(y)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if *y.G != 23 {
+ t.Errorf("expected '23 got %d", *y.G)
+ }
+}
+
+// As long as the fields have the same name and implement the
+// interface, we can cross-connect them. Not sure it's useful
+// and may even be bad but it works and it's hard to prevent
+// without exposing the contents of the object, which would
+// defeat the purpose.
+func TestGobEncoderFieldsOfDifferentType(t *testing.T) {
+ // first, string in field to byte in field
+ b := new(bytes.Buffer)
+ enc := NewEncoder(b)
+ err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ x := new(GobTest0)
+ err = dec.Decode(x)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if x.G.a != 'A' {
+ t.Errorf("expected 'A' got %c", x.G.a)
+ }
+ // now the other direction, byte in field to string in field
+ b.Reset()
+ err = enc.Encode(GobTest0{17, &ByteStruct{'X'}})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ y := new(GobTest1)
+ err = dec.Decode(y)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if y.G.s != "XYZ" {
+ t.Fatalf("expected `XYZ` got %c", y.G.s)
+ }
+}
+
+// Test that we can encode a value and decode into a pointer.
+func TestGobEncoderValueEncoder(t *testing.T) {
+ // first, string in field to byte in field
+ b := new(bytes.Buffer)
+ enc := NewEncoder(b)
+ err := enc.Encode(GobTest4{17, ValueGobber("hello")})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ x := new(GobTest5)
+ err = dec.Decode(x)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if *x.V != "hello" {
+ t.Errorf("expected `hello` got %s", x.V)
+ }
+}
+
+func TestGobEncoderFieldTypeError(t *testing.T) {
+ // GobEncoder to non-decoder: error
+ b := new(bytes.Buffer)
+ enc := NewEncoder(b)
+ err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ x := &GobTest2{}
+ err = dec.Decode(x)
+ if err == nil {
+ t.Fatal("expected decode error for mistmatched fields (encoder to non-decoder)")
+ }
+ if strings.Index(err.String(), "type") < 0 {
+ t.Fatal("expected type error; got", err)
+ }
+ // Non-encoder to GobDecoder: error
+ b.Reset()
+ err = enc.Encode(GobTest2{17, "ABC"})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ y := &GobTest1{}
+ err = dec.Decode(y)
+ if err == nil {
+ t.Fatal("expected decode error for mistmatched fields (non-encoder to decoder)")
+ }
+ if strings.Index(err.String(), "type") < 0 {
+ t.Fatal("expected type error; got", err)
+ }
+}
+
+// Even though ByteStruct is a struct, it's treated as a singleton at the top level.
+func TestGobEncoderStructSingleton(t *testing.T) {
+ b := new(bytes.Buffer)
+ enc := NewEncoder(b)
+ err := enc.Encode(&ByteStruct{'A'})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ x := new(ByteStruct)
+ err = dec.Decode(x)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if x.a != 'A' {
+ t.Errorf("expected 'A' got %c", x.a)
+ }
+}
+
+func TestGobEncoderNonStructSingleton(t *testing.T) {
+ b := new(bytes.Buffer)
+ enc := NewEncoder(b)
+ g := Gobber(1234) // TODO: shouldn't need to take the address here.
+ err := enc.Encode(&g)
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ var x Gobber
+ err = dec.Decode(&x)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if x != 1234 {
+ t.Errorf("expected 1234 got %c", x)
+ }
+}
+
+func TestGobEncoderIgnoreStructField(t *testing.T) {
+ b := new(bytes.Buffer)
+ // First a field that's a structure.
+ enc := NewEncoder(b)
+ err := enc.Encode(GobTest0{17, &ByteStruct{'A'}})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ x := new(GobTestIgnoreEncoder)
+ err = dec.Decode(x)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if x.X != 17 {
+ t.Errorf("expected 17 got %c", x.X)
+ }
+}
+
+func TestGobEncoderIgnoreNonStructField(t *testing.T) {
+ b := new(bytes.Buffer)
+ // First a field that's a structure.
+ enc := NewEncoder(b)
+ gobber := Gobber(23)
+ err := enc.Encode(GobTest3{17, &gobber})
+ if err != nil {
+ t.Fatal("encode error:", err)
+ }
+ dec := NewDecoder(b)
+ x := new(GobTestIgnoreEncoder)
+ err = dec.Decode(x)
+ if err != nil {
+ t.Fatal("decode error:", err)
+ }
+ if x.X != 17 {
+ t.Errorf("expected 17 got %c", x.X)
+ }
+}
diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go
index 6e3f148b4e..05d5f122e9 100644
--- a/src/pkg/gob/type.go
+++ b/src/pkg/gob/type.go
@@ -15,9 +15,13 @@ import (
// to the package. It's computed once and stored in a map keyed by reflection
// type.
type userTypeInfo struct {
- user reflect.Type // the type the user handed us
- base reflect.Type // the base type after all indirections
- indir int // number of indirections to reach the base type
+ user reflect.Type // the type the user handed us
+ base reflect.Type // the base type after all indirections
+ indir int // number of indirections to reach the base type
+ isGobEncoder bool // does the type implement _GobEncoder?
+ isGobDecoder bool // does the type implement _GobDecoder?
+ encIndir int8 // number of indirections to reach the receiver type; may be negative
+ decIndir int8 // number of indirections to reach the receiver type; may be negative
}
var (
@@ -68,10 +72,83 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
}
ut.indir++
}
+ ut.isGobEncoder, ut.encIndir = implementsGobEncoder(ut.user)
+ ut.isGobDecoder, ut.decIndir = implementsGobDecoder(ut.user)
userTypeCache[rt] = ut
+ if ut.encIndir != 0 || ut.decIndir != 0 {
+ // There are checks in lots of other places, but putting this here means we won't even
+ // attempt to encode/decode this type.
+ // TODO: make it possible to handle types that are indirect to the implementation,
+ // such as a structure field of type T when *T implements GobDecoder.
+ return nil, os.ErrorString("TODO: gob can't handle indirections to GobEncoder/Decoder")
+ }
return
}
+const (
+ gobEncodeMethodName = "_GobEncode"
+ gobDecodeMethodName = "_GobDecode"
+)
+
+// implementsGobEncoder reports whether the type implements the interface. It also
+// returns the number of indirections required to get to the implementation.
+// TODO: when reflection makes it possible, should also be prepared to climb up
+// one level if we're not on a pointer (implementation could be on *T for our T).
+// That will mean that indir could be < 0, which is sure to cause problems, but
+// we ignore them now as indir is always >= 0 now.
+func implementsGobEncoder(rt reflect.Type) (implements bool, indir int8) {
+ if rt == nil {
+ return
+ }
+ // The type might be a pointer, or it might not, and we need to keep
+ // dereferencing to the base type until we find an implementation.
+ for {
+ if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance
+ if _, ok := reflect.MakeZero(rt).Interface().(_GobEncoder); ok {
+ return true, indir
+ }
+ }
+ if p, ok := rt.(*reflect.PtrType); ok {
+ indir++
+ if indir > 100 { // insane number of indirections
+ return false, 0
+ }
+ rt = p.Elem()
+ continue
+ }
+ break
+ }
+ return false, 0
+}
+
+// implementsGobDecoder reports whether the type implements the interface. It also
+// returns the number of indirections required to get to the implementation.
+// TODO: see comment on implementsGobEncoder.
+func implementsGobDecoder(rt reflect.Type) (implements bool, indir int8) {
+ if rt == nil {
+ return
+ }
+ // The type might be a pointer, or it might not, and we need to keep
+ // dereferencing to the base type until we find an implementation.
+ for {
+ if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance
+ if _, ok := reflect.MakeZero(rt).Interface().(_GobDecoder); ok {
+ return true, indir
+ }
+ }
+ if p, ok := rt.(*reflect.PtrType); ok {
+ indir++
+ if indir > 100 { // insane number of indirections
+ return false, 0
+ }
+ rt = p.Elem()
+ continue
+ }
+ break
+ }
+ return false, 0
+}
+
// userType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, it calls error.
func userType(rt reflect.Type) *userTypeInfo {
@@ -229,6 +306,23 @@ func (a *arrayType) safeString(seen map[typeId]bool) string {
func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) }
+// GobEncoder type (something that implements the _GobEncoder interface)
+type gobEncoderType struct {
+ CommonType
+}
+
+func newGobEncoderType(name string) *gobEncoderType {
+ g := &gobEncoderType{CommonType{Name: name}}
+ setTypeId(g)
+ return g
+}
+
+func (g *gobEncoderType) safeString(seen map[typeId]bool) string {
+ return g.Name
+}
+
+func (g *gobEncoderType) string() string { return g.Name }
+
// Map type
type mapType struct {
CommonType
@@ -328,7 +422,16 @@ func (s *structType) init(field []*fieldType) {
s.Field = field
}
-func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
+// newTypeObject allocates a gobType for the reflection type rt.
+// Unless ut represents a GobEncoder, rt should be the base type
+// of ut.
+// This is only called from the encoding side. The decoding side
+// works through typeIds and userTypeInfos alone.
+func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
+ // Does this type implement GobEncoder?
+ if ut.isGobEncoder {
+ return newGobEncoderType(name), nil
+ }
var err os.Error
var type0, type1 gobType
defer func() {
@@ -364,7 +467,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
case *reflect.ArrayType:
at := newArrayType(name)
types[rt] = at
- type0, err = getType("", t.Elem())
+ type0, err = getBaseType("", t.Elem())
if err != nil {
return nil, err
}
@@ -382,11 +485,11 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
case *reflect.MapType:
mt := newMapType(name)
types[rt] = mt
- type0, err = getType("", t.Key())
+ type0, err = getBaseType("", t.Key())
if err != nil {
return nil, err
}
- type1, err = getType("", t.Elem())
+ type1, err = getBaseType("", t.Elem())
if err != nil {
return nil, err
}
@@ -400,7 +503,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
}
st := newSliceType(name)
types[rt] = st
- type0, err = getType(t.Elem().Name(), t.Elem())
+ type0, err = getBaseType(t.Elem().Name(), t.Elem())
if err != nil {
return nil, err
}
@@ -413,6 +516,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
idToType[st.id()] = st
field := make([]*fieldType, t.NumField())
for i := 0; i < t.NumField(); i++ {
+ // TODO: don't send unexported fields.
f := t.Field(i)
typ := userType(f.Type).base
tname := typ.Name()
@@ -420,7 +524,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
t := userType(f.Type).base
tname = t.String()
}
- gt, err := getType(tname, f.Type)
+ gt, err := getBaseType(tname, f.Type)
if err != nil {
return nil, err
}
@@ -435,15 +539,24 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
return nil, nil
}
+// getBaseType returns the Gob type describing the given reflect.Type's base type.
+// typeLock must be held.
+func getBaseType(name string, rt reflect.Type) (gobType, os.Error) {
+ ut := userType(rt)
+ return getType(name, ut, ut.base)
+}
+
// getType returns the Gob type describing the given reflect.Type.
+// Should be called only when handling GobEncoders/Decoders,
+// which may be pointers. All other types are handled through the
+// base type, never a pointer.
// typeLock must be held.
-func getType(name string, rt reflect.Type) (gobType, os.Error) {
- rt = userType(rt).base
+func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) {
typ, present := types[rt]
if present {
return typ, nil
}
- typ, err := newTypeObject(name, rt)
+ typ, err := newTypeObject(name, ut, rt)
if err == nil {
types[rt] = typ
}
@@ -484,10 +597,11 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId {
// To maintain binary compatibility, if you extend this type, always put
// the new fields last.
type wireType struct {
- ArrayT *arrayType
- SliceT *sliceType
- StructT *structType
- MapT *mapType
+ ArrayT *arrayType
+ SliceT *sliceType
+ StructT *structType
+ MapT *mapType
+ GobEncoderT *gobEncoderType
}
func (w *wireType) string() string {
@@ -504,6 +618,8 @@ func (w *wireType) string() string {
return w.StructT.Name
case w.MapT != nil:
return w.MapT.Name
+ case w.GobEncoderT != nil:
+ return w.GobEncoderT.Name
}
return unknown
}
@@ -516,23 +632,43 @@ type typeInfo struct {
var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock
-// The reflection type must have all its indirections processed out.
// typeLock must be held.
-func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
- if rt.Kind() == reflect.Ptr {
- panic("pointer type in getTypeInfo: " + rt.String())
+func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) {
+
+ if ut.isGobEncoder {
+ // TODO: clean up this code - too much duplication.
+ info, ok := typeInfoMap[ut.user]
+ if ok {
+ return info, nil
+ }
+ // We want the user type, not the base type.
+ userType, err := getType(ut.user.Name(), ut, ut.user)
+ if err != nil {
+ return nil, err
+ }
+ info = new(typeInfo)
+ gt, err := getBaseType(ut.base.Name(), ut.base)
+ if err != nil {
+ return nil, err
+ }
+ info.id = gt.id()
+ info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)}
+ typeInfoMap[ut.user] = info
+ return info, nil
}
- info, ok := typeInfoMap[rt]
+
+ base := ut.base
+ info, ok := typeInfoMap[base]
if !ok {
info = new(typeInfo)
- name := rt.Name()
- gt, err := getType(name, rt)
+ name := base.Name()
+ gt, err := getBaseType(name, base)
if err != nil {
return nil, err
}
info.id = gt.id()
t := info.id.gobType()
- switch typ := rt.(type) {
+ switch typ := base.(type) {
case *reflect.ArrayType:
info.wire = &wireType{ArrayT: t.(*arrayType)}
case *reflect.MapType:
@@ -545,20 +681,27 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) {
case *reflect.StructType:
info.wire = &wireType{StructT: t.(*structType)}
}
- typeInfoMap[rt] = info
+ typeInfoMap[base] = info
}
return info, nil
}
// Called only when a panic is acceptable and unexpected.
func mustGetTypeInfo(rt reflect.Type) *typeInfo {
- t, err := getTypeInfo(rt)
+ t, err := getTypeInfo(userType(rt))
if err != nil {
panic("getTypeInfo: " + err.String())
}
return t
}
+type _GobEncoder interface {
+ _GobEncode() ([]byte, os.Error)
+} // use _ prefix until we get it working properly
+type _GobDecoder interface {
+ _GobDecode([]byte) os.Error
+} // use _ prefix until we get it working properly
+
var (
nameToConcreteType = make(map[string]reflect.Type)
concreteTypeToName = make(map[reflect.Type]string)
diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go
index 5aecde103a..6fe1ecf93e 100644
--- a/src/pkg/gob/type_test.go
+++ b/src/pkg/gob/type_test.go
@@ -26,7 +26,7 @@ var basicTypes = []typeT{
func getTypeUnlocked(name string, rt reflect.Type) gobType {
typeLock.Lock()
defer typeLock.Unlock()
- t, err := getType(name, rt)
+ t, err := getBaseType(name, rt)
if err != nil {
panic("getTypeUnlocked: " + err.String())
}