aboutsummaryrefslogtreecommitdiff
path: root/src/encoding/asn1
diff options
context:
space:
mode:
authorHiroshi Ioka <hirochachacha@gmail.com>2016-08-02 14:41:53 +0900
committerBrad Fitzpatrick <bradfitz@golang.org>2016-09-13 21:05:27 +0000
commitae4aac00bba5d1d616408a1c07bd4ef5691e3a00 (patch)
treef5e2b6856fea65d469af67b5334a13bcce06c90f /src/encoding/asn1
parentee3f3a60070ee9edeb3f10fa2e4b90404068cb3a (diff)
downloadgo-ae4aac00bba5d1d616408a1c07bd4ef5691e3a00.tar.gz
go-ae4aac00bba5d1d616408a1c07bd4ef5691e3a00.zip
encoding/asn1: reduce allocations in Marshal
Current code uses trees of bytes.Buffer as data representation. Each bytes.Buffer takes 4k bytes at least, so it's waste of memory. The change introduces trees of lazy-encoder as alternative one which reduce allocations. name old time/op new time/op delta Marshal-4 64.7µs ± 2% 42.0µs ± 1% -35.07% (p=0.000 n=9+10) name old alloc/op new alloc/op delta Marshal-4 35.1kB ± 0% 7.6kB ± 0% -78.27% (p=0.000 n=10+10) name old allocs/op new allocs/op delta Marshal-4 503 ± 0% 293 ± 0% -41.75% (p=0.000 n=10+10) Change-Id: I32b96c20b8df00414b282d69743d71a598a11336 Reviewed-on: https://go-review.googlesource.com/27030 Reviewed-by: Adam Langley <agl@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Adam Langley <agl@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
Diffstat (limited to 'src/encoding/asn1')
-rw-r--r--src/encoding/asn1/asn1_test.go6
-rw-r--r--src/encoding/asn1/marshal.go615
-rw-r--r--src/encoding/asn1/marshal_test.go10
3 files changed, 317 insertions, 314 deletions
diff --git a/src/encoding/asn1/asn1_test.go b/src/encoding/asn1/asn1_test.go
index f8623fa9a2..81f4dba8c2 100644
--- a/src/encoding/asn1/asn1_test.go
+++ b/src/encoding/asn1/asn1_test.go
@@ -132,9 +132,9 @@ func TestParseBigInt(t *testing.T) {
if ret.String() != test.base10 {
t.Errorf("#%d: bad result from %x, got %s want %s", i, test.in, ret.String(), test.base10)
}
- fw := newForkableWriter()
- marshalBigInt(fw, ret)
- result := fw.Bytes()
+ e := makeBigInt(ret)
+ result := make([]byte, e.Len())
+ e.Encode(result)
if !bytes.Equal(result, test.in) {
t.Errorf("#%d: got %x from marshaling %s, want %x", i, result, ret, test.in)
}
diff --git a/src/encoding/asn1/marshal.go b/src/encoding/asn1/marshal.go
index 30797ef099..f0664d3d46 100644
--- a/src/encoding/asn1/marshal.go
+++ b/src/encoding/asn1/marshal.go
@@ -5,77 +5,125 @@
package asn1
import (
- "bytes"
"errors"
"fmt"
- "io"
"math/big"
"reflect"
"time"
"unicode/utf8"
)
-// A forkableWriter is an in-memory buffer that can be
-// 'forked' to create new forkableWriters that bracket the
-// original. After
-// pre, post := w.fork()
-// the overall sequence of bytes represented is logically w+pre+post.
-type forkableWriter struct {
- *bytes.Buffer
- pre, post *forkableWriter
+var (
+ byte00Encoder encoder = byteEncoder(0x00)
+ byteFFEncoder encoder = byteEncoder(0xff)
+)
+
+// encoder represents a ASN.1 element that is waiting to be marshaled.
+type encoder interface {
+ // Len returns the number of bytes needed to marshal this element.
+ Len() int
+ // Encode encodes this element by writing Len() bytes to dst.
+ Encode(dst []byte)
+}
+
+type byteEncoder byte
+
+func (c byteEncoder) Len() int {
+ return 1
}
-func newForkableWriter() *forkableWriter {
- return &forkableWriter{new(bytes.Buffer), nil, nil}
+func (c byteEncoder) Encode(dst []byte) {
+ dst[0] = byte(c)
}
-func (f *forkableWriter) fork() (pre, post *forkableWriter) {
- if f.pre != nil || f.post != nil {
- panic("have already forked")
+type bytesEncoder []byte
+
+func (b bytesEncoder) Len() int {
+ return len(b)
+}
+
+func (b bytesEncoder) Encode(dst []byte) {
+ if copy(dst, b) != len(b) {
+ panic("internal error")
}
- f.pre = newForkableWriter()
- f.post = newForkableWriter()
- return f.pre, f.post
}
-func (f *forkableWriter) Len() (l int) {
- l += f.Buffer.Len()
- if f.pre != nil {
- l += f.pre.Len()
+type stringEncoder string
+
+func (s stringEncoder) Len() int {
+ return len(s)
+}
+
+func (s stringEncoder) Encode(dst []byte) {
+ if copy(dst, s) != len(s) {
+ panic("internal error")
}
- if f.post != nil {
- l += f.post.Len()
+}
+
+type multiEncoder []encoder
+
+func (m multiEncoder) Len() int {
+ var size int
+ for _, e := range m {
+ size += e.Len()
}
- return
+ return size
}
-func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
- n, err = out.Write(f.Bytes())
- if err != nil {
- return
+func (m multiEncoder) Encode(dst []byte) {
+ var off int
+ for _, e := range m {
+ e.Encode(dst[off:])
+ off += e.Len()
}
+}
- var nn int
+type taggedEncoder struct {
+ // scratch contains temporary space for encoding the tag and length of
+ // an element in order to avoid extra allocations.
+ scratch [8]byte
+ tag encoder
+ body encoder
+}
- if f.pre != nil {
- nn, err = f.pre.writeTo(out)
- n += nn
- if err != nil {
- return
- }
+func (t *taggedEncoder) Len() int {
+ return t.tag.Len() + t.body.Len()
+}
+
+func (t *taggedEncoder) Encode(dst []byte) {
+ t.tag.Encode(dst)
+ t.body.Encode(dst[t.tag.Len():])
+}
+
+type int64Encoder int64
+
+func (i int64Encoder) Len() int {
+ n := 1
+
+ for i > 127 {
+ n++
+ i >>= 8
}
- if f.post != nil {
- nn, err = f.post.writeTo(out)
- n += nn
+ for i < -128 {
+ n++
+ i >>= 8
}
- return
+
+ return n
}
-func marshalBase128Int(out *forkableWriter, n int64) (err error) {
+func (i int64Encoder) Encode(dst []byte) {
+ n := i.Len()
+
+ for j := 0; j < n; j++ {
+ dst[j] = byte(i >> uint((n-1-j)*8))
+ }
+}
+
+func base128IntLength(n int64) int {
if n == 0 {
- err = out.WriteByte(0)
- return
+ return 1
}
l := 0
@@ -83,54 +131,29 @@ func marshalBase128Int(out *forkableWriter, n int64) (err error) {
l++
}
+ return l
+}
+
+func appendBase128Int(dst []byte, n int64) []byte {
+ l := base128IntLength(n)
+
for i := l - 1; i >= 0; i-- {
o := byte(n >> uint(i*7))
o &= 0x7f
if i != 0 {
o |= 0x80
}
- err = out.WriteByte(o)
- if err != nil {
- return
- }
- }
-
- return nil
-}
-func marshalInt64(out *forkableWriter, i int64) (err error) {
- n := int64Length(i)
-
- for ; n > 0; n-- {
- err = out.WriteByte(byte(i >> uint((n-1)*8)))
- if err != nil {
- return
- }
+ dst = append(dst, o)
}
- return nil
+ return dst
}
-func int64Length(i int64) (numBytes int) {
- numBytes = 1
-
- for i > 127 {
- numBytes++
- i >>= 8
- }
-
- for i < -128 {
- numBytes++
- i >>= 8
- }
-
- return
-}
-
-func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
+func makeBigInt(n *big.Int) encoder {
if n.Sign() < 0 {
// A negative number has to be converted to two's-complement
- // form. So we'll subtract 1 and invert. If the
+ // form. So we'll invert and subtract 1. If the
// most-significant-bit isn't set then we'll need to pad the
// beginning with 0xff in order to keep the number negative.
nMinus1 := new(big.Int).Neg(n)
@@ -140,41 +163,31 @@ func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
bytes[i] ^= 0xff
}
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
- err = out.WriteByte(0xff)
- if err != nil {
- return
- }
+ return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)})
}
- _, err = out.Write(bytes)
+ return bytesEncoder(bytes)
} else if n.Sign() == 0 {
// Zero is written as a single 0 zero rather than no bytes.
- err = out.WriteByte(0x00)
+ return byte00Encoder
} else {
bytes := n.Bytes()
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
// We'll have to pad this with 0x00 in order to stop it
// looking like a negative number.
- err = out.WriteByte(0)
- if err != nil {
- return
- }
+ return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)})
}
- _, err = out.Write(bytes)
+ return bytesEncoder(bytes)
}
- return
}
-func marshalLength(out *forkableWriter, i int) (err error) {
+func appendLength(dst []byte, i int) []byte {
n := lengthLength(i)
for ; n > 0; n-- {
- err = out.WriteByte(byte(i >> uint((n-1)*8)))
- if err != nil {
- return
- }
+ dst = append(dst, byte(i>>uint((n-1)*8)))
}
- return nil
+ return dst
}
func lengthLength(i int) (numBytes int) {
@@ -186,123 +199,104 @@ func lengthLength(i int) (numBytes int) {
return
}
-func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
+func appendTagAndLength(dst []byte, t tagAndLength) []byte {
b := uint8(t.class) << 6
if t.isCompound {
b |= 0x20
}
if t.tag >= 31 {
b |= 0x1f
- err = out.WriteByte(b)
- if err != nil {
- return
- }
- err = marshalBase128Int(out, int64(t.tag))
- if err != nil {
- return
- }
+ dst = append(dst, b)
+ dst = appendBase128Int(dst, int64(t.tag))
} else {
b |= uint8(t.tag)
- err = out.WriteByte(b)
- if err != nil {
- return
- }
+ dst = append(dst, b)
}
if t.length >= 128 {
l := lengthLength(t.length)
- err = out.WriteByte(0x80 | byte(l))
- if err != nil {
- return
- }
- err = marshalLength(out, t.length)
- if err != nil {
- return
- }
+ dst = append(dst, 0x80|byte(l))
+ dst = appendLength(dst, t.length)
} else {
- err = out.WriteByte(byte(t.length))
- if err != nil {
- return
- }
+ dst = append(dst, byte(t.length))
}
- return nil
+ return dst
}
-func marshalBitString(out *forkableWriter, b BitString) (err error) {
- paddingBits := byte((8 - b.BitLength%8) % 8)
- err = out.WriteByte(paddingBits)
- if err != nil {
- return
- }
- _, err = out.Write(b.Bytes)
- return
+type bitStringEncoder BitString
+
+func (b bitStringEncoder) Len() int {
+ return len(b.Bytes) + 1
}
-func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
- if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
- return StructuralError{"invalid object identifier"}
+func (b bitStringEncoder) Encode(dst []byte) {
+ dst[0] = byte((8 - b.BitLength%8) % 8)
+ if copy(dst[1:], b.Bytes) != len(b.Bytes) {
+ panic("internal error")
}
+}
- err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
- if err != nil {
- return
+type oidEncoder []int
+
+func (oid oidEncoder) Len() int {
+ l := base128IntLength(int64(oid[0]*40 + oid[1]))
+ for i := 2; i < len(oid); i++ {
+ l += base128IntLength(int64(oid[i]))
}
+ return l
+}
+
+func (oid oidEncoder) Encode(dst []byte) {
+ dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
for i := 2; i < len(oid); i++ {
- err = marshalBase128Int(out, int64(oid[i]))
- if err != nil {
- return
- }
+ dst = appendBase128Int(dst, int64(oid[i]))
+ }
+}
+
+func makeObjectIdentifier(oid []int) (e encoder, err error) {
+ if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
+ return nil, StructuralError{"invalid object identifier"}
}
- return
+ return oidEncoder(oid), nil
}
-func marshalPrintableString(out *forkableWriter, s string) (err error) {
- b := []byte(s)
- for _, c := range b {
- if !isPrintable(c) {
- return StructuralError{"PrintableString contains invalid character"}
+func makePrintableString(s string) (e encoder, err error) {
+ for i := 0; i < len(s); i++ {
+ if !isPrintable(s[i]) {
+ return nil, StructuralError{"PrintableString contains invalid character"}
}
}
- _, err = out.Write(b)
- return
+ return stringEncoder(s), nil
}
-func marshalIA5String(out *forkableWriter, s string) (err error) {
- b := []byte(s)
- for _, c := range b {
- if c > 127 {
- return StructuralError{"IA5String contains invalid character"}
+func makeIA5String(s string) (e encoder, err error) {
+ for i := 0; i < len(s); i++ {
+ if s[i] > 127 {
+ return nil, StructuralError{"IA5String contains invalid character"}
}
}
- _, err = out.Write(b)
- return
+ return stringEncoder(s), nil
}
-func marshalUTF8String(out *forkableWriter, s string) (err error) {
- _, err = out.Write([]byte(s))
- return
+func makeUTF8String(s string) encoder {
+ return stringEncoder(s)
}
-func marshalTwoDigits(out *forkableWriter, v int) (err error) {
- err = out.WriteByte(byte('0' + (v/10)%10))
- if err != nil {
- return
- }
- return out.WriteByte(byte('0' + v%10))
+func appendTwoDigits(dst []byte, v int) []byte {
+ return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
}
-func marshalFourDigits(out *forkableWriter, v int) (err error) {
+func appendFourDigits(dst []byte, v int) []byte {
var bytes [4]byte
for i := range bytes {
bytes[3-i] = '0' + byte(v%10)
v /= 10
}
- _, err = out.Write(bytes[:])
- return
+ return append(dst, bytes[:]...)
}
func outsideUTCRange(t time.Time) bool {
@@ -310,80 +304,75 @@ func outsideUTCRange(t time.Time) bool {
return year < 1950 || year >= 2050
}
-func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
+func makeUTCTime(t time.Time) (e encoder, err error) {
+ dst := make([]byte, 0, 18)
+
+ dst, err = appendUTCTime(dst, t)
+ if err != nil {
+ return nil, err
+ }
+
+ return bytesEncoder(dst), nil
+}
+
+func makeGeneralizedTime(t time.Time) (e encoder, err error) {
+ dst := make([]byte, 0, 20)
+
+ dst, err = appendGeneralizedTime(dst, t)
+ if err != nil {
+ return nil, err
+ }
+
+ return bytesEncoder(dst), nil
+}
+
+func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year()
switch {
case 1950 <= year && year < 2000:
- err = marshalTwoDigits(out, year-1900)
+ dst = appendTwoDigits(dst, year-1900)
case 2000 <= year && year < 2050:
- err = marshalTwoDigits(out, year-2000)
+ dst = appendTwoDigits(dst, year-2000)
default:
- return StructuralError{"cannot represent time as UTCTime"}
- }
- if err != nil {
- return
+ return nil, StructuralError{"cannot represent time as UTCTime"}
}
- return marshalTimeCommon(out, t)
+ return appendTimeCommon(dst, t), nil
}
-func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
+func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
year := t.Year()
if year < 0 || year > 9999 {
- return StructuralError{"cannot represent time as GeneralizedTime"}
- }
- if err = marshalFourDigits(out, year); err != nil {
- return
+ return nil, StructuralError{"cannot represent time as GeneralizedTime"}
}
- return marshalTimeCommon(out, t)
+ dst = appendFourDigits(dst, year)
+
+ return appendTimeCommon(dst, t), nil
}
-func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
+func appendTimeCommon(dst []byte, t time.Time) []byte {
_, month, day := t.Date()
- err = marshalTwoDigits(out, int(month))
- if err != nil {
- return
- }
-
- err = marshalTwoDigits(out, day)
- if err != nil {
- return
- }
+ dst = appendTwoDigits(dst, int(month))
+ dst = appendTwoDigits(dst, day)
hour, min, sec := t.Clock()
- err = marshalTwoDigits(out, hour)
- if err != nil {
- return
- }
-
- err = marshalTwoDigits(out, min)
- if err != nil {
- return
- }
-
- err = marshalTwoDigits(out, sec)
- if err != nil {
- return
- }
+ dst = appendTwoDigits(dst, hour)
+ dst = appendTwoDigits(dst, min)
+ dst = appendTwoDigits(dst, sec)
_, offset := t.Zone()
switch {
case offset/60 == 0:
- err = out.WriteByte('Z')
- return
+ return append(dst, 'Z')
case offset > 0:
- err = out.WriteByte('+')
+ dst = append(dst, '+')
case offset < 0:
- err = out.WriteByte('-')
- }
-
- if err != nil {
- return
+ dst = append(dst, '-')
}
offsetMinutes := offset / 60
@@ -391,13 +380,10 @@ func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
offsetMinutes = -offsetMinutes
}
- err = marshalTwoDigits(out, offsetMinutes/60)
- if err != nil {
- return
- }
+ dst = appendTwoDigits(dst, offsetMinutes/60)
+ dst = appendTwoDigits(dst, offsetMinutes%60)
- err = marshalTwoDigits(out, offsetMinutes%60)
- return
+ return dst
}
func stripTagAndLength(in []byte) []byte {
@@ -408,114 +394,124 @@ func stripTagAndLength(in []byte) []byte {
return in[offset:]
}
-func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
+func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
switch value.Type() {
case flagType:
- return nil
+ return bytesEncoder(nil), nil
case timeType:
t := value.Interface().(time.Time)
if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
- return marshalGeneralizedTime(out, t)
- } else {
- return marshalUTCTime(out, t)
+ return makeGeneralizedTime(t)
}
+ return makeUTCTime(t)
case bitStringType:
- return marshalBitString(out, value.Interface().(BitString))
+ return bitStringEncoder(value.Interface().(BitString)), nil
case objectIdentifierType:
- return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
+ return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
case bigIntType:
- return marshalBigInt(out, value.Interface().(*big.Int))
+ return makeBigInt(value.Interface().(*big.Int)), nil
}
switch v := value; v.Kind() {
case reflect.Bool:
if v.Bool() {
- return out.WriteByte(255)
- } else {
- return out.WriteByte(0)
+ return byteFFEncoder, nil
}
+ return byte00Encoder, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return marshalInt64(out, v.Int())
+ return int64Encoder(v.Int()), nil
case reflect.Struct:
t := v.Type()
startingField := 0
+ n := t.NumField()
+ if n == 0 {
+ return bytesEncoder(nil), nil
+ }
+
// If the first element of the structure is a non-empty
// RawContents, then we don't bother serializing the rest.
- if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
+ if t.Field(0).Type == rawContentsType {
s := v.Field(0)
if s.Len() > 0 {
- bytes := make([]byte, s.Len())
- for i := 0; i < s.Len(); i++ {
- bytes[i] = uint8(s.Index(i).Uint())
- }
+ bytes := s.Bytes()
/* The RawContents will contain the tag and
* length fields but we'll also be writing
* those ourselves, so we strip them out of
* bytes */
- _, err = out.Write(stripTagAndLength(bytes))
- return
- } else {
- startingField = 1
+ return bytesEncoder(stripTagAndLength(bytes)), nil
}
+
+ startingField = 1
}
- for i := startingField; i < t.NumField(); i++ {
- var pre *forkableWriter
- pre, out = out.fork()
- err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
- if err != nil {
- return
+ switch n1 := n - startingField; n1 {
+ case 0:
+ return bytesEncoder(nil), nil
+ case 1:
+ return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
+ default:
+ m := make([]encoder, n1)
+ for i := 0; i < n1; i++ {
+ m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
+ if err != nil {
+ return nil, err
+ }
}
+
+ return multiEncoder(m), nil
}
- return
case reflect.Slice:
sliceType := v.Type()
if sliceType.Elem().Kind() == reflect.Uint8 {
- bytes := make([]byte, v.Len())
- for i := 0; i < v.Len(); i++ {
- bytes[i] = uint8(v.Index(i).Uint())
- }
- _, err = out.Write(bytes)
- return
+ return bytesEncoder(v.Bytes()), nil
}
var fp fieldParameters
- for i := 0; i < v.Len(); i++ {
- var pre *forkableWriter
- pre, out = out.fork()
- err = marshalField(pre, v.Index(i), fp)
- if err != nil {
- return
+
+ switch l := v.Len(); l {
+ case 0:
+ return bytesEncoder(nil), nil
+ case 1:
+ return makeField(v.Index(0), fp)
+ default:
+ m := make([]encoder, l)
+
+ for i := 0; i < l; i++ {
+ m[i], err = makeField(v.Index(i), fp)
+ if err != nil {
+ return nil, err
+ }
}
+
+ return multiEncoder(m), nil
}
- return
case reflect.String:
switch params.stringType {
case TagIA5String:
- return marshalIA5String(out, v.String())
+ return makeIA5String(v.String())
case TagPrintableString:
- return marshalPrintableString(out, v.String())
+ return makePrintableString(v.String())
default:
- return marshalUTF8String(out, v.String())
+ return makeUTF8String(v.String()), nil
}
}
- return StructuralError{"unknown Go type"}
+ return nil, StructuralError{"unknown Go type"}
}
-func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
+func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
if !v.IsValid() {
- return fmt.Errorf("asn1: cannot marshal nil value")
+ return nil, fmt.Errorf("asn1: cannot marshal nil value")
}
// If the field is an interface{} then recurse into it.
if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
- return marshalField(out, v.Elem(), params)
+ return makeField(v.Elem(), params)
}
if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
- return
+ return bytesEncoder(nil), nil
}
if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
@@ -523,7 +519,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
defaultValue.SetInt(*params.defaultValue)
if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
- return
+ return bytesEncoder(nil), nil
}
}
@@ -532,37 +528,36 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
// behaviour, but it's what Go has traditionally done.
if params.optional && params.defaultValue == nil {
if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
- return
+ return bytesEncoder(nil), nil
}
}
if v.Type() == rawValueType {
rv := v.Interface().(RawValue)
if len(rv.FullBytes) != 0 {
- _, err = out.Write(rv.FullBytes)
- } else {
- err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
- if err != nil {
- return
- }
- _, err = out.Write(rv.Bytes)
+ return bytesEncoder(rv.FullBytes), nil
}
- return
+
+ t := new(taggedEncoder)
+
+ t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
+ t.body = bytesEncoder(rv.Bytes)
+
+ return t, nil
}
tag, isCompound, ok := getUniversalType(v.Type())
if !ok {
- err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
- return
+ return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
}
class := ClassUniversal
if params.timeType != 0 && tag != TagUTCTime {
- return StructuralError{"explicit time type given to non-time member"}
+ return nil, StructuralError{"explicit time type given to non-time member"}
}
if params.stringType != 0 && tag != TagPrintableString {
- return StructuralError{"explicit string type given to non-string member"}
+ return nil, StructuralError{"explicit string type given to non-string member"}
}
switch tag {
@@ -574,7 +569,7 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
for _, r := range v.String() {
if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
if !utf8.ValidString(v.String()) {
- return errors.New("asn1: string not valid UTF-8")
+ return nil, errors.New("asn1: string not valid UTF-8")
}
tag = TagUTF8String
break
@@ -591,46 +586,46 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
if params.set {
if tag != TagSequence {
- return StructuralError{"non sequence tagged as set"}
+ return nil, StructuralError{"non sequence tagged as set"}
}
tag = TagSet
}
- tags, body := out.fork()
+ t := new(taggedEncoder)
- err = marshalBody(body, v, params)
+ t.body, err = makeBody(v, params)
if err != nil {
- return
+ return nil, err
}
- bodyLen := body.Len()
+ bodyLen := t.body.Len()
- var explicitTag *forkableWriter
if params.explicit {
- explicitTag, tags = tags.fork()
- }
+ t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
- if !params.explicit && params.tag != nil {
- // implicit tag.
- tag = *params.tag
- class = ClassContextSpecific
- }
+ tt := new(taggedEncoder)
- err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
- if err != nil {
- return
- }
+ tt.body = t
- if params.explicit {
- err = marshalTagAndLength(explicitTag, tagAndLength{
+ tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
class: ClassContextSpecific,
tag: *params.tag,
- length: bodyLen + tags.Len(),
+ length: bodyLen + t.tag.Len(),
isCompound: true,
- })
+ }))
+
+ return tt, nil
+ }
+
+ if params.tag != nil {
+ // implicit tag.
+ tag = *params.tag
+ class = ClassContextSpecific
}
- return err
+ t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
+
+ return t, nil
}
// Marshal returns the ASN.1 encoding of val.
@@ -643,13 +638,11 @@ func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters)
// printable: causes strings to be marshaled as ASN.1, PrintableString strings.
// utf8: causes strings to be marshaled as ASN.1, UTF8 strings
func Marshal(val interface{}) ([]byte, error) {
- var out bytes.Buffer
- v := reflect.ValueOf(val)
- f := newForkableWriter()
- err := marshalField(f, v, fieldParameters{})
+ e, err := makeField(reflect.ValueOf(val), fieldParameters{})
if err != nil {
return nil, err
}
- _, err = f.writeTo(&out)
- return out.Bytes(), err
+ b := make([]byte, e.Len())
+ e.Encode(b)
+ return b, nil
}
diff --git a/src/encoding/asn1/marshal_test.go b/src/encoding/asn1/marshal_test.go
index cdca8aa336..6af770fcc3 100644
--- a/src/encoding/asn1/marshal_test.go
+++ b/src/encoding/asn1/marshal_test.go
@@ -173,3 +173,13 @@ func TestInvalidUTF8(t *testing.T) {
t.Errorf("invalid UTF8 string was accepted")
}
}
+
+func BenchmarkMarshal(b *testing.B) {
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ for _, test := range marshalTests {
+ Marshal(test.in)
+ }
+ }
+}