Skip to content

Commit b6ed9ec

Browse files
committed
Make from_ method more generalizable for custom classes
Signed-off-by: HarshvMahawar <[email protected]>
1 parent 3c3d806 commit b6ed9ec

File tree

11 files changed

+233
-354
lines changed

11 files changed

+233
-354
lines changed

src/base.py

Lines changed: 74 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,97 @@
1-
from abc import ABC, abstractmethod
2-
from typing import Any, Dict, Type, TypeVar, get_type_hints
31
import json
2+
from abc import ABC
3+
from typing import Any, ClassVar, Dict, Tuple, Type, TypeVar, Union, get_args
44

55
T = TypeVar("T", bound="BaseJCSerializable")
66

7-
class BaseJCSerializable(ABC):
8-
jc_map: Dict[str, int]
97

10-
def to_dict(self) -> Dict[str, Any]:
11-
raise NotImplementedError("to_dict must be implemented in subclasses.")
8+
def to_data(value: Any, keys_as_int=False) -> Any:
9+
if hasattr(value, "to_data"):
10+
return value.to_data(keys_as_int)
11+
if hasattr(value, "items"): # dict-like
12+
return {
13+
to_data(k, keys_as_int): to_data(v, keys_as_int) for k, v in value.items()
14+
}
15+
if hasattr(value, "__iter__") and not isinstance(value, str): # list-like
16+
return [to_data(v, keys_as_int) for v in value]
1217

13-
def to_json(self) -> str:
14-
return json.dumps(self.to_dict())
18+
if hasattr(
19+
value, "value"
20+
): # custom classes that have value attr but don't have 'to_data'
21+
return value.value # type: ignore[attr-defined]
22+
# scalar and no to_data(), so assume serializable as-is
23+
return value
1524

16-
def to_cbor(self) -> Dict[int, Any]:
17-
cbor_data = {}
18-
for attr, cbor_key in self.jc_map.items():
19-
if "." in attr: # skiped nested keys, cuz, they will be processed when we go inside submods(not a nested key as in jc_map)
20-
continue
21-
value = getattr(self, attr, None)
22-
if isinstance(value, BaseJCSerializable): # trust_vector and status will be processed here
23-
cbor_data[cbor_key] = value.to_cbor()
24-
elif isinstance(value, dict): # submods will be processed here
25-
nested = {}
26-
for k, v in value.items():
27-
nested[k] = self._serialize_nested_dict(attr, v)
28-
cbor_data[cbor_key] = nested
29-
elif hasattr(value, "to_dict"): # for trust_claim
30-
cbor_data[cbor_key] = value.to_dict()
31-
else:
32-
cbor_data[cbor_key] = value
33-
return cbor_data
34-
35-
def _serialize_nested_dict(self, prefix: str, d: dict) -> dict:
36-
out = {}
37-
for subkey, val in d.items():
38-
if hasattr(val, "to_cbor"):
39-
out[self.jc_map[f"{prefix}.{subkey}"]] = val.to_cbor()
40-
elif hasattr(val, "value"): # status with trust_tier
41-
out[self.jc_map[f"{prefix}.{subkey}"]] = val.value
42-
else:
43-
out[self.jc_map.get(f"{prefix}.{subkey}", subkey)] = val
44-
return out
4525

46-
@classmethod
47-
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
48-
raise NotImplementedError("from_dict must be implemented in subclasses.")
26+
class BaseJCSerializable(ABC):
27+
jc_map: ClassVar[Dict[str, Tuple[int, str]]]
28+
29+
def to_data(self, keys_as_int=False) -> Dict[Union[str, int], Any]:
30+
return {
31+
(int_key if keys_as_int else str_key): to_data(
32+
getattr(self, attr), keys_as_int
33+
)
34+
for attr, (int_key, str_key) in self.jc_map.items()
35+
}
4936

5037
@classmethod
51-
def from_cbor(cls: Type[T], data: Dict[int, Any]) -> T:
52-
kwargs = {}
53-
reverse_map = {v: k for k, v in cls.jc_map.items()}
54-
type_hints = get_type_hints(cls)
55-
56-
for key, val in data.items():
57-
attr = reverse_map.get(key)
58-
if attr is None or "." in attr:
38+
def from_data(cls: Type[T], data: dict, keys_as_int=False) -> T:
39+
40+
if keys_as_int:
41+
index = 0
42+
else:
43+
index = 1
44+
init_kwargs = {}
45+
reverse_map = {v[index]: k for k, v in cls.jc_map.items()}
46+
for key, value in data.items():
47+
if key not in reverse_map:
5948
continue
6049

61-
hint = type_hints.get(attr)
50+
attr = reverse_map[key]
51+
field_type = getattr(cls, "__annotations__", {}).get(attr)
52+
if field_type is None:
53+
continue
6254

63-
# Handle BaseJCSerializable directly
64-
if isinstance(val, dict) and hasattr(hint, "from_cbor"):
65-
kwargs[attr] = hint.from_cbor(val)
55+
args = get_args(field_type)
6656

67-
# Handle Dict[str, BaseJCSerializable] or Dict[str, Any] with nested mapping
68-
elif isinstance(val, dict) and isinstance(hint, type) and issubclass(hint, dict):
69-
sub_hint = None
70-
if hasattr(hint, "__args__") and len(hint.__args__) > 1:
71-
sub_hint = hint.__args__[1]
72-
kwargs[attr] = {
73-
k: cls._deserialize_nested_dict(attr, v, sub_hint=sub_hint)
74-
for k, v in val.items()
57+
if hasattr(field_type, "from_data"):
58+
# Direct object
59+
init_kwargs[attr] = field_type.from_data(value, keys_as_int=keys_as_int)
60+
61+
elif hasattr(field_type, "items") and hasattr(args[1], "from_data"):
62+
# Dict[str | int, CustomClass]
63+
init_kwargs[attr] = {
64+
k: args[1].from_data(v, keys_as_int=keys_as_int)
65+
for k, v in value.items()
7566
}
7667

68+
elif args:
69+
# custom classes that dont have 'from_data'
70+
init_kwargs[attr] = args[0](value)
71+
7772
else:
78-
kwargs[attr] = val
73+
init_kwargs[attr] = field_type(value)
74+
75+
return cls(**init_kwargs)
76+
77+
def to_dict(self) -> Dict[str, Any]:
78+
# default str_keys
79+
return self.to_data() # type: ignore[return-value] # pyright: ignore[reportGeneralTypeIssues] # noqa: E501 # pylint: disable=line-too-long
7980

80-
return cls(**kwargs)
81+
def to_int_keys(self) -> Dict[Union[str, int], Any]:
82+
return self.to_data(keys_as_int=True)
8183

8284
@classmethod
83-
def _deserialize_nested_dict(cls, prefix: str, d: dict, sub_hint=None) -> dict:
84-
out = {}
85-
86-
# If sub_hint isn't given, try to get it from type hints
87-
if sub_hint is None:
88-
type_hints = get_type_hints(cls)
89-
hint = type_hints.get(prefix)
90-
if hasattr(hint, '__args__') and len(hint.__args__) > 1:
91-
sub_hint = hint.__args__[1]
92-
93-
for map_key, jc_key in cls.jc_map.items():
94-
if not map_key.startswith(f"{prefix}."):
95-
continue
85+
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
86+
return cls.from_data(data)
9687

97-
field_name = map_key.split(".")[-1]
98-
if jc_key in d:
99-
val = d[jc_key]
88+
@classmethod
89+
def from_int_keys(cls: Type[T], data: Dict[int, Any]) -> T:
90+
return cls.from_data(data, keys_as_int=True)
10091

101-
# Handle BaseJCSerializable subclasses inside subdict
102-
if hasattr(sub_hint, 'from_cbor') and isinstance(val, dict):
103-
out[field_name] = sub_hint.from_cbor(val)
104-
elif callable(sub_hint):
105-
out[field_name] = sub_hint(val)
106-
else:
107-
out[field_name] = val
92+
@classmethod
93+
def from_json(cls, json_str: str):
94+
return cls.from_dict(json.loads(json_str))
10895

109-
return out
96+
def to_json(self):
97+
return json.dumps(self.to_data())

src/claims.py

Lines changed: 12 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
import json
21
from dataclasses import dataclass, field
32
from datetime import datetime, timedelta
4-
from typing import Any, Dict
3+
from typing import Dict
54

65
from jose import jwt # type: ignore # pylint: disable=import-error
76

87
from src.base import BaseJCSerializable
98
from src.errors import EARValidationError
109
from src.jwt_config import DEFAULT_ALGORITHM, DEFAULT_EXPIRATION_MINUTES
11-
from src.trust_tier import to_trust_tier
12-
from src.trust_vector import TrustVector
10+
from src.submod import Submod
1311
from src.verifier_id import VerifierID
1412

1513

@@ -19,82 +17,16 @@ class AttestationResult(BaseJCSerializable):
1917
profile: str
2018
issued_at: int
2119
verifier_id: VerifierID
22-
submods: Dict[str, Dict[str, Any]] = field(default_factory=dict)
20+
submods: Dict[str, Submod] = field(default_factory=dict)
2321

2422
# https://www.ietf.org/archive/id/draft-ietf-rats-eat-31.html#section-7.2.4
2523
jc_map = {
26-
"profile": 265,
27-
"issued_at": 6,
28-
"verifier_id": 1004,
29-
"submods": 266,
30-
"submods.trust_vector": 1001,
31-
"submods.status": 1000,
24+
"profile": (265, "profile"),
25+
"issued_at": (6, "issued_at"),
26+
"verifier_id": (1004, "verifier_id"),
27+
"submods": (266, "submods"),
3228
}
3329

34-
def to_dict(self) -> Dict[str, Any]:
35-
return {
36-
"eat_profile": self.profile,
37-
"iat": self.issued_at,
38-
"ear.verifier-id": self.verifier_id.to_dict(),
39-
"submods": {
40-
key: {
41-
"trust_vector": value["trust_vector"].to_dict(),
42-
"status": value["status"].value,
43-
}
44-
for key, value in self.submods.items()
45-
},
46-
}
47-
48-
# def to_cbor(self) -> Dict[int, Any]:
49-
# return {
50-
# self.jc_map["profile"]: self.profile,
51-
# self.jc_map["issued_at"]: self.issued_at,
52-
# self.jc_map["verifier_id"]: self.verifier_id.to_cbor(),
53-
# self.jc_map["submods"]: {
54-
# key: {
55-
# self.jc_map["submod.trust_vector"]: value["trust_vector"].to_cbor(),
56-
# self.jc_map["submod.status"]: value["status"].value,
57-
# }
58-
# for key, value in self.submods.items()
59-
# },
60-
# }
61-
62-
@classmethod
63-
def from_dict(cls, data: Dict[str, Any]):
64-
return cls(
65-
profile=data.get("eat_profile", ""),
66-
issued_at=data.get("iat", 0),
67-
verifier_id=VerifierID.from_dict(data.get("ear.verifier-id", {})),
68-
submods={
69-
key: {
70-
"trust_vector": TrustVector.from_dict(value["trust_vector"]),
71-
"status": to_trust_tier(value["status"]),
72-
}
73-
for key, value in data.get("submods", {}).items()
74-
},
75-
)
76-
77-
@classmethod
78-
def from_json(cls, json_str: str):
79-
return cls.from_dict(json.loads(json_str))
80-
81-
# @classmethod
82-
# def from_cbor(cls, data: Dict[int, Any]):
83-
# return cls(
84-
# profile=data.get(cls.jc_map["profile"], ""),
85-
# issued_at=data.get(cls.jc_map["issued_at"], 0),
86-
# verifier_id=VerifierID.from_cbor(data.get(cls.jc_map["verifier_id"], {})),
87-
# submods={
88-
# key: {
89-
# "trust_vector": TrustVector.from_cbor(
90-
# value.get(cls.jc_map["submod.trust_vector"], {})
91-
# ),
92-
# "status": to_trust_tier(value.get(cls.jc_map["submod.status"], 0)),
93-
# }
94-
# for key, value in data.get(cls.jc_map["submods"], {}).items()
95-
# },
96-
# )
97-
9830
def validate(self):
9931
# Validates an AttestationResult object
10032
if not isinstance(self.profile, str) or not self.profile:
@@ -109,16 +41,12 @@ def validate(self):
10941
self.verifier_id.validate()
11042

11143
for submod, details in self.submods.items():
112-
if (
113-
not isinstance(details, Dict)
114-
or "trust_vector" not in details
115-
or "status" not in details
116-
):
44+
if not isinstance(details, Submod):
11745
raise EARValidationError(
11846
f"Submodule {submod} must contain a valid trust_vector and status"
11947
)
12048

121-
trust_vector = details["trust_vector"]
49+
trust_vector = details.trust_vector
12250
trust_vector.validate()
12351

12452
def encode_jwt(
@@ -132,7 +60,9 @@ def encode_jwt(
13260
payload["exp"] = int(
13361
datetime.timestamp(datetime.now() + timedelta(minutes=expiration_minutes))
13462
)
135-
return jwt.encode(payload, secret_key, algorithm=algorithm)
63+
return jwt.encode(
64+
payload, secret_key, algorithm=algorithm
65+
) # pyright: ignore[reportGeneralTypeIssues]
13666

13767
@classmethod
13868
def decode_jwt(

src/example/jwt_example.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from src.claims import AttestationResult
44
from src.jwt_config import generate_secret_key
5+
from src.submod import Submod
56
from src.trust_claims import TRUSTWORTHY_INSTANCE_CLAIM, UNRECOGNIZED_INSTANCE_CLAIM
67
from src.trust_tier import TRUST_TIER_AFFIRMING, TRUST_TIER_CONTRAINDICATED
78
from src.trust_vector import TrustVector
@@ -18,14 +19,14 @@
1819
issued_at=int(datetime.timestamp(datetime.now())),
1920
verifier_id=VerifierID(developer="Acme Inc.", build="v1"),
2021
submods={
21-
"submod1": {
22-
"trust_vector": TrustVector(instance_identity=UNRECOGNIZED_INSTANCE_CLAIM),
23-
"status": TRUST_TIER_AFFIRMING,
24-
},
25-
"submod2": {
26-
"trust_vector": TrustVector(instance_identity=TRUSTWORTHY_INSTANCE_CLAIM),
27-
"status": TRUST_TIER_CONTRAINDICATED,
28-
},
22+
"submod1": Submod(
23+
trust_vector=TrustVector(instance_identity=UNRECOGNIZED_INSTANCE_CLAIM),
24+
status=TRUST_TIER_AFFIRMING,
25+
),
26+
"submod2": Submod(
27+
trust_vector=TrustVector(instance_identity=TRUSTWORTHY_INSTANCE_CLAIM),
28+
status=TRUST_TIER_CONTRAINDICATED,
29+
),
2930
},
3031
)
3132

0 commit comments

Comments
 (0)