@@ -578,6 +578,15 @@ def protected_endpoint():
578578 if isinstance (param .default , ShieldDepends )
579579 }
580580
581+ request_annotation_in_guard_fn : bool = False
582+ request_param_name_in_guard_fn : str = "request"
583+
584+ for k , v in self ._guard_func_params .items ():
585+ if v .annotation is Request :
586+ request_param_name_in_guard_fn = k
587+ request_annotation_in_guard_fn = True
588+ break
589+
581590 dependency_cache : Optional [Dict [str , Any ]] = {} if self .use_cache else None
582591
583592 @wraps (endpoint )
@@ -593,12 +602,12 @@ async def wrapper(*args, **kwargs):
593602 if obj :
594603 # from here onwards, the shield's job is done
595604 # hence we should raise an error from now on if anything goes wrong
596- request = kwargs .get ("request" )
605+ request = kwargs .get (request_param_name_in_guard_fn )
597606
598607 if not request or not isinstance (request , Request ):
599608 raise HTTPException (
600609 status .HTTP_400_BAD_REQUEST ,
601- detail = "Request is required or `request ` is not of type `Request`" ,
610+ detail = f "Request is required or `{ request_param_name_in_guard_fn } ` is not of type `Request`" ,
602611 )
603612
604613 if not hasattr (wrapper , SHIELDED_ENDPOINT_PATH_FORMAT_KEY ):
@@ -640,8 +649,10 @@ async def wrapper(*args, **kwargs):
640649
641650 wrapper .__signature__ = Signature ( # type:ignore[attr-defined]
642651 rearrange_params ( # type:ignore[reportArgumentType]
643- merge_dedup_seq_params (
644- prepend_request_to_signature_params_of_function (self ._guard_func ),
652+ merge_dedup_seq_params ( # type:ignore[arg-type]
653+ prepend_request_to_signature_params_of_function (self ._guard_func ) # type:ignore[arg-type]
654+ if not request_annotation_in_guard_fn
655+ else self ._guard_func_params .values (),
645656 )
646657 )
647658 )
0 commit comments