aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compress/flate/deflate.go2
-rw-r--r--src/compress/flate/deflate_test.go50
-rw-r--r--src/compress/flate/huffman_bit_writer.go40
3 files changed, 78 insertions, 14 deletions
diff --git a/src/compress/flate/deflate.go b/src/compress/flate/deflate.go
index 3e4dc7b57e..9f53d51a6e 100644
--- a/src/compress/flate/deflate.go
+++ b/src/compress/flate/deflate.go
@@ -724,7 +724,7 @@ func (w *Writer) Close() error {
// the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
- if dw, ok := w.d.w.w.(*dictWriter); ok {
+ if dw, ok := w.d.w.writer.(*dictWriter); ok {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
diff --git a/src/compress/flate/deflate_test.go b/src/compress/flate/deflate_test.go
index 27a3b3823a..3322c40845 100644
--- a/src/compress/flate/deflate_test.go
+++ b/src/compress/flate/deflate_test.go
@@ -6,6 +6,7 @@ package flate
import (
"bytes"
+ "errors"
"fmt"
"internal/testenv"
"io"
@@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) {
}
}
}
+
+var errIO = errors.New("IO error")
+
+// failWriter fails with errIO exactly at the nth call to Write.
+type failWriter struct{ n int }
+
+func (w *failWriter) Write(b []byte) (int, error) {
+ w.n--
+ if w.n == -1 {
+ return 0, errIO
+ }
+ return len(b), nil
+}
+
+func TestWriterPersistentError(t *testing.T) {
+ d, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
+ if err != nil {
+ t.Fatalf("ReadFile: %v", err)
+ }
+ d = d[:10000] // Keep this test short
+
+ zw, err := NewWriter(nil, DefaultCompression)
+ if err != nil {
+ t.Fatalf("NewWriter: %v", err)
+ }
+
+ // Sweep over the threshold at which an error is returned.
+ // The variable i makes it such that the ith call to failWriter.Write will
+ // return errIO. Since failWriter errors are not persistent, we must ensure
+ // that flate.Writer errors are persistent.
+ for i := 0; i < 1000; i++ {
+ fw := &failWriter{i}
+ zw.Reset(fw)
+
+ _, werr := zw.Write(d)
+ cerr := zw.Close()
+ if werr != errIO && werr != nil {
+ t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
+ }
+ if cerr != errIO && fw.n < 0 {
+ t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
+ }
+ if fw.n >= 0 {
+ // At this point, the failure threshold was sufficiently high enough
+ // that we wrote the whole stream without any errors.
+ return
+ }
+ }
+}
diff --git a/src/compress/flate/huffman_bit_writer.go b/src/compress/flate/huffman_bit_writer.go
index c4adef9ff5..d8b5a3ebd7 100644
--- a/src/compress/flate/huffman_bit_writer.go
+++ b/src/compress/flate/huffman_bit_writer.go
@@ -77,7 +77,11 @@ var offsetBase = []uint32{
var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
type huffmanBitWriter struct {
- w io.Writer
+ // writer is the underlying writer.
+ // Do not use it directly; use the write method, which ensures
+ // that Write errors are sticky.
+ writer io.Writer
+
// Data waiting to be written is bytes[0:nbytes]
// and then the low nbits of bits.
bits uint64
@@ -96,7 +100,7 @@ type huffmanBitWriter struct {
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{
- w: w,
+ writer: w,
literalFreq: make([]int32, maxNumLit),
offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxNumLit+offsetCodeCount+1),
@@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
}
func (w *huffmanBitWriter) reset(writer io.Writer) {
- w.w = writer
+ w.writer = writer
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
w.bytes = [bufferSize]byte{}
}
@@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() {
n++
}
w.bits = 0
- _, w.err = w.w.Write(w.bytes[:n])
+ w.write(w.bytes[:n])
w.nbytes = 0
}
+func (w *huffmanBitWriter) write(b []byte) {
+ if w.err != nil {
+ return
+ }
+ _, w.err = w.writer.Write(b)
+}
+
func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
+ if w.err != nil {
+ return
+ }
w.bits |= uint64(b) << w.nbits
w.nbits += nb
if w.nbits >= 48 {
@@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
- _, w.err = w.w.Write(w.bytes[:n])
+ w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
@@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
n++
}
if n != 0 {
- _, w.err = w.w.Write(w.bytes[:n])
- if w.err != nil {
- return
- }
+ w.write(w.bytes[:n])
}
w.nbytes = 0
- _, w.err = w.w.Write(bytes)
+ w.write(bytes)
}
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
@@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) {
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
- _, w.err = w.w.Write(w.bytes[:n])
+ w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
@@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets
// writeTokens writes a slice of tokens to the output.
// codes for literal and offset encoding must be supplied.
func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
+ if w.err != nil {
+ return
+ }
for _, t := range tokens {
if t < matchType {
w.writeCode(leCodes[t.literal()])
@@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
if n < bufferFlushSize {
continue
}
- _, w.err = w.w.Write(w.bytes[:n])
+ w.write(w.bytes[:n])
if w.err != nil {
- return
+ return // Return early in the event of write failures
}
n = 0
}