@@ -39,61 +39,91 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
3939 return output ;
4040}
4141
42- static VALUE decompress_buffered (ZSTD_DCtx * dctx , const char * input_data , size_t input_size )
43- {
44- ZSTD_inBuffer input = { input_data , input_size , 0 };
45- VALUE result = rb_str_new (0 , 0 );
42+ static VALUE decode_one_frame (ZSTD_DCtx * dctx , const unsigned char * src , size_t size , VALUE kwargs ) {
43+ VALUE out = rb_str_buf_new (0 );
44+ size_t cap = ZSTD_DStreamOutSize ();
45+ char * buf = ALLOC_N (char , cap );
46+ ZSTD_inBuffer in = (ZSTD_inBuffer ){ src , size , 0 };
4647
47- while (input .pos < input .size ) {
48- ZSTD_outBuffer output = { NULL , 0 , 0 };
49- output .size += ZSTD_DStreamOutSize ();
50- VALUE output_string = rb_str_new (NULL , output .size );
51- output .dst = RSTRING_PTR (output_string );
48+ ZSTD_DCtx_reset (dctx , ZSTD_reset_session_only );
49+ set_decompress_params (dctx , kwargs );
5250
53- size_t ret = zstd_stream_decompress (dctx , & output , & input , false);
51+ for (;;) {
52+ ZSTD_outBuffer o = (ZSTD_outBuffer ){ buf , cap , 0 };
53+ size_t ret = ZSTD_decompressStream (dctx , & o , & in );
5454 if (ZSTD_isError (ret )) {
55- ZSTD_freeDCtx (dctx );
56- rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_decompressStream failed" , ZSTD_getErrorName (ret ));
55+ xfree (buf );
56+ rb_raise (rb_eRuntimeError , "ZSTD_decompressStream failed: %s" , ZSTD_getErrorName (ret ));
57+ }
58+ if (o .pos ) {
59+ rb_str_cat (out , buf , o .pos );
60+ }
61+ if (ret == 0 ) {
62+ break ;
5763 }
58- rb_str_cat (result , output .dst , output .pos );
5964 }
60- ZSTD_freeDCtx (dctx );
61- return result ;
65+ xfree (buf );
66+ return out ;
67+ }
68+
69+ static VALUE decompress_buffered (ZSTD_DCtx * dctx , const char * data , size_t len ) {
70+ return decode_one_frame (dctx , (const unsigned char * )data , len , Qnil );
6271}
6372
6473static VALUE rb_decompress (int argc , VALUE * argv , VALUE self )
6574{
66- VALUE input_value ;
67- VALUE kwargs ;
75+ VALUE input_value , kwargs ;
6876 rb_scan_args (argc , argv , "10:" , & input_value , & kwargs );
6977 StringValue (input_value );
70- char * input_data = RSTRING_PTR (input_value );
71- size_t input_size = RSTRING_LEN (input_value );
72- ZSTD_DCtx * const dctx = ZSTD_createDCtx ();
73- if (dctx == NULL ) {
74- rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
75- }
76- set_decompress_params (dctx , kwargs );
7778
78- unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
79- if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
80- rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
81- }
82- // ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
83- // Therefore, we will not standardize on ZSTD_decompressStream
84- if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
85- return decompress_buffered (dctx , input_data , input_size );
86- }
79+ size_t in_size = RSTRING_LEN (input_value );
80+ const unsigned char * in_r = (const unsigned char * )RSTRING_PTR (input_value );
81+ unsigned char * in = ALLOC_N (unsigned char , in_size );
82+ memcpy (in , in_r , in_size );
83+
84+ size_t off = 0 ;
85+ const uint32_t ZSTD_MAGIC = 0xFD2FB528U ;
86+ const uint32_t SKIP_LO = 0x184D2A50U ; /* ...5F */
87+
88+ while (off + 4 <= in_size ) {
89+ uint32_t magic = (uint32_t )in [off ]
90+ | ((uint32_t )in [off + 1 ] << 8 )
91+ | ((uint32_t )in [off + 2 ] << 16 )
92+ | ((uint32_t )in [off + 3 ] << 24 );
93+
94+ if ((magic & 0xFFFFFFF0U ) == (SKIP_LO & 0xFFFFFFF0U )) {
95+ if (off + 8 > in_size ) break ;
96+ uint32_t skipLen = (uint32_t )in [off + 4 ]
97+ | ((uint32_t )in [off + 5 ] << 8 )
98+ | ((uint32_t )in [off + 6 ] << 16 )
99+ | ((uint32_t )in [off + 7 ] << 24 );
100+ size_t adv = (size_t )8 + (size_t )skipLen ;
101+ if (off + adv > in_size ) break ;
102+ off += adv ;
103+ continue ;
104+ }
87105
88- VALUE output = rb_str_new (NULL , uncompressed_size );
89- char * output_data = RSTRING_PTR (output );
106+ if (magic == ZSTD_MAGIC ) {
107+ ZSTD_DCtx * dctx = ZSTD_createDCtx ();
108+ if (!dctx ) {
109+ xfree (in );
110+ rb_raise (rb_eRuntimeError , "ZSTD_createDCtx failed" );
111+ }
112+
113+ VALUE out = decode_one_frame (dctx , in + off , in_size - off , kwargs );
90114
91- size_t const decompress_size = zstd_decompress (dctx , output_data , uncompressed_size , input_data , input_size , false);
92- if (ZSTD_isError (decompress_size )) {
93- rb_raise (rb_eRuntimeError , "%s: %s" , "decompress error" , ZSTD_getErrorName (decompress_size ));
115+ ZSTD_freeDCtx (dctx );
116+ xfree (in );
117+ RB_GC_GUARD (input_value );
118+ return out ;
119+ }
120+
121+ off += 1 ;
94122 }
95- ZSTD_freeDCtx (dctx );
96- return output ;
123+
124+ xfree (in );
125+ RB_GC_GUARD (input_value );
126+ rb_raise (rb_eRuntimeError , "not a zstd frame (magic not found)" );
97127}
98128
99129static void free_cdict (void * dict )
0 commit comments