1717import hmac
1818import json
1919from dataclasses import dataclass
20+ from dataclasses import field as dataclass_field
2021from datetime import UTC , datetime
2122from typing import Final , Literal
2223
2728MfaChallengePurpose = Literal ["login" , "action" , "phone_verify" ]
2829
2930_AUDIENCE_FIELD : Final [str ] = "audience_json"
31+ _EXTRA_FIELD : Final [str ] = "extra_json"
3032
3133
3234class MfaChallengeStoreError (Exception ):
@@ -57,6 +59,7 @@ class ChallengeState:
5759 jti : str
5860 audience : str | list [str ] | None
5961 created_at : datetime
62+ extra : dict [str , str ] = dataclass_field (default_factory = dict )
6063
6164
6265class MfaChallengeStore :
@@ -75,8 +78,18 @@ async def store(
7578 jti : str ,
7679 ttl_seconds : int ,
7780 audience : str | list [str ] | None = None ,
81+ extra : dict [str , str ] | None = None ,
7882 ) -> None :
79- """Persist a fresh challenge, overwriting any prior row for (user, purpose)."""
83+ """Persist a fresh challenge, overwriting any prior row for (user, purpose).
84+
85+ ``extra`` lets callers attach purpose-specific metadata (for example the
86+ pending phone ciphertext during a phone-verify flow, or the action name
87+ on an action challenge) that survives the round trip through Redis.
88+ Reserved field names (``user_id``, ``purpose``, ``method``,
89+ ``code_hash``, ``attempt_count``, ``jti``, ``created_at``,
90+ ``audience_json``, ``extra_json``) are silently overwritten by the
91+ canonical payload to prevent shadow data.
92+ """
8093 if not user_id .strip ():
8194 raise ValueError ("user_id must be non-empty." )
8295 if not jti .strip ():
@@ -94,6 +107,7 @@ async def store(
94107 "jti" : jti ,
95108 "created_at" : datetime .now (UTC ).isoformat (),
96109 _AUDIENCE_FIELD : json .dumps (audience ) if audience is not None else "" ,
110+ _EXTRA_FIELD : json .dumps (extra ) if extra else "" ,
97111 }
98112
99113 try :
@@ -106,6 +120,46 @@ async def store(
106120 "session_backend_unavailable" ,
107121 ) from exc
108122
123+ async def store_safely (
124+ self ,
125+ * ,
126+ user_id : str ,
127+ purpose : MfaChallengePurpose ,
128+ method : MfaMethod ,
129+ code_hash : str ,
130+ jti : str ,
131+ ttl_seconds : int ,
132+ audience : str | list [str ] | None = None ,
133+ pending_phone_ciphertext : bytes | None = None ,
134+ pending_phone_lookup_hash : str | None = None ,
135+ ) -> None :
136+ """Phone-verify-flavored ``store`` that bundles pending phone state.
137+
138+ The ciphertext is hex-encoded for safe round-tripping through Redis.
139+ Callers can read it back via :meth:`read_extra` using the
140+ ``pending_phone_ciphertext_hex`` and ``pending_phone_lookup_hash`` keys.
141+ """
142+ extra : dict [str , str ] = {}
143+ if pending_phone_ciphertext is not None :
144+ extra ["pending_phone_ciphertext_hex" ] = pending_phone_ciphertext .hex ()
145+ if pending_phone_lookup_hash is not None :
146+ extra ["pending_phone_lookup_hash" ] = pending_phone_lookup_hash
147+ await self .store (
148+ user_id = user_id ,
149+ purpose = purpose ,
150+ method = method ,
151+ code_hash = code_hash ,
152+ jti = jti ,
153+ ttl_seconds = ttl_seconds ,
154+ audience = audience ,
155+ extra = extra or None ,
156+ )
157+
158+ @staticmethod
159+ def read_extra (* , challenge : ChallengeState , key : str ) -> str | None :
160+ """Return one extra field by name from a loaded :class:`ChallengeState`."""
161+ return challenge .extra .get (key ) if challenge .extra else None
162+
109163 async def load (
110164 self ,
111165 * ,
@@ -189,6 +243,14 @@ def _deserialize(self, raw: dict[str, str]) -> ChallengeState:
189243 except (KeyError , ValueError ):
190244 created_at = datetime .now (UTC )
191245
246+ extra_raw = raw .get (_EXTRA_FIELD , "" )
247+ try :
248+ extra : dict [str , str ] = (
249+ {str (k ): str (v ) for k , v in json .loads (extra_raw ).items ()} if extra_raw else {}
250+ )
251+ except (json .JSONDecodeError , AttributeError ):
252+ extra = {}
253+
192254 return ChallengeState (
193255 user_id = raw .get ("user_id" , "" ),
194256 purpose = raw .get ("purpose" , "login" ), # type: ignore[arg-type]
@@ -198,4 +260,5 @@ def _deserialize(self, raw: dict[str, str]) -> ChallengeState:
198260 jti = raw .get ("jti" , "" ),
199261 audience = audience ,
200262 created_at = created_at ,
263+ extra = extra ,
201264 )
0 commit comments