diff options
author | Rob Pike <r@golang.org> | 2010-10-22 15:16:34 -0700 |
---|---|---|
committer | Rob Pike <r@golang.org> | 2010-10-22 15:16:34 -0700 |
commit | f593b37f237f624724c6bbe32124819395fa9d1d (patch) | |
tree | d4738d13c07288eefe3fc8a5e7093ef2c3406106 | |
parent | 1dd0319be36ae0b00d14caeb13912c1cc2f13d1f (diff) | |
download | go-f593b37f237f624724c6bbe32124819395fa9d1d.tar.gz go-f593b37f237f624724c6bbe32124819395fa9d1d.zip |
gobs: error cleanup part 1.
Remove err from the encoderState and decoderState types, so we're
not always copying to and from various copies of the error, and then
use panic/recover to eliminate lots of error checking.
another pass might take a crack at the same thing for the compilation phase.
R=rsc
CC=golang-dev
https://golang.org/cl/2660042
-rw-r--r-- | src/pkg/gob/Makefile | 1 | ||||
-rw-r--r-- | src/pkg/gob/codec_test.go | 30 | ||||
-rw-r--r-- | src/pkg/gob/decode.go | 174 | ||||
-rw-r--r-- | src/pkg/gob/decoder.go | 40 | ||||
-rw-r--r-- | src/pkg/gob/encode.go | 76 | ||||
-rw-r--r-- | src/pkg/gob/encoder.go | 21 | ||||
-rw-r--r-- | src/pkg/gob/encoder_test.go | 60 | ||||
-rw-r--r-- | src/pkg/gob/error.go | 41 |
8 files changed, 229 insertions, 214 deletions
diff --git a/src/pkg/gob/Makefile b/src/pkg/gob/Makefile index 77ec9d98ce..68007c189e 100644 --- a/src/pkg/gob/Makefile +++ b/src/pkg/gob/Makefile @@ -11,6 +11,7 @@ GOFILES=\ doc.go\ encode.go\ encoder.go\ + error.go\ type.go\ include ../../Make.pkg diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index ba97f51a1b..2e52a0f1dd 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -37,16 +37,23 @@ var encodeT = []EncodeT{ {1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, } +// testError is meant to be used as a deferred function to turn a panic(gobError) into a +// plain test.Error call. +func testError(t *testing.T) { + if e := recover(); e != nil { + t.Error(e.(gobError).Error) // Will re-panic if not one of our errors, such as a runtime error. + } + return +} + // Test basic encode/decode routines for unsigned integers func TestUintCodec(t *testing.T) { + defer testError(t) b := new(bytes.Buffer) encState := newEncoderState(b) for _, tt := range encodeT { b.Reset() encodeUint(encState, tt.x) - if encState.err != nil { - t.Error("encodeUint:", tt.x, encState.err) - } if !bytes.Equal(tt.b, b.Bytes()) { t.Errorf("encodeUint: %#x encode: expected % x got % x", tt.x, tt.b, b.Bytes()) } @@ -55,13 +62,7 @@ func TestUintCodec(t *testing.T) { for u := uint64(0); ; u = (u + 1) * 7 { b.Reset() encodeUint(encState, u) - if encState.err != nil { - t.Error("encodeUint:", u, encState.err) - } v := decodeUint(decState) - if decState.err != nil { - t.Error("DecodeUint:", u, decState.err) - } if u != v { t.Errorf("Encode/Decode: sent %#x received %#x", u, v) } @@ -72,18 +73,13 @@ func TestUintCodec(t *testing.T) { } func verifyInt(i int64, t *testing.T) { + defer testError(t) var b = new(bytes.Buffer) encState := newEncoderState(b) encodeInt(encState, i) - if encState.err != nil { - t.Error("encodeInt:", i, encState.err) - } decState := newDecodeState(&b) decState.buf = make([]byte, 8) j := decodeInt(decState) - if decState.err != nil { - t.Error("DecodeInt:", i, decState.err) - } if i != j { t.Errorf("Encode/Decode: sent %#x received %#x", uint64(i), uint64(j)) } @@ -320,10 +316,8 @@ func TestScalarEncInstructions(t *testing.T) { } func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) { + defer testError(t) v := int(decodeUint(state)) - if state.err != nil { - t.Fatalf("decoding %s field: %v", typ, state.err) - } if v+state.fieldnum != 6 { t.Fatalf("decoding field number %d, got %d", 6, v+state.fieldnum) } diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 5791c37ecb..96d3176847 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -27,8 +27,8 @@ var ( type decodeState struct { // The buffer is stored with an extra indirection because it may be replaced // if we load a type during decode (when reading an interface value). - b **bytes.Buffer - err os.Error + b **bytes.Buffer + // err os.Error fieldnum int // the last field number read. buf []byte } @@ -80,21 +80,21 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, err os.Error) { // Sets state.err. If state.err is already non-nil, it does nothing. // Does not check for overflow. func decodeUint(state *decodeState) (x uint64) { - if state.err != nil { - return + b, err := state.b.ReadByte() + if err != nil { + error(err) } - var b uint8 - b, state.err = state.b.ReadByte() if b <= 0x7f { // includes state.err != nil return uint64(b) } nb := -int(int8(b)) if nb > uint64Size { - state.err = errBadUint - return + error(errBadUint) + } + n, err := state.b.Read(state.buf[0:nb]) + if err != nil { + error(err) } - var n int - n, state.err = state.b.Read(state.buf[0:nb]) // Don't need to check error; it's safe to loop regardless. // Could check that the high byte is zero but it's not worth it. for i := 0; i < n; i++ { @@ -109,9 +109,6 @@ func decodeUint(state *decodeState) (x uint64) { // Does not check for overflow. func decodeInt(state *decodeState) int64 { x := decodeUint(state) - if state.err != nil { - return 0 - } if x&1 != 0 { return ^int64(x >> 1) } @@ -176,7 +173,7 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { } v := decodeInt(state) if v < math.MinInt8 || math.MaxInt8 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*int8)(p) = int8(v) } @@ -191,7 +188,7 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { } v := decodeUint(state) if math.MaxUint8 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*uint8)(p) = uint8(v) } @@ -206,7 +203,7 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { } v := decodeInt(state) if v < math.MinInt16 || math.MaxInt16 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*int16)(p) = int16(v) } @@ -221,7 +218,7 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { } v := decodeUint(state) if math.MaxUint16 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*uint16)(p) = uint16(v) } @@ -236,7 +233,7 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { } v := decodeInt(state) if v < math.MinInt32 || math.MaxInt32 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*int32)(p) = int32(v) } @@ -251,7 +248,7 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { } v := decodeUint(state) if math.MaxUint32 < v { - state.err = i.ovfl + error(i.ovfl) } else { *(*uint32)(p) = uint32(v) } @@ -300,7 +297,7 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { } // +Inf is OK in both 32- and 64-bit floats. Underflow is always OK. if math.MaxFloat32 < av && av <= math.MaxFloat64 { - state.err = i.ovfl + error(i.ovfl) } else { *(*float32)(p) = float32(v) } @@ -407,15 +404,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { return *(*uintptr)(up) } -func decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) os.Error { +func decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { + defer catchError(&err) p = allocate(rtyp, p, indir) state := newDecodeState(b) state.fieldnum = singletonField basep := p delta := int(decodeUint(state)) if delta != 0 { - state.err = os.ErrorString("gob decode: corrupted data: non-zero delta for singleton") - return state.err + errorf("gob decode: corrupted data: non-zero delta for singleton") } instr := &engine.instr[singletonField] ptr := unsafe.Pointer(basep) // offset will be zero @@ -423,26 +420,26 @@ func decodeSingle(engine *decEngine, rtyp reflect.Type, b **bytes.Buffer, p uint ptr = decIndirect(ptr, instr.indir) } instr.op(instr, state, ptr) - return state.err + return nil } -func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) os.Error { +func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b **bytes.Buffer, p uintptr, indir int) (err os.Error) { + defer catchError(&err) p = allocate(rtyp, p, indir) state := newDecodeState(b) state.fieldnum = -1 basep := p - for state.b.Len() > 0 && state.err == nil { + for state.b.Len() > 0 { delta := int(decodeUint(state)) if delta < 0 { - state.err = os.ErrorString("gob decode: corrupted data: negative delta") - break + errorf("gob decode: corrupted data: negative delta") } - if state.err != nil || delta == 0 { // struct terminator is zero delta fieldnum + if delta == 0 { // struct terminator is zero delta fieldnum break } fieldnum := state.fieldnum + delta if fieldnum >= len(engine.instr) { - state.err = errRange + error(errRange) break } instr := &engine.instr[fieldnum] @@ -453,36 +450,35 @@ func (dec *Decoder) decodeStruct(engine *decEngine, rtyp *reflect.StructType, b instr.op(instr, state, p) state.fieldnum = fieldnum } - return state.err + return nil } -func ignoreStruct(engine *decEngine, b **bytes.Buffer) os.Error { +func ignoreStruct(engine *decEngine, b **bytes.Buffer) (err os.Error) { + defer catchError(&err) state := newDecodeState(b) state.fieldnum = -1 - for state.b.Len() > 0 && state.err == nil { + for state.b.Len() > 0 { delta := int(decodeUint(state)) if delta < 0 { - state.err = os.ErrorString("gob ignore decode: corrupted data: negative delta") - break + errorf("gob ignore decode: corrupted data: negative delta") } - if state.err != nil || delta == 0 { // struct terminator is zero delta fieldnum + if delta == 0 { // struct terminator is zero delta fieldnum break } fieldnum := state.fieldnum + delta if fieldnum >= len(engine.instr) { - state.err = errRange - break + error(errRange) } instr := &engine.instr[fieldnum] instr.op(instr, state, unsafe.Pointer(nil)) state.fieldnum = fieldnum } - return state.err + return nil } -func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) os.Error { +func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl} - for i := 0; i < length && state.err == nil; i++ { + for i := 0; i < length; i++ { up := unsafe.Pointer(p) if elemIndir > 1 { up = decIndirect(up, elemIndir) @@ -490,17 +486,16 @@ func decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uint elemOp(instr, state, up) p += uintptr(elemWid) } - return state.err } -func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) os.Error { +func decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } if n := decodeUint(state); n != uint64(length) { - return os.ErrorString("gob: length mismatch in decodeArray") + errorf("gob: length mismatch in decodeArray") } - return decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) + decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { @@ -513,7 +508,7 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o return v } -func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) os.Error { +func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect } @@ -527,47 +522,38 @@ func decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elem // the iteration. v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue) n := int(decodeUint(state)) - for i := 0; i < n && state.err == nil; i++ { + for i := 0; i < n; i++ { key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) - if state.err != nil { - break - } elem := decodeIntoValue(state, elemOp, elemIndir, reflect.MakeZero(mtyp.Elem()), ovfl) - if state.err != nil { - break - } v.SetElem(key, elem) } - return state.err } -func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) os.Error { +func ignoreArrayHelper(state *decodeState, elemOp decOp, length int) { instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} - for i := 0; i < length && state.err == nil; i++ { + for i := 0; i < length; i++ { elemOp(instr, state, nil) } - return state.err } -func ignoreArray(state *decodeState, elemOp decOp, length int) os.Error { +func ignoreArray(state *decodeState, elemOp decOp, length int) { if n := decodeUint(state); n != uint64(length) { - return os.ErrorString("gob: length mismatch in ignoreArray") + errorf("gob: length mismatch in ignoreArray") } - return ignoreArrayHelper(state, elemOp, length) + ignoreArrayHelper(state, elemOp, length) } -func ignoreMap(state *decodeState, keyOp, elemOp decOp) os.Error { +func ignoreMap(state *decodeState, keyOp, elemOp decOp) { n := int(decodeUint(state)) keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")} elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} - for i := 0; i < n && state.err == nil; i++ { + for i := 0; i < n; i++ { keyOp(keyInstr, state, nil) elemOp(elemInstr, state, nil) } - return state.err } -func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) os.Error { +func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { n := int(uintptr(decodeUint(state))) if indir > 0 { up := unsafe.Pointer(p) @@ -583,11 +569,11 @@ func decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp hdrp.Data = uintptr(unsafe.NewArray(atyp.Elem(), n)) hdrp.Len = n hdrp.Cap = n - return decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl) + decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl) } -func ignoreSlice(state *decodeState, elemOp decOp) os.Error { - return ignoreArrayHelper(state, elemOp, int(decodeUint(state))) +func ignoreSlice(state *decodeState, elemOp decOp) { + ignoreArrayHelper(state, elemOp, int(decodeUint(state))) } // setInterfaceValue sets an interface value to a concrete value through @@ -596,19 +582,18 @@ func ignoreSlice(state *decodeState, elemOp decOp) os.Error { // This dance avoids manually checking that the value satisfies the // interface. // TODO(rsc): avoid panic+recover after fixing issue 327. -func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) (err os.Error) { +func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) { defer func() { if e := recover(); e != nil { - err = e.(os.Error) + error(e.(os.Error)) } }() ivalue.Set(value) - return nil } // decodeInterface receives the name of a concrete type followed by its value. // If the name is empty, the value is nil and no value is sent. -func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) os.Error { +func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) { // Create an interface reflect.Value. We need one even for the nil case. ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue) // Read the name of the concrete type. @@ -619,20 +604,18 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt // Copy the representation of the nil interface value to the target. // This is horribly unsafe and special. *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() - return state.err + return } // The concrete type must be registered. typ, ok := nameToConcreteType[name] if !ok { - state.err = os.ErrorString("gob: name not registered for interface: " + name) - return state.err + errorf("gob: name not registered for interface: %q", name) } // Read the concrete value. value := reflect.MakeZero(typ) dec.decodeValueFromBuffer(value, false) - if dec.state.err != nil { - state.err = dec.state.err - return state.err + if dec.err != nil { + error(dec.err) } // Allocate the destination interface value. if indir > 0 { @@ -640,26 +623,22 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt } // Assign the concrete value to the interface. // Tread carefully; it might not satisfy the interface. - dec.state.err = setInterfaceValue(ivalue, value) - if dec.state.err != nil { - state.err = dec.state.err - return state.err - } + setInterfaceValue(ivalue, value) // Copy the representation of the interface value to the target. // This is horribly unsafe and special. *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() - return nil } -func (dec *Decoder) ignoreInterface(state *decodeState) os.Error { +func (dec *Decoder) ignoreInterface(state *decodeState) { // Read the name of the concrete type. b := make([]byte, decodeUint(state)) _, err := state.b.Read(b) if err != nil { dec.decodeValueFromBuffer(nil, true) - err = dec.state.err + if dec.err != nil { + error(err) + } } - return err } // Index by Go types. @@ -712,7 +691,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) + decodeArray(t, state, uintptr(p), elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) } case *reflect.MapType: @@ -730,7 +709,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { up := unsafe.Pointer(p) - state.err = decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) + decodeMap(t, state, uintptr(up), keyOp, elemOp, i.indir, keyIndir, elemIndir, ovfl) } case *reflect.SliceType: @@ -751,7 +730,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } ovfl := overflow(name) op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) + decodeSlice(t, state, uintptr(p), elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: @@ -762,11 +741,14 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs - state.err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir) + err = dec.decodeStruct(*enginePtr, t, state.b, uintptr(p), i.indir) + if err != nil { + error(err) + } } case *reflect.InterfaceType: op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = dec.decodeInterface(t, state, uintptr(p), i.indir) + dec.decodeInterface(t, state, uintptr(p), i.indir) } } } @@ -784,7 +766,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { // Special case because it's a method: the ignored item might // define types and we need to record their state in the decoder. op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = dec.ignoreInterface(state) + dec.ignoreInterface(state) } return op, nil } @@ -800,7 +782,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { return nil, err } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreArray(state, elemOp, wire.arrayT.Len) + ignoreArray(state, elemOp, wire.arrayT.Len) } case wire.mapT != nil: @@ -815,7 +797,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { return nil, err } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreMap(state, keyOp, elemOp) + ignoreMap(state, keyOp, elemOp) } case wire.sliceT != nil: @@ -825,7 +807,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { return nil, err } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - state.err = ignoreSlice(state, elemOp) + ignoreSlice(state, elemOp) } case wire.structT != nil: @@ -836,7 +818,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) { } op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs - state.err = ignoreStruct(*enginePtr, state.b) + ignoreStruct(*enginePtr, state.b) } } } diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index e2f1e363f6..b86bdf3985 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -25,6 +25,7 @@ type Decoder struct { buf []byte countBuf [9]byte // counts may be uint64s (unlikely!), require 9 bytes byteBuffer *bytes.Buffer + err os.Error } // NewDecoder returns a new decoder that reads from the io.Reader. @@ -43,13 +44,16 @@ func NewDecoder(r io.Reader) *Decoder { func (dec *Decoder) recvType(id typeId) { // Have we already seen this type? That's an error if dec.wireType[id] != nil { - dec.state.err = os.ErrorString("gob: duplicate type received") + dec.err = os.ErrorString("gob: duplicate type received") return } // Type: wire := new(wireType) - dec.state.err = dec.decode(tWireType, reflect.NewValue(wire)) + dec.err = dec.decode(tWireType, reflect.NewValue(wire)) + if dec.err != nil { + return + } // Remember we've seen this type. dec.wireType[id] = wire @@ -66,8 +70,8 @@ func (dec *Decoder) Decode(e interface{}) os.Error { // If e represents a value as opposed to a pointer, the answer won't // get back to the caller. Make sure it's a pointer. if value.Type().Kind() != reflect.Ptr { - dec.state.err = os.ErrorString("gob: attempt to decode into a non-pointer") - return dec.state.err + dec.err = os.ErrorString("gob: attempt to decode into a non-pointer") + return dec.err } return dec.DecodeValue(value) } @@ -77,8 +81,8 @@ func (dec *Decoder) Decode(e interface{}) os.Error { func (dec *Decoder) recv() { // Read a count. var nbytes uint64 - nbytes, dec.state.err = decodeUintReader(dec.r, dec.countBuf[0:]) - if dec.state.err != nil { + nbytes, dec.err = decodeUintReader(dec.r, dec.countBuf[0:]) + if dec.err != nil { return } // Allocate the buffer. @@ -88,10 +92,10 @@ func (dec *Decoder) recv() { dec.byteBuffer = bytes.NewBuffer(dec.buf[0:nbytes]) // Read the data - _, dec.state.err = io.ReadFull(dec.r, dec.buf[0:nbytes]) - if dec.state.err != nil { - if dec.state.err == os.EOF { - dec.state.err = io.ErrUnexpectedEOF + _, dec.err = io.ReadFull(dec.r, dec.buf[0:nbytes]) + if dec.err != nil { + if dec.err == os.EOF { + dec.err = io.ErrUnexpectedEOF } return } @@ -104,7 +108,7 @@ func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignore bool) { for dec.state.b.Len() > 0 { // Receive a type id. id := typeId(decodeInt(dec.state)) - if dec.state.err != nil { + if dec.err != nil { break } @@ -112,7 +116,7 @@ func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignore bool) { if id < 0 { // 0 is the error state, handled above // If the id is negative, we have a type. dec.recvType(-id) - if dec.state.err != nil { + if dec.err != nil { break } continue @@ -126,10 +130,10 @@ func (dec *Decoder) decodeValueFromBuffer(value reflect.Value, ignore bool) { // Make sure the type has been defined already or is a builtin type (for // top-level singleton values). if dec.wireType[id] == nil && builtinIdToType[id] == nil { - dec.state.err = errBadType + dec.err = errBadType break } - dec.state.err = dec.decode(id, value) + dec.err = dec.decode(id, value) break } } @@ -143,11 +147,11 @@ func (dec *Decoder) DecodeValue(value reflect.Value) os.Error { dec.mutex.Lock() defer dec.mutex.Unlock() - dec.state.err = nil + dec.err = nil dec.recv() - if dec.state.err != nil { - return dec.state.err + if dec.err != nil { + return dec.err } dec.decodeValueFromBuffer(value, false) - return dec.state.err + return dec.err } diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 833b87c767..0be2d81a5a 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -21,7 +21,6 @@ const uint64Size = unsafe.Sizeof(uint64(0)) // 0 terminates the structure. type encoderState struct { b *bytes.Buffer - err os.Error // error encountered during encoding. sendZero bool // encoding an array element or map key/value pair; send zero values fieldnum int // the last field number written. buf [1 + uint64Size]byte // buffer used by the encoder; here to avoid allocation. @@ -39,11 +38,11 @@ func newEncoderState(b *bytes.Buffer) *encoderState { // encodeUint writes an encoded unsigned integer to state.b. Sets state.err. // If state.err is already non-nil, it does nothing. func encodeUint(state *encoderState, x uint64) { - if state.err != nil { - return - } if x <= 0x7F { - state.err = state.b.WriteByte(uint8(x)) + err := state.b.WriteByte(uint8(x)) + if err != nil { + error(err) + } return } var n, m int @@ -54,7 +53,10 @@ func encodeUint(state *encoderState, x uint64) { m-- } state.buf[m] = uint8(-(n - 1)) - n, state.err = state.b.Write(state.buf[m : uint64Size+1]) + n, err := state.b.Write(state.buf[m : uint64Size+1]) + if err != nil { + error(err) + } } // encodeInt writes an encoded signed integer to state.w. @@ -317,7 +319,8 @@ type encEngine struct { const singletonField = 0 -func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { +func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Error) { + defer catchError(&err) state := newEncoderState(b) state.fieldnum = singletonField // There is no surrounding struct to frame the transmission, so we must @@ -331,10 +334,11 @@ func encodeSingle(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { } } instr.op(instr, state, p) - return state.err + return } -func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { +func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) (err os.Error) { + defer catchError(&err) state := newEncoderState(b) state.fieldnum = -1 for i := 0; i < len(engine.instr); i++ { @@ -346,32 +350,27 @@ func encodeStruct(engine *encEngine, b *bytes.Buffer, basep uintptr) os.Error { } } instr.op(instr, state, p) - if state.err != nil { - break - } } - return state.err + return nil } -func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) os.Error { +func encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) { state := newEncoderState(b) state.fieldnum = -1 state.sendZero = true encodeUint(state, uint64(length)) - for i := 0; i < length && state.err == nil; i++ { + for i := 0; i < length; i++ { elemp := p up := unsafe.Pointer(elemp) if elemIndir > 0 { if up = encIndirect(up, elemIndir); up == nil { - state.err = os.ErrorString("gob: encodeArray: nil element") - break + errorf("gob: encodeArray: nil element") } elemp = uintptr(up) } op(nil, state, unsafe.Pointer(elemp)) p += uintptr(elemWid) } - return state.err } func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) { @@ -379,61 +378,54 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in v = reflect.Indirect(v) } if v == nil { - state.err = os.ErrorString("gob: encodeReflectValue: nil element") - return + errorf("gob: encodeReflectValue: nil element") } op(nil, state, unsafe.Pointer(v.Addr())) } -func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) os.Error { +func encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) { state := newEncoderState(b) state.fieldnum = -1 state.sendZero = true keys := mv.Keys() encodeUint(state, uint64(len(keys))) for _, key := range keys { - if state.err != nil { - break - } encodeReflectValue(state, key, keyOp, keyIndir) encodeReflectValue(state, mv.Elem(key), elemOp, elemIndir) } - return state.err } // To send an interface, we send a string identifying the concrete type, followed // by the type identifier (which might require defining that type right now), followed // by the concrete value. A nil value gets sent as the empty string for the name, // followed by no value. -func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) os.Error { +func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) { state := newEncoderState(b) state.fieldnum = -1 state.sendZero = true if iv.IsNil() { encodeUint(state, 0) - return state.err + return } typ := iv.Elem().Type() name, ok := concreteTypeToName[typ] if !ok { - state.err = os.ErrorString("gob: type not registered for interface: " + typ.String()) - return state.err + errorf("gob: type not registered for interface: %s", typ) } // Send the name. encodeUint(state, uint64(len(name))) - _, state.err = io.WriteString(state.b, name) - if state.err != nil { - return state.err + _, err := io.WriteString(state.b, name) + if err != nil { + error(err) } // Send (and maybe first define) the type id. enc.sendTypeDescriptor(typ) - if state.err != nil { - return state.err - } // Send the value. - state.err = enc.encode(state.b, iv.Elem()) - return state.err + err = enc.encode(state.b, iv.Elem()) + if err != nil { + error(err) + } } var encOpMap = []encOp{ @@ -486,7 +478,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { return } state.update(i) - state.err = encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len)) + encodeArray(state.b, slice.Data, elemOp, t.Elem().Size(), indir, int(slice.Len)) } case *reflect.ArrayType: // True arrays have size in the type. @@ -496,7 +488,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { } op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) - state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) + encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), indir, t.Len()) } case *reflect.MapType: keyOp, keyIndir, err := enc.encOpFor(t.Key()) @@ -517,7 +509,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { return } state.update(i) - state.err = encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir) + encodeMap(state.b, mv, keyOp, elemOp, keyIndir, elemIndir) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. @@ -529,7 +521,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) // indirect through info to delay evaluation for recursive structs - state.err = encodeStruct(info.encoder, state.b, uintptr(p)) + encodeStruct(info.encoder, state.b, uintptr(p)) } case *reflect.InterfaceType: op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { @@ -541,7 +533,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type) (encOp, int, os.Error) { return } state.update(i) - state.err = enc.encodeInterface(state.b, iv) + enc.encodeInterface(state.b, iv) } } } diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index ff9834600f..5d12d920b4 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -21,6 +21,7 @@ type Encoder struct { state *encoderState // so we can encode integers, strings directly countState *encoderState // stage for writing counts buf []byte // for collecting the output. + err os.Error } // NewEncoder returns a new encoder that will transmit on the io.Writer. @@ -38,8 +39,8 @@ func (enc *Encoder) badType(rt reflect.Type) { } func (enc *Encoder) setError(err os.Error) { - if enc.state.err == nil { // remember the first. - enc.state.err = err + if enc.err == nil { // remember the first. + enc.err = err } enc.state.b.Reset() } @@ -115,7 +116,7 @@ func (enc *Encoder) sendType(origt reflect.Type) (sent bool) { // Type: enc.encode(enc.state.b, reflect.NewValue(info.wire)) enc.send() - if enc.state.err != nil { + if enc.err != nil { return } @@ -150,7 +151,7 @@ func (enc *Encoder) sendTypeDescriptor(rt reflect.Type) { if _, alreadySent := enc.sent[rt]; !alreadySent { // No, so send it. sent := enc.sendType(rt) - if enc.state.err != nil { + if enc.err != nil { return } // If the type info has still not been transmitted, it means we have @@ -180,18 +181,18 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { enc.mutex.Lock() defer enc.mutex.Unlock() - enc.state.err = nil + enc.err = nil rt, _ := indirect(value.Type()) // Sanity check only: encoder should never come in with data present. if enc.state.b.Len() > 0 || enc.countState.b.Len() > 0 { - enc.state.err = os.ErrorString("encoder: buffer not empty") - return enc.state.err + enc.err = os.ErrorString("encoder: buffer not empty") + return enc.err } enc.sendTypeDescriptor(rt) - if enc.state.err != nil { - return enc.state.err + if enc.err != nil { + return enc.err } // Encode the object. @@ -202,5 +203,5 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { enc.send() } - return enc.state.err + return enc.err } diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index 4f2702a4dd..91d85bb7ad 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -43,15 +43,15 @@ func TestEncoderDecoder(t *testing.T) { et1 := new(ET1) et1.a = 7 et1.et2 = new(ET2) - enc.Encode(et1) - if enc.state.err != nil { - t.Error("encoder fail:", enc.state.err) + err := enc.Encode(et1) + if err != nil { + t.Error("encoder fail:", err) } dec := NewDecoder(b) newEt1 := new(ET1) - dec.Decode(newEt1) - if dec.state.err != nil { - t.Fatal("error decoding ET1:", dec.state.err) + err = dec.Decode(newEt1) + if err != nil { + t.Fatal("error decoding ET1:", err) } if !reflect.DeepEqual(et1, newEt1) { @@ -63,9 +63,9 @@ func TestEncoderDecoder(t *testing.T) { enc.Encode(et1) newEt1 = new(ET1) - dec.Decode(newEt1) - if dec.state.err != nil { - t.Fatal("round 2: error decoding ET1:", dec.state.err) + err = dec.Decode(newEt1) + if err != nil { + t.Fatal("round 2: error decoding ET1:", err) } if !reflect.DeepEqual(et1, newEt1) { t.Fatalf("round 2: invalid data for et1: expected %+v; got %+v", *et1, *newEt1) @@ -75,13 +75,13 @@ func TestEncoderDecoder(t *testing.T) { } // Now test with a running encoder/decoder pair that we recognize a type mismatch. - enc.Encode(et1) - if enc.state.err != nil { - t.Error("round 3: encoder fail:", enc.state.err) + err = enc.Encode(et1) + if err != nil { + t.Error("round 3: encoder fail:", err) } newEt2 := new(ET2) - dec.Decode(newEt2) - if dec.state.err == nil { + err = dec.Decode(newEt2) + if err == nil { t.Fatal("round 3: expected `bad type' error decoding ET2") } } @@ -94,17 +94,17 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) { et1 := new(ET1) et1.a = 7 et1.et2 = new(ET2) - enc.Encode(et1) - if enc.state.err != nil { - t.Error("encoder fail:", enc.state.err) + err := enc.Encode(et1) + if err != nil { + t.Error("encoder fail:", err) } dec := NewDecoder(b) - dec.Decode(e) - if shouldFail && (dec.state.err == nil) { + err = dec.Decode(e) + if shouldFail && err == nil { t.Error("expected error for", msg) } - if !shouldFail && (dec.state.err != nil) { - t.Error("unexpected error for", msg, dec.state.err) + if !shouldFail && err != nil { + t.Error("unexpected error for", msg, err) } } @@ -118,9 +118,9 @@ func TestWrongTypeDecoder(t *testing.T) { func corruptDataCheck(s string, err os.Error, t *testing.T) { b := bytes.NewBufferString(s) dec := NewDecoder(b) - dec.Decode(new(ET2)) - if dec.state.err != err { - t.Error("expected error", err, "got", dec.state.err) + err1 := dec.Decode(new(ET2)) + if err1 != err { + t.Error("expected error", err, "got", err1) } } @@ -151,14 +151,14 @@ func TestUnsupported(t *testing.T) { func encAndDec(in, out interface{}) os.Error { b := new(bytes.Buffer) enc := NewEncoder(b) - enc.Encode(in) - if enc.state.err != nil { - return enc.state.err + err := enc.Encode(in) + if err != nil { + return err } dec := NewDecoder(b) - dec.Decode(out) - if dec.state.err != nil { - return dec.state.err + err = dec.Decode(out) + if err != nil { + return err } return nil } diff --git a/src/pkg/gob/error.go b/src/pkg/gob/error.go new file mode 100644 index 0000000000..b053761fbc --- /dev/null +++ b/src/pkg/gob/error.go @@ -0,0 +1,41 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gob + +import ( + "fmt" + "os" +) + +// Errors in decoding and encoding are handled using panic and recover. +// Panics caused by user error (that is, everything except run-time panics +// such as "index out of bounds" errors) do not leave the file that caused +// them, but are instead turned into plain os.Error returns. Encoding and +// decoding functions and methods that do not return an os.Error either use +// panic to report an error or are guaranteed error-free. + +// A gobError wraps an os.Error and is used to distinguish errors (panics) generated in this package. +type gobError struct { + os.Error +} + +// errorf is like error but takes Printf-style arguments to construct an os.Error. +func errorf(format string, args ...interface{}) { + error(fmt.Errorf(format, args...)) +} + +// error wraps the argument error and uses it as the argument to panic. +func error(err os.Error) { + panic(gobError{Error: err}) +} + +// catchError is meant to be used as a deferred function to turn a panic(gobError) into a +// plain os.Error. It overwrites the error return of the function that deferred its call. +func catchError(err *os.Error) { + if e := recover(); e != nil { + *err = e.(gobError).Error // Will re-panic if not one of our errors, such as a runtime error. + } + return +} |