diff --git a/robyn/__init__.py b/robyn/__init__.py index 08dd0631d..a570114b7 100644 --- a/robyn/__init__.py +++ b/robyn/__init__.py @@ -155,6 +155,7 @@ def add_route( auth_required: bool = False, openapi_name: str = "", openapi_tags: Union[List[str], None] = None, + openapi_responses: Optional[dict] = None, ): """ Connect a URI to a handler @@ -215,6 +216,7 @@ def add_route( openapi_tags=list_openapi_tags, exception_handler=self.exception_handler, injected_dependencies=injected_dependencies, + openapi_responses=openapi_responses, ) logger.info("Added route %s %s", route_type, normalized_endpoint) @@ -371,6 +373,7 @@ def get( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["get"], + openapi_responses: Optional[dict] = None, ): """ The @app.get decorator to add a route with the GET method @@ -380,10 +383,11 @@ def get( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required, openapi_name, openapi_tags) + return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required, openapi_name, openapi_tags, openapi_responses) return inner @@ -393,6 +397,7 @@ def post( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["post"], + openapi_responses: Optional[dict] = None, ): """ The @app.post decorator to add a route with POST method @@ -401,10 +406,19 @@ def post( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.POST, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -414,6 +428,7 @@ def put( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["put"], + openapi_responses: Optional[dict] = None, ): """ The @app.put decorator to add a get route with PUT method @@ -422,10 +437,19 @@ def put( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.PUT, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -435,6 +459,7 @@ def delete( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["delete"], + openapi_responses: Optional[dict] = None, ): """ The @app.delete decorator to add a route with DELETE method @@ -443,10 +468,19 @@ def delete( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.DELETE, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -456,6 +490,7 @@ def patch( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["patch"], + openapi_responses: Optional[dict] = None, ): """ The @app.patch decorator to add a route with PATCH method @@ -464,10 +499,19 @@ def patch( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.PATCH, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -477,6 +521,7 @@ def head( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["head"], + openapi_responses: Optional[dict] = None, ): """ The @app.head decorator to add a route with HEAD method @@ -485,10 +530,19 @@ def head( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.HEAD, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -498,6 +552,7 @@ def options( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["options"], + openapi_responses: Optional[dict] = None, ): """ The @app.options decorator to add a route with OPTIONS method @@ -506,10 +561,19 @@ def options( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.OPTIONS, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -519,6 +583,7 @@ def connect( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["connect"], + openapi_responses: Optional[dict] = None, ): """ The @app.connect decorator to add a route with CONNECT method @@ -527,10 +592,19 @@ def connect( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.CONNECT, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -540,6 +614,7 @@ def trace( auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["trace"], + openapi_responses: Optional[dict] = None, ): """ The @app.trace decorator to add a route with TRACE method @@ -548,10 +623,19 @@ def trace( :param auth_required bool: represents if the route needs authentication or not :param openapi_name: str -- the name of the endpoint in the openapi spec :param openapi_tags: List[str] -- for grouping of endpoints in the openapi spec + :param openapi_responses: Optional[dict] -- additional response definitions keyed by status code """ def inner(handler): - return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + return self.add_route( + HttpMethod.TRACE, + endpoint, + handler, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) return inner @@ -705,29 +789,105 @@ def __add_prefix(self, endpoint: str): return f"{normalized_prefix}{normalized_endpoint}" - def get(self, endpoint: str, const: bool = False, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["get"]): - return super().get(endpoint=self.__add_prefix(endpoint), const=const, auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def get( + self, + endpoint: str, + const: bool = False, + auth_required: bool = False, + openapi_name: str = "", + openapi_tags: List[str] = ["get"], + openapi_responses: Optional[dict] = None, + ): + return super().get( + endpoint=self.__add_prefix(endpoint), + const=const, + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) - def post(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["post"]): - return super().post(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def post( + self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["post"], openapi_responses: Optional[dict] = None + ): + return super().post( + endpoint=self.__add_prefix(endpoint), + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) - def put(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["put"]): - return super().put(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def put( + self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["put"], openapi_responses: Optional[dict] = None + ): + return super().put( + endpoint=self.__add_prefix(endpoint), + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) - def delete(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["delete"]): - return super().delete(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def delete( + self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["delete"], openapi_responses: Optional[dict] = None + ): + return super().delete( + endpoint=self.__add_prefix(endpoint), + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) - def patch(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["patch"]): - return super().patch(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def patch( + self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["patch"], openapi_responses: Optional[dict] = None + ): + return super().patch( + endpoint=self.__add_prefix(endpoint), + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) - def head(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["head"]): - return super().head(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def head( + self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["head"], openapi_responses: Optional[dict] = None + ): + return super().head( + endpoint=self.__add_prefix(endpoint), + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) - def trace(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["trace"]): - return super().trace(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def trace( + self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["trace"], openapi_responses: Optional[dict] = None + ): + return super().trace( + endpoint=self.__add_prefix(endpoint), + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) - def options(self, endpoint: str, auth_required: bool = False, openapi_name: str = "", openapi_tags: List[str] = ["options"]): - return super().options(endpoint=self.__add_prefix(endpoint), auth_required=auth_required, openapi_name=openapi_name, openapi_tags=openapi_tags) + def options( + self, + endpoint: str, + auth_required: bool = False, + openapi_name: str = "", + openapi_tags: List[str] = ["options"], + openapi_responses: Optional[dict] = None, + ): + return super().options( + endpoint=self.__add_prefix(endpoint), + auth_required=auth_required, + openapi_name=openapi_name, + openapi_tags=openapi_tags, + openapi_responses=openapi_responses, + ) def websocket(self, endpoint: str): """ diff --git a/robyn/openapi.py b/robyn/openapi.py index 7383a3daf..2cbfc88da 100644 --- a/robyn/openapi.py +++ b/robyn/openapi.py @@ -168,7 +168,9 @@ def __post_init__(self): "externalDocs": asdict(self.info.externalDocs) if self.info.externalDocs.url else None, } - def add_openapi_path_obj(self, route_type: str, endpoint: str, openapi_name: str, openapi_tags: List[str], handler: Callable): + def add_openapi_path_obj( + self, route_type: str, endpoint: str, openapi_name: str, openapi_tags: List[str], handler: Callable, openapi_responses: Optional[dict] = None + ): """ Adds the given path to openapi spec @@ -177,6 +179,7 @@ def add_openapi_path_obj(self, route_type: str, endpoint: str, openapi_name: str @param openapi_name: str the name of the endpoint @param openapi_tags: List[str] for grouping of endpoints @param handler: Callable the handler function for the endpoint + @param openapi_responses: Optional[dict] additional response definitions keyed by status code """ if self.openapi_file_override: @@ -224,7 +227,7 @@ def add_openapi_path_obj(self, route_type: str, endpoint: str, openapi_name: str return_annotation = signature.return_annotation modified_endpoint, path_obj = self.get_path_obj( - endpoint, openapi_name, openapi_description, openapi_tags, query_params, request_body, return_annotation + endpoint, openapi_name, openapi_description, openapi_tags, query_params, request_body, return_annotation, openapi_responses ) if modified_endpoint not in self.openapi_spec["paths"]: @@ -274,6 +277,7 @@ def get_path_obj( query_params: Optional[str_typed_dict], request_body: Optional[str_typed_dict], return_annotation: Optional[str_typed_dict], + openapi_responses: Optional[dict] = None, ) -> Tuple[str, dict]: """ Get the "path" openapi object according to spec @@ -371,6 +375,14 @@ def get_path_obj( openapi_path_object["responses"] = {"200": {"description": "Successful Response", "content": {response_type: {"schema": response_schema}}}} + if openapi_responses: + for status_code_str, response_spec in openapi_responses.items(): + code = str(status_code_str) + if isinstance(response_spec, dict): + openapi_path_object["responses"][code] = response_spec + else: + openapi_path_object["responses"][code] = {"description": str(response_spec)} + return endpoint_with_path_params_wrapped_in_braces, openapi_path_object def get_openapi_type(self, typed_dict: str_typed_dict) -> str: diff --git a/robyn/router.py b/robyn/router.py index 242e57531..f8de0e911 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -37,6 +37,7 @@ class Route(NamedTuple): auth_required: bool openapi_name: str openapi_tags: List[str] + openapi_responses: Optional[dict] = None class RouteMiddleware(NamedTuple): @@ -138,6 +139,7 @@ def add_route( # type: ignore openapi_tags: List[str], exception_handler: Optional[Callable], injected_dependencies: dict, + openapi_responses: Optional[dict] = None, ) -> Union[Callable, CoroutineType]: # Pre-compute handler signature ONCE at registration time. # This avoids calling inspect.signature() on every request. @@ -340,7 +342,7 @@ def inner_handler(*args, **kwargs): params, new_injected_dependencies, ) - self.routes.append(Route(route_type, endpoint, function, is_const, auth_required, openapi_name, openapi_tags)) + self.routes.append(Route(route_type, endpoint, function, is_const, auth_required, openapi_name, openapi_tags, openapi_responses)) return async_inner_handler else: function = FunctionInfo( @@ -350,12 +352,14 @@ def inner_handler(*args, **kwargs): params, new_injected_dependencies, ) - self.routes.append(Route(route_type, endpoint, function, is_const, auth_required, openapi_name, openapi_tags)) + self.routes.append(Route(route_type, endpoint, function, is_const, auth_required, openapi_name, openapi_tags, openapi_responses)) return inner_handler def prepare_routes_openapi(self, openapi: OpenAPI, included_routers: List) -> None: for route in self.routes: - openapi.add_openapi_path_obj(lower_http_method(route.route_type), route.route, route.openapi_name, route.openapi_tags, route.function.handler) + openapi.add_openapi_path_obj( + lower_http_method(route.route_type), route.route, route.openapi_name, route.openapi_tags, route.function.handler, route.openapi_responses + ) # TODO! after include_routes does not immediately merge all the routes # for router in included_routers: diff --git a/unit_tests/test_openapi_multiple_responses.py b/unit_tests/test_openapi_multiple_responses.py new file mode 100644 index 000000000..446f668d9 --- /dev/null +++ b/unit_tests/test_openapi_multiple_responses.py @@ -0,0 +1,30 @@ +from robyn import Robyn + + +def test_openapi_responses_registered(): + app = Robyn(__file__) + + @app.get( + "/items/:id", + openapi_responses={ + 404: {"description": "Not found"}, + 422: {"description": "Validation error"}, + }, + ) + async def get_item(): + return {} + + routes = app.router.get_routes() + assert routes[0].openapi_responses is not None + assert 404 in routes[0].openapi_responses + + +def test_openapi_responses_default_none(): + app = Robyn(__file__) + + @app.get("/items") + async def list_items(): + return [] + + routes = app.router.get_routes() + assert routes[0].openapi_responses is None