@@ -100,73 +100,13 @@ impl QuantizedArray {
100
100
read_norms,
101
101
} )
102
102
}
103
- }
104
-
105
- impl Storage for QuantizedArray {
106
- fn embedding ( & self , idx : usize ) -> CowArray1 < f32 > {
107
- let mut reconstructed = self . quantizer . reconstruct_vector ( self . quantized . row ( idx) ) ;
108
- if let Some ( ref norms) = self . norms {
109
- reconstructed *= norms[ idx] ;
110
- }
111
-
112
- CowArray :: Owned ( reconstructed)
113
- }
114
-
115
- fn shape ( & self ) -> ( usize , usize ) {
116
- ( self . quantized . rows ( ) , self . quantizer . reconstructed_len ( ) )
117
- }
118
- }
119
103
120
- impl ReadChunk for QuantizedArray {
121
- fn read_chunk < R > ( read : & mut R ) -> Result < Self >
122
- where
123
- R : Read + Seek ,
124
- {
125
- ChunkIdentifier :: ensure_chunk_type ( read, ChunkIdentifier :: QuantizedArray ) ?;
126
-
127
- // Read and discard chunk length.
128
- read. read_u64 :: < LittleEndian > ( ) . map_err ( |e| {
129
- ErrorKind :: io_error ( "Cannot read quantized embedding matrix chunk length" , e)
130
- } ) ?;
131
-
132
- let PQRead {
133
- n_embeddings,
134
- quantizer,
135
- read_norms,
136
- } = Self :: read_product_quantizer ( read) ?;
137
-
138
- let norms = if read_norms {
139
- let mut norms_vec = vec ! [ 0f32 ; n_embeddings] ;
140
- read. read_f32_into :: < LittleEndian > ( & mut norms_vec)
141
- . map_err ( |e| ErrorKind :: io_error ( "Cannot read norms" , e) ) ?;
142
- Some ( Array1 :: from_vec ( norms_vec) )
143
- } else {
144
- None
145
- } ;
146
-
147
- let mut quantized_embeddings_vec = vec ! [ 0u8 ; n_embeddings * quantizer. quantized_len( ) ] ;
148
- read. read_exact ( & mut quantized_embeddings_vec)
149
- . map_err ( |e| ErrorKind :: io_error ( "Cannot read quantized embeddings" , e) ) ?;
150
- let quantized = Array2 :: from_shape_vec (
151
- ( n_embeddings, quantizer. quantized_len ( ) ) ,
152
- quantized_embeddings_vec,
153
- )
154
- . map_err ( Error :: Shape ) ?;
155
-
156
- Ok ( QuantizedArray {
157
- quantizer,
158
- quantized,
159
- norms,
160
- } )
161
- }
162
- }
163
-
164
- impl WriteChunk for QuantizedArray {
165
- fn chunk_identifier ( & self ) -> ChunkIdentifier {
166
- ChunkIdentifier :: QuantizedArray
167
- }
168
-
169
- fn write_chunk < W > ( & self , write : & mut W ) -> Result < ( ) >
104
+ fn write_chunk < W > (
105
+ write : & mut W ,
106
+ quantizer : & PQ < f32 > ,
107
+ quantized : ArrayView2 < u8 > ,
108
+ norms : Option < ArrayView1 < f32 > > ,
109
+ ) -> Result < ( ) >
170
110
where
171
111
W : Write + Seek ,
172
112
{
@@ -194,16 +134,16 @@ impl WriteChunk for QuantizedArray {
194
134
+ size_of :: < u64 > ( )
195
135
+ 2 * size_of :: < u32 > ( )
196
136
+ n_padding as usize
197
- + self . quantizer . projection ( ) . is_some ( ) as usize
198
- * self . quantizer . reconstructed_len ( )
199
- * self . quantizer . reconstructed_len ( )
137
+ + quantizer. projection ( ) . is_some ( ) as usize
138
+ * quantizer. reconstructed_len ( )
139
+ * quantizer. reconstructed_len ( )
200
140
* size_of :: < f32 > ( )
201
- + self . quantizer . quantized_len ( )
202
- * self . quantizer . n_quantizer_centroids ( )
203
- * ( self . quantizer . reconstructed_len ( ) / self . quantizer . quantized_len ( ) )
141
+ + quantizer. quantized_len ( )
142
+ * quantizer. n_quantizer_centroids ( )
143
+ * ( quantizer. reconstructed_len ( ) / quantizer. quantized_len ( ) )
204
144
* size_of :: < f32 > ( )
205
- + self . norms . is_some ( ) as usize * self . quantized . rows ( ) * size_of :: < f32 > ( )
206
- + self . quantized . rows ( ) * self . quantizer . quantized_len ( ) ;
145
+ + norms. is_some ( ) as usize * quantized. rows ( ) * size_of :: < f32 > ( )
146
+ + quantized. rows ( ) * quantizer. quantized_len ( ) ;
207
147
208
148
write
209
149
. write_u64 :: < LittleEndian > ( chunk_size as u64 )
@@ -212,24 +152,24 @@ impl WriteChunk for QuantizedArray {
212
152
} ) ?;
213
153
214
154
write
215
- . write_u32 :: < LittleEndian > ( self . quantizer . projection ( ) . is_some ( ) as u32 )
155
+ . write_u32 :: < LittleEndian > ( quantizer. projection ( ) . is_some ( ) as u32 )
216
156
. map_err ( |e| {
217
157
ErrorKind :: io_error ( "Cannot write quantized embedding matrix projection" , e)
218
158
} ) ?;
219
159
write
220
- . write_u32 :: < LittleEndian > ( self . norms . is_some ( ) as u32 )
160
+ . write_u32 :: < LittleEndian > ( norms. is_some ( ) as u32 )
221
161
. map_err ( |e| ErrorKind :: io_error ( "Cannot write quantized embedding matrix norms" , e) ) ?;
222
162
write
223
- . write_u32 :: < LittleEndian > ( self . quantizer . quantized_len ( ) as u32 )
163
+ . write_u32 :: < LittleEndian > ( quantizer. quantized_len ( ) as u32 )
224
164
. map_err ( |e| ErrorKind :: io_error ( "Cannot write quantized embedding length" , e) ) ?;
225
165
write
226
- . write_u32 :: < LittleEndian > ( self . quantizer . reconstructed_len ( ) as u32 )
166
+ . write_u32 :: < LittleEndian > ( quantizer. reconstructed_len ( ) as u32 )
227
167
. map_err ( |e| ErrorKind :: io_error ( "Cannot write reconstructed embedding length" , e) ) ?;
228
168
write
229
- . write_u32 :: < LittleEndian > ( self . quantizer . n_quantizer_centroids ( ) as u32 )
169
+ . write_u32 :: < LittleEndian > ( quantizer. n_quantizer_centroids ( ) as u32 )
230
170
. map_err ( |e| ErrorKind :: io_error ( "Cannot write number of subquantizers" , e) ) ?;
231
171
write
232
- . write_u64 :: < LittleEndian > ( self . quantized . rows ( ) as u64 )
172
+ . write_u64 :: < LittleEndian > ( quantized. rows ( ) as u64 )
233
173
. map_err ( |e| ErrorKind :: io_error ( "Cannot write number of quantized embeddings" , e) ) ?;
234
174
235
175
// Quantized and reconstruction types.
@@ -250,7 +190,7 @@ impl WriteChunk for QuantizedArray {
250
190
. map_err ( |e| ErrorKind :: io_error ( "Cannot write padding" , e) ) ?;
251
191
252
192
// Write projection matrix.
253
- if let Some ( projection) = self . quantizer . projection ( ) {
193
+ if let Some ( projection) = quantizer. projection ( ) {
254
194
for row in projection. outer_iter ( ) {
255
195
for & col in row {
256
196
write. write_f32 :: < LittleEndian > ( col) . map_err ( |e| {
@@ -261,7 +201,7 @@ impl WriteChunk for QuantizedArray {
261
201
}
262
202
263
203
// Write subquantizers.
264
- for subquantizer in self . quantizer . subquantizers ( ) {
204
+ for subquantizer in quantizer. subquantizers ( ) {
265
205
for row in subquantizer. outer_iter ( ) {
266
206
for & col in row {
267
207
write. write_f32 :: < LittleEndian > ( col) . map_err ( |e| {
@@ -272,7 +212,7 @@ impl WriteChunk for QuantizedArray {
272
212
}
273
213
274
214
// Write norms.
275
- if let Some ( ref norms) = self . norms {
215
+ if let Some ( ref norms) = norms {
276
216
for row in norms. outer_iter ( ) {
277
217
for & col in row {
278
218
write. write_f32 :: < LittleEndian > ( col) . map_err ( |e| {
@@ -283,7 +223,7 @@ impl WriteChunk for QuantizedArray {
283
223
}
284
224
285
225
// Write quantized embedding matrix.
286
- for row in self . quantized . outer_iter ( ) {
226
+ for row in quantized. outer_iter ( ) {
287
227
for & col in row {
288
228
write. write_u8 ( col) . map_err ( |e| {
289
229
ErrorKind :: io_error ( "Cannot write quantized embedding matrix component" , e)
@@ -295,6 +235,83 @@ impl WriteChunk for QuantizedArray {
295
235
}
296
236
}
297
237
238
+ impl Storage for QuantizedArray {
239
+ fn embedding ( & self , idx : usize ) -> CowArray1 < f32 > {
240
+ let mut reconstructed = self . quantizer . reconstruct_vector ( self . quantized . row ( idx) ) ;
241
+ if let Some ( ref norms) = self . norms {
242
+ reconstructed *= norms[ idx] ;
243
+ }
244
+
245
+ CowArray :: Owned ( reconstructed)
246
+ }
247
+
248
+ fn shape ( & self ) -> ( usize , usize ) {
249
+ ( self . quantized . rows ( ) , self . quantizer . reconstructed_len ( ) )
250
+ }
251
+ }
252
+
253
+ impl ReadChunk for QuantizedArray {
254
+ fn read_chunk < R > ( read : & mut R ) -> Result < Self >
255
+ where
256
+ R : Read + Seek ,
257
+ {
258
+ ChunkIdentifier :: ensure_chunk_type ( read, ChunkIdentifier :: QuantizedArray ) ?;
259
+
260
+ // Read and discard chunk length.
261
+ read. read_u64 :: < LittleEndian > ( ) . map_err ( |e| {
262
+ ErrorKind :: io_error ( "Cannot read quantized embedding matrix chunk length" , e)
263
+ } ) ?;
264
+
265
+ let PQRead {
266
+ n_embeddings,
267
+ quantizer,
268
+ read_norms,
269
+ } = Self :: read_product_quantizer ( read) ?;
270
+
271
+ let norms = if read_norms {
272
+ let mut norms_vec = vec ! [ 0f32 ; n_embeddings] ;
273
+ read. read_f32_into :: < LittleEndian > ( & mut norms_vec)
274
+ . map_err ( |e| ErrorKind :: io_error ( "Cannot read norms" , e) ) ?;
275
+ Some ( Array1 :: from_vec ( norms_vec) )
276
+ } else {
277
+ None
278
+ } ;
279
+
280
+ let mut quantized_embeddings_vec = vec ! [ 0u8 ; n_embeddings * quantizer. quantized_len( ) ] ;
281
+ read. read_exact ( & mut quantized_embeddings_vec)
282
+ . map_err ( |e| ErrorKind :: io_error ( "Cannot read quantized embeddings" , e) ) ?;
283
+ let quantized = Array2 :: from_shape_vec (
284
+ ( n_embeddings, quantizer. quantized_len ( ) ) ,
285
+ quantized_embeddings_vec,
286
+ )
287
+ . map_err ( Error :: Shape ) ?;
288
+
289
+ Ok ( QuantizedArray {
290
+ quantizer,
291
+ quantized,
292
+ norms,
293
+ } )
294
+ }
295
+ }
296
+
297
+ impl WriteChunk for QuantizedArray {
298
+ fn chunk_identifier ( & self ) -> ChunkIdentifier {
299
+ ChunkIdentifier :: QuantizedArray
300
+ }
301
+
302
+ fn write_chunk < W > ( & self , write : & mut W ) -> Result < ( ) >
303
+ where
304
+ W : Write + Seek ,
305
+ {
306
+ Self :: write_chunk (
307
+ write,
308
+ & self . quantizer ,
309
+ self . quantized . view ( ) ,
310
+ self . norms . as_ref ( ) . map ( Array1 :: view) ,
311
+ )
312
+ }
313
+ }
314
+
298
315
/// Quantizable embedding matrix.
299
316
pub trait Quantize {
300
317
/// Quantize the embedding matrix.
@@ -401,6 +418,27 @@ pub struct MmapQuantizedArray {
401
418
norms : Option < Mmap > ,
402
419
}
403
420
421
+ impl MmapQuantizedArray {
422
+ unsafe fn norms ( & self ) -> Option < ArrayView1 < f32 > > {
423
+ let n_embeddings = self . shape ( ) . 0 ;
424
+
425
+ #[ allow( clippy:: cast_ptr_alignment) ]
426
+ self . norms . as_ref ( ) . map ( |norms|
427
+ // Alignment is ok, padding guarantees that the pointer is at
428
+ // a multiple of 4.
429
+ ArrayView1 :: from_shape_ptr ( ( n_embeddings, ) , norms. as_ptr ( ) as * const f32 ) )
430
+ }
431
+
432
+ unsafe fn quantized ( & self ) -> ArrayView2 < u8 > {
433
+ let n_embeddings = self . shape ( ) . 0 ;
434
+
435
+ ArrayView2 :: from_shape_ptr (
436
+ ( n_embeddings, self . quantizer . quantized_len ( ) ) ,
437
+ self . quantized . as_ptr ( ) ,
438
+ )
439
+ }
440
+ }
441
+
404
442
impl MmapQuantizedArray {
405
443
fn mmap_norms ( read : & mut BufReader < File > , n_embeddings : usize ) -> Result < Mmap > {
406
444
let offset = read. seek ( SeekFrom :: Current ( 0 ) ) . map_err ( |e| {
@@ -460,23 +498,10 @@ impl MmapQuantizedArray {
460
498
461
499
impl Storage for MmapQuantizedArray {
462
500
fn embedding ( & self , idx : usize ) -> CowArray1 < f32 > {
463
- let n_embeddings = self . shape ( ) . 0 ;
464
-
465
- let quantized = unsafe {
466
- ArrayView2 :: from_shape_ptr (
467
- ( n_embeddings, self . quantizer . quantized_len ( ) ) ,
468
- self . quantized . as_ptr ( ) ,
469
- )
470
- } ;
501
+ let quantized = unsafe { self . quantized ( ) } ;
471
502
472
503
let mut reconstructed = self . quantizer . reconstruct_vector ( quantized. row ( idx) ) ;
473
- if let Some ( ref norms) = self . norms {
474
- // Alignment is ok, padding guarantees that the pointer is at
475
- // a multiple of 4.
476
- #[ allow( clippy:: cast_ptr_alignment) ]
477
- let norms = unsafe {
478
- ArrayView1 :: from_shape_ptr ( ( n_embeddings, ) , norms. as_ptr ( ) as * const f32 )
479
- } ;
504
+ if let Some ( norms) = unsafe { self . norms ( ) } {
480
505
reconstructed *= norms[ idx] ;
481
506
}
482
507
@@ -523,6 +548,24 @@ impl MmapChunk for MmapQuantizedArray {
523
548
}
524
549
}
525
550
551
+ impl WriteChunk for MmapQuantizedArray {
552
+ fn chunk_identifier ( & self ) -> ChunkIdentifier {
553
+ ChunkIdentifier :: QuantizedArray
554
+ }
555
+
556
+ fn write_chunk < W > ( & self , write : & mut W ) -> Result < ( ) >
557
+ where
558
+ W : Write + Seek ,
559
+ {
560
+ QuantizedArray :: write_chunk (
561
+ write,
562
+ & self . quantizer ,
563
+ unsafe { self . quantized ( ) } ,
564
+ unsafe { self . norms ( ) } ,
565
+ )
566
+ }
567
+ }
568
+
526
569
#[ cfg( test) ]
527
570
mod tests {
528
571
use std:: fs:: File ;
@@ -559,6 +602,17 @@ mod tests {
559
602
read. read_u64 :: < LittleEndian > ( ) . unwrap ( )
560
603
}
561
604
605
+ // Compare storage for which Eq is not implemented.
606
+ fn storage_eq ( arr : & impl Storage , check_arr : & impl Storage ) {
607
+ assert_eq ! ( arr. shape( ) , check_arr. shape( ) ) ;
608
+ for idx in 0 ..check_arr. shape ( ) . 0 {
609
+ assert_eq ! (
610
+ arr. embedding( idx) . as_view( ) ,
611
+ check_arr. embedding( idx) . as_view( )
612
+ ) ;
613
+ }
614
+ }
615
+
562
616
#[ test]
563
617
fn quantized_array_correct_chunk_size ( ) {
564
618
let check_arr = test_quantized_array ( false ) ;
@@ -609,12 +663,25 @@ mod tests {
609
663
let arr = MmapQuantizedArray :: mmap_chunk ( & mut storage_read) . unwrap ( ) ;
610
664
611
665
// Check
612
- assert_eq ! ( arr. shape( ) , check_arr. shape( ) ) ;
613
- for idx in 0 ..check_arr. shape ( ) . 0 {
614
- assert_eq ! (
615
- arr. embedding( idx) . as_view( ) ,
616
- check_arr. embedding( idx) . as_view( )
617
- ) ;
618
- }
666
+ storage_eq ( & arr, & check_arr) ;
667
+ }
668
+
669
#[test]
fn write_mmap_quantized_array() {
    // Memory map the reference matrix from the test data.
    let test_file = File::open("testdata/quantized_storage.bin").unwrap();
    let mut storage_read = BufReader::new(test_file);
    let check_arr = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();

    // Serialize the memory-mapped matrix into an in-memory buffer.
    let mut buffer = Cursor::new(Vec::new());
    check_arr.write_chunk(&mut buffer).unwrap();

    // Rewind and read the chunk back with the non-mmap reader.
    buffer.seek(SeekFrom::Start(0)).unwrap();
    let arr = QuantizedArray::read_chunk(&mut buffer).unwrap();

    // The round-tripped storage must match the memory-mapped original.
    storage_eq(&arr, &check_arr);
}
620
687
}
0 commit comments