diff --git a/src/firebase_functions/identity_fn.py b/src/firebase_functions/identity_fn.py index 978dd81..453707b 100644 --- a/src/firebase_functions/identity_fn.py +++ b/src/firebase_functions/identity_fn.py @@ -203,6 +203,9 @@ class AdditionalUserInfo: is_new_user: bool """A boolean indicating if the user is new or not.""" + recaptcha_score: float | None + """The user's reCAPTCHA score, if available.""" + @_dataclasses.dataclass(frozen=True) class Credential: @@ -282,6 +285,12 @@ class AuthBlockingEvent: The time the event was triggered.""" +RecaptchaActionOptions = _typing.Literal["ALLOW", "BLOCK"] +""" +The reCAPTCHA action options. +""" + + class BeforeCreateResponse(_typing.TypedDict, total=False): """ The handler response type for 'before_user_created' blocking events. @@ -302,6 +311,8 @@ class BeforeCreateResponse(_typing.TypedDict, total=False): custom_claims: dict[str, _typing.Any] | None """The user's custom claims object if available.""" + recaptcha_action_override: RecaptchaActionOptions | None + class BeforeSignInResponse(BeforeCreateResponse, total=False): """ diff --git a/src/firebase_functions/private/_identity_fn.py b/src/firebase_functions/private/_identity_fn.py index b64b0c3..9597965 100644 --- a/src/firebase_functions/private/_identity_fn.py +++ b/src/firebase_functions/private/_identity_fn.py @@ -165,6 +165,7 @@ def _additional_user_info_from_token_data(token_data: dict[str, _typing.Any]): profile=profile, username=username, is_new_user=is_new_user, + recaptcha_score=token_data.get("recaptcha_score"), ) @@ -302,9 +303,35 @@ def _validate_auth_response( auth_response_dict["customClaims"] = auth_response["custom_claims"] if "session_claims" in auth_response_keys: auth_response_dict["sessionClaims"] = auth_response["session_claims"] + if "recaptcha_action_override" in auth_response_keys: + auth_response_dict["recaptchaActionOverride"] = auth_response[ + "recaptcha_action_override"] return auth_response_dict +def _generate_response_payload( + auth_response_dict: dict[str, _typing.Any] | None +) -> dict[str, _typing.Any]: + if not auth_response_dict: + return {} + + formatted_auth_response = auth_response_dict.copy() + recaptcha_action_override = formatted_auth_response.pop( + "recaptchaActionOverride", None) + result = {} + update_mask = ",".join(formatted_auth_response.keys()) + + if len(update_mask): + result["userRecord"] = { + **formatted_auth_response, "updateMask": update_mask + } + + if recaptcha_action_override is not None: + result["recaptchaActionOverride"] = recaptcha_action_override + + return result + + def before_operation_handler( func: _typing.Callable, event_type: str, @@ -329,13 +356,7 @@ def before_operation_handler( if not auth_response: return _jsonify({}) auth_response_dict = _validate_auth_response(event_type, auth_response) - update_mask = ",".join(auth_response_dict.keys()) - result = { - "userRecord": { - **auth_response_dict, - "updateMask": update_mask, - } - } + result = _generate_response_payload(auth_response_dict) return _jsonify(result) # Disable broad exceptions lint since we want to handle all exceptions. # pylint: disable=broad-except