Skip to content

Commit 942921a

Browse files
authored
Merge pull request #83 from jymchng/gh-#80-req-diff-name
#80; completely implemented this ticket
2 parents 4f6352b + 9e383b2 commit 942921a

File tree

3 files changed

+857
-4
lines changed

3 files changed

+857
-4
lines changed

src/fastapi_shield/shield.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,15 @@ def protected_endpoint():
578578
if isinstance(param.default, ShieldDepends)
579579
}
580580

581+
request_annotation_in_guard_fn: bool = False
582+
request_param_name_in_guard_fn: str = "request"
583+
584+
for k, v in self._guard_func_params.items():
585+
if v.annotation is Request:
586+
request_param_name_in_guard_fn = k
587+
request_annotation_in_guard_fn = True
588+
break
589+
581590
dependency_cache: Optional[Dict[str, Any]] = {} if self.use_cache else None
582591

583592
@wraps(endpoint)
@@ -593,12 +602,12 @@ async def wrapper(*args, **kwargs):
593602
if obj:
594603
# from here onwards, the shield's job is done
595604
# hence we should raise an error from now on if anything goes wrong
596-
request = kwargs.get("request")
605+
request = kwargs.get(request_param_name_in_guard_fn)
597606

598607
if not request or not isinstance(request, Request):
599608
raise HTTPException(
600609
status.HTTP_400_BAD_REQUEST,
601-
detail="Request is required or `request` is not of type `Request`",
610+
detail=f"Request is required or `{request_param_name_in_guard_fn}` is not of type `Request`",
602611
)
603612

604613
if not hasattr(wrapper, SHIELDED_ENDPOINT_PATH_FORMAT_KEY):
@@ -640,8 +649,10 @@ async def wrapper(*args, **kwargs):
640649

641650
wrapper.__signature__ = Signature( # type:ignore[attr-defined]
642651
rearrange_params( # type:ignore[reportArgumentType]
643-
merge_dedup_seq_params(
644-
prepend_request_to_signature_params_of_function(self._guard_func),
652+
merge_dedup_seq_params( # type:ignore[arg-type]
653+
prepend_request_to_signature_params_of_function(self._guard_func) # type:ignore[arg-type]
654+
if not request_annotation_in_guard_fn
655+
else self._guard_func_params.values(),
645656
)
646657
)
647658
)

0 commit comments

Comments
 (0)