Skip to content

Commit 4e514aa

Browse files
committed
asn1: refactor expected tag logic
Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com>
1 parent 5424a89 commit 4e514aa

File tree

3 files changed

+65
-34
lines changed

3 files changed

+65
-34
lines changed

src/rust/src/declarative_asn1/decode.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use pyo3::types::{PyAnyMethods, PyListMethods};
77

88
use crate::asn1::big_byte_slice_to_py_int;
99
use crate::declarative_asn1::types::{
10-
check_size_constraint, type_to_tag, AnnotatedType, Annotation, BitString, Encoding,
10+
check_size_constraint, is_tag_valid_for_type, AnnotatedType, Annotation, BitString, Encoding,
1111
GeneralizedTime, IA5String, PrintableString, Type, UtcTime,
1212
};
1313
use crate::error::CryptographyError;
@@ -172,10 +172,9 @@ pub(crate) fn decode_annotated_type<'a>(
172172
// Handle DEFAULT annotation if field is not present (by
173173
// returning the default value)
174174
if let Some(default) = &ann_type.annotation.get().default {
175-
let expected_tag = type_to_tag(inner, encoding);
176-
let next_tag = parser.peek_tag();
177-
if next_tag != Some(expected_tag) {
178-
return Ok(default.clone_ref(py).into_bound(py));
175+
match parser.peek_tag() {
176+
Some(next_tag) if is_tag_valid_for_type(next_tag, inner, encoding) => (),
177+
_ => return Ok(default.clone_ref(py).into_bound(py)),
179178
}
180179
}
181180

@@ -210,9 +209,8 @@ pub(crate) fn decode_annotated_type<'a>(
210209
})?
211210
}
212211
Type::Option(cls) => {
213-
let inner_tag = type_to_tag(cls.get().inner.get(), encoding);
214212
match parser.peek_tag() {
215-
Some(t) if t == inner_tag => {
213+
Some(t) if is_tag_valid_for_type(t, cls.get().inner.get(), encoding) => {
216214
// For optional types, annotations will always be associated to the `Optional` type
217215
// i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type.
218216
// Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional`

src/rust/src/declarative_asn1/types.rs

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -419,29 +419,52 @@ pub(crate) fn python_class_to_annotated<'p>(
419419
}
420420
}
421421

422-
pub(crate) fn type_to_tag(t: &Type, encoding: &Option<pyo3::Py<Encoding>>) -> asn1::Tag {
423-
let inner_tag = match t {
424-
Type::Sequence(_, _) => asn1::Sequence::TAG,
425-
Type::SequenceOf(_) => asn1::Sequence::TAG,
426-
Type::Option(t) => type_to_tag(t.get().inner.get(), encoding),
427-
Type::PyBool() => bool::TAG,
428-
Type::PyInt() => asn1::BigInt::TAG,
429-
Type::PyBytes() => <&[u8] as SimpleAsn1Readable>::TAG,
430-
Type::PyStr() => asn1::Utf8String::TAG,
431-
Type::PrintableString() => asn1::PrintableString::TAG,
432-
Type::IA5String() => asn1::IA5String::TAG,
433-
Type::ObjectIdentifier() => asn1::ObjectIdentifier::TAG,
434-
Type::UtcTime() => asn1::UtcTime::TAG,
435-
Type::GeneralizedTime() => asn1::GeneralizedTime::TAG,
436-
Type::BitString() => asn1::BitString::TAG,
437-
};
438-
439-
match encoding {
422+
// Checks if encoding `tag_without_encoding` using `encoding` results
423+
// in `tag`
424+
fn check_tag_with_encoding(
425+
tag_without_encoding: asn1::Tag,
426+
encoding: &Option<pyo3::Py<Encoding>>,
427+
tag: asn1::Tag,
428+
) -> bool {
429+
let tag_with_encoding = match encoding {
440430
Some(e) => match e.get() {
441-
Encoding::Implicit(n) => asn1::implicit_tag(*n, inner_tag),
431+
Encoding::Implicit(n) => asn1::implicit_tag(*n, tag_without_encoding),
442432
Encoding::Explicit(n) => asn1::explicit_tag(*n),
443433
},
444-
None => inner_tag,
434+
None => tag_without_encoding,
435+
};
436+
tag_with_encoding == tag
437+
}
438+
439+
// Given `tag` and `encoding`, returns whether that tag with that encoding
440+
// matches what one would expect to see when decoding `type_`
441+
pub(crate) fn is_tag_valid_for_type(
442+
tag: asn1::Tag,
443+
type_: &Type,
444+
encoding: &Option<pyo3::Py<Encoding>>,
445+
) -> bool {
446+
match type_ {
447+
Type::Sequence(_, _) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
448+
Type::SequenceOf(_) => check_tag_with_encoding(asn1::Sequence::TAG, encoding, tag),
449+
Type::Option(t) => is_tag_valid_for_type(tag, t.get().inner.get(), encoding),
450+
Type::PyBool() => check_tag_with_encoding(bool::TAG, encoding, tag),
451+
Type::PyInt() => check_tag_with_encoding(asn1::BigInt::TAG, encoding, tag),
452+
Type::PyBytes() => {
453+
check_tag_with_encoding(<&[u8] as SimpleAsn1Readable>::TAG, encoding, tag)
454+
}
455+
Type::PyStr() => check_tag_with_encoding(asn1::Utf8String::TAG, encoding, tag),
456+
Type::PrintableString() => {
457+
check_tag_with_encoding(asn1::PrintableString::TAG, encoding, tag)
458+
}
459+
Type::IA5String() => check_tag_with_encoding(asn1::IA5String::TAG, encoding, tag),
460+
Type::ObjectIdentifier() => {
461+
check_tag_with_encoding(asn1::ObjectIdentifier::TAG, encoding, tag)
462+
}
463+
Type::UtcTime() => check_tag_with_encoding(asn1::UtcTime::TAG, encoding, tag),
464+
Type::GeneralizedTime() => {
465+
check_tag_with_encoding(asn1::GeneralizedTime::TAG, encoding, tag)
466+
}
467+
Type::BitString() => check_tag_with_encoding(asn1::BitString::TAG, encoding, tag),
445468
}
446469
}
447470

@@ -468,14 +491,15 @@ pub(crate) fn check_size_constraint(
468491
#[cfg(test)]
469492
mod tests {
470493

494+
use asn1::SimpleAsn1Readable;
471495
use pyo3::IntoPyObject;
472496

473-
use super::{type_to_tag, AnnotatedType, Annotation, Type};
497+
use super::{is_tag_valid_for_type, AnnotatedType, Annotation, Type};
474498

475499
#[test]
476-
// Needed for coverage of `type_to_tag(Type::Option(..))`, since
477-
// `type_to_tag` is never called with an optional value.
478-
fn test_option_type_to_tag() {
500+
// Needed for coverage of `is_tag_valid_for_type(Type::Option(..))`, since
501+
// `is_tag_valid_for_type` is never called with an optional value.
502+
fn test_option_is_tag_valid_for_type() {
479503
pyo3::Python::initialize();
480504

481505
pyo3::Python::attach(|py| {
@@ -509,8 +533,11 @@ mod tests {
509533
},
510534
)
511535
.unwrap();
512-
let expected_tag = type_to_tag(&Type::Option(optional_type), &None);
513-
assert_eq!(expected_tag, type_to_tag(&Type::PyInt(), &None))
536+
assert!(is_tag_valid_for_type(
537+
asn1::BigInt::TAG,
538+
&Type::Option(optional_type),
539+
&None
540+
));
514541
})
515542
}
516543
}

tests/hazmat/asn1/test_serialization.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,10 @@ class Example:
467467
h: typing.Union[asn1.BitString, None]
468468
i: typing.Union[asn1.IA5String, None]
469469
j: typing.Union[x509.ObjectIdentifier, None]
470+
k: Annotated[typing.Union[str, None], asn1.Implicit(0)]
471+
only_field_present: Annotated[
472+
typing.Union[str, None], asn1.Implicit(1)
473+
]
470474

471475
assert_roundtrips(
472476
[
@@ -482,8 +486,10 @@ class Example:
482486
h=None,
483487
i=None,
484488
j=None,
489+
k=None,
490+
only_field_present="a",
485491
),
486-
b"\x30\x00",
492+
b"\x30\x03\x81\x01a",
487493
)
488494
]
489495
)

0 commit comments

Comments
 (0)