Skip to content

Commit 518a166

Browse files
committed
only support all Variant, or all non-Variant types
Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com>
1 parent db9795f commit 518a166

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

src/cryptography/hazmat/asn1/asn1.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,27 @@ def _normalize_field_type(
187187
for arg in union_args
188188
if arg is not type(None)
189189
]
190+
191+
# Union types should either be all Variants
192+
# (`Variant[..] | Variant[..] | etc`) or all non Variants
193+
are_union_types_tagged = variants[0].tag_name is not None
194+
if any(
195+
(v.tag_name is not None) != are_union_types_tagged
196+
for v in variants
197+
):
198+
raise TypeError(
199+
"When using `asn1.Variant` in a union, all the other "
200+
"types in the union must also be `asn1.Variant`"
201+
)
202+
203+
if are_union_types_tagged:
204+
tags = [v.tag_name for v in variants]
205+
if len(tags) != len(set(tags)):
206+
raise TypeError(
207+
"When using `asn1.Variant` in a union, the tags used "
208+
"must be unique"
209+
)
210+
190211
rust_choice_type = declarative_asn1.Type.Choice(variants)
191212
# If None is part of the union types, this is an OPTIONAL CHOICE
192213
rust_field_type = (

tests/hazmat/asn1/test_api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,37 @@ class Example2:
312312
Annotated[int, asn1.Default(value=9)], None
313313
]
314314

315+
def test_fail_choice_with_inconsistent_types(self) -> None:
316+
with pytest.raises(
317+
TypeError,
318+
match=re.escape(
319+
"When using `asn1.Variant` in a union, all the other "
320+
"types in the union must also be `asn1.Variant`"
321+
),
322+
):
323+
324+
@asn1.sequence
325+
class Example2:
326+
invalid: typing.Union[
327+
int, asn1.Variant[bool, typing.Literal["myTag"]]
328+
]
329+
330+
def test_fail_choice_with_duplicate_tags(self) -> None:
331+
with pytest.raises(
332+
TypeError,
333+
match=re.escape(
334+
"When using `asn1.Variant` in a union, the tags used "
335+
"must be unique"
336+
),
337+
):
338+
339+
@asn1.sequence
340+
class Example2:
341+
invalid: typing.Union[
342+
asn1.Variant[int, typing.Literal["myTag"]],
343+
asn1.Variant[bool, typing.Literal["myTag"]],
344+
]
345+
315346
def test_fields_of_variant_type(self) -> None:
316347
from cryptography.hazmat.bindings._rust import declarative_asn1
317348

0 commit comments

Comments
 (0)