|
1 | | -import json |
2 | 1 | from dataclasses import dataclass, field |
3 | | -from typing import Any, Dict |
| 2 | +from datetime import datetime, timedelta |
| 3 | +from typing import Dict |
4 | 4 |
|
| 5 | +from jose import jwt # type: ignore # pylint: disable=import-error |
5 | 6 |
|
| 7 | +from src.base import BaseJCSerializable, KeyMapping |
| 8 | +from src.errors import EARValidationError |
| 9 | +from src.jwt_config import DEFAULT_ALGORITHM, DEFAULT_EXPIRATION_MINUTES |
| 10 | +from src.submod import Submod |
| 11 | +from src.verifier_id import VerifierID |
| 12 | + |
| 13 | + |
| 14 | +# https://datatracker.ietf.org/doc/draft-fv-rats-ear/ |
6 | 15 | @dataclass |
7 | | -class EARClaims: |
| 16 | +class AttestationResult(BaseJCSerializable): |
8 | 17 | profile: str |
9 | 18 | issued_at: int |
10 | | - verifier_id: Dict[str, str] = field(default_factory=dict) |
11 | | - submods: Dict[str, Any] = field(default_factory=dict) |
| 19 | + verifier_id: VerifierID |
| 20 | + submods: Dict[str, Submod] = field(default_factory=dict) |
12 | 21 |
|
13 | | - def to_dict(self) -> Dict[str, Any]: |
14 | | - return { |
15 | | - "eat_profile": self.profile, |
16 | | - "iat": self.issued_at, |
17 | | - "ear.verifier-id": self.verifier_id, |
18 | | - "submods": self.submods, |
19 | | - } |
| 22 | + # https://www.ietf.org/archive/id/draft-ietf-rats-eat-31.html#section-7.2.4 |
| 23 | + jc_map = { |
| 24 | + "profile": KeyMapping(265, "eat_profile"), |
| 25 | + "issued_at": KeyMapping(6, "iat"), |
| 26 | + "verifier_id": KeyMapping(1004, "ear.verifier-id"), |
| 27 | + "submods": KeyMapping(266, "submods"), |
| 28 | + } |
20 | 29 |
|
21 | | - @classmethod |
22 | | - def from_dict(cls, data: Dict[str, Any]): |
23 | | - return cls( |
24 | | - profile=data.get("eat_profile", ""), |
25 | | - issued_at=data.get("iat", 0), |
26 | | - verifier_id=data.get("ear.verifier-id", {}), |
27 | | - submods=data.get("submods", {}), |
28 | | - ) |
| 30 | + def validate(self): |
| 31 | + # Validates an AttestationResult object |
| 32 | + if not isinstance(self.profile, str) or not self.profile: |
| 33 | + raise EARValidationError( |
| 34 | + "AttestationResult profile must be a non-empty string" |
| 35 | + ) |
| 36 | + if not isinstance(self.issued_at, int) or self.issued_at <= 0: |
| 37 | + raise EARValidationError( |
| 38 | + "AttestationResult issued_at must be a positive integer" |
| 39 | + ) |
29 | 40 |
|
30 | | - def to_json(self) -> str: |
31 | | - return json.dumps(self.to_dict()) |
| 41 | + self.verifier_id.validate() |
| 42 | + |
| 43 | + for submod, details in self.submods.items(): |
| 44 | + if not isinstance(details, Submod): |
| 45 | + raise EARValidationError( |
| 46 | + f"Submodule {submod} must contain a valid trust_vector and status" |
| 47 | + ) |
| 48 | + |
| 49 | + trust_vector = details.trust_vector |
| 50 | + trust_vector.validate() |
| 51 | + |
| 52 | + def encode_jwt( |
| 53 | + self, |
| 54 | + secret_key: str, |
| 55 | + algorithm: str = DEFAULT_ALGORITHM, |
| 56 | + expiration_minutes: int = DEFAULT_EXPIRATION_MINUTES, |
| 57 | + ) -> str: |
| 58 | + # Signs an AttestationResult object and returns a JWT |
| 59 | + payload = self.to_dict() |
| 60 | + payload["exp"] = int( |
| 61 | + datetime.timestamp(datetime.now() + timedelta(minutes=expiration_minutes)) |
| 62 | + ) |
| 63 | + return jwt.encode( |
| 64 | + payload, secret_key, algorithm=algorithm |
| 65 | + ) # pyright: ignore[reportGeneralTypeIssues] |
32 | 66 |
|
33 | 67 | @classmethod |
34 | | - def from_json(cls, json_str: str): |
35 | | - return cls.from_dict(json.loads(json_str)) |
| 68 | + def decode_jwt( |
| 69 | + cls, token: str, secret_key: str, algorithm: str = DEFAULT_ALGORITHM |
| 70 | + ): |
| 71 | + # Verifies a JWT and returns the decoded AttestationResult object. |
| 72 | + try: |
| 73 | + payload = jwt.decode(token, secret_key, algorithms=[algorithm]) |
| 74 | + return cls.from_dict(payload) |
| 75 | + except Exception as exc: |
| 76 | + raise ValueError(f"JWT decoding failed: {exc}") from exc |
0 commit comments