diff --git a/avro/src/de.rs b/avro/src/de.rs
index 13021260..55bc945a 100644
--- a/avro/src/de.rs
+++ b/avro/src/de.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 //! Logic for serde-compatible deserialization.
-use crate::{Error, bytes::DE_BYTES_BORROWED, error::Details, types::Value};
+use crate::{AvroResult, Error, bytes::DE_BYTES_BORROWED, error::Details, types::Value};
 use serde::{
     Deserialize,
     de::{self, DeserializeSeed, Deserializer as _, Visitor},
@@ -756,7 +756,7 @@ impl<'de> de::Deserializer<'de> for StringDeserializer {
 ///
 /// This conversion can fail if the structure of the `Value` does not match the
 /// structure expected by `D`.
-pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> Result<D, Error> {
+pub fn from_value<'de, D: Deserialize<'de>>(value: &'de Value) -> AvroResult<D> {
     let de = Deserializer::new(value);
     D::deserialize(&de)
 }
diff --git a/avro/src/de_schema.rs b/avro/src/de_schema.rs
new file mode 100644
index 00000000..4d47de62
--- /dev/null
+++ b/avro/src/de_schema.rs
@@ -0,0 +1,982 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Logic for schema-aware serde deserialization straight from an Avro-encoded reader.
+use crate::{
+    Error, Schema,
+    error::Details,
+    util::{safe_len, zag_i32, zag_i64},
+};
+use serde::{
+    de::{self, Visitor},
+    forward_to_deserialize_any,
+};
+use std::{io::Read, slice::Iter};
+
+/// A serde deserializer that decodes values from `reader`, driven by `root_schema`.
+///
+/// Most `deserialize_*` entry points are still `todo!()` stubs; only booleans,
+/// integers, strings, options and structs are wired up so far.
+pub struct SchemaAwareReadDeserializer<'a, R: Read> {
+    reader: &'a mut R,
+    root_schema: &'a Schema,
+}
+
+impl<'a, R: Read> SchemaAwareReadDeserializer<'a, R> {
+    /// Creates a deserializer that reads from `reader` according to `root_schema`.
+    pub(crate) fn new(reader: &'a mut R, root_schema: &'a Schema) -> Self {
+        Self {
+            reader,
+            root_schema,
+        }
+    }
+}
+
+impl<'de: 'a, 'a, R: Read> serde::de::Deserializer<'de> for SchemaAwareReadDeserializer<'a, R> {
+    type Error = Error;
+
+    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        // Implement the deserialization logic here
+        unimplemented!()
+    }
+
+    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        let schema = self.root_schema;
+        let mut this = self;
+        this.deserialize_bool_with_schema(visitor, schema)
+    }
+
+    // The narrower integer types are read as an Avro int; the visitor handed in by
+    // serde is responsible for converting the `i32` to the requested width.
+    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        self.deserialize_i32(visitor)
+    }
+
+    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        self.deserialize_i32(visitor)
+    }
+
+    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        let schema = self.root_schema;
+        let mut this = self;
+        this.deserialize_i32_with_schema(visitor, schema)
+    }
+
+    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        let schema = self.root_schema;
+        let mut this = self;
+        this.deserialize_i64_with_schema(visitor, schema)
+    }
+
+    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        self.deserialize_i32(visitor)
+    }
+
+    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        self.deserialize_i32(visitor)
+    }
+
+    // `u32` and `u64` do not fit into an Avro int, so they go through the long path.
+    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        self.deserialize_i64(visitor)
+    }
+
+    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        self.deserialize_i64(visitor)
+    }
+
+    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!("Implement deserialization for str")
+    }
+
+    /// Reads a length-prefixed UTF-8 string; only a `Schema::String` root is accepted.
+    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        match self.root_schema {
+            Schema::String => {
+                let len = zag_i64(self.reader).map_err(|e| {
+                    <Error as de::Error>::custom(format!(
+                        "Cannot read the length of the string schema {:?}",
+                        e.into_details()
+                    ))
+                })?;
+                // The length comes from the (untrusted) input: reject negative values
+                // and bound the allocation before reserving the buffer.
+                let len = safe_len(
+                    usize::try_from(len).map_err(|e| Details::ConvertI64ToUsize(e, len))?,
+                )?;
+                let mut buf = vec![0; len];
+                self.reader
+                    .read_exact(&mut buf)
+                    .map_err(Details::ReadBytes)?;
+                let string = String::from_utf8(buf).map_err(Details::ConvertToUtf8)?;
+                visitor.visit_string(string)
+            }
+            not_implemented => Err(de::Error::custom(format!(
+                "Expected a String schema, but got {:?}",
+                not_implemented
+            ))),
+        }
+    }
+
+    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    /// Reads the union branch index and visits `None` for a `Null` branch, or `Some`
+    /// with a nested deserializer bound to the selected branch schema otherwise.
+    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        match self.root_schema {
+            Schema::Null => visitor.visit_none(),
+            Schema::Union(union_schema) => {
+                let index = zag_i64(self.reader).map_err(|e| {
+                    <Error as de::Error>::custom(format!(
+                        "Cannot read the index of the union schema variant {:?}",
+                        e.into_details()
+                    ))
+                })?;
+                let variants = union_schema.variants();
+                let position =
+                    usize::try_from(index).map_err(|e| Details::ConvertI64ToUsize(e, index))?;
+                let variant = variants.get(position).ok_or(Details::GetUnionVariant {
+                    index,
+                    num_variants: variants.len(),
+                })?;
+                match variant {
+                    Schema::Null => visitor.visit_none(),
+                    _ => visitor.visit_some(SchemaAwareReadDeserializer::new(self.reader, variant)),
+                }
+            }
+            _ => Err(de::Error::custom(format!(
+                "Expected a Union, but got {:?}",
+                self.root_schema
+            ))),
+        }
+    }
+
+    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_unit_struct<V>(
+        self,
+        _name: &'static str,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_newtype_struct<V>(
+        self,
+        _name: &'static str,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_tuple_struct<V>(
+        self,
+        _name: &'static str,
+        _len: usize,
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_struct<V>(
+        self,
+        name: &'static str,
+        fields: &'static [&'static str],
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        let schema = self.root_schema;
+        let mut this = self;
+        this.deserialize_struct_with_schema(name, fields, visitor, schema)
+    }
+
+    fn deserialize_enum<V>(
+        self,
+        _name: &'static str,
+        _variants: &'static [&'static str],
+        _visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+
+    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        todo!()
+    }
+}
+
+impl<'de: 'a, 'a, R: Read> SchemaAwareReadDeserializer<'a, R> {
+    /// Reads one byte as a boolean when `schema` is `Boolean`, or a union that
+    /// contains a `Boolean` variant; any other schema is an error.
+    fn deserialize_bool_with_schema<V>(
+        &mut self,
+        visitor: V,
+        schema: &Schema,
+    ) -> Result<V::Value, Error>
+    where
+        V: Visitor<'de>,
+    {
+        let create_error = |cause: &str| {
+            Error::new(Details::DeserializeValueWithSchema {
+                value_type: "bool",
+                value: format!("Cause: {cause}"),
+                schema: schema.clone(),
+            })
+        };
+
+        match schema {
+            Schema::Boolean => {
+                let mut buf = [0; 1];
+                self.reader
+                    .read_exact(&mut buf) // Read a single byte
+                    .map_err(|e| create_error(&format!("Failed to read: {e}")))?;
+                let value = buf[0] != 0;
+                visitor.visit_bool(value)
+            }
+            Schema::Union(union_schema) => {
+                // NOTE(review): the Avro binary encoding prefixes a union value with its
+                // branch index; this arm does not consume it - confirm against the writer.
+                for variant_schema in union_schema.schemas.iter() {
+                    match variant_schema {
+                        Schema::Boolean => {
+                            return self.deserialize_bool_with_schema(visitor, variant_schema);
+                        }
+                        _ => { /* skip */ }
+                    }
+                }
+                Err(create_error(&format!(
+                    "The union schema must have a Boolean variant: {schema:?}"
+                )))
+            }
+            unexpected => Err(create_error(&format!(
+                "Expected a boolean schema, found: {unexpected:?}"
+            ))),
+        }
+    }
+
+    /// Reads a zig-zag encoded int when `schema` is `Int`, `TimeMillis` or `Date`,
+    /// or a union that contains one of them; any other schema is an error.
+    fn deserialize_i32_with_schema<V>(
+        &mut self,
+        visitor: V,
+        schema: &Schema,
+    ) -> Result<V::Value, Error>
+    where
+        V: Visitor<'de>,
+    {
+        let create_error = |cause: &str| {
+            Error::new(Details::DeserializeValueWithSchema {
+                value_type: "i32",
+                value: format!("Cause: {cause}"),
+                schema: schema.clone(),
+            })
+        };
+
+        match schema {
+            Schema::Int | Schema::TimeMillis | Schema::Date => {
+                let int = zag_i32(self.reader)?;
+                visitor.visit_i32(int)
+            }
+            Schema::Union(union_schema) => {
+                // NOTE(review): the union branch index is not consumed here - confirm.
+                for variant_schema in union_schema.schemas.iter() {
+                    match variant_schema {
+                        Schema::Int | Schema::TimeMillis | Schema::Date => {
+                            return self.deserialize_i32_with_schema(visitor, variant_schema);
+                        }
+                        _ => { /* skip */ }
+                    }
+                }
+                Err(create_error(&format!(
+                    "The union schema must have an Int[-like] variant: {schema:?}"
+                )))
+            }
+            unexpected => Err(create_error(&format!(
+                "Expected an Int[-like] schema, found: {unexpected:?}"
+            ))),
+        }
+    }
+
+    /// Reads a zig-zag encoded integer for the int-like and long-like schemas, or a
+    /// union that contains one of them; any other schema is an error.
+    ///
+    /// Int-like schemas are range-checked and handed to the visitor as `i32`.
+    fn deserialize_i64_with_schema<V>(
+        &mut self,
+        visitor: V,
+        schema: &Schema,
+    ) -> Result<V::Value, Error>
+    where
+        V: Visitor<'de>,
+    {
+        let create_error = |cause: &str| {
+            Error::new(Details::DeserializeValueWithSchema {
+                value_type: "i64",
+                value: format!("Cause: {cause}"),
+                schema: schema.clone(),
+            })
+        };
+
+        match schema {
+            Schema::Int | Schema::TimeMillis | Schema::Date => {
+                let long = zag_i64(self.reader)?;
+                let int = i32::try_from(long)
+                    .map_err(|cause| create_error(cause.to_string().as_str()))?;
+                visitor.visit_i32(int)
+            }
+            Schema::Long
+            | Schema::TimeMicros
+            | Schema::TimestampMillis
+            | Schema::TimestampMicros
+            | Schema::TimestampNanos
+            | Schema::LocalTimestampMillis
+            | Schema::LocalTimestampMicros
+            | Schema::LocalTimestampNanos => {
+                let long = zag_i64(self.reader)?;
+                visitor.visit_i64(long)
+            }
+            Schema::Union(union_schema) => {
+                // NOTE(review): the union branch index is not consumed here - confirm.
+                for variant_schema in union_schema.schemas.iter() {
+                    match variant_schema {
+                        Schema::Int
+                        | Schema::TimeMillis
+                        | Schema::Date
+                        | Schema::Long
+                        | Schema::TimeMicros
+                        | Schema::TimestampMillis
+                        | Schema::TimestampMicros
+                        | Schema::TimestampNanos
+                        | Schema::LocalTimestampMillis
+                        | Schema::LocalTimestampMicros
+                        | Schema::LocalTimestampNanos => {
+                            return self.deserialize_i64_with_schema(visitor, variant_schema);
+                        }
+                        _ => { /* skip */ }
+                    }
+                }
+                Err(create_error(&format!(
+                    "The union schema must have a Long[-like] variant: {schema:?}"
+                )))
+            }
+            unexpected => Err(create_error(&format!(
+                "Expected a Long[-like] schema, found: {unexpected:?}"
+            ))),
+        }
+    }
+
+    /// Visits a record as a serde map when `schema` is a `Record`, or a union that
+    /// contains a `Record` variant; any other schema is an error.
+    fn deserialize_struct_with_schema<V>(
+        &'a mut self,
+        name: &'static str,
+        fields: &'static [&'static str],
+        visitor: V,
+        schema: &'a Schema,
+    ) -> Result<V::Value, Error>
+    where
+        V: Visitor<'de>,
+    {
+        let create_error = |cause: &str| {
+            Error::new(Details::DeserializeValueWithSchema {
+                value_type: "struct",
+                value: format!("Cause: {cause}"),
+                schema: schema.clone(),
+            })
+        };
+
+        match schema {
+            Schema::Record(record_schema) => {
+                visitor.visit_map(RecordSchemaAwareReadDeserializerStruct::new(
+                    self,
+                    name,
+                    fields.iter(),
+                    record_schema,
+                ))
+            }
+            Schema::Union(union_schema) => {
+                // NOTE(review): the union branch index is not consumed here - confirm.
+                let record_variant = union_schema
+                    .schemas
+                    .iter()
+                    .find(|variant| matches!(variant, Schema::Record(_)));
+                match record_variant {
+                    Some(variant_schema) => {
+                        self.deserialize_struct_with_schema(name, fields, visitor, variant_schema)
+                    }
+                    None => Err(create_error(&format!(
+                        "The union schema must have a Record variant: {schema:?}"
+                    ))),
+                }
+            }
+            unexpected => Err(create_error(&format!(
+                "Expected a Record schema, found: {unexpected:?}"
+            ))),
+        }
+    }
+}
+
+/// `MapAccess` adapter that walks the Rust struct's field names and deserializes each
+/// value with the schema of the record field of the same name.
+struct RecordSchemaAwareReadDeserializerStruct<'a, R: Read> {
+    deser: &'a mut SchemaAwareReadDeserializer<'a, R>,
+    _schema_name: &'static str,
+    fields: Iter<'a, &'static str>,
+    // Key handed out by `next_key_seed`, consumed by the following `next_value_seed`.
+    current_field: Option<&'static str>,
+    record_schema: &'a crate::schema::RecordSchema,
+}
+
+impl<'a, R: Read> RecordSchemaAwareReadDeserializerStruct<'a, R> {
+    fn new(
+        deser: &'a mut SchemaAwareReadDeserializer<'a, R>,
+        _schema_name: &'static str,
+        fields: Iter<'a, &'static str>,
+        record_schema: &'a crate::schema::RecordSchema,
+    ) -> Self {
+        Self {
+            deser,
+            _schema_name,
+            fields,
+            current_field: None,
+            record_schema,
+        }
+    }
+}
+
+impl<'de: 'a, 'a, R: Read> de::MapAccess<'de>
+    for RecordSchemaAwareReadDeserializerStruct<'a, R>
+{
+    type Error = Error;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: de::DeserializeSeed<'de>,
+    {
+        match self.fields.next() {
+            Some(&field_name) => {
+                self.current_field = Some(field_name);
+                seed.deserialize(StringDeserializer { input: field_name })
+                    .map(Some)
+            }
+            None => Ok(None),
+        }
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: de::DeserializeSeed<'de>,
+    {
+        match self.current_field.take() {
+            Some(field_name) => {
+                let schema = self.record_schema;
+                let record_field = schema
+                    .lookup
+                    .get(field_name)
+                    .and_then(|idx| schema.fields.get(*idx))
+                    .ok_or_else(|| {
+                        Error::new(Details::DeserializeValueWithSchema {
+                            value_type: "struct",
+                            value: format!("Field '{field_name}' not found in record schema"),
+                            schema: Schema::Record(schema.clone()),
+                        })
+                    })?;
+                let field_schema = &record_field.schema;
+                seed.deserialize(SchemaAwareReadDeserializer::new(
+                    self.deser.reader,
+                    field_schema,
+                ))
+            }
+            None => Err(de::Error::custom("should not happen - too many values")),
+        }
+    }
+}
+
+/// Deserializer that hands a borrowed field name to the visitor as a `str`.
+#[derive(Clone)]
+struct StringDeserializer<'de> {
+    input: &'de str,
+}
+
+impl<'de> de::Deserializer<'de> for StringDeserializer<'de> {
+    type Error = Error;
+
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: Visitor<'de>,
+    {
+        visitor.visit_str(self.input)
+    }
+
+    forward_to_deserialize_any! {
+        bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string unit option
+        seq bytes byte_buf map unit_struct newtype_struct
+        tuple_struct struct tuple enum identifier ignored_any
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::error::Details;
+    use crate::{
+        reader::read_avro_datum_ref,
+        schema::{Schema, UnionSchema},
+        util::{zig_i32, zig_i64},
+    };
+    use apache_avro_test_helper::TestResult;
+    use std::io::Cursor;
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_boolean_schema() -> TestResult {
+        let schema = Schema::Boolean;
+
+        for (byte, expected) in [(0, false), (1, true)] {
+            let mut reader: &[u8] = &[byte];
+            let read: bool = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, expected);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_union_boolean_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Boolean])?);
+
+        for (byte, expected) in [(0, false), (1, true)] {
+            let mut reader: &[u8] = &[byte];
+            let read: bool = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, expected);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_invalid_schema() -> TestResult {
+        let schema = Schema::Long; // Using a non-boolean schema
+
+        let mut reader: &[u8] = &[0, 1, 2];
+        match read_avro_datum_ref::<bool, _>(&schema, &mut reader).map_err(Error::into_details)
+        {
+            Err(Details::DeserializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "bool");
+                assert!(value.contains("Cause: Expected a boolean schema"));
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid schema"),
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_bool_union_invalid_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Long])?);
+
+        let mut reader: &[u8] = &[1, 2, 3];
+        match read_avro_datum_ref::<bool, _>(&schema, &mut reader).map_err(Error::into_details)
+        {
+            Err(Details::DeserializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "bool");
+                assert!(value.contains("The union schema must have a Boolean variant"));
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid union schema"),
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_int32_int_schema() -> TestResult {
+        let schema = Schema::Int;
+
+        for value in [123_i32, -1024_i32] {
+            let mut writer = vec![];
+            zig_i32(value, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: i32 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_int32_union_int_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Int])?);
+
+        for value in [123_i32, -1024_i32] {
+            let mut writer = vec![];
+            zig_i32(value, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: i32 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_i32_invalid_schema() -> TestResult {
+        let schema = Schema::Long; // Using a non-Int schema
+
+        let mut reader: &[u8] = &[0, 1, 2];
+        match read_avro_datum_ref::<i32, _>(&schema, &mut reader).map_err(Error::into_details) {
+            Err(Details::DeserializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "i32");
+                assert!(
+                    value.contains("Cause: Expected an Int[-like] schema"),
+                    "Got: {value}",
+                );
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid schema"),
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_i32_union_invalid_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Long])?);
+
+        let mut reader: &[u8] = &[1, 2, 3];
+        match read_avro_datum_ref::<i32, _>(&schema, &mut reader).map_err(Error::into_details) {
+            Err(Details::DeserializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "i32");
+                assert!(
+                    value.contains("The union schema must have an Int[-like] variant"),
+                    "Got: {value}",
+                );
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid union schema"),
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_int64_int_schema() -> TestResult {
+        let schema = Schema::TimeMillis;
+
+        for value in [i32::MAX, -i32::MAX] {
+            let mut writer = vec![];
+            zig_i64(value as i64, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: i64 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value as i64);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_int64_long_schema() -> TestResult {
+        let schema = Schema::TimestampMicros;
+
+        for value in [i64::MAX, -i64::MAX] {
+            let mut writer = vec![];
+            zig_i64(value, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: i64 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_int64_union_int_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::TimeMicros])?);
+
+        for value in [123_i64, -1024_i64] {
+            let mut writer = vec![];
+            zig_i64(value, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: i64 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_i64_invalid_schema() -> TestResult {
+        let schema = Schema::Uuid; // Using a non-Long schema
+
+        let mut reader: &[u8] = &[0, 1, 2];
+        match read_avro_datum_ref::<i64, _>(&schema, &mut reader).map_err(Error::into_details) {
+            Err(Details::DeserializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "i64");
+                assert!(
+                    value.contains("Cause: Expected a Long[-like] schema"),
+                    "Got: {value}",
+                );
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid schema"),
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_i64_union_invalid_schema() -> TestResult {
+        let schema = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::String])?);
+
+        let mut reader: &[u8] = &[1, 2, 3];
+        match read_avro_datum_ref::<i64, _>(&schema, &mut reader).map_err(Error::into_details) {
+            Err(Details::DeserializeValueWithSchema {
+                value_type,
+                value,
+                schema: error_schema,
+            }) => {
+                assert_eq!(value_type, "i64");
+                assert!(
+                    value.contains("The union schema must have a Long[-like] variant"),
+                    "Got: {value}",
+                );
+                assert_eq!(error_schema.to_string(), schema.to_string());
+            }
+            _ => panic!("Expected an error for invalid union schema"),
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_u8_int_schema() -> TestResult {
+        let schema = Schema::TimeMillis;
+
+        for value in [u8::MAX, 0] {
+            let mut writer = vec![];
+            zig_i32(value as i32, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: u8 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_u16_int_schema() -> TestResult {
+        let schema = Schema::TimeMillis;
+
+        for value in [u16::MAX, 0] {
+            let mut writer = vec![];
+            zig_i32(value as i32, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: u16 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn avro_rs_226_deserialize_u32_long_schema() -> TestResult {
+        let schema = Schema::TimeMicros;
+
+        for value in [u32::MAX, 0] {
+            let mut writer = vec![];
+            zig_i64(value as i64, &mut writer)?;
+            let mut reader = Cursor::new(&writer);
+            let read: u32 = read_avro_datum_ref(&schema, &mut reader)?;
+            assert_eq!(read, value);
+        }
+        Ok(())
+    }
+}
diff --git a/avro/src/error.rs b/avro/src/error.rs
index 3e4553c4..945d5ab1 100644
--- a/avro/src/error.rs
+++ b/avro/src/error.rs
@@ -508,6 +508,13 @@ pub enum Details {
         schema: Schema,
     },
 
+    #[error("Failed to deserialize value of type {value_type} using schema {schema:?}: {value}")]
+    DeserializeValueWithSchema {
+        value_type: &'static str,
+        value: String,
+        schema: Schema,
+    },
+
     #[error("Failed to serialize field '{field_name}' for record {record_schema:?}: {error}")]
     SerializeRecordFieldWithSchema {
         field_name: &'static str,
diff --git a/avro/src/lib.rs b/avro/src/lib.rs
index 225cbc13..f5fa905c 100644
--- a/avro/src/lib.rs
+++ b/avro/src/lib.rs
@@ -870,6 +870,7 @@ mod ser_schema;
 mod util;
 mod writer;
 
+mod de_schema;
 pub mod error;
 pub mod headers;
 pub mod rabin;
@@ -899,7 +900,7 @@ pub use duration::{Days, Duration, Millis, Months};
 pub use error::Error;
 pub use reader::{
     GenericSingleObjectReader, Reader, SpecificSingleObjectReader, from_avro_datum,
-    from_avro_datum_reader_schemata, from_avro_datum_schemata, read_marker,
+    from_avro_datum_reader_schemata, from_avro_datum_schemata, read_avro_datum_ref, read_marker,
 };
 pub use schema::{AvroSchema, Schema};
 pub use ser::to_value;
diff --git a/avro/src/reader.rs b/avro/src/reader.rs
index e2f7570f..d22fca31 100644
--- a/avro/src/reader.rs
+++ b/avro/src/reader.rs
@@ -16,10 +16,11 @@
 // under the License.
 
 //! Logic handling reading from Avro format at user level.
 use crate::{
     AvroResult, Codec, Error,
+    de_schema::SchemaAwareReadDeserializer,
     decode::{decode, decode_internal},
     error::Details,
     from_value,
     headers::{HeaderBuilder, RabinFingerprintHeader},
     schema::{
@@ -597,6 +598,16 @@ pub fn read_marker(bytes: &[u8]) -> [u8; 16] {
     marker
 }
 
+/// Reads a single Avro datum from `reader` and deserializes it directly into `D`,
+/// driven by `schema`.
+pub fn read_avro_datum_ref<D: DeserializeOwned, R: Read>(
+    schema: &Schema,
+    reader: &mut R,
+) -> AvroResult<D> {
+    let deserializer = SchemaAwareReadDeserializer::new(reader, schema);
+    D::deserialize(deserializer)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/avro/src/writer.rs b/avro/src/writer.rs
index 9c879918..f9d4a0ac 100644
--- a/avro/src/writer.rs
+++ b/avro/src/writer.rs
@@ -728,8 +728,7 @@ pub fn write_avro_datum_ref<T: Serialize, W: Write>(
 ) -> AvroResult<usize> {
     let names: HashMap<Name, Schema> = HashMap::new();
     let mut serializer = SchemaAwareWriteSerializer::new(writer, schema, &names, None);
-    let bytes_written = data.serialize(&mut serializer)?;
-    Ok(bytes_written)
+    data.serialize(&mut serializer)
 }
 
 /// Encode a compatible value (implementing the `ToAvro` trait) into Avro format, also
diff --git a/avro/tests/avro-rs-226.rs b/avro/tests/avro-rs-226.rs
index 10dc80db..f306708a 100644
--- a/avro/tests/avro-rs-226.rs
+++ b/avro/tests/avro-rs-226.rs
@@ -15,36 +15,33 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use apache_avro::{AvroSchema, Schema, Writer, from_value};
+use apache_avro::{AvroSchema, Schema, read_avro_datum_ref, write_avro_datum_ref};
 use apache_avro_test_helper::TestResult;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use std::fmt::Debug;
+use std::io::Cursor;
 
 fn ser_deser<T>(schema: &Schema, record: T) -> TestResult
 where
     T: Serialize + DeserializeOwned + Debug + PartialEq + Clone,
 {
     let record2 = record.clone();
-    let mut writer = Writer::new(schema, vec![]);
-    writer.append_ser(record)?;
-    let bytes_written = writer.into_inner()?;
+    let mut bytes = vec![];
+    write_avro_datum_ref(schema, &record, &mut bytes)?;
 
-    let reader = apache_avro::Reader::new(&bytes_written[..])?;
-    for value in reader {
-        let value = value?;
-        let deserialized = from_value::<T>(&value)?;
-        assert_eq!(deserialized, record2);
-    }
+    let mut reader = Cursor::new(&bytes);
+    let deserialized: T = read_avro_datum_ref(schema, &mut reader)?;
+    assert_eq!(deserialized, record2);
 
     Ok(())
 }
 
 #[test]
-fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field() -> TestResult {
+fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_first_field() -> TestResult {
     #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
     struct T {
-        x: Option<i8>,
         #[serde(skip_serializing_if = "Option::is_none")]
+        x: Option<i8>,
         y: Option<String>,
         z: Option<i8>,
     }
@@ -53,18 +50,18 @@ fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field
         &T::get_schema(),
         T {
             x: None,
-            y: None,
-            z: Some(1),
+            y: Some("test".to_string()),
+            z: Some(23),
         },
     )
 }
 
 #[test]
-fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_first_field() -> TestResult {
+fn avro_rs_226_index_out_of_bounds_with_serde_skip_serializing_skip_middle_field() -> TestResult {
     #[derive(AvroSchema, Clone, Debug, Deserialize, PartialEq, Serialize)]
     struct T {
-        #[serde(skip_serializing_if = "Option::is_none")]
         x: Option<i8>,
+        #[serde(skip_serializing_if = "Option::is_none")]
         y: Option<String>,
         z: Option<i8>,
     }