@@ -39,61 +39,91 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
3939  return  output ;
4040}
4141
42- static  VALUE  decompress_buffered (ZSTD_DCtx *  dctx , const  char *  input_data , size_t  input_size )
43- {
44-   ZSTD_inBuffer  input  =  { input_data , input_size , 0  };
45-   VALUE  result  =  rb_str_new (0 , 0 );
42+ static  VALUE  decode_one_frame (ZSTD_DCtx *  dctx , const  unsigned char  *  src , size_t  size , VALUE  kwargs ) {
43+   VALUE  out  =  rb_str_buf_new (0 );
44+   size_t  cap  =  ZSTD_DStreamOutSize ();
45+   char  * buf  =  ALLOC_N (char , cap );
46+   ZSTD_inBuffer  in  =  (ZSTD_inBuffer ){ src , size , 0  };
4647
47-   while  (input .pos  <  input .size ) {
48-     ZSTD_outBuffer  output  =  { NULL , 0 , 0  };
49-     output .size  +=  ZSTD_DStreamOutSize ();
50-     VALUE  output_string  =  rb_str_new (NULL , output .size );
51-     output .dst  =  RSTRING_PTR (output_string );
48+   ZSTD_DCtx_reset (dctx , ZSTD_reset_session_only );
49+   set_decompress_params (dctx , kwargs );
5250
53-     size_t  ret  =  zstd_stream_decompress (dctx , & output , & input , false);
51+   for  (;;) {
52+     ZSTD_outBuffer  o  =  (ZSTD_outBuffer ){ buf , cap , 0  };
53+     size_t  ret  =  ZSTD_decompressStream (dctx , & o , & in );
5454    if  (ZSTD_isError (ret )) {
55-       ZSTD_freeDCtx (dctx );
56-       rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_decompressStream failed" , ZSTD_getErrorName (ret ));
55+       xfree (buf );
56+       rb_raise (rb_eRuntimeError , "ZSTD_decompressStream failed: %s" , ZSTD_getErrorName (ret ));
57+     }
58+     if  (o .pos ) {
59+       rb_str_cat (out , buf , o .pos );
60+     }
61+     if  (ret  ==  0 ) {
62+       break ;
5763    }
58-     rb_str_cat (result , output .dst , output .pos );
5964  }
60-   ZSTD_freeDCtx (dctx );
61-   return  result ;
65+   xfree (buf );
66+   return  out ;
67+ }
68+ 
69+ static  VALUE  decompress_buffered (ZSTD_DCtx *  dctx , const  char *  data , size_t  len ) {
70+   return  decode_one_frame (dctx , (const  unsigned char  * )data , len , Qnil );
6271}
6372
6473static  VALUE  rb_decompress (int  argc , VALUE  * argv , VALUE  self )
6574{
66-   VALUE  input_value ;
67-   VALUE  kwargs ;
75+   VALUE  input_value , kwargs ;
6876  rb_scan_args (argc , argv , "10:" , & input_value , & kwargs );
6977  StringValue (input_value );
70-   char *  input_data  =  RSTRING_PTR (input_value );
71-   size_t  input_size  =  RSTRING_LEN (input_value );
72-   ZSTD_DCtx *  const  dctx  =  ZSTD_createDCtx ();
73-   if  (dctx  ==  NULL ) {
74-     rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
75-   }
76-   set_decompress_params (dctx , kwargs );
7778
78-   unsigned long long const   uncompressed_size  =  ZSTD_getFrameContentSize (input_data , input_size );
79-   if  (uncompressed_size  ==  ZSTD_CONTENTSIZE_ERROR ) {
80-     rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
81-   }
82-   // ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness. 
83-   // Therefore, we will not standardize on ZSTD_decompressStream 
84-   if  (uncompressed_size  ==  ZSTD_CONTENTSIZE_UNKNOWN ) {
85-     return  decompress_buffered (dctx , input_data , input_size );
86-   }
79+   size_t  in_size  =  RSTRING_LEN (input_value );
80+   const  unsigned char   * in_r  =  (const  unsigned char   * )RSTRING_PTR (input_value );
81+   unsigned char   * in  =  ALLOC_N (unsigned  char , in_size );
82+   memcpy (in , in_r , in_size );
83+ 
84+   size_t  off  =  0 ;
85+   const  uint32_t  ZSTD_MAGIC  =  0xFD2FB528U ;
86+   const  uint32_t  SKIP_LO     =  0x184D2A50U ; /* ...5F */ 
87+ 
88+   while  (off  +  4  <= in_size ) {
89+     uint32_t  magic  =  (uint32_t )in [off ]
90+                    | ((uint32_t )in [off + 1 ] << 8 )
91+                    | ((uint32_t )in [off + 2 ] << 16 )
92+                    | ((uint32_t )in [off + 3 ] << 24 );
93+ 
94+     if  ((magic  &  0xFFFFFFF0U ) ==  (SKIP_LO  &  0xFFFFFFF0U )) {
95+       if  (off  +  8  >  in_size ) break ;
96+       uint32_t  skipLen  =  (uint32_t )in [off + 4 ]
97+                        | ((uint32_t )in [off + 5 ] << 8 )
98+                        | ((uint32_t )in [off + 6 ] << 16 )
99+                        | ((uint32_t )in [off + 7 ] << 24 );
100+       size_t  adv  =  (size_t )8  +  (size_t )skipLen ;
101+       if  (off  +  adv  >  in_size ) break ;
102+       off  +=  adv ;
103+       continue ;
104+     }
87105
88-   VALUE  output  =  rb_str_new (NULL , uncompressed_size );
89-   char *  output_data  =  RSTRING_PTR (output );
106+     if  (magic  ==  ZSTD_MAGIC ) {
107+       ZSTD_DCtx  * dctx  =  ZSTD_createDCtx ();
108+       if  (!dctx ) {
109+         xfree (in );
110+         rb_raise (rb_eRuntimeError , "ZSTD_createDCtx failed" );
111+       }
112+ 
113+       VALUE  out  =  decode_one_frame (dctx , in  +  off , in_size  -  off , kwargs );
90114
91-   size_t  const  decompress_size  =  zstd_decompress (dctx , output_data , uncompressed_size , input_data , input_size , false);
92-   if  (ZSTD_isError (decompress_size )) {
93-     rb_raise (rb_eRuntimeError , "%s: %s" , "decompress error" , ZSTD_getErrorName (decompress_size ));
115+       ZSTD_freeDCtx (dctx );
116+       xfree (in );
117+       RB_GC_GUARD (input_value );
118+       return  out ;
119+     }
120+ 
121+     off  +=  1 ;
94122  }
95-   ZSTD_freeDCtx (dctx );
96-   return  output ;
123+ 
124+   xfree (in );
125+   RB_GC_GUARD (input_value );
126+   rb_raise (rb_eRuntimeError , "not a zstd frame (magic not found)" );
97127}
98128
99129static  void  free_cdict (void  * dict )
0 commit comments