diff options
Diffstat (limited to 'src/encoding')
-rw-r--r-- | src/encoding/gob/decode.go | 19 | ||||
-rw-r--r-- | src/encoding/gob/gobencdec_test.go | 24 | ||||
-rw-r--r-- | src/encoding/xml/read.go | 42 | ||||
-rw-r--r-- | src/encoding/xml/read_test.go | 32 |
4 files changed, 95 insertions, 22 deletions
diff --git a/src/encoding/gob/decode.go b/src/encoding/gob/decode.go index d2f6c749b1..0e0ec75ccc 100644 --- a/src/encoding/gob/decode.go +++ b/src/encoding/gob/decode.go @@ -871,8 +871,13 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg return &op } +var maxIgnoreNestingDepth = 10000 + // decIgnoreOpFor returns the decoding op for a field that has no destination. -func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) *decOp { +func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp, depth int) *decOp { + if depth > maxIgnoreNestingDepth { + error_(errors.New("invalid nesting depth")) + } // If this type is already in progress, it's a recursive type (e.g. map[string]*T). // Return the pointer to the op we're already building. if opPtr := inProgress[wireId]; opPtr != nil { @@ -896,7 +901,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) errorf("bad data: undefined type %s", wireId.string()) case wire.ArrayT != nil: elemId := wire.ArrayT.Elem - elemOp := dec.decIgnoreOpFor(elemId, inProgress) + elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) op = func(i *decInstr, state *decoderState, value reflect.Value) { state.dec.ignoreArray(state, *elemOp, wire.ArrayT.Len) } @@ -904,15 +909,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId, inProgress map[typeId]*decOp) case wire.MapT != nil: keyId := dec.wireType[wireId].MapT.Key elemId := dec.wireType[wireId].MapT.Elem - keyOp := dec.decIgnoreOpFor(keyId, inProgress) - elemOp := dec.decIgnoreOpFor(elemId, inProgress) + keyOp := dec.decIgnoreOpFor(keyId, inProgress, depth+1) + elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) op = func(i *decInstr, state *decoderState, value reflect.Value) { state.dec.ignoreMap(state, *keyOp, *elemOp) } case wire.SliceT != nil: elemId := wire.SliceT.Elem - elemOp := dec.decIgnoreOpFor(elemId, inProgress) + elemOp := dec.decIgnoreOpFor(elemId, inProgress, depth+1) op = func(i *decInstr, state *decoderState, value reflect.Value) { state.dec.ignoreSlice(state, *elemOp) } @@ -1073,7 +1078,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *de func (dec *Decoder) compileIgnoreSingle(remoteId typeId) *decEngine { engine := new(decEngine) engine.instr = make([]decInstr, 1) // one item - op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp)) + op := dec.decIgnoreOpFor(remoteId, make(map[typeId]*decOp), 0) ovfl := overflow(dec.typeString(remoteId)) engine.instr[0] = decInstr{*op, 0, nil, ovfl} engine.numInstr = 1 @@ -1118,7 +1123,7 @@ func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEn localField, present := srt.FieldByName(wireField.Name) // TODO(r): anonymous names if !present || !isExported(wireField.Name) { - op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp)) + op := dec.decIgnoreOpFor(wireField.Id, make(map[typeId]*decOp), 0) engine.instr[fieldnum] = decInstr{*op, fieldnum, nil, ovfl} continue } diff --git a/src/encoding/gob/gobencdec_test.go b/src/encoding/gob/gobencdec_test.go index 6d2c8db42d..1b52ecc6c8 100644 --- a/src/encoding/gob/gobencdec_test.go +++ b/src/encoding/gob/gobencdec_test.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "net" + "reflect" "strings" "testing" "time" @@ -796,3 +797,26 @@ func TestNetIP(t *testing.T) { t.Errorf("decoded to %v, want 1.2.3.4", ip.String()) } } + +func TestIngoreDepthLimit(t *testing.T) { + // We don't test the actual depth limit because it requires building an + // extremely large message, which takes quite a while. + oldNestingDepth := maxIgnoreNestingDepth + maxIgnoreNestingDepth = 100 + defer func() { maxIgnoreNestingDepth = oldNestingDepth }() + b := new(bytes.Buffer) + enc := NewEncoder(b) + typ := reflect.TypeOf(int(0)) + nested := reflect.ArrayOf(1, typ) + for i := 0; i < 100; i++ { + nested = reflect.ArrayOf(1, nested) + } + badStruct := reflect.New(reflect.StructOf([]reflect.StructField{{Name: "F", Type: nested}})) + enc.Encode(badStruct.Interface()) + dec := NewDecoder(b) + var output struct{ Hello int } + expectedErr := "invalid nesting depth" + if err := dec.Decode(&output); err == nil || err.Error() != expectedErr { + t.Errorf("Decode didn't fail with depth limit of 100: want %q, got %q", expectedErr, err) + } +} diff --git a/src/encoding/xml/read.go b/src/encoding/xml/read.go index ef5df3f7f6..c77579880c 100644 --- a/src/encoding/xml/read.go +++ b/src/encoding/xml/read.go @@ -148,7 +148,7 @@ func (d *Decoder) DecodeElement(v interface{}, start *StartElement) error { if val.Kind() != reflect.Ptr { return errors.New("non-pointer passed to Unmarshal") } - return d.unmarshal(val.Elem(), start) + return d.unmarshal(val.Elem(), start, 0) } // An UnmarshalError represents an error in the unmarshaling process. @@ -304,8 +304,15 @@ var ( textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() ) +const maxUnmarshalDepth = 10000 + +var errExeceededMaxUnmarshalDepth = errors.New("exceeded max depth") + // Unmarshal a single XML element into val. -func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error { +func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) error { + if depth >= maxUnmarshalDepth { + return errExeceededMaxUnmarshalDepth + } // Find start element if we need it. if start == nil { for { @@ -398,7 +405,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error { v.Set(reflect.Append(val, reflect.Zero(v.Type().Elem()))) // Recur to read element into slice. - if err := d.unmarshal(v.Index(n), start); err != nil { + if err := d.unmarshal(v.Index(n), start, depth+1); err != nil { v.SetLen(n) return err } @@ -521,13 +528,15 @@ Loop: case StartElement: consumed := false if sv.IsValid() { - consumed, err = d.unmarshalPath(tinfo, sv, nil, &t) + // unmarshalPath can call unmarshal, so we need to pass the depth through so that + // we can continue to enforce the maximum recusion limit. + consumed, err = d.unmarshalPath(tinfo, sv, nil, &t, depth) if err != nil { return err } if !consumed && saveAny.IsValid() { consumed = true - if err := d.unmarshal(saveAny, &t); err != nil { + if err := d.unmarshal(saveAny, &t, depth+1); err != nil { return err } } @@ -672,7 +681,7 @@ func copyValue(dst reflect.Value, src []byte) (err error) { // The consumed result tells whether XML elements have been consumed // from the Decoder until start's matching end element, or if it's // still untouched because start is uninteresting for sv's fields. -func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) { +func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement, depth int) (consumed bool, err error) { recurse := false Loop: for i := range tinfo.fields { @@ -687,7 +696,7 @@ Loop: } if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local { // It's a perfect match, unmarshal the field. - return true, d.unmarshal(finfo.value(sv, initNilPointers), start) + return true, d.unmarshal(finfo.value(sv, initNilPointers), start, depth+1) } if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local { // It's a prefix for the field. Break and recurse @@ -716,7 +725,9 @@ Loop: } switch t := tok.(type) { case StartElement: - consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t) + // the recursion depth of unmarshalPath is limited to the path length specified + // by the struct field tag, so we don't increment the depth here. + consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t, depth) if err != nil { return true, err } @@ -732,12 +743,12 @@ Loop: } // Skip reads tokens until it has consumed the end element -// matching the most recent start element already consumed. -// It recurs if it encounters a start element, so it can be used to -// skip nested structures. +// matching the most recent start element already consumed, +// skipping nested structures. // It returns nil if it finds an end element matching the start // element; otherwise it returns an error describing the problem. func (d *Decoder) Skip() error { + var depth int64 for { tok, err := d.Token() if err != nil { @@ -745,11 +756,12 @@ func (d *Decoder) Skip() error { } switch tok.(type) { case StartElement: - if err := d.Skip(); err != nil { - return err - } + depth++ case EndElement: - return nil + if depth == 0 { + return nil + } + depth-- } } } diff --git a/src/encoding/xml/read_test.go b/src/encoding/xml/read_test.go index 8c2e70fa22..8c940aefb8 100644 --- a/src/encoding/xml/read_test.go +++ b/src/encoding/xml/read_test.go @@ -5,8 +5,11 @@ package xml import ( + "bytes" + "errors" "io" "reflect" + "runtime" "strings" "testing" "time" @@ -1079,3 +1082,32 @@ func TestUnmarshalWhitespaceAttrs(t *testing.T) { t.Fatalf("whitespace attrs: Unmarshal:\nhave: %#+v\nwant: %#+v", v, want) } } + +func TestCVE202230633(t *testing.T) { + if runtime.GOARCH == "wasm" { + t.Skip("causes memory exhaustion on js/wasm") + } + defer func() { + p := recover() + if p != nil { + t.Fatal("Unmarshal panicked") + } + }() + var example struct { + Things []string + } + Unmarshal(bytes.Repeat([]byte("<a>"), 17_000_000), &example) +} + +func TestCVE202228131(t *testing.T) { + type nested struct { + Parent *nested `xml:",any"` + } + var n nested + err := Unmarshal(bytes.Repeat([]byte("<a>"), maxUnmarshalDepth+1), &n) + if err == nil { + t.Fatal("Unmarshal did not fail") + } else if !errors.Is(err, errExeceededMaxUnmarshalDepth) { + t.Fatalf("Unmarshal unexpected error: got %q, want %q", err, errExeceededMaxUnmarshalDepth) + } +} |