Skip to content

Commit 7951b2b

Browse files
danieldk (Daniël de Kok)
authored and committed
Implement WriteChunk for MmapQuantizedArray
1 parent a2044f6 commit 7951b2b

File tree

2 files changed

+175
-108
lines changed

2 files changed

+175
-108
lines changed

src/chunks/storage/quantized.rs

Lines changed: 173 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -100,73 +100,13 @@ impl QuantizedArray {
100100
read_norms,
101101
})
102102
}
103-
}
104-
105-
impl Storage for QuantizedArray {
106-
fn embedding(&self, idx: usize) -> CowArray1<f32> {
107-
let mut reconstructed = self.quantizer.reconstruct_vector(self.quantized.row(idx));
108-
if let Some(ref norms) = self.norms {
109-
reconstructed *= norms[idx];
110-
}
111-
112-
CowArray::Owned(reconstructed)
113-
}
114-
115-
fn shape(&self) -> (usize, usize) {
116-
(self.quantized.rows(), self.quantizer.reconstructed_len())
117-
}
118-
}
119103

120-
impl ReadChunk for QuantizedArray {
121-
fn read_chunk<R>(read: &mut R) -> Result<Self>
122-
where
123-
R: Read + Seek,
124-
{
125-
ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::QuantizedArray)?;
126-
127-
// Read and discard chunk length.
128-
read.read_u64::<LittleEndian>().map_err(|e| {
129-
ErrorKind::io_error("Cannot read quantized embedding matrix chunk length", e)
130-
})?;
131-
132-
let PQRead {
133-
n_embeddings,
134-
quantizer,
135-
read_norms,
136-
} = Self::read_product_quantizer(read)?;
137-
138-
let norms = if read_norms {
139-
let mut norms_vec = vec![0f32; n_embeddings];
140-
read.read_f32_into::<LittleEndian>(&mut norms_vec)
141-
.map_err(|e| ErrorKind::io_error("Cannot read norms", e))?;
142-
Some(Array1::from_vec(norms_vec))
143-
} else {
144-
None
145-
};
146-
147-
let mut quantized_embeddings_vec = vec![0u8; n_embeddings * quantizer.quantized_len()];
148-
read.read_exact(&mut quantized_embeddings_vec)
149-
.map_err(|e| ErrorKind::io_error("Cannot read quantized embeddings", e))?;
150-
let quantized = Array2::from_shape_vec(
151-
(n_embeddings, quantizer.quantized_len()),
152-
quantized_embeddings_vec,
153-
)
154-
.map_err(Error::Shape)?;
155-
156-
Ok(QuantizedArray {
157-
quantizer,
158-
quantized,
159-
norms,
160-
})
161-
}
162-
}
163-
164-
impl WriteChunk for QuantizedArray {
165-
fn chunk_identifier(&self) -> ChunkIdentifier {
166-
ChunkIdentifier::QuantizedArray
167-
}
168-
169-
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
104+
fn write_chunk<W>(
105+
write: &mut W,
106+
quantizer: &PQ<f32>,
107+
quantized: ArrayView2<u8>,
108+
norms: Option<ArrayView1<f32>>,
109+
) -> Result<()>
170110
where
171111
W: Write + Seek,
172112
{
@@ -194,16 +134,16 @@ impl WriteChunk for QuantizedArray {
194134
+ size_of::<u64>()
195135
+ 2 * size_of::<u32>()
196136
+ n_padding as usize
197-
+ self.quantizer.projection().is_some() as usize
198-
* self.quantizer.reconstructed_len()
199-
* self.quantizer.reconstructed_len()
137+
+ quantizer.projection().is_some() as usize
138+
* quantizer.reconstructed_len()
139+
* quantizer.reconstructed_len()
200140
* size_of::<f32>()
201-
+ self.quantizer.quantized_len()
202-
* self.quantizer.n_quantizer_centroids()
203-
* (self.quantizer.reconstructed_len() / self.quantizer.quantized_len())
141+
+ quantizer.quantized_len()
142+
* quantizer.n_quantizer_centroids()
143+
* (quantizer.reconstructed_len() / quantizer.quantized_len())
204144
* size_of::<f32>()
205-
+ self.norms.is_some() as usize * self.quantized.rows() * size_of::<f32>()
206-
+ self.quantized.rows() * self.quantizer.quantized_len();
145+
+ norms.is_some() as usize * quantized.rows() * size_of::<f32>()
146+
+ quantized.rows() * quantizer.quantized_len();
207147

208148
write
209149
.write_u64::<LittleEndian>(chunk_size as u64)
@@ -212,24 +152,24 @@ impl WriteChunk for QuantizedArray {
212152
})?;
213153

214154
write
215-
.write_u32::<LittleEndian>(self.quantizer.projection().is_some() as u32)
155+
.write_u32::<LittleEndian>(quantizer.projection().is_some() as u32)
216156
.map_err(|e| {
217157
ErrorKind::io_error("Cannot write quantized embedding matrix projection", e)
218158
})?;
219159
write
220-
.write_u32::<LittleEndian>(self.norms.is_some() as u32)
160+
.write_u32::<LittleEndian>(norms.is_some() as u32)
221161
.map_err(|e| ErrorKind::io_error("Cannot write quantized embedding matrix norms", e))?;
222162
write
223-
.write_u32::<LittleEndian>(self.quantizer.quantized_len() as u32)
163+
.write_u32::<LittleEndian>(quantizer.quantized_len() as u32)
224164
.map_err(|e| ErrorKind::io_error("Cannot write quantized embedding length", e))?;
225165
write
226-
.write_u32::<LittleEndian>(self.quantizer.reconstructed_len() as u32)
166+
.write_u32::<LittleEndian>(quantizer.reconstructed_len() as u32)
227167
.map_err(|e| ErrorKind::io_error("Cannot write reconstructed embedding length", e))?;
228168
write
229-
.write_u32::<LittleEndian>(self.quantizer.n_quantizer_centroids() as u32)
169+
.write_u32::<LittleEndian>(quantizer.n_quantizer_centroids() as u32)
230170
.map_err(|e| ErrorKind::io_error("Cannot write number of subquantizers", e))?;
231171
write
232-
.write_u64::<LittleEndian>(self.quantized.rows() as u64)
172+
.write_u64::<LittleEndian>(quantized.rows() as u64)
233173
.map_err(|e| ErrorKind::io_error("Cannot write number of quantized embeddings", e))?;
234174

235175
// Quantized and reconstruction types.
@@ -250,7 +190,7 @@ impl WriteChunk for QuantizedArray {
250190
.map_err(|e| ErrorKind::io_error("Cannot write padding", e))?;
251191

252192
// Write projection matrix.
253-
if let Some(projection) = self.quantizer.projection() {
193+
if let Some(projection) = quantizer.projection() {
254194
for row in projection.outer_iter() {
255195
for &col in row {
256196
write.write_f32::<LittleEndian>(col).map_err(|e| {
@@ -261,7 +201,7 @@ impl WriteChunk for QuantizedArray {
261201
}
262202

263203
// Write subquantizers.
264-
for subquantizer in self.quantizer.subquantizers() {
204+
for subquantizer in quantizer.subquantizers() {
265205
for row in subquantizer.outer_iter() {
266206
for &col in row {
267207
write.write_f32::<LittleEndian>(col).map_err(|e| {
@@ -272,7 +212,7 @@ impl WriteChunk for QuantizedArray {
272212
}
273213

274214
// Write norms.
275-
if let Some(ref norms) = self.norms {
215+
if let Some(ref norms) = norms {
276216
for row in norms.outer_iter() {
277217
for &col in row {
278218
write.write_f32::<LittleEndian>(col).map_err(|e| {
@@ -283,7 +223,7 @@ impl WriteChunk for QuantizedArray {
283223
}
284224

285225
// Write quantized embedding matrix.
286-
for row in self.quantized.outer_iter() {
226+
for row in quantized.outer_iter() {
287227
for &col in row {
288228
write.write_u8(col).map_err(|e| {
289229
ErrorKind::io_error("Cannot write quantized embedding matrix component", e)
@@ -295,6 +235,83 @@ impl WriteChunk for QuantizedArray {
295235
}
296236
}
297237

238+
// `Storage` implementation for the in-memory quantized matrix: embeddings are
// not stored directly but reconstructed from their quantized rows on access.
impl Storage for QuantizedArray {
239+
/// Reconstruct the embedding stored at row `idx`.
///
/// The quantized row is passed through the quantizer's
/// `reconstruct_vector`; when norms are stored, the result is multiplied
/// by `norms[idx]`. The returned array is always owned, since it is
/// computed on the fly.
fn embedding(&self, idx: usize) -> CowArray1<f32> {
240+
let mut reconstructed = self.quantizer.reconstruct_vector(self.quantized.row(idx));
241+
// Scale by the stored norm for this row, if norms are present.
if let Some(ref norms) = self.norms {
242+
reconstructed *= norms[idx];
243+
}
244+
245+
CowArray::Owned(reconstructed)
246+
}
247+
248+
/// Return `(rows of the quantized matrix, reconstructed embedding length)`.
///
/// The second component is the quantizer's `reconstructed_len()`, not the
/// width of the quantized matrix.
fn shape(&self) -> (usize, usize) {
249+
(self.quantized.rows(), self.quantizer.reconstructed_len())
250+
}
251+
}
252+
253+
// Deserialization of a quantized embedding matrix chunk into memory.
impl ReadChunk for QuantizedArray {
254+
/// Read a quantized array chunk from `read`.
///
/// Steps, in the order the data is consumed:
/// 1. verify the chunk identifier is `ChunkIdentifier::QuantizedArray`;
/// 2. read (and ignore) the little-endian `u64` chunk length;
/// 3. read the product quantizer via `Self::read_product_quantizer`;
/// 4. read `n_embeddings` little-endian `f32` norms, if the header says
///    norms are present;
/// 5. read `n_embeddings * quantized_len` bytes of quantized embeddings.
///
/// # Errors
///
/// I/O failures are wrapped with `ErrorKind::io_error` and a message
/// naming the part that could not be read; a shape mismatch when building
/// the matrix is reported as `Error::Shape`. Errors from
/// `ensure_chunk_type` and `read_product_quantizer` are propagated as-is.
fn read_chunk<R>(read: &mut R) -> Result<Self>
255+
where
256+
R: Read + Seek,
257+
{
258+
ChunkIdentifier::ensure_chunk_type(read, ChunkIdentifier::QuantizedArray)?;
259+
260+
// Read and discard chunk length.
261+
read.read_u64::<LittleEndian>().map_err(|e| {
262+
ErrorKind::io_error("Cannot read quantized embedding matrix chunk length", e)
263+
})?;
264+
265+
// The helper returns the quantizer together with the number of
// embeddings and whether a norms section follows.
let PQRead {
266+
n_embeddings,
267+
quantizer,
268+
read_norms,
269+
} = Self::read_product_quantizer(read)?;
270+
271+
// Norms, when present, come before the quantized matrix.
let norms = if read_norms {
272+
let mut norms_vec = vec![0f32; n_embeddings];
273+
read.read_f32_into::<LittleEndian>(&mut norms_vec)
274+
.map_err(|e| ErrorKind::io_error("Cannot read norms", e))?;
275+
Some(Array1::from_vec(norms_vec))
276+
} else {
277+
None
278+
};
279+
280+
// NOTE(review): both factors come from `read_product_quantizer` and the
// product is unchecked — confirm the header values are validated there.
let mut quantized_embeddings_vec = vec![0u8; n_embeddings * quantizer.quantized_len()];
281+
read.read_exact(&mut quantized_embeddings_vec)
282+
.map_err(|e| ErrorKind::io_error("Cannot read quantized embeddings", e))?;
283+
// Interpret the flat byte buffer as one quantized row per embedding.
let quantized = Array2::from_shape_vec(
284+
(n_embeddings, quantizer.quantized_len()),
285+
quantized_embeddings_vec,
286+
)
287+
.map_err(Error::Shape)?;
288+
289+
Ok(QuantizedArray {
290+
quantizer,
291+
quantized,
292+
norms,
293+
})
294+
}
295+
}
296+
297+
// Serialization of the in-memory quantized matrix.
impl WriteChunk for QuantizedArray {
298+
/// The chunk type written by this storage: `ChunkIdentifier::QuantizedArray`.
fn chunk_identifier(&self) -> ChunkIdentifier {
299+
ChunkIdentifier::QuantizedArray
300+
}
301+
302+
/// Write this array to `write` as a quantized array chunk.
///
/// The work is forwarded to a four-argument `write_chunk`, passing the
/// quantizer by reference and the matrix/norms as views.
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
303+
where
304+
W: Write + Seek,
305+
{
306+
// This call passes `(write, quantizer, quantized, norms)`, which matches
// the associated function `write_chunk` declared in `impl QuantizedArray`
// rather than this trait method (which takes `&self, write`), so the
// call does not recurse.
Self::write_chunk(
307+
write,
308+
&self.quantizer,
309+
self.quantized.view(),
310+
// `Option<Array1<f32>>` -> `Option<ArrayView1<f32>>` without copying.
self.norms.as_ref().map(Array1::view),
311+
)
312+
}
313+
}
314+
298315
/// Quantizable embedding matrix.
299316
pub trait Quantize {
300317
/// Quantize the embedding matrix.
@@ -401,6 +418,27 @@ pub struct MmapQuantizedArray {
401418
norms: Option<Mmap>,
402419
}
403420

421+
// Helpers that expose the stored norms and quantized buffers as ndarray
// views, so that reading embeddings and writing chunks share one code path.
impl MmapQuantizedArray {
422+
/// View the stored norms as a one-dimensional `f32` array with
/// `self.shape().0` entries, or `None` when no norms are stored.
///
/// # Safety
///
/// The view is built from a raw pointer into `self.norms` without any
/// length check: the buffer must hold at least `n_embeddings` `f32`
/// values, and its start must be aligned for `f32`.
unsafe fn norms(&self) -> Option<ArrayView1<f32>> {
423+
let n_embeddings = self.shape().0;
424+
425+
#[allow(clippy::cast_ptr_alignment)]
426+
self.norms.as_ref().map(|norms|
427+
// Alignment is ok, padding guarantees that the pointer is at
428+
// a multiple of 4.
429+
ArrayView1::from_shape_ptr((n_embeddings,), norms.as_ptr() as *const f32))
430+
}
431+
432+
/// View the stored quantized embeddings as a `u8` matrix of shape
/// `(self.shape().0, self.quantizer.quantized_len())`.
///
/// # Safety
///
/// The view is built from a raw pointer into `self.quantized` without any
/// length check: the buffer must hold at least
/// `n_embeddings * quantized_len` bytes.
unsafe fn quantized(&self) -> ArrayView2<u8> {
433+
let n_embeddings = self.shape().0;
434+
435+
ArrayView2::from_shape_ptr(
436+
(n_embeddings, self.quantizer.quantized_len()),
437+
self.quantized.as_ptr(),
438+
)
439+
}
440+
}
441+
404442
impl MmapQuantizedArray {
405443
fn mmap_norms(read: &mut BufReader<File>, n_embeddings: usize) -> Result<Mmap> {
406444
let offset = read.seek(SeekFrom::Current(0)).map_err(|e| {
@@ -460,23 +498,10 @@ impl MmapQuantizedArray {
460498

461499
impl Storage for MmapQuantizedArray {
462500
fn embedding(&self, idx: usize) -> CowArray1<f32> {
463-
let n_embeddings = self.shape().0;
464-
465-
let quantized = unsafe {
466-
ArrayView2::from_shape_ptr(
467-
(n_embeddings, self.quantizer.quantized_len()),
468-
self.quantized.as_ptr(),
469-
)
470-
};
501+
let quantized = unsafe { self.quantized() };
471502

472503
let mut reconstructed = self.quantizer.reconstruct_vector(quantized.row(idx));
473-
if let Some(ref norms) = self.norms {
474-
// Alignment is ok, padding guarantees that the pointer is at
475-
// a multiple of 4.
476-
#[allow(clippy::cast_ptr_alignment)]
477-
let norms = unsafe {
478-
ArrayView1::from_shape_ptr((n_embeddings,), norms.as_ptr() as *const f32)
479-
};
504+
if let Some(norms) = unsafe { self.norms() } {
480505
reconstructed *= norms[idx];
481506
}
482507

@@ -523,6 +548,24 @@ impl MmapChunk for MmapQuantizedArray {
523548
}
524549
}
525550

551+
// Serialization of the memory-mapped quantized matrix. It reuses the writer
// of `QuantizedArray`, so both storages write through the same routine.
impl WriteChunk for MmapQuantizedArray {
552+
/// The chunk type written by this storage: `ChunkIdentifier::QuantizedArray`,
/// the same identifier `QuantizedArray` reports.
fn chunk_identifier(&self) -> ChunkIdentifier {
553+
ChunkIdentifier::QuantizedArray
554+
}
555+
556+
/// Write this array to `write` as a quantized array chunk.
///
/// Delegates to `QuantizedArray::write_chunk`, passing views over the
/// stored quantized matrix and (optional) norms.
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
557+
where
558+
W: Write + Seek,
559+
{
560+
QuantizedArray::write_chunk(
561+
write,
562+
&self.quantizer,
563+
// SAFETY (assumed — TODO confirm): relies on the quantized buffer
// covering `shape().0 * quantized_len()` bytes.
unsafe { self.quantized() },
564+
// SAFETY (assumed — TODO confirm): relies on the norms buffer covering
// `shape().0` `f32` values at a 4-byte-aligned address.
unsafe { self.norms() },
565+
)
566+
}
567+
}
568+
526569
#[cfg(test)]
527570
mod tests {
528571
use std::fs::File;
@@ -559,6 +602,17 @@ mod tests {
559602
read.read_u64::<LittleEndian>().unwrap()
560603
}
561604

605+
// Compare storage for which Eq is not implemented.
//
// Asserts that both storages report the same shape and that every embedding
// (by row index) compares equal. Panics via `assert_eq!` on the first
// difference, so it is only meant to be used from tests.
606+
fn storage_eq(arr: &impl Storage, check_arr: &impl Storage) {
607+
assert_eq!(arr.shape(), check_arr.shape());
608+
// Shapes are equal at this point, so iterating over the rows of
// `check_arr` covers every row of `arr` as well.
for idx in 0..check_arr.shape().0 {
609+
assert_eq!(
610+
arr.embedding(idx).as_view(),
611+
check_arr.embedding(idx).as_view()
612+
);
613+
}
614+
}
615+
562616
#[test]
563617
fn quantized_array_correct_chunk_size() {
564618
let check_arr = test_quantized_array(false);
@@ -609,12 +663,25 @@ mod tests {
609663
let arr = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
610664

611665
// Check
612-
assert_eq!(arr.shape(), check_arr.shape());
613-
for idx in 0..check_arr.shape().0 {
614-
assert_eq!(
615-
arr.embedding(idx).as_view(),
616-
check_arr.embedding(idx).as_view()
617-
);
618-
}
666+
storage_eq(&arr, &check_arr);
667+
}
668+
669+
// Round-trip test: a memory-mapped quantized array written with `write_chunk`
// must be readable by the non-mmap reader and yield the same embeddings.
#[test]
670+
fn write_mmap_quantized_array() {
671+
// Memory map matrix.
672+
let mut storage_read =
673+
BufReader::new(File::open("testdata/quantized_storage.bin").unwrap());
674+
let check_arr = MmapQuantizedArray::mmap_chunk(&mut storage_read).unwrap();
675+
676+
// Write the matrix to an in-memory buffer.
677+
let mut cursor = Cursor::new(Vec::new());
678+
check_arr.write_chunk(&mut cursor).unwrap();
679+
680+
// Rewind and read the written chunk back using the non-mmap'ed reader.
681+
cursor.seek(SeekFrom::Start(0)).unwrap();
682+
let arr = QuantizedArray::read_chunk(&mut cursor).unwrap();
683+
684+
// Check that shape and all embeddings match the memory-mapped original.
685+
storage_eq(&arr, &check_arr);
619686
}
620687
}

0 commit comments

Comments
 (0)