aboutsummaryrefslogtreecommitdiff
path: root/src/encoding
diff options
context:
space:
mode:
authorlujjjh <lujjjh@gmail.com>2020-09-17 14:39:13 +0000
committerDaniel Martí <mvdan@mvdan.cc>2020-09-24 18:18:20 +0000
commit428509402b03c608e625a4844ab0cce75e4bead2 (patch)
treeb24e7a03d00c5c7f30e935801c1f3c426f8abcd9 /src/encoding
parent25a33daa2b7e7bda773705215113450923ae4815 (diff)
downloadgo-428509402b03c608e625a4844ab0cce75e4bead2.tar.gz
go-428509402b03c608e625a4844ab0cce75e4bead2.zip
encoding/json: detect cyclic maps and slices
Now reports an error if cyclic maps and slices are to be encoded instead of an infinite recursion. This case wasn't handled in CL 187920. Fixes #40745. Change-Id: Ia34b014ecbb71fd2663bb065ba5355a307dbcc15 GitHub-Last-Rev: 6f874944f4065b5237babbb0fdce14c1c74a3c97 GitHub-Pull-Request: golang/go#40756 Reviewed-on: https://go-review.googlesource.com/c/go/+/248358 Reviewed-by: Daniel Martí <mvdan@mvdan.cc> Trust: Daniel Martí <mvdan@mvdan.cc> Trust: Bryan C. Mills <bcmills@google.com> Run-TryBot: Daniel Martí <mvdan@mvdan.cc> TryBot-Result: Go Bot <gobot@golang.org>
Diffstat (limited to 'src/encoding')
-rw-r--r--src/encoding/json/encode.go27
-rw-r--r--src/encoding/json/encode_test.go27
2 files changed, 53 insertions, 1 deletions
diff --git a/src/encoding/json/encode.go b/src/encoding/json/encode.go
index c2d191442c..ea5eca51ef 100644
--- a/src/encoding/json/encode.go
+++ b/src/encoding/json/encode.go
@@ -779,6 +779,16 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteString("null")
return
}
+ if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
+ // We're a large number of nested ptrEncoder.encode calls deep;
+ // start checking if we've run into a pointer cycle.
+ ptr := v.Pointer()
+ if _, ok := e.ptrSeen[ptr]; ok {
+ e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
+ }
+ e.ptrSeen[ptr] = struct{}{}
+ defer delete(e.ptrSeen, ptr)
+ }
e.WriteByte('{')
// Extract and sort the keys.
@@ -801,6 +811,7 @@ func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
me.elemEnc(e, v.MapIndex(kv.v), opts)
}
e.WriteByte('}')
+ e.ptrLevel--
}
func newMapEncoder(t reflect.Type) encoderFunc {
@@ -857,7 +868,23 @@ func (se sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
e.WriteString("null")
return
}
+ if e.ptrLevel++; e.ptrLevel > startDetectingCyclesAfter {
+ // We're a large number of nested ptrEncoder.encode calls deep;
+ // start checking if we've run into a pointer cycle.
+ // Here we use a struct to memorize the pointer to the first element of the slice
+ // and its length.
+ ptr := struct {
+ ptr uintptr
+ len int
+ }{v.Pointer(), v.Len()}
+ if _, ok := e.ptrSeen[ptr]; ok {
+ e.error(&UnsupportedValueError{v, fmt.Sprintf("encountered a cycle via %s", v.Type())})
+ }
+ e.ptrSeen[ptr] = struct{}{}
+ defer delete(e.ptrSeen, ptr)
+ }
se.arrayEnc(e, v, opts)
+ e.ptrLevel--
}
func newSliceEncoder(t reflect.Type) encoderFunc {
diff --git a/src/encoding/json/encode_test.go b/src/encoding/json/encode_test.go
index 7290eca06f..42bb09d5cd 100644
--- a/src/encoding/json/encode_test.go
+++ b/src/encoding/json/encode_test.go
@@ -183,7 +183,15 @@ type PointerCycleIndirect struct {
Ptrs []interface{}
}
-var pointerCycleIndirect = &PointerCycleIndirect{}
+type RecursiveSlice []RecursiveSlice
+
+var (
+ pointerCycleIndirect = &PointerCycleIndirect{}
+ mapCycle = make(map[string]interface{})
+ sliceCycle = []interface{}{nil}
+ sliceNoCycle = []interface{}{nil, nil}
+ recursiveSliceCycle = []RecursiveSlice{nil}
+)
func init() {
ptr := &SamePointerNoCycle{}
@@ -192,6 +200,14 @@ func init() {
pointerCycle.Ptr = pointerCycle
pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect}
+
+ mapCycle["x"] = mapCycle
+ sliceCycle[0] = sliceCycle
+ sliceNoCycle[1] = sliceNoCycle[:1]
+ for i := startDetectingCyclesAfter; i > 0; i-- {
+ sliceNoCycle = []interface{}{sliceNoCycle}
+ }
+ recursiveSliceCycle[0] = recursiveSliceCycle
}
func TestSamePointerNoCycle(t *testing.T) {
@@ -200,12 +216,21 @@ func TestSamePointerNoCycle(t *testing.T) {
}
}
+func TestSliceNoCycle(t *testing.T) {
+ if _, err := Marshal(sliceNoCycle); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
var unsupportedValues = []interface{}{
math.NaN(),
math.Inf(-1),
math.Inf(1),
pointerCycle,
pointerCycleIndirect,
+ mapCycle,
+ sliceCycle,
+ recursiveSliceCycle,
}
func TestUnsupportedValues(t *testing.T) {