From c0f5de1ad275cf9a2d7f572be76e7b10fcd3e48f Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Tue, 16 Aug 2016 16:03:00 -0700 Subject: [release-branch.go1.7] compress/flate: make huffmanBitWriter errors persistent For persistent error handling, the methods of huffmanBitWriter have to be consistent about how they check errors. It must either consistently check error *before* every operation OR immediately *after* every operation. Since most of the current logic uses the previous approach, we apply the same style of error checking to writeBits and all calls to Write such that they only operate if w.err is already nil going into them. The error handling approach is brittle and easily broken by future commits to the code. In the near future, we should switch the logic to use panic at the lowest levels and a recover at the edge of the public API to ensure that errors are always persistent. Fixes #16749 Change-Id: Ie1d83e4ed8842f6911a31e23311cd3cbf38abe8c Reviewed-on: https://go-review.googlesource.com/27200 Reviewed-by: Matthew Dempsky Reviewed-by: Brad Fitzpatrick Reviewed-on: https://go-review.googlesource.com/28634 Reviewed-by: Joe Tsai Run-TryBot: Brad Fitzpatrick --- src/compress/flate/deflate.go | 2 +- src/compress/flate/deflate_test.go | 50 ++++++++++++++++++++++++++++++++ src/compress/flate/huffman_bit_writer.go | 40 ++++++++++++++++--------- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/src/compress/flate/deflate.go b/src/compress/flate/deflate.go index 3e4dc7b57e..9f53d51a6e 100644 --- a/src/compress/flate/deflate.go +++ b/src/compress/flate/deflate.go @@ -724,7 +724,7 @@ func (w *Writer) Close() error { // the result of NewWriter or NewWriterDict called with dst // and w's level and dictionary. func (w *Writer) Reset(dst io.Writer) { - if dw, ok := w.d.w.w.(*dictWriter); ok { + if dw, ok := w.d.w.writer.(*dictWriter); ok { // w was created with NewWriterDict dw.w = dst w.d.reset(dw) diff --git a/src/compress/flate/deflate_test.go b/src/compress/flate/deflate_test.go index 27a3b3823a..3322c40845 100644 --- a/src/compress/flate/deflate_test.go +++ b/src/compress/flate/deflate_test.go @@ -6,6 +6,7 @@ package flate import ( "bytes" + "errors" "fmt" "internal/testenv" "io" @@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) { } } } + +var errIO = errors.New("IO error") + +// failWriter fails with errIO exactly at the nth call to Write. +type failWriter struct{ n int } + +func (w *failWriter) Write(b []byte) (int, error) { + w.n-- + if w.n == -1 { + return 0, errIO + } + return len(b), nil +} + +func TestWriterPersistentError(t *testing.T) { + d, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + d = d[:10000] // Keep this test short + + zw, err := NewWriter(nil, DefaultCompression) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + + // Sweep over the threshold at which an error is returned. + // The variable i makes it such that the ith call to failWriter.Write will + // return errIO. Since failWriter errors are not persistent, we must ensure + // that flate.Writer errors are persistent. + for i := 0; i < 1000; i++ { + fw := &failWriter{i} + zw.Reset(fw) + + _, werr := zw.Write(d) + cerr := zw.Close() + if werr != errIO && werr != nil { + t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO) + } + if cerr != errIO && fw.n < 0 { + t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO) + } + if fw.n >= 0 { + // At this point, the failure threshold was sufficiently high enough + // that we wrote the whole stream without any errors. + return + } + } +} diff --git a/src/compress/flate/huffman_bit_writer.go b/src/compress/flate/huffman_bit_writer.go index c4adef9ff5..d8b5a3ebd7 100644 --- a/src/compress/flate/huffman_bit_writer.go +++ b/src/compress/flate/huffman_bit_writer.go @@ -77,7 +77,11 @@ var offsetBase = []uint32{ var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} type huffmanBitWriter struct { - w io.Writer + // writer is the underlying writer. + // Do not use it directly; use the write method, which ensures + // that Write errors are sticky. + writer io.Writer + // Data waiting to be written is bytes[0:nbytes] // and then the low nbits of bits. bits uint64 @@ -96,7 +100,7 @@ type huffmanBitWriter struct { func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { return &huffmanBitWriter{ - w: w, + writer: w, literalFreq: make([]int32, maxNumLit), offsetFreq: make([]int32, offsetCodeCount), codegen: make([]uint8, maxNumLit+offsetCodeCount+1), @@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { } func (w *huffmanBitWriter) reset(writer io.Writer) { - w.w = writer + w.writer = writer w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil w.bytes = [bufferSize]byte{} } @@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() { n++ } w.bits = 0 - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) w.nbytes = 0 } +func (w *huffmanBitWriter) write(b []byte) { + if w.err != nil { + return + } + _, w.err = w.writer.Write(b) +} + func (w *huffmanBitWriter) writeBits(b int32, nb uint) { + if w.err != nil { + return + } w.bits |= uint64(b) << w.nbits w.nbits += nb if w.nbits >= 48 { @@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) { bytes[5] = byte(bits >> 40) n += 6 if n >= bufferFlushSize { - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) n = 0 } w.nbytes = n @@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) { n++ } if n != 0 { - _, w.err = w.w.Write(w.bytes[:n]) - if w.err != nil { - return - } + w.write(w.bytes[:n]) } w.nbytes = 0 - _, w.err = w.w.Write(bytes) + w.write(bytes) } // RFC 1951 3.2.7 specifies a special run-length encoding for specifying @@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) { bytes[5] = byte(bits >> 40) n += 6 if n >= bufferFlushSize { - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) n = 0 } w.nbytes = n @@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets // writeTokens writes a slice of tokens to the output. // codes for literal and offset encoding must be supplied. func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) { + if w.err != nil { + return + } for _, t := range tokens { if t < matchType { w.writeCode(leCodes[t.literal()]) @@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) { if n < bufferFlushSize { continue } - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) if w.err != nil { - return + return // Return early in the event of write failures } n = 0 } -- cgit v1.2.3-54-g00ecf