aboutsummaryrefslogtreecommitdiff
path: root/src/compress
diff options
context:
space:
mode:
authorAgniva De Sarker <agnivade@yahoo.co.in>2020-11-27 11:25:00 +0530
committerAgniva De Sarker <agniva.quicksilver@gmail.com>2021-03-17 07:01:00 +0000
commit72b501cb0326f62a20621636942e5a95fc3c1466 (patch)
tree4b0332323b2c1085f970185bf5bd99fc59c16096 /src/compress
parent119d76d98e4a367ad9b35dfa436692e2976c99aa (diff)
downloadgo-72b501cb0326f62a20621636942e5a95fc3c1466.tar.gz
go-72b501cb0326f62a20621636942e5a95fc3c1466.zip
compress/lzw: add Reset method to Reader and Writer
We add a Reset method which clears any internal state of an encoder or a decoder to let it be reused again as a new Writer or Reader respectively. We also export the encoder and decoder structs, renaming them to be Reader and Writer, and we guarantee that the underlying types from the constructors will always be Reader and Writer respectively. Benchmark results by reusing the encoder: on cpu: Intel(R) Core(TM) i5-8265U CPU @ 1.60GHz name time/op Decoder/1e4-8 93.6µs ± 1% Decoder/1e-Reuse4-8 87.7µs ± 1% Decoder/1e5-8 877µs ± 1% Decoder/1e-Reuse5-8 860µs ± 3% Decoder/1e6-8 8.79ms ± 1% Decoder/1e-Reuse6-8 8.82ms ± 4% Encoder/1e4-8 168µs ± 2% Encoder/1e-Reuse4-8 160µs ± 1% Encoder/1e5-8 1.64ms ± 1% Encoder/1e-Reuse5-8 1.61ms ± 2% Encoder/1e6-8 16.2ms ± 6% Encoder/1e-Reuse6-8 15.8ms ± 2% name speed Decoder/1e4-8 107MB/s ± 1% Decoder/1e-Reuse4-8 114MB/s ± 1% Decoder/1e5-8 114MB/s ± 1% Decoder/1e-Reuse5-8 116MB/s ± 3% Decoder/1e6-8 114MB/s ± 1% Decoder/1e-Reuse6-8 113MB/s ± 5% Encoder/1e4-8 59.7MB/s ± 2% Encoder/1e-Reuse4-8 62.4MB/s ± 1% Encoder/1e5-8 61.1MB/s ± 1% Encoder/1e-Reuse5-8 62.0MB/s ± 2% Encoder/1e6-8 61.7MB/s ± 5% Encoder/1e-Reuse6-8 63.4MB/s ± 2% name alloc/op Decoder/1e4-8 21.8kB ± 0% Decoder/1e-Reuse4-8 50.0B ± 0% Decoder/1e5-8 21.8kB ± 0% Decoder/1e-Reuse5-8 70.4B ± 2% Decoder/1e6-8 21.9kB ± 0% Decoder/1e-Reuse6-8 271B ± 3% Encoder/1e4-8 77.9kB ± 0% Encoder/1e-Reuse4-8 4.17kB ± 0% Encoder/1e5-8 77.9kB ± 0% Encoder/1e-Reuse5-8 4.27kB ± 0% Encoder/1e6-8 77.9kB ± 0% Encoder/1e-Reuse6-8 5.22kB ± 0% name allocs/op Decoder/1e4-8 2.00 ± 0% Decoder/1e-Reuse4-8 1.00 ± 0% Decoder/1e5-8 2.00 ± 0% Decoder/1e-Reuse5-8 1.00 ± 0% Decoder/1e6-8 2.00 ± 0% Decoder/1e-Reuse6-8 1.00 ± 0% Encoder/1e4-8 3.00 ± 0% Encoder/1e-Reuse4-8 2.00 ± 0% Encoder/1e5-8 3.00 ± 0% Encoder/1e-Reuse5-8 2.00 ± 0% Encoder/1e6-8 3.00 ± 0% Encoder/1e-Reuse6-8 2.00 ± 0% Fixes #26535 Change-Id: Icde613fea6234a5bdce95f1e49910f5687e30b22 Reviewed-on: https://go-review.googlesource.com/c/go/+/273667 Trust: Agniva De Sarker <agniva.quicksilver@gmail.com> Trust: Joe Tsai <thebrokentoaster@gmail.com> Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Diffstat (limited to 'src/compress')
-rw-r--r--src/compress/lzw/reader.go207
-rw-r--r--src/compress/lzw/reader_test.go88
-rw-r--r--src/compress/lzw/writer.go206
-rw-r--r--src/compress/lzw/writer_test.go54
4 files changed, 349 insertions, 206 deletions
diff --git a/src/compress/lzw/reader.go b/src/compress/lzw/reader.go
index f08021190c..952870a56a 100644
--- a/src/compress/lzw/reader.go
+++ b/src/compress/lzw/reader.go
@@ -42,15 +42,15 @@ const (
flushBuffer = 1 << maxWidth
)
-// decoder is the state from which the readXxx method converts a byte
-// stream into a code stream.
-type decoder struct {
+// Reader is an io.Reader which can be used to read compressed data in the
+// LZW format.
+type Reader struct {
r io.ByteReader
bits uint32
nBits uint
width uint
- read func(*decoder) (uint16, error) // readLSB or readMSB
- litWidth int // width in bits of literal codes
+ read func(*Reader) (uint16, error) // readLSB or readMSB
+ litWidth int // width in bits of literal codes
err error
// The first 1<<litWidth codes are literal codes.
@@ -87,148 +87,158 @@ type decoder struct {
}
// readLSB returns the next code for "Least Significant Bits first" data.
-func (d *decoder) readLSB() (uint16, error) {
- for d.nBits < d.width {
- x, err := d.r.ReadByte()
+func (r *Reader) readLSB() (uint16, error) {
+ for r.nBits < r.width {
+ x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
- d.bits |= uint32(x) << d.nBits
- d.nBits += 8
+ r.bits |= uint32(x) << r.nBits
+ r.nBits += 8
}
- code := uint16(d.bits & (1<<d.width - 1))
- d.bits >>= d.width
- d.nBits -= d.width
+ code := uint16(r.bits & (1<<r.width - 1))
+ r.bits >>= r.width
+ r.nBits -= r.width
return code, nil
}
// readMSB returns the next code for "Most Significant Bits first" data.
-func (d *decoder) readMSB() (uint16, error) {
- for d.nBits < d.width {
- x, err := d.r.ReadByte()
+func (r *Reader) readMSB() (uint16, error) {
+ for r.nBits < r.width {
+ x, err := r.r.ReadByte()
if err != nil {
return 0, err
}
- d.bits |= uint32(x) << (24 - d.nBits)
- d.nBits += 8
+ r.bits |= uint32(x) << (24 - r.nBits)
+ r.nBits += 8
}
- code := uint16(d.bits >> (32 - d.width))
- d.bits <<= d.width
- d.nBits -= d.width
+ code := uint16(r.bits >> (32 - r.width))
+ r.bits <<= r.width
+ r.nBits -= r.width
return code, nil
}
-func (d *decoder) Read(b []byte) (int, error) {
+// Read implements io.Reader, reading uncompressed bytes from its underlying Reader.
+func (r *Reader) Read(b []byte) (int, error) {
for {
- if len(d.toRead) > 0 {
- n := copy(b, d.toRead)
- d.toRead = d.toRead[n:]
+ if len(r.toRead) > 0 {
+ n := copy(b, r.toRead)
+ r.toRead = r.toRead[n:]
return n, nil
}
- if d.err != nil {
- return 0, d.err
+ if r.err != nil {
+ return 0, r.err
}
- d.decode()
+ r.decode()
}
}
// decode decompresses bytes from r and leaves them in d.toRead.
// read specifies how to decode bytes into codes.
// litWidth is the width in bits of literal codes.
-func (d *decoder) decode() {
+func (r *Reader) decode() {
// Loop over the code stream, converting codes into decompressed bytes.
loop:
for {
- code, err := d.read(d)
+ code, err := r.read(r)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
- d.err = err
+ r.err = err
break
}
switch {
- case code < d.clear:
+ case code < r.clear:
// We have a literal code.
- d.output[d.o] = uint8(code)
- d.o++
- if d.last != decoderInvalidCode {
+ r.output[r.o] = uint8(code)
+ r.o++
+ if r.last != decoderInvalidCode {
// Save what the hi code expands to.
- d.suffix[d.hi] = uint8(code)
- d.prefix[d.hi] = d.last
+ r.suffix[r.hi] = uint8(code)
+ r.prefix[r.hi] = r.last
}
- case code == d.clear:
- d.width = 1 + uint(d.litWidth)
- d.hi = d.eof
- d.overflow = 1 << d.width
- d.last = decoderInvalidCode
+ case code == r.clear:
+ r.width = 1 + uint(r.litWidth)
+ r.hi = r.eof
+ r.overflow = 1 << r.width
+ r.last = decoderInvalidCode
continue
- case code == d.eof:
- d.err = io.EOF
+ case code == r.eof:
+ r.err = io.EOF
break loop
- case code <= d.hi:
- c, i := code, len(d.output)-1
- if code == d.hi && d.last != decoderInvalidCode {
+ case code <= r.hi:
+ c, i := code, len(r.output)-1
+ if code == r.hi && r.last != decoderInvalidCode {
// code == hi is a special case which expands to the last expansion
// followed by the head of the last expansion. To find the head, we walk
// the prefix chain until we find a literal code.
- c = d.last
- for c >= d.clear {
- c = d.prefix[c]
+ c = r.last
+ for c >= r.clear {
+ c = r.prefix[c]
}
- d.output[i] = uint8(c)
+ r.output[i] = uint8(c)
i--
- c = d.last
+ c = r.last
}
// Copy the suffix chain into output and then write that to w.
- for c >= d.clear {
- d.output[i] = d.suffix[c]
+ for c >= r.clear {
+ r.output[i] = r.suffix[c]
i--
- c = d.prefix[c]
+ c = r.prefix[c]
}
- d.output[i] = uint8(c)
- d.o += copy(d.output[d.o:], d.output[i:])
- if d.last != decoderInvalidCode {
+ r.output[i] = uint8(c)
+ r.o += copy(r.output[r.o:], r.output[i:])
+ if r.last != decoderInvalidCode {
// Save what the hi code expands to.
- d.suffix[d.hi] = uint8(c)
- d.prefix[d.hi] = d.last
+ r.suffix[r.hi] = uint8(c)
+ r.prefix[r.hi] = r.last
}
default:
- d.err = errors.New("lzw: invalid code")
+ r.err = errors.New("lzw: invalid code")
break loop
}
- d.last, d.hi = code, d.hi+1
- if d.hi >= d.overflow {
- if d.hi > d.overflow {
+ r.last, r.hi = code, r.hi+1
+ if r.hi >= r.overflow {
+ if r.hi > r.overflow {
panic("unreachable")
}
- if d.width == maxWidth {
- d.last = decoderInvalidCode
+ if r.width == maxWidth {
+ r.last = decoderInvalidCode
// Undo the d.hi++ a few lines above, so that (1) we maintain
// the invariant that d.hi < d.overflow, and (2) d.hi does not
// eventually overflow a uint16.
- d.hi--
+ r.hi--
} else {
- d.width++
- d.overflow = 1 << d.width
+ r.width++
+ r.overflow = 1 << r.width
}
}
- if d.o >= flushBuffer {
+ if r.o >= flushBuffer {
break
}
}
// Flush pending output.
- d.toRead = d.output[:d.o]
- d.o = 0
+ r.toRead = r.output[:r.o]
+ r.o = 0
}
var errClosed = errors.New("lzw: reader/writer is closed")
-func (d *decoder) Close() error {
- d.err = errClosed // in case any Reads come along
+// Close closes the Reader and returns an error for any future read operation.
+// It does not close the underlying io.Reader.
+func (r *Reader) Close() error {
+ r.err = errClosed // in case any Reads come along
return nil
}
+// Reset clears the Reader's state and allows it to be reused again
+// as a new Reader.
+func (r *Reader) Reset(src io.Reader, order Order, litWidth int) {
+ *r = Reader{}
+ r.init(src, order, litWidth)
+}
+
// NewReader creates a new io.ReadCloser.
// Reads from the returned io.ReadCloser read and decompress data from r.
// If r does not also implement io.ByteReader,
@@ -238,32 +248,43 @@ func (d *decoder) Close() error {
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. It must equal the litWidth
// used during compression.
+//
+// It is guaranteed that the underlying type of the returned io.ReadCloser
+// is a *Reader.
func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
- d := new(decoder)
+ return newReader(r, order, litWidth)
+}
+
+func newReader(src io.Reader, order Order, litWidth int) *Reader {
+ r := new(Reader)
+ r.init(src, order, litWidth)
+ return r
+}
+
+func (r *Reader) init(src io.Reader, order Order, litWidth int) {
switch order {
case LSB:
- d.read = (*decoder).readLSB
+ r.read = (*Reader).readLSB
case MSB:
- d.read = (*decoder).readMSB
+ r.read = (*Reader).readMSB
default:
- d.err = errors.New("lzw: unknown order")
- return d
+ r.err = errors.New("lzw: unknown order")
+ return
}
if litWidth < 2 || 8 < litWidth {
- d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
- return d
+ r.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
+ return
}
- if br, ok := r.(io.ByteReader); ok {
- d.r = br
- } else {
- d.r = bufio.NewReader(r)
- }
- d.litWidth = litWidth
- d.width = 1 + uint(litWidth)
- d.clear = uint16(1) << uint(litWidth)
- d.eof, d.hi = d.clear+1, d.clear+1
- d.overflow = uint16(1) << d.width
- d.last = decoderInvalidCode
- return d
+ br, ok := src.(io.ByteReader)
+ if !ok && src != nil {
+ br = bufio.NewReader(src)
+ }
+ r.r = br
+ r.litWidth = litWidth
+ r.width = 1 + uint(litWidth)
+ r.clear = uint16(1) << uint(litWidth)
+ r.eof, r.hi = r.clear+1, r.clear+1
+ r.overflow = uint16(1) << r.width
+ r.last = decoderInvalidCode
}
diff --git a/src/compress/lzw/reader_test.go b/src/compress/lzw/reader_test.go
index d1eb76d042..9a2a477302 100644
--- a/src/compress/lzw/reader_test.go
+++ b/src/compress/lzw/reader_test.go
@@ -120,6 +120,53 @@ func TestReader(t *testing.T) {
}
}
+func TestReaderReset(t *testing.T) {
+ var b bytes.Buffer
+ for _, tt := range lzwTests {
+ d := strings.Split(tt.desc, ";")
+ var order Order
+ switch d[1] {
+ case "LSB":
+ order = LSB
+ case "MSB":
+ order = MSB
+ default:
+ t.Errorf("%s: bad order %q", tt.desc, d[1])
+ }
+ litWidth, _ := strconv.Atoi(d[2])
+ rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
+ defer rc.Close()
+ b.Reset()
+ n, err := io.Copy(&b, rc)
+ b1 := b.Bytes()
+ if err != nil {
+ if err != tt.err {
+ t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
+ }
+ if err == io.ErrUnexpectedEOF {
+ // Even if the input is truncated, we should still return the
+ // partial decoded result.
+ if n == 0 || !strings.HasPrefix(tt.raw, b.String()) {
+ t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, b.String(), tt.raw)
+ }
+ }
+ continue
+ }
+
+ b.Reset()
+ rc.(*Reader).Reset(strings.NewReader(tt.compressed), order, litWidth)
+ n, err = io.Copy(&b, rc)
+ b2 := b.Bytes()
+ if err != nil {
+ t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, nil)
+ continue
+ }
+ if !bytes.Equal(b1, b2) {
+ t.Errorf("bytes read were not the same")
+ }
+ }
+}
+
type devZero struct{}
func (devZero) Read(p []byte) (int, error) {
@@ -131,7 +178,7 @@ func (devZero) Read(p []byte) (int, error) {
func TestHiCodeDoesNotOverflow(t *testing.T) {
r := NewReader(devZero{}, LSB, 8)
- d := r.(*decoder)
+ d := r.(*Reader)
buf := make([]byte, 1024)
oldHi := uint16(0)
for i := 0; i < 100; i++ {
@@ -226,28 +273,43 @@ func BenchmarkDecoder(b *testing.B) {
b.Fatalf("test file has no data")
}
+ getInputBuf := func(buf []byte, n int) []byte {
+ compressed := new(bytes.Buffer)
+ w := NewWriter(compressed, LSB, 8)
+ for i := 0; i < n; i += len(buf) {
+ if len(buf) > n-i {
+ buf = buf[:n-i]
+ }
+ w.Write(buf)
+ }
+ w.Close()
+ return compressed.Bytes()
+ }
+
for e := 4; e <= 6; e++ {
n := int(math.Pow10(e))
b.Run(fmt.Sprint("1e", e), func(b *testing.B) {
b.StopTimer()
b.SetBytes(int64(n))
- buf0 := buf
- compressed := new(bytes.Buffer)
- w := NewWriter(compressed, LSB, 8)
- for i := 0; i < n; i += len(buf0) {
- if len(buf0) > n-i {
- buf0 = buf0[:n-i]
- }
- w.Write(buf0)
- }
- w.Close()
- buf1 := compressed.Bytes()
- buf0, compressed, w = nil, nil, nil
+ buf1 := getInputBuf(buf, n)
runtime.GC()
b.StartTimer()
for i := 0; i < b.N; i++ {
io.Copy(io.Discard, NewReader(bytes.NewReader(buf1), LSB, 8))
}
})
+ b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) {
+ b.StopTimer()
+ b.SetBytes(int64(n))
+ buf1 := getInputBuf(buf, n)
+ runtime.GC()
+ b.StartTimer()
+ r := NewReader(bytes.NewReader(buf1), LSB, 8)
+ for i := 0; i < b.N; i++ {
+ io.Copy(io.Discard, r)
+ r.Close()
+ r.(*Reader).Reset(bytes.NewReader(buf1), LSB, 8)
+ }
+ })
}
}
diff --git a/src/compress/lzw/writer.go b/src/compress/lzw/writer.go
index 6ddb335f31..552bdc2ce1 100644
--- a/src/compress/lzw/writer.go
+++ b/src/compress/lzw/writer.go
@@ -17,19 +17,6 @@ type writer interface {
Flush() error
}
-// An errWriteCloser is an io.WriteCloser that always returns a given error.
-type errWriteCloser struct {
- err error
-}
-
-func (e *errWriteCloser) Write([]byte) (int, error) {
- return 0, e.err
-}
-
-func (e *errWriteCloser) Close() error {
- return e.err
-}
-
const (
// A code is a 12 bit value, stored as a uint32 when encoding to avoid
// type conversions when shifting bits.
@@ -44,14 +31,15 @@ const (
invalidEntry = 0
)
-// encoder is LZW compressor.
-type encoder struct {
+// Writer is an LZW compressor. It writes the compressed form of the data
+// to an underlying writer (see NewWriter).
+type Writer struct {
// w is the writer that compressed bytes are written to.
w writer
// order, write, bits, nBits and width are the state for
// converting a code stream into a byte stream.
order Order
- write func(*encoder, uint32) error
+ write func(*Writer, uint32) error
bits uint32
nBits uint
width uint
@@ -63,7 +51,7 @@ type encoder struct {
// savedCode is the accumulated code at the end of the most recent Write
// call. It is equal to invalidCode if there was no such call.
savedCode uint32
- // err is the first error encountered during writing. Closing the encoder
+ // err is the first error encountered during writing. Closing the writer
// will make any future Write calls return errClosed
err error
// table is the hash table from 20-bit keys to 12-bit values. Each table
@@ -74,80 +62,80 @@ type encoder struct {
}
// writeLSB writes the code c for "Least Significant Bits first" data.
-func (e *encoder) writeLSB(c uint32) error {
- e.bits |= c << e.nBits
- e.nBits += e.width
- for e.nBits >= 8 {
- if err := e.w.WriteByte(uint8(e.bits)); err != nil {
+func (w *Writer) writeLSB(c uint32) error {
+ w.bits |= c << w.nBits
+ w.nBits += w.width
+ for w.nBits >= 8 {
+ if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
- e.bits >>= 8
- e.nBits -= 8
+ w.bits >>= 8
+ w.nBits -= 8
}
return nil
}
// writeMSB writes the code c for "Most Significant Bits first" data.
-func (e *encoder) writeMSB(c uint32) error {
- e.bits |= c << (32 - e.width - e.nBits)
- e.nBits += e.width
- for e.nBits >= 8 {
- if err := e.w.WriteByte(uint8(e.bits >> 24)); err != nil {
+func (w *Writer) writeMSB(c uint32) error {
+ w.bits |= c << (32 - w.width - w.nBits)
+ w.nBits += w.width
+ for w.nBits >= 8 {
+ if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
return err
}
- e.bits <<= 8
- e.nBits -= 8
+ w.bits <<= 8
+ w.nBits -= 8
}
return nil
}
-// errOutOfCodes is an internal error that means that the encoder has run out
+// errOutOfCodes is an internal error that means that the writer has run out
// of unused codes and a clear code needs to be sent next.
var errOutOfCodes = errors.New("lzw: out of codes")
// incHi increments e.hi and checks for both overflow and running out of
// unused codes. In the latter case, incHi sends a clear code, resets the
-// encoder state and returns errOutOfCodes.
-func (e *encoder) incHi() error {
- e.hi++
- if e.hi == e.overflow {
- e.width++
- e.overflow <<= 1
+// writer state and returns errOutOfCodes.
+func (w *Writer) incHi() error {
+ w.hi++
+ if w.hi == w.overflow {
+ w.width++
+ w.overflow <<= 1
}
- if e.hi == maxCode {
- clear := uint32(1) << e.litWidth
- if err := e.write(e, clear); err != nil {
+ if w.hi == maxCode {
+ clear := uint32(1) << w.litWidth
+ if err := w.write(w, clear); err != nil {
return err
}
- e.width = e.litWidth + 1
- e.hi = clear + 1
- e.overflow = clear << 1
- for i := range e.table {
- e.table[i] = invalidEntry
+ w.width = w.litWidth + 1
+ w.hi = clear + 1
+ w.overflow = clear << 1
+ for i := range w.table {
+ w.table[i] = invalidEntry
}
return errOutOfCodes
}
return nil
}
-// Write writes a compressed representation of p to e's underlying writer.
-func (e *encoder) Write(p []byte) (n int, err error) {
- if e.err != nil {
- return 0, e.err
+// Write writes a compressed representation of p to w's underlying writer.
+func (w *Writer) Write(p []byte) (n int, err error) {
+ if w.err != nil {
+ return 0, w.err
}
if len(p) == 0 {
return 0, nil
}
- if maxLit := uint8(1<<e.litWidth - 1); maxLit != 0xff {
+ if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
for _, x := range p {
if x > maxLit {
- e.err = errors.New("lzw: input byte too large for the litWidth")
- return 0, e.err
+ w.err = errors.New("lzw: input byte too large for the litWidth")
+ return 0, w.err
}
}
}
n = len(p)
- code := e.savedCode
+ code := w.savedCode
if code == invalidCode {
// The first code sent is always a literal code.
code, p = uint32(p[0]), p[1:]
@@ -159,77 +147,84 @@ loop:
// If there is a hash table hit for this key then we continue the loop
// and do not emit a code yet.
hash := (key>>12 ^ key) & tableMask
- for h, t := hash, e.table[hash]; t != invalidEntry; {
+ for h, t := hash, w.table[hash]; t != invalidEntry; {
if key == t>>12 {
code = t & maxCode
continue loop
}
h = (h + 1) & tableMask
- t = e.table[h]
+ t = w.table[h]
}
// Otherwise, write the current code, and literal becomes the start of
// the next emitted code.
- if e.err = e.write(e, code); e.err != nil {
- return 0, e.err
+ if w.err = w.write(w, code); w.err != nil {
+ return 0, w.err
}
code = literal
// Increment e.hi, the next implied code. If we run out of codes, reset
- // the encoder state (including clearing the hash table) and continue.
- if err1 := e.incHi(); err1 != nil {
+ // the writer state (including clearing the hash table) and continue.
+ if err1 := w.incHi(); err1 != nil {
if err1 == errOutOfCodes {
continue
}
- e.err = err1
- return 0, e.err
+ w.err = err1
+ return 0, w.err
}
// Otherwise, insert key -> e.hi into the map that e.table represents.
for {
- if e.table[hash] == invalidEntry {
- e.table[hash] = (key << 12) | e.hi
+ if w.table[hash] == invalidEntry {
+ w.table[hash] = (key << 12) | w.hi
break
}
hash = (hash + 1) & tableMask
}
}
- e.savedCode = code
+ w.savedCode = code
return n, nil
}
-// Close closes the encoder, flushing any pending output. It does not close or
-// flush e's underlying writer.
-func (e *encoder) Close() error {
- if e.err != nil {
- if e.err == errClosed {
+// Close closes the Writer, flushing any pending output. It does not close
+// w's underlying writer.
+func (w *Writer) Close() error {
+ if w.err != nil {
+ if w.err == errClosed {
return nil
}
- return e.err
+ return w.err
}
// Make any future calls to Write return errClosed.
- e.err = errClosed
+ w.err = errClosed
// Write the savedCode if valid.
- if e.savedCode != invalidCode {
- if err := e.write(e, e.savedCode); err != nil {
+ if w.savedCode != invalidCode {
+ if err := w.write(w, w.savedCode); err != nil {
return err
}
- if err := e.incHi(); err != nil && err != errOutOfCodes {
+ if err := w.incHi(); err != nil && err != errOutOfCodes {
return err
}
}
// Write the eof code.
- eof := uint32(1)<<e.litWidth + 1
- if err := e.write(e, eof); err != nil {
+ eof := uint32(1)<<w.litWidth + 1
+ if err := w.write(w, eof); err != nil {
return err
}
// Write the final bits.
- if e.nBits > 0 {
- if e.order == MSB {
- e.bits >>= 24
+ if w.nBits > 0 {
+ if w.order == MSB {
+ w.bits >>= 24
}
- if err := e.w.WriteByte(uint8(e.bits)); err != nil {
+ if err := w.w.WriteByte(uint8(w.bits)); err != nil {
return err
}
}
- return e.w.Flush()
+ return w.w.Flush()
+}
+
+// Reset clears the Writer's state and allows it to be reused again
+// as a new Writer.
+func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
+ *w = Writer{}
+ w.init(dst, order, litWidth)
}
// NewWriter creates a new io.WriteCloser.
@@ -238,32 +233,43 @@ func (e *encoder) Close() error {
// finished writing.
// The number of bits to use for literal codes, litWidth, must be in the
// range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
+//
+// It is guaranteed that the underlying type of the returned io.WriteCloser
+// is a *Writer.
func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
- var write func(*encoder, uint32) error
+ return newWriter(w, order, litWidth)
+}
+
+func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
+ w := new(Writer)
+ w.init(dst, order, litWidth)
+ return w
+}
+
+func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
switch order {
case LSB:
- write = (*encoder).writeLSB
+ w.write = (*Writer).writeLSB
case MSB:
- write = (*encoder).writeMSB
+ w.write = (*Writer).writeMSB
default:
- return &errWriteCloser{errors.New("lzw: unknown order")}
+ w.err = errors.New("lzw: unknown order")
+ return
}
if litWidth < 2 || 8 < litWidth {
- return &errWriteCloser{fmt.Errorf("lzw: litWidth %d out of range", litWidth)}
+ w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
+ return
}
- bw, ok := w.(writer)
- if !ok {
- bw = bufio.NewWriter(w)
+ bw, ok := dst.(writer)
+ if !ok && dst != nil {
+ bw = bufio.NewWriter(dst)
}
+ w.w = bw
lw := uint(litWidth)
- return &encoder{
- w: bw,
- order: order,
- write: write,
- width: 1 + lw,
- litWidth: lw,
- hi: 1<<lw + 1,
- overflow: 1 << (lw + 1),
- savedCode: invalidCode,
- }
+ w.order = order
+ w.width = 1 + lw
+ w.litWidth = lw
+ w.hi = 1<<lw + 1
+ w.overflow = 1 << (lw + 1)
+ w.savedCode = invalidCode
}
diff --git a/src/compress/lzw/writer_test.go b/src/compress/lzw/writer_test.go
index 1a5dbcae93..9f59c8bb18 100644
--- a/src/compress/lzw/writer_test.go
+++ b/src/compress/lzw/writer_test.go
@@ -5,6 +5,7 @@
package lzw
import (
+ "bytes"
"fmt"
"internal/testenv"
"io"
@@ -105,6 +106,50 @@ func TestWriter(t *testing.T) {
}
}
+func TestWriterReset(t *testing.T) {
+ for _, order := range [...]Order{LSB, MSB} {
+ t.Run(fmt.Sprintf("Order %d", order), func(t *testing.T) {
+ for litWidth := 6; litWidth <= 8; litWidth++ {
+ t.Run(fmt.Sprintf("LitWidth %d", litWidth), func(t *testing.T) {
+ var data []byte
+ if litWidth == 6 {
+ data = []byte{1, 2, 3}
+ } else {
+ data = []byte(`lorem ipsum dolor sit amet`)
+ }
+ var buf bytes.Buffer
+ w := NewWriter(&buf, order, litWidth)
+ if _, err := w.Write(data); err != nil {
+ t.Errorf("write: %v: %v", string(data), err)
+ }
+
+ if err := w.Close(); err != nil {
+ t.Errorf("close: %v", err)
+ }
+
+ b1 := buf.Bytes()
+ buf.Reset()
+
+ w.(*Writer).Reset(&buf, order, litWidth)
+
+ if _, err := w.Write(data); err != nil {
+ t.Errorf("write: %v: %v", string(data), err)
+ }
+
+ if err := w.Close(); err != nil {
+ t.Errorf("close: %v", err)
+ }
+ b2 := buf.Bytes()
+
+ if !bytes.Equal(b1, b2) {
+ t.Errorf("bytes written were not same")
+ }
+ })
+ }
+ })
+ }
+}
+
func TestWriterReturnValues(t *testing.T) {
w := NewWriter(io.Discard, LSB, 8)
n, err := w.Write([]byte("asdf"))
@@ -152,5 +197,14 @@ func BenchmarkEncoder(b *testing.B) {
w.Close()
}
})
+ b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) {
+ b.SetBytes(int64(n))
+ w := NewWriter(io.Discard, LSB, 8)
+ for i := 0; i < b.N; i++ {
+ w.Write(buf1)
+ w.Close()
+ w.(*Writer).Reset(io.Discard, LSB, 8)
+ }
+ })
}
}