@@ -637,6 +637,74 @@ where
637
637
}
638
638
}
639
639
640
/// A nullable, length-prefixed sequence of values of type `T`.
///
/// `T` may be a primitive type (e.g. STRING) or a structure. On the wire the array begins with an UNSIGNED_VARINT
/// holding the element count plus one (`N + 1`), followed by the `N` encoded elements. A length field of `0` denotes
/// a null array, which is represented here as `None`; an empty (non-null) array is `Some` of an empty vector. The
/// protocol documentation refers to an array of `T` instances as `[T]`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub struct CompactArray<T>(pub Option<Vec<T>>);
648
+
649
+ impl < R , T > ReadType < R > for CompactArray < T >
650
+ where
651
+ R : Read ,
652
+ T : ReadType < R > ,
653
+ {
654
+ fn read ( reader : & mut R ) -> Result < Self , ReadError > {
655
+ let len = UnsignedVarint :: read ( reader) ?. 0 ;
656
+ match len {
657
+ 0 => Ok ( Self ( None ) ) ,
658
+ n => {
659
+ let len = usize:: try_from ( n - 1 ) . map_err ( ReadError :: Overflow ) ?;
660
+ let mut builder = VecBuilder :: new ( len) ;
661
+ for _ in 0 ..len {
662
+ builder. push ( T :: read ( reader) ?) ;
663
+ }
664
+ Ok ( Self ( Some ( builder. into ( ) ) ) )
665
+ }
666
+ }
667
+ }
668
+ }
669
+
670
+ impl < W , T > WriteType < W > for CompactArray < T >
671
+ where
672
+ W : Write ,
673
+ T : WriteType < W > ,
674
+ {
675
+ fn write ( & self , writer : & mut W ) -> Result < ( ) , WriteError > {
676
+ CompactArrayRef ( self . 0 . as_deref ( ) ) . write ( writer)
677
+ }
678
+ }
679
+
680
/// Borrowed counterpart of [`CompactArray`]: wraps an optional slice instead of an owned vector.
///
/// `None` stands for a null array. Only writing is supported for this type.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CompactArrayRef<'a, T>(pub Option<&'a [T]>);
685
+
686
+ impl < ' a , W , T > WriteType < W > for CompactArrayRef < ' a , T >
687
+ where
688
+ W : Write ,
689
+ T : WriteType < W > ,
690
+ {
691
+ fn write ( & self , writer : & mut W ) -> Result < ( ) , WriteError > {
692
+ match self . 0 {
693
+ None => UnsignedVarint ( 0 ) . write ( writer) ,
694
+ Some ( inner) => {
695
+ let len = u64:: try_from ( inner. len ( ) + 1 ) . map_err ( WriteError :: from) ?;
696
+ UnsignedVarint ( len) . write ( writer) ?;
697
+
698
+ for element in inner {
699
+ element. write ( writer) ?;
700
+ }
701
+
702
+ Ok ( ( ) )
703
+ }
704
+ }
705
+ }
706
+ }
707
+
640
708
/// Represents a sequence of Kafka records as NULLABLE_BYTES.
641
709
///
642
710
/// This primitive actually depends on the message version and evolved twice in [KIP-32] and [KIP-98]. We only support
@@ -933,23 +1001,19 @@ mod tests {
933
1001
Int32 ( i32:: MAX ) . write ( & mut buf) . unwrap ( ) ;
934
1002
buf. set_position ( 0 ) ;
935
1003
936
- // Use a rather large struct here to trigger OOM
937
- #[ derive( Debug ) ]
938
- struct Large {
939
- _inner : [ u8 ; 1024 ] ,
940
- }
1004
+ let err = Array :: < Large > :: read ( & mut buf) . unwrap_err ( ) ;
1005
+ assert_matches ! ( err, ReadError :: IO ( _) ) ;
1006
+ }
941
1007
942
- impl < R > ReadType < R > for Large
943
- where
944
- R : Read ,
945
- {
946
- fn read ( reader : & mut R ) -> Result < Self , ReadError > {
947
- Int32 :: read ( reader) ?;
948
- unreachable ! ( )
949
- }
950
- }
1008
+ test_roundtrip ! ( CompactArray <Int32 >, test_compact_array_roundtrip) ;
951
1009
952
- let err = Array :: < Large > :: read ( & mut buf) . unwrap_err ( ) ;
1010
/// Feeds a compact array whose length prefix is `u64::MAX` and no element data, and expects the read to end in an
/// I/O error.
#[test]
fn test_compact_array_blowup_memory() {
    let mut cursor = Cursor::new(Vec::<u8>::new());
    UnsignedVarint(u64::MAX).write(&mut cursor).unwrap();
    // Rewind so the read starts at the length prefix that was just written.
    cursor.set_position(0);

    let read_result = CompactArray::<Large>::read(&mut cursor);
    let err = read_result.unwrap_err();
    assert_matches!(err, ReadError::IO(_));
}
955
1019
@@ -989,4 +1053,20 @@ mod tests {
989
1053
timestamp_type : RecordBatchTimestampType :: CreateTime ,
990
1054
}
991
1055
}
1056
+
1057
/// Test helper: a deliberately large (1 KiB) element type used by the memory blow-up tests to trigger OOM.
#[derive(Debug)]
struct Large {
    /// Padding only; the leading underscore marks the field as intentionally unread.
    _inner: [u8; 1024],
}
1062
+
1063
+ impl < R > ReadType < R > for Large
1064
+ where
1065
+ R : Read ,
1066
+ {
1067
+ fn read ( reader : & mut R ) -> Result < Self , ReadError > {
1068
+ Int32 :: read ( reader) ?;
1069
+ unreachable ! ( )
1070
+ }
1071
+ }
992
1072
}
0 commit comments