summaryrefslogtreecommitdiff
path: root/src/common/compress_zstd.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/common/compress_zstd.c')
-rw-r--r--src/common/compress_zstd.c45
1 files changed, 37 insertions, 8 deletions
diff --git a/src/common/compress_zstd.c b/src/common/compress_zstd.c
index a136db48bf..0808bcd9ab 100644
--- a/src/common/compress_zstd.c
+++ b/src/common/compress_zstd.c
@@ -98,6 +98,8 @@ struct tor_zstd_compress_state_t {
#endif // HAVE_ZSTD.
int compress; /**< True if we are compressing; false if we are inflating */
+ int have_called_end; /**< True if we are compressing and we've called
+ * ZSTD_endStream */
/** Number of bytes read so far. Used to detect compression bombs. */
size_t input_so_far;
@@ -194,31 +196,41 @@ tor_zstd_compress_new(int compress,
result->u.compress_stream = ZSTD_createCStream();
if (result->u.compress_stream == NULL) {
- log_warn(LD_GENERAL, "Error while creating Zstandard stream");
+ // LCOV_EXCL_START
+ log_warn(LD_GENERAL, "Error while creating Zstandard compression "
+ "stream");
goto err;
+ // LCOV_EXCL_STOP
}
retval = ZSTD_initCStream(result->u.compress_stream, preset);
if (ZSTD_isError(retval)) {
+ // LCOV_EXCL_START
log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
ZSTD_getErrorName(retval));
goto err;
+ // LCOV_EXCL_STOP
}
} else {
result->u.decompress_stream = ZSTD_createDStream();
if (result->u.decompress_stream == NULL) {
- log_warn(LD_GENERAL, "Error while creating Zstandard stream");
+ // LCOV_EXCL_START
+ log_warn(LD_GENERAL, "Error while creating Zstandard decompression "
+ "stream");
goto err;
+ // LCOV_EXCL_STOP
}
retval = ZSTD_initDStream(result->u.decompress_stream);
if (ZSTD_isError(retval)) {
+ // LCOV_EXCL_START
log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
ZSTD_getErrorName(retval));
goto err;
+ // LCOV_EXCL_STOP
}
}
@@ -226,6 +238,7 @@ tor_zstd_compress_new(int compress,
return result;
err:
+ // LCOV_EXCL_START
if (compress) {
ZSTD_freeCStream(result->u.compress_stream);
} else {
@@ -234,6 +247,7 @@ tor_zstd_compress_new(int compress,
tor_free(result);
return NULL;
+ // LCOV_EXCL_STOP
#else // HAVE_ZSTD.
(void)compress;
(void)method;
@@ -270,9 +284,16 @@ tor_zstd_compress_process(tor_zstd_compress_state_t *state,
ZSTD_inBuffer input = { *in, *in_len, 0 };
ZSTD_outBuffer output = { *out, *out_len, 0 };
+ if (BUG(finish == 0 && state->have_called_end)) {
+ finish = 1;
+ }
+
if (state->compress) {
- retval = ZSTD_compressStream(state->u.compress_stream,
- &output, &input);
+ if (! state->have_called_end)
+ retval = ZSTD_compressStream(state->u.compress_stream,
+ &output, &input);
+ else
+ retval = 0;
} else {
retval = ZSTD_decompressStream(state->u.decompress_stream,
&output, &input);
@@ -300,7 +321,7 @@ tor_zstd_compress_process(tor_zstd_compress_state_t *state,
return TOR_COMPRESS_ERROR;
}
- if (state->compress && !finish) {
+ if (state->compress && !state->have_called_end) {
retval = ZSTD_flushStream(state->u.compress_stream, &output);
*out = (char *)output.dst + output.pos;
@@ -314,16 +335,24 @@ tor_zstd_compress_process(tor_zstd_compress_state_t *state,
// ZSTD_flushStream returns 0 if the frame is done, or >0 if it
// is incomplete.
- if (retval > 0)
+ if (retval > 0) {
return TOR_COMPRESS_BUFFER_FULL;
+ }
}
if (!finish) {
- // We're not done with the input, so no need to flush.
+ // The caller says we're not done with the input, so no need to write an
+ // epilogue.
return TOR_COMPRESS_OK;
} else if (state->compress) {
- retval = ZSTD_endStream(state->u.compress_stream, &output);
+ if (*in_len) {
+ // We say that we're not done with the input, so we can't write an
+ // epilogue.
+ return TOR_COMPRESS_OK;
+ }
+ retval = ZSTD_endStream(state->u.compress_stream, &output);
+ state->have_called_end = 1;
*out = (char *)output.dst + output.pos;
*out_len = output.size - output.pos;