|
| 1 | +"""CMPC bilateral chain-session state machine.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from dataclasses import dataclass, field |
| 6 | +from datetime import datetime, timezone |
| 7 | +from enum import Enum |
| 8 | +import hashlib |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +from concordia.canonicalization import canonicalize_jcs |
| 12 | + |
| 13 | + |
| 14 | +class ChainSessionState(str, Enum): |
| 15 | + PROPOSED = "PROPOSED" |
| 16 | + OPEN = "OPEN" |
| 17 | + ACTIVATED = "ACTIVATED" |
| 18 | + DISSOLVED = "DISSOLVED" |
| 19 | + EXPIRED = "EXPIRED" |
| 20 | + |
| 21 | + |
| 22 | +LEGAL_TRANSITIONS: dict[ChainSessionState, set[ChainSessionState]] = { |
| 23 | + ChainSessionState.PROPOSED: {ChainSessionState.OPEN}, |
| 24 | + ChainSessionState.OPEN: { |
| 25 | + ChainSessionState.ACTIVATED, |
| 26 | + ChainSessionState.DISSOLVED, |
| 27 | + ChainSessionState.EXPIRED, |
| 28 | + }, |
| 29 | + ChainSessionState.ACTIVATED: set(), |
| 30 | + ChainSessionState.DISSOLVED: set(), |
| 31 | + ChainSessionState.EXPIRED: set(), |
| 32 | +} |
| 33 | + |
| 34 | + |
| 35 | +class InvalidTransitionError(Exception): |
| 36 | + """Raised when a ChainSession state transition is not allowed.""" |
| 37 | + |
| 38 | + |
| 39 | +def _enum_value(value: ChainSessionState | str) -> str: |
| 40 | + return value.value if isinstance(value, ChainSessionState) else value |
| 41 | + |
| 42 | + |
| 43 | +@dataclass(kw_only=True) |
| 44 | +class TransitionRecord: |
| 45 | + from_state: ChainSessionState | str |
| 46 | + to_state: ChainSessionState | str |
| 47 | + transitioned_at: datetime |
| 48 | + evidence: dict[str, Any] | None |
| 49 | + prev_transition_hash: str | None |
| 50 | + transition_hash: str | None = None |
| 51 | + |
| 52 | + def __post_init__(self) -> None: |
| 53 | + if isinstance(self.from_state, str): |
| 54 | + self.from_state = ChainSessionState(self.from_state) |
| 55 | + if isinstance(self.to_state, str): |
| 56 | + self.to_state = ChainSessionState(self.to_state) |
| 57 | + |
| 58 | + def canonical_bytes_excl_hash(self) -> bytes: |
| 59 | + data = { |
| 60 | + "from_state": _enum_value(self.from_state), |
| 61 | + "to_state": _enum_value(self.to_state), |
| 62 | + "transitioned_at": self.transitioned_at.isoformat(), |
| 63 | + "evidence": self.evidence, |
| 64 | + "prev_transition_hash": self.prev_transition_hash, |
| 65 | + } |
| 66 | + return canonicalize_jcs(data) |
| 67 | + |
| 68 | + def compute_hash(self) -> str: |
| 69 | + return hashlib.sha256(self.canonical_bytes_excl_hash()).hexdigest() |
| 70 | + |
| 71 | + def to_dict(self) -> dict[str, Any]: |
| 72 | + return { |
| 73 | + "from_state": _enum_value(self.from_state), |
| 74 | + "to_state": _enum_value(self.to_state), |
| 75 | + "transitioned_at": self.transitioned_at, |
| 76 | + "evidence": self.evidence, |
| 77 | + "prev_transition_hash": self.prev_transition_hash, |
| 78 | + "transition_hash": self.transition_hash, |
| 79 | + } |
| 80 | + |
| 81 | + |
| 82 | +@dataclass(kw_only=True) |
| 83 | +class ChainSession: |
| 84 | + chain_session_id: str |
| 85 | + participants: list[str] |
| 86 | + closure_predicate_ref: str |
| 87 | + state: ChainSessionState | str |
| 88 | + created_at: datetime |
| 89 | + activation_deadline: datetime |
| 90 | + activated_at: datetime | None = None |
| 91 | + dissolved_at: datetime | None = None |
| 92 | + commitments: list[str] = field(default_factory=list) |
| 93 | + unwind_record_id: str | None = None |
| 94 | + activation_proof_id: str | None = None |
| 95 | + transitions: list[TransitionRecord] = field(default_factory=list) |
| 96 | + |
| 97 | + def __post_init__(self) -> None: |
| 98 | + if isinstance(self.state, str): |
| 99 | + self.state = ChainSessionState(self.state) |
| 100 | + self.transitions = [ |
| 101 | + record |
| 102 | + if isinstance(record, TransitionRecord) |
| 103 | + else TransitionRecord(**record) |
| 104 | + for record in self.transitions |
| 105 | + ] |
| 106 | + |
| 107 | + @classmethod |
| 108 | + def from_dict(cls, data: dict[str, Any]) -> "ChainSession": |
| 109 | + return cls(**data) |
| 110 | + |
| 111 | + def to_dict(self) -> dict[str, Any]: |
| 112 | + data: dict[str, Any] = { |
| 113 | + "chain_session_id": self.chain_session_id, |
| 114 | + "participants": self.participants, |
| 115 | + "closure_predicate_ref": self.closure_predicate_ref, |
| 116 | + "state": _enum_value(self.state), |
| 117 | + "created_at": self.created_at, |
| 118 | + "activation_deadline": self.activation_deadline, |
| 119 | + "activated_at": self.activated_at, |
| 120 | + "dissolved_at": self.dissolved_at, |
| 121 | + "commitments": self.commitments, |
| 122 | + "unwind_record_id": self.unwind_record_id, |
| 123 | + "activation_proof_id": self.activation_proof_id, |
| 124 | + } |
| 125 | + if self.transitions: |
| 126 | + data["transitions"] = [record.to_dict() for record in self.transitions] |
| 127 | + return data |
| 128 | + |
| 129 | + def transition_to( |
| 130 | + self, |
| 131 | + new_state: ChainSessionState, |
| 132 | + evidence: dict[str, Any] | None = None, |
| 133 | + now: datetime | None = None, |
| 134 | + ) -> None: |
| 135 | + now = now or datetime.now(timezone.utc) |
| 136 | + current_state = ChainSessionState(self.state) |
| 137 | + if new_state not in LEGAL_TRANSITIONS.get(current_state, set()): |
| 138 | + raise InvalidTransitionError( |
| 139 | + f"Illegal transition: {current_state.value} -> {new_state.value}" |
| 140 | + ) |
| 141 | + |
| 142 | + self._validate_transition_preconditions(new_state, now) |
| 143 | + prev_hash = self.transitions[-1].transition_hash if self.transitions else None |
| 144 | + record = TransitionRecord( |
| 145 | + from_state=self.state, |
| 146 | + to_state=new_state, |
| 147 | + transitioned_at=now, |
| 148 | + evidence=evidence, |
| 149 | + prev_transition_hash=prev_hash, |
| 150 | + ) |
| 151 | + record.transition_hash = record.compute_hash() |
| 152 | + self.transitions.append(record) |
| 153 | + self.state = new_state |
| 154 | + if new_state == ChainSessionState.ACTIVATED: |
| 155 | + self.activated_at = now |
| 156 | + elif new_state in (ChainSessionState.DISSOLVED, ChainSessionState.EXPIRED): |
| 157 | + self.dissolved_at = now |
| 158 | + |
| 159 | + def expire_due_to_timeout(self, now: datetime | None = None) -> None: |
| 160 | + self.transition_to( |
| 161 | + ChainSessionState.EXPIRED, |
| 162 | + evidence={"reason": "activation_timeout"}, |
| 163 | + now=now, |
| 164 | + ) |
| 165 | + |
| 166 | + def _validate_transition_preconditions( |
| 167 | + self, |
| 168 | + new_state: ChainSessionState, |
| 169 | + now: datetime, |
| 170 | + ) -> None: |
| 171 | + if self.state == ChainSessionState.PROPOSED and new_state == ChainSessionState.OPEN: |
| 172 | + if len(self.commitments) != len(self.participants): |
| 173 | + raise InvalidTransitionError( |
| 174 | + "PROPOSED -> OPEN requires " |
| 175 | + f"len(commitments)={len(self.commitments)} == " |
| 176 | + f"len(participants)={len(self.participants)}" |
| 177 | + ) |
| 178 | + |
| 179 | + if self.state == ChainSessionState.OPEN and new_state == ChainSessionState.ACTIVATED: |
| 180 | + if self.activation_proof_id is None: |
| 181 | + raise InvalidTransitionError( |
| 182 | + "OPEN -> ACTIVATED requires activation_proof_id" |
| 183 | + ) |
| 184 | + if now >= self.activation_deadline: |
| 185 | + raise InvalidTransitionError( |
| 186 | + "OPEN -> ACTIVATED requires now < activation_deadline; " |
| 187 | + f"got now={now.isoformat()}, " |
| 188 | + f"deadline={self.activation_deadline.isoformat()}" |
| 189 | + ) |
| 190 | + |
| 191 | + if self.state == ChainSessionState.OPEN and new_state == ChainSessionState.DISSOLVED: |
| 192 | + if self.unwind_record_id is None: |
| 193 | + raise InvalidTransitionError( |
| 194 | + "OPEN -> DISSOLVED requires unwind_record_id" |
| 195 | + ) |
| 196 | + |
| 197 | + if self.state == ChainSessionState.OPEN and new_state == ChainSessionState.EXPIRED: |
| 198 | + if now < self.activation_deadline: |
| 199 | + raise InvalidTransitionError( |
| 200 | + "OPEN -> EXPIRED requires now >= activation_deadline" |
| 201 | + ) |
| 202 | + if self.activation_proof_id is not None: |
| 203 | + raise InvalidTransitionError( |
| 204 | + "OPEN -> EXPIRED requires no activation_proof_id" |
| 205 | + ) |
| 206 | + |
| 207 | + |
| 208 | +def verify_transcript(chain_session: ChainSession) -> bool: |
| 209 | + prev_hash: str | None = None |
| 210 | + for record in chain_session.transitions: |
| 211 | + if record.prev_transition_hash != prev_hash: |
| 212 | + return False |
| 213 | + if record.transition_hash != record.compute_hash(): |
| 214 | + return False |
| 215 | + prev_hash = record.transition_hash |
| 216 | + return True |
0 commit comments