Skip to content

Commit 9f764cb

Browse files
congx4 and cursoragent committed
Fix CardinalityCollector deserialization: use deserialize_byte_buf
The Serialize impl uses serializer.serialize_bytes() which hints the serializer to use its bytes-optimized path. The Deserialize impl was using Vec<u8>::deserialize() which goes through deserialize_seq(), a different deserialization path. While the wire formats are theoretically identical in postcard for u8 sequences, using deserialize_byte_buf explicitly pairs with serialize_bytes and eliminates any potential mismatch in binary serializers. The visitor implements visit_bytes, visit_borrowed_bytes, visit_byte_buf, and visit_seq to handle all possible deserialization paths gracefully. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 18fedd9 commit 9f764cb

File tree

1 file changed

+249
-1
lines changed

1 file changed

+249
-1
lines changed

src/aggregation/metric/cardinality.rs

Lines changed: 249 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,43 @@ impl Serialize for CardinalityCollector {
341341

342342
impl<'de> Deserialize<'de> for CardinalityCollector {
343343
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
344-
let bytes: Vec<u8> = Deserialize::deserialize(deserializer)?;
344+
struct HllBytesVisitor;
345+
346+
impl<'de> serde::de::Visitor<'de> for HllBytesVisitor {
347+
type Value = Vec<u8>;
348+
349+
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
350+
f.write_str("HLL sketch bytes")
351+
}
352+
353+
fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Vec<u8>, E> {
354+
Ok(v.to_vec())
355+
}
356+
357+
fn visit_borrowed_bytes<E: serde::de::Error>(
358+
self,
359+
v: &'de [u8],
360+
) -> Result<Vec<u8>, E> {
361+
Ok(v.to_vec())
362+
}
363+
364+
fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<Vec<u8>, E> {
365+
Ok(v)
366+
}
367+
368+
fn visit_seq<A: serde::de::SeqAccess<'de>>(
369+
self,
370+
mut seq: A,
371+
) -> Result<Vec<u8>, A::Error> {
372+
let mut bytes = Vec::with_capacity(seq.size_hint().unwrap_or(0));
373+
while let Some(byte) = seq.next_element()? {
374+
bytes.push(byte);
375+
}
376+
Ok(bytes)
377+
}
378+
}
379+
380+
let bytes = deserializer.deserialize_byte_buf(HllBytesVisitor)?;
345381
let sketch = HllSketch::deserialize(&bytes).map_err(serde::de::Error::custom)?;
346382
Ok(Self { sketch, salt: 0 })
347383
}
@@ -558,6 +594,218 @@ mod tests {
558594
assert_eq!(original_estimate, 2.0);
559595
}
560596

597+
// Postcard roundtrip of a small collector: three inserts with one duplicate
// must yield the same estimate (2.0) before and after serialization.
#[test]
fn cardinality_collector_postcard_roundtrip() {
    use super::CardinalityCollector;

    let mut collector = CardinalityCollector::default();
    for term in ["hello", "world", "hello"] {
        collector.insert(term);
    }

    // `finalize` is called on a clone so the collector itself can still be serialized.
    let estimate_before = collector.clone().finalize().unwrap();

    let encoded = postcard::to_allocvec(&collector).expect("postcard serialize failed");
    let restored = postcard::from_bytes::<CardinalityCollector>(&encoded)
        .expect("postcard deserialize failed");
    let estimate_after = restored.finalize().unwrap();

    assert_eq!(estimate_before, estimate_after);
    assert_eq!(estimate_before, 2.0);
}
616+
617+
// Checks that the sketch's own serialized bytes are identical before and
// after the collector has gone through a Postcard roundtrip.
#[test]
fn cardinality_collector_postcard_bytes_fidelity() {
    use super::CardinalityCollector;

    let mut collector = CardinalityCollector::default();
    (0..10u64).for_each(|value| {
        collector.insert(value);
    });

    let hll_bytes = collector.sketch.serialize();
    // Only the first 16 bytes (or fewer, for short payloads) are printed.
    let head_len = hll_bytes.len().min(16);
    println!("HLL bytes len: {}, first 16: {:?}", hll_bytes.len(), &hll_bytes[..head_len]);

    let postcard_bytes = postcard::to_allocvec(&collector).unwrap();
    println!("Postcard bytes len: {}", postcard_bytes.len());

    let restored = postcard::from_bytes::<CardinalityCollector>(&postcard_bytes).unwrap();
    let hll_bytes_after = restored.sketch.serialize();
    let head_len_after = hll_bytes_after.len().min(16);
    println!(
        "HLL bytes after roundtrip len: {}, first 16: {:?}",
        hll_bytes_after.len(),
        &hll_bytes_after[..head_len_after]
    );

    assert_eq!(
        hll_bytes, hll_bytes_after,
        "HLL bytes should be identical after Postcard roundtrip"
    );
}
642+
643+
// Postcard roundtrip with 1000 distinct values: the estimate must be within
// 50 of 1000 and must be exactly preserved by serialization.
#[test]
fn cardinality_collector_postcard_roundtrip_large() {
    use super::CardinalityCollector;

    let mut collector = CardinalityCollector::default();
    (0..1000u64).for_each(|value| {
        collector.insert(value);
    });

    let original_estimate = collector.clone().finalize().unwrap();
    assert!((original_estimate - 1000.0).abs() < 50.0);

    let encoded = postcard::to_allocvec(&collector).expect("postcard serialize failed");
    println!("Large HLL sketch serialized to {} postcard bytes", encoded.len());

    let restored = postcard::from_bytes::<CardinalityCollector>(&encoded)
        .expect("postcard deserialize failed");
    let estimate_after = restored.finalize().unwrap();

    assert_eq!(original_estimate, estimate_after);
}
666+
667+
// Roundtrips a collector wrapped in `IntermediateMetricResult::Cardinality`
// through Postcard and checks both the variant and the estimate (2.0).
#[test]
fn cardinality_intermediate_metric_result_postcard_roundtrip() {
    use super::CardinalityCollector;
    use crate::aggregation::intermediate_agg_result::IntermediateMetricResult;

    let mut collector = CardinalityCollector::default();
    for term in ["hello", "world"] {
        collector.insert(term);
    }

    let wrapped = IntermediateMetricResult::Cardinality(collector);
    let encoded = postcard::to_allocvec(&wrapped).expect("postcard serialize failed");
    let restored = postcard::from_bytes::<IntermediateMetricResult>(&encoded)
        .expect("postcard deserialize failed");

    if let IntermediateMetricResult::Cardinality(c) = restored {
        assert_eq!(c.finalize().unwrap(), 2.0);
    } else {
        panic!("expected Cardinality variant");
    }
}
689+
690+
// End-to-end check over a multi-segment index: each segment's aggregation
// result is serialized with Postcard, then the results are deserialized and
// merged one by one, and the merged result is serialized once more.
//
// NOTE(review): only successful (de)serialization and merging are checked;
// the merged cardinality value itself is never asserted. Consider adding that.
#[test]
fn cardinality_postcard_multisegment_roundtrip() {
    use crate::aggregation::agg_req::Aggregations;
    use crate::aggregation::collector::AggregationCollector;
    use crate::aggregation::intermediate_agg_result::IntermediateAggregationResults;
    use crate::aggregation::AggContextParams;
    use crate::collector::{Collector, SegmentCollector};
    use crate::query::AllQuery;

    // One inner vec per segment; "terma" appears in two of them.
    let segment_and_terms = vec![
        vec!["terma"],
        vec!["termb"],
        vec!["termc"],
        vec!["terma"],
    ];
    let index = get_test_index_from_terms(false, &segment_and_terms).unwrap();

    let request_json = json!({
        "cardinality": {
            "cardinality": {
                "field": "string_id",
            }
        },
    });
    let agg_req: Aggregations = serde_json::from_value(request_json).unwrap();

    let context_params = AggContextParams::new(Default::default(), index.tokenizers().clone());
    let collector = AggregationCollector::from_aggs(agg_req, context_params);

    let reader = index.reader().unwrap();
    let searcher = reader.searcher();

    let segments = searcher.segment_readers();
    assert!(
        segments.len() > 1,
        "Need multiple segments for this test, got {}",
        segments.len()
    );

    // Collect from each segment individually and serialize via Postcard
    let mut serialized_results: Vec<Vec<u8>> = Vec::with_capacity(segments.len());
    for (ord, segment_reader) in segments.iter().enumerate() {
        let mut segment_collector = collector
            .for_segment(ord as u32, segment_reader)
            .unwrap();
        for doc in segment_reader.doc_ids_alive() {
            segment_collector.collect(doc, 0.0);
        }
        let fruit = segment_collector.harvest().unwrap();
        let encoded = postcard::to_allocvec(&fruit).expect("postcard serialize should work");
        serialized_results.push(encoded);
    }

    // Deserialize and merge (this is what quickwit does)
    let mut accumulated: Option<IntermediateAggregationResults> = None;
    for bytes in serialized_results.iter() {
        let fruits = postcard::from_bytes::<IntermediateAggregationResults>(bytes)
            .expect("postcard deserialize should work");
        // The first result seeds the accumulator; later ones are merged into it.
        if let Some(so_far) = accumulated.as_mut() {
            so_far.merge_fruits(fruits).unwrap();
        } else {
            accumulated = Some(fruits);
        }
    }
    let merged: IntermediateAggregationResults = accumulated.unwrap();

    // Verify the merged result can be serialized again
    let _final_bytes =
        postcard::to_allocvec(&merged).expect("final postcard serialize should work");
}
769+
770+
// Roundtrips a full `IntermediateAggregationResults` holding one cardinality
// metric under the key "test_card" and checks that the entry survives with
// the expected variant and estimate (2.0).
#[test]
fn cardinality_full_intermediate_agg_results_postcard_roundtrip() {
    use super::CardinalityCollector;
    use crate::aggregation::intermediate_agg_result::{
        IntermediateAggregationResult, IntermediateAggregationResults,
        IntermediateMetricResult,
    };

    let mut collector = CardinalityCollector::default();
    for term in ["hello", "world"] {
        collector.insert(term);
    }

    let metric = IntermediateAggregationResult::Metric(
        IntermediateMetricResult::Cardinality(collector),
    );
    let mut results = IntermediateAggregationResults::default();
    results.push("test_card".to_string(), metric).unwrap();

    let encoded = postcard::to_allocvec(&results).expect("postcard serialize failed");
    let restored = postcard::from_bytes::<IntermediateAggregationResults>(&encoded)
        .expect("postcard deserialize failed");

    let entry = restored.aggs_res.get("test_card").expect("missing key");
    if let IntermediateAggregationResult::Metric(IntermediateMetricResult::Cardinality(c)) =
        entry
    {
        // `entry` is only borrowed, so the collector is cloned before `finalize`.
        assert_eq!(c.clone().finalize().unwrap(), 2.0);
    } else {
        panic!("expected Cardinality variant");
    }
}
808+
561809
#[test]
562810
fn cardinality_collector_merge() {
563811
use super::CardinalityCollector;

0 commit comments

Comments
 (0)