Skip to content

Commit a930581

Browse files
facutuescaalex
andauthored
asn1: Add support for CHOICE fields (pyca#14201)
* asn1: Add support for CHOICE fields Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * only support all Variant, or all non-Variant types Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * Update src/cryptography/hazmat/asn1/asn1.py Co-authored-by: Alex Gaynor <alex.gaynor@gmail.com> * simplify tests Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * add comment about implicit CHOICEs Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * check that tags are Literal types Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * refactor tag check Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * add tests for missing coverage Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * add bound to Tag type Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * fix LiteralString usage in Python < 3.11 Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> * remove redundant assert_roundtrips call Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> --------- Signed-off-by: Facundo Tuesca <facundo.tuesca@trailofbits.com> Co-authored-by: Alex Gaynor <alex.gaynor@gmail.com>
1 parent c5503fa commit a930581

File tree

9 files changed

+824
-23
lines changed

9 files changed

+824
-23
lines changed

src/cryptography/hazmat/asn1/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
PrintableString,
1313
Size,
1414
UtcTime,
15+
Variant,
1516
decode_der,
1617
encode_der,
1718
sequence,
@@ -27,6 +28,7 @@
2728
"PrintableString",
2829
"Size",
2930
"UtcTime",
31+
"Variant",
3032
"decode_der",
3133
"encode_der",
3234
"sequence",

src/cryptography/hazmat/asn1/asn1.py

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
if sys.version_info < (3, 11):
1414
import typing_extensions
1515

16+
LiteralString = typing_extensions.LiteralString
17+
1618
# We use the `include_extras` parameter of `get_type_hints`, which was
1719
# added in Python 3.9. This can be replaced by the `typing` version
1820
# once the min version is >= 3.9
@@ -31,6 +33,7 @@
3133
get_type_args = typing.get_args
3234
get_type_origin = typing.get_origin
3335
Annotated = typing.Annotated
36+
LiteralString = typing.LiteralString
3437

3538
if sys.version_info < (3, 10):
3639
NoneType = type(None)
@@ -41,6 +44,30 @@
4144

4245
T = typing.TypeVar("T", covariant=True)
4346
U = typing.TypeVar("U")
47+
Tag = typing.TypeVar("Tag", bound=LiteralString)
48+
49+
50+
@dataclasses.dataclass(frozen=True)
51+
class Variant(typing.Generic[U, Tag]):
52+
"""
53+
A tagged variant for CHOICE fields with the same underlying type.
54+
55+
Use this when you have multiple CHOICE alternatives with the same type
56+
and need to distinguish between them:
57+
58+
foo: (
59+
Annotated[Variant[int, typing.Literal["IntA"]], Implicit(0)]
60+
| Annotated[Variant[int, typing.Literal["IntB"]], Implicit(1)]
61+
)
62+
63+
Usage:
64+
example = Example(foo=Variant(5, "IntA"))
65+
decoded.foo.value # The int value
66+
decoded.foo.tag # "IntA" or "IntB"
67+
"""
68+
69+
value: U
70+
tag: str
4471

4572

4673
decode_der = declarative_asn1.decode_der
@@ -101,7 +128,7 @@ def _normalize_field_type(
101128
# from it if it exists.
102129
if get_type_origin(field_type) is Annotated:
103130
annotation = _extract_annotation(field_type.__metadata__, field_name)
104-
field_type, _ = get_type_args(field_type)
131+
field_type, *_ = get_type_args(field_type)
105132
else:
106133
annotation = declarative_asn1.Annotation()
107134

@@ -150,10 +177,53 @@ def _normalize_field_type(
150177
)
151178

152179
rust_field_type = declarative_asn1.Type.Option(annotated_type)
180+
153181
else:
154-
raise TypeError(
155-
"union types other than `X | None` are currently not supported"
182+
# Otherwise, the Union is a CHOICE
183+
if isinstance(annotation.encoding, Implicit):
184+
# CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9.
185+
raise TypeError(
186+
"CHOICE (`X | Y | ...`) types should not have an IMPLICIT "
187+
"annotation"
188+
)
189+
variants = [
190+
_type_to_variant(arg, field_name)
191+
for arg in union_args
192+
if arg is not type(None)
193+
]
194+
195+
# Union types should either be all Variants
196+
# (`Variant[..] | Variant[..] | etc`) or all non Variants
197+
are_union_types_tagged = variants[0].tag_name is not None
198+
if any(
199+
(v.tag_name is not None) != are_union_types_tagged
200+
for v in variants
201+
):
202+
raise TypeError(
203+
"When using `asn1.Variant` in a union, all the other "
204+
"types in the union must also be `asn1.Variant`"
205+
)
206+
207+
if are_union_types_tagged:
208+
tags = {v.tag_name for v in variants}
209+
if len(variants) != len(tags):
210+
raise TypeError(
211+
"When using `asn1.Variant` in a union, the tags used "
212+
"must be unique"
213+
)
214+
215+
rust_choice_type = declarative_asn1.Type.Choice(variants)
216+
# If None is part of the union types, this is an OPTIONAL CHOICE
217+
rust_field_type = (
218+
declarative_asn1.Type.Option(
219+
declarative_asn1.AnnotatedType(
220+
rust_choice_type, declarative_asn1.Annotation()
221+
)
222+
)
223+
if NoneType in union_args
224+
else rust_choice_type
156225
)
226+
157227
elif get_type_origin(field_type) is builtins.list:
158228
inner_type = _normalize_field_type(
159229
get_type_args(field_type)[0], field_name
@@ -165,6 +235,51 @@ def _normalize_field_type(
165235
return declarative_asn1.AnnotatedType(rust_field_type, annotation)
166236

167237

238+
# Convert a type to a Variant. Used with types inside Union
239+
# annotations (T1, T2, etc in `Union[T1, T2, ...]`).
240+
def _type_to_variant(
241+
t: typing.Any, field_name: str
242+
) -> declarative_asn1.Variant:
243+
is_annotated = get_type_origin(t) is Annotated
244+
inner_type = get_type_args(t)[0] if is_annotated else t
245+
246+
# Check if this is a Variant[T, Tag] type
247+
if get_type_origin(inner_type) is Variant:
248+
value_type, tag_literal = get_type_args(inner_type)
249+
if get_type_origin(tag_literal) is not typing.Literal:
250+
raise TypeError(
251+
"When using `asn1.Variant` in a type annotation, the second "
252+
"type parameter must be a `typing.Literal` type. E.g: "
253+
'`Variant[int, typing.Literal["MyInt"]]`.'
254+
)
255+
tag_name = get_type_args(tag_literal)[0]
256+
257+
if hasattr(value_type, "__asn1_root__"):
258+
rust_type = value_type.__asn1_root__.inner
259+
else:
260+
rust_type = declarative_asn1.non_root_python_to_rust(value_type)
261+
262+
if is_annotated:
263+
ann_type = declarative_asn1.AnnotatedType(
264+
rust_type,
265+
_extract_annotation(t.__metadata__, field_name),
266+
)
267+
else:
268+
ann_type = declarative_asn1.AnnotatedType(
269+
rust_type,
270+
declarative_asn1.Annotation(),
271+
)
272+
273+
return declarative_asn1.Variant(Variant, ann_type, tag_name)
274+
else:
275+
# Plain type (not a tagged Variant)
276+
return declarative_asn1.Variant(
277+
inner_type,
278+
_normalize_field_type(t, field_name),
279+
None,
280+
)
281+
282+
168283
def _annotate_fields(
169284
raw_fields: dict[str, type],
170285
) -> dict[str, declarative_asn1.AnnotatedType]:

src/cryptography/hazmat/bindings/_rust/declarative_asn1.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Type:
1515
Sequence: typing.ClassVar[type]
1616
SequenceOf: typing.ClassVar[type]
1717
Option: typing.ClassVar[type]
18+
Choice: typing.ClassVar[type]
1819
PyBool: typing.ClassVar[type]
1920
PyInt: typing.ClassVar[type]
2021
PyBytes: typing.ClassVar[type]
@@ -60,6 +61,18 @@ class AnnotatedTypeObject:
6061
cls, annotated_type: AnnotatedType, value: typing.Any
6162
) -> AnnotatedTypeObject: ...
6263

64+
class Variant:
65+
python_class: type
66+
ann_type: AnnotatedType
67+
tag_name: str | None
68+
69+
def __new__(
70+
cls,
71+
python_class: type,
72+
ann_type: AnnotatedType,
73+
tag_name: str | None,
74+
) -> Variant: ...
75+
6376
class PrintableString:
6477
def __new__(cls, inner: str) -> PrintableString: ...
6578
def __repr__(self) -> str: ...

src/rust/src/declarative_asn1/decode.rs

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ use pyo3::types::{PyAnyMethods, PyListMethods};
77

88
use crate::asn1::big_byte_slice_to_py_int;
99
use crate::declarative_asn1::types::{
10-
check_size_constraint, is_tag_valid_for_type, AnnotatedType, Annotation, BitString, Encoding,
11-
GeneralizedTime, IA5String, PrintableString, Type, UtcTime,
10+
check_size_constraint, is_tag_valid_for_type, is_tag_valid_for_variant, AnnotatedType,
11+
Annotation, BitString, Encoding, GeneralizedTime, IA5String, PrintableString, Type, UtcTime,
12+
Variant,
1213
};
1314
use crate::error::CryptographyError;
1415

@@ -160,6 +161,47 @@ fn decode_bitstring<'a>(
160161
)?)
161162
}
162163

164+
// Utility function to handle explicit encoding when parsing
165+
// CHOICE fields.
166+
fn decode_choice_with_encoding<'a>(
167+
py: pyo3::Python<'a>,
168+
parser: &mut Parser<'a>,
169+
ann_type: &AnnotatedType,
170+
encoding: &Encoding,
171+
) -> ParseResult<pyo3::Bound<'a, pyo3::PyAny>> {
172+
match encoding {
173+
Encoding::Implicit(_) => Err(CryptographyError::Py(
174+
// CHOICEs cannot be IMPLICIT. See X.680 section 31.2.9.
175+
pyo3::exceptions::PyValueError::new_err(
176+
"invalid type definition: CHOICE fields cannot be implicitly encoded".to_string(),
177+
),
178+
))?,
179+
Encoding::Explicit(n) => {
180+
// Since we don't know which of the variants is present for this
181+
// CHOICE field, we'll parse this as a generic TLV encoded with
182+
// EXPLICIT, so `read_explicit_element` will consume the EXPLICIT
183+
// wrapper tag, and the TLV data will contain the variant.
184+
let tlv = parser.read_explicit_element::<asn1::Tlv<'_>>(*n)?;
185+
let type_without_explicit = AnnotatedType {
186+
inner: ann_type.inner.clone_ref(py),
187+
annotation: pyo3::Py::new(
188+
py,
189+
Annotation {
190+
default: None,
191+
encoding: None,
192+
size: None,
193+
},
194+
)?,
195+
};
196+
// Parse the TLV data (which contains the field without the EXPLICIT
197+
// wrapper)
198+
asn1::parse(tlv.full_data(), |d| {
199+
decode_annotated_type(py, d, &type_without_explicit)
200+
})
201+
}
202+
}
203+
}
204+
163205
pub(crate) fn decode_annotated_type<'a>(
164206
py: pyo3::Python<'a>,
165207
parser: &mut Parser<'a>,
@@ -173,7 +215,7 @@ pub(crate) fn decode_annotated_type<'a>(
173215
// returning the default value)
174216
if let Some(default) = &ann_type.annotation.get().default {
175217
match parser.peek_tag() {
176-
Some(next_tag) if is_tag_valid_for_type(next_tag, inner, encoding) => (),
218+
Some(next_tag) if is_tag_valid_for_type(py, next_tag, inner, encoding) => (),
177219
_ => return Ok(default.clone_ref(py).into_bound(py)),
178220
}
179221
}
@@ -210,7 +252,7 @@ pub(crate) fn decode_annotated_type<'a>(
210252
}
211253
Type::Option(cls) => {
212254
match parser.peek_tag() {
213-
Some(t) if is_tag_valid_for_type(t, cls.get().inner.get(), encoding) => {
255+
Some(t) if is_tag_valid_for_type(py, t, cls.get().inner.get(), encoding) => {
214256
// For optional types, annotations will always be associated to the `Optional` type
215257
// i.e: `Annotated[Optional[T], annotation]`, as opposed to the inner `T` type.
216258
// Therefore, when decoding the inner type `T` we must pass the annotation of the `Optional`
@@ -223,6 +265,33 @@ pub(crate) fn decode_annotated_type<'a>(
223265
_ => pyo3::types::PyNone::get(py).to_owned().into_any(),
224266
}
225267
}
268+
Type::Choice(ts) => match encoding {
269+
Some(e) => decode_choice_with_encoding(py, parser, ann_type, e.get())?,
270+
None => {
271+
for t in ts.bind(py) {
272+
let variant = t.cast::<Variant>()?.get();
273+
match parser.peek_tag() {
274+
Some(tag) if is_tag_valid_for_variant(py, tag, variant, encoding) => {
275+
let decoded_value =
276+
decode_annotated_type(py, parser, variant.ann_type.get())?;
277+
return match &variant.tag_name {
278+
Some(tag_name) => Ok(variant
279+
.python_class
280+
.call1(py, (decoded_value, tag_name))?
281+
.into_bound(py)),
282+
None => Ok(decoded_value),
283+
};
284+
}
285+
_ => continue,
286+
}
287+
}
288+
Err(CryptographyError::Py(
289+
pyo3::exceptions::PyValueError::new_err(
290+
"could not find matching variant when parsing CHOICE field".to_string(),
291+
),
292+
))?
293+
}
294+
},
226295
Type::PyBool() => decode_pybool(py, parser, encoding)?.into_any(),
227296
Type::PyInt() => decode_pyint(py, parser, encoding)?.into_any(),
228297
Type::PyBytes() => decode_pybytes(py, parser, annotation)?.into_any(),
@@ -244,3 +313,33 @@ pub(crate) fn decode_annotated_type<'a>(
244313
_ => Ok(decoded),
245314
}
246315
}
316+
317+
#[cfg(test)]
318+
mod tests {
319+
use crate::declarative_asn1::types::{AnnotatedType, Annotation, Encoding, Type, Variant};
320+
#[test]
321+
fn test_decode_implicit_choice() {
322+
pyo3::Python::initialize();
323+
pyo3::Python::attach(|py| {
324+
let result = asn1::parse(&[], |parser| {
325+
let variants: Vec<Variant> = vec![];
326+
let choice = Type::Choice(pyo3::types::PyList::new(py, variants)?.unbind());
327+
let annotation = Annotation {
328+
default: None,
329+
encoding: None,
330+
size: None,
331+
};
332+
let ann_type = AnnotatedType {
333+
inner: pyo3::Py::new(py, choice)?,
334+
annotation: pyo3::Py::new(py, annotation)?,
335+
};
336+
let encoding = Encoding::Implicit(0);
337+
super::decode_choice_with_encoding(py, parser, &ann_type, &encoding)
338+
});
339+
assert!(result.is_err());
340+
let error = result.unwrap_err();
341+
assert!(format!("{error}")
342+
.contains("invalid type definition: CHOICE fields cannot be implicitly encoded"));
343+
});
344+
}
345+
}

0 commit comments

Comments
 (0)