diff options
-rw-r--r-- | src/compress/flate/deflate.go | 2 | ||||
-rw-r--r-- | src/compress/flate/deflate_test.go | 50 | ||||
-rw-r--r-- | src/compress/flate/huffman_bit_writer.go | 40 |
3 files changed, 78 insertions, 14 deletions
diff --git a/src/compress/flate/deflate.go b/src/compress/flate/deflate.go index 3e4dc7b57e..9f53d51a6e 100644 --- a/src/compress/flate/deflate.go +++ b/src/compress/flate/deflate.go @@ -724,7 +724,7 @@ func (w *Writer) Close() error { // the result of NewWriter or NewWriterDict called with dst // and w's level and dictionary. func (w *Writer) Reset(dst io.Writer) { - if dw, ok := w.d.w.w.(*dictWriter); ok { + if dw, ok := w.d.w.writer.(*dictWriter); ok { // w was created with NewWriterDict dw.w = dst w.d.reset(dw) diff --git a/src/compress/flate/deflate_test.go b/src/compress/flate/deflate_test.go index 27a3b3823a..3322c40845 100644 --- a/src/compress/flate/deflate_test.go +++ b/src/compress/flate/deflate_test.go @@ -6,6 +6,7 @@ package flate import ( "bytes" + "errors" "fmt" "internal/testenv" "io" @@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) { } } } + +var errIO = errors.New("IO error") + +// failWriter fails with errIO exactly at the nth call to Write. +type failWriter struct{ n int } + +func (w *failWriter) Write(b []byte) (int, error) { + w.n-- + if w.n == -1 { + return 0, errIO + } + return len(b), nil +} + +func TestWriterPersistentError(t *testing.T) { + d, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + d = d[:10000] // Keep this test short + + zw, err := NewWriter(nil, DefaultCompression) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + + // Sweep over the threshold at which an error is returned. + // The variable i makes it such that the ith call to failWriter.Write will + // return errIO. Since failWriter errors are not persistent, we must ensure + // that flate.Writer errors are persistent. + for i := 0; i < 1000; i++ { + fw := &failWriter{i} + zw.Reset(fw) + + _, werr := zw.Write(d) + cerr := zw.Close() + if werr != errIO && werr != nil { + t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO) + } + if cerr != errIO && fw.n < 0 { + t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO) + } + if fw.n >= 0 { + // At this point, the failure threshold was sufficiently high enough + // that we wrote the whole stream without any errors. + return + } + } +} diff --git a/src/compress/flate/huffman_bit_writer.go b/src/compress/flate/huffman_bit_writer.go index c4adef9ff5..d8b5a3ebd7 100644 --- a/src/compress/flate/huffman_bit_writer.go +++ b/src/compress/flate/huffman_bit_writer.go @@ -77,7 +77,11 @@ var offsetBase = []uint32{ var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} type huffmanBitWriter struct { - w io.Writer + // writer is the underlying writer. + // Do not use it directly; use the write method, which ensures + // that Write errors are sticky. + writer io.Writer + // Data waiting to be written is bytes[0:nbytes] // and then the low nbits of bits. bits uint64 @@ -96,7 +100,7 @@ type huffmanBitWriter struct { func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { return &huffmanBitWriter{ - w: w, + writer: w, literalFreq: make([]int32, maxNumLit), offsetFreq: make([]int32, offsetCodeCount), codegen: make([]uint8, maxNumLit+offsetCodeCount+1), @@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { } func (w *huffmanBitWriter) reset(writer io.Writer) { - w.w = writer + w.writer = writer w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil w.bytes = [bufferSize]byte{} } @@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() { n++ } w.bits = 0 - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) w.nbytes = 0 } +func (w *huffmanBitWriter) write(b []byte) { + if w.err != nil { + return + } + _, w.err = w.writer.Write(b) +} + func (w *huffmanBitWriter) writeBits(b int32, nb uint) { + if w.err != nil { + return + } w.bits |= uint64(b) << w.nbits w.nbits += nb if w.nbits >= 48 { @@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) { bytes[5] = byte(bits >> 40) n += 6 if n >= bufferFlushSize { - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) n = 0 } w.nbytes = n @@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) { n++ } if n != 0 { - _, w.err = w.w.Write(w.bytes[:n]) - if w.err != nil { - return - } + w.write(w.bytes[:n]) } w.nbytes = 0 - _, w.err = w.w.Write(bytes) + w.write(bytes) } // RFC 1951 3.2.7 specifies a special run-length encoding for specifying @@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) { bytes[5] = byte(bits >> 40) n += 6 if n >= bufferFlushSize { - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) n = 0 } w.nbytes = n @@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets // writeTokens writes a slice of tokens to the output. // codes for literal and offset encoding must be supplied. func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) { + if w.err != nil { + return + } for _, t := range tokens { if t < matchType { w.writeCode(leCodes[t.literal()]) @@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) { if n < bufferFlushSize { continue } - _, w.err = w.w.Write(w.bytes[:n]) + w.write(w.bytes[:n]) if w.err != nil { - return + return // Return early in the event of write failures } n = 0 } |