Skip to content

Commit 8463fb9

Browse files
anafalcaoleandrodamascenadreamorosi
authored
feat(bedrock_agents): add optional fields to response payload (#6336)
* feat(bedrock_agents): add optional fields to response payload * reformat bedrock response * fix type knowledge_base_config * fix type knowledge_base_config * fix bedrock response * mypy * mypy * remove unnecessary attributes and add docstrings * remove unnecessary attributes and add docstrings * add more tests * Fix middleware validation * Fix middleware validation --------- Co-authored-by: Leandro Damascena <[email protected]> Co-authored-by: Andrea Amorosi <[email protected]>
1 parent 640a032 commit 8463fb9

File tree

6 files changed

+239
-18
lines changed

6 files changed

+239
-18
lines changed

aws_lambda_powertools/event_handler/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Response,
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
14-
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
14+
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver, BedrockResponse
1515
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
1616
from aws_lambda_powertools.event_handler.lambda_function_url import (
1717
LambdaFunctionUrlResolver,
@@ -26,6 +26,7 @@
2626
"ALBResolver",
2727
"ApiGatewayResolver",
2828
"BedrockAgentResolver",
29+
"BedrockResponse",
2930
"CORSConfig",
3031
"LambdaFunctionUrlResolver",
3132
"Response",

aws_lambda_powertools/event_handler/api_gateway.py

+48-12
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
7474
_ROUTE_REGEX = "^{}$"
7575
_JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=Encoder)
76+
_DEFAULT_CONTENT_TYPE = "application/json"
7677

7778
ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
7879
ResponseT = TypeVar("ResponseT")
@@ -255,6 +256,35 @@ def build_allow_methods(methods: set[str]) -> str:
255256
return ",".join(sorted(methods))
256257

257258

259+
class BedrockResponse(Generic[ResponseT]):
260+
"""
261+
Contains the response body, status code, content type, and optional attributes
262+
for session management and knowledge base configuration.
263+
"""
264+
265+
def __init__(
266+
self,
267+
body: Any = None,
268+
status_code: int = 200,
269+
content_type: str = _DEFAULT_CONTENT_TYPE,
270+
session_attributes: dict[str, Any] | None = None,
271+
prompt_session_attributes: dict[str, Any] | None = None,
272+
knowledge_bases_configuration: list[dict[str, Any]] | None = None,
273+
) -> None:
274+
self.body = body
275+
self.status_code = status_code
276+
self.content_type = content_type
277+
self.session_attributes = session_attributes
278+
self.prompt_session_attributes = prompt_session_attributes
279+
self.knowledge_bases_configuration = knowledge_bases_configuration
280+
281+
def is_json(self) -> bool:
282+
"""
283+
Returns True if the response is JSON, based on the Content-Type.
284+
"""
285+
return True
286+
287+
258288
class Response(Generic[ResponseT]):
259289
"""Response data class that provides greater control over what is returned from the proxy event"""
260290

@@ -300,7 +330,7 @@ def is_json(self) -> bool:
300330
content_type = self.headers.get("Content-Type", "")
301331
if isinstance(content_type, list):
302332
content_type = content_type[0]
303-
return content_type.startswith("application/json")
333+
return content_type.startswith(_DEFAULT_CONTENT_TYPE)
304334

305335

306336
class Route:
@@ -572,7 +602,7 @@ def _get_openapi_path(
572602
operation_responses: dict[int, OpenAPIResponse] = {
573603
422: {
574604
"description": "Validation Error",
575-
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}},
605+
"content": {_DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}},
576606
},
577607
}
578608

@@ -581,7 +611,9 @@ def _get_openapi_path(
581611
http_code = self.custom_response_validation_http_code.value
582612
operation_responses[http_code] = {
583613
"description": "Response Validation Error",
584-
"content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}},
614+
"content": {
615+
_DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}},
616+
},
585617
}
586618
# Add model definition
587619
definitions["ResponseValidationError"] = response_validation_error_response_definition
@@ -594,7 +626,7 @@ def _get_openapi_path(
594626
# Case 1: there is not 'content' key
595627
if "content" not in response:
596628
response["content"] = {
597-
"application/json": self._openapi_operation_return(
629+
_DEFAULT_CONTENT_TYPE: self._openapi_operation_return(
598630
param=dependant.return_param,
599631
model_name_map=model_name_map,
600632
field_mapping=field_mapping,
@@ -645,7 +677,7 @@ def _get_openapi_path(
645677
# Add the response schema to the OpenAPI 200 response
646678
operation_responses[200] = {
647679
"description": self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
648-
"content": {"application/json": response_schema},
680+
"content": {_DEFAULT_CONTENT_TYPE: response_schema},
649681
}
650682

651683
operation["responses"] = operation_responses
@@ -1474,7 +1506,10 @@ def __call__(self, app: ApiGatewayResolver) -> dict | tuple | Response:
14741506
return self.current_middleware(app, self.next_middleware)
14751507

14761508

1477-
def _registered_api_adapter(app: ApiGatewayResolver, next_middleware: Callable[..., Any]) -> dict | tuple | Response:
1509+
def _registered_api_adapter(
1510+
app: ApiGatewayResolver,
1511+
next_middleware: Callable[..., Any],
1512+
) -> dict | tuple | Response | BedrockResponse:
14781513
"""
14791514
Calls the registered API using the "_route_args" from the Resolver context to ensure the last call
14801515
in the chain will match the API route function signature and ensure that Powertools passes the API
@@ -1632,7 +1667,7 @@ def _add_resolver_response_validation_error_response_to_route(
16321667
response_validation_error_response = {
16331668
"description": "Response Validation Error",
16341669
"content": {
1635-
"application/json": {
1670+
_DEFAULT_CONTENT_TYPE: {
16361671
"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"},
16371672
},
16381673
},
@@ -2151,7 +2186,7 @@ def swagger_handler():
21512186
if query_params.get("format") == "json":
21522187
return Response(
21532188
status_code=200,
2154-
content_type="application/json",
2189+
content_type=_DEFAULT_CONTENT_TYPE,
21552190
body=escaped_spec,
21562191
)
21572192

@@ -2538,7 +2573,7 @@ def _call_route(self, route: Route, route_arguments: dict[str, str]) -> Response
25382573
self._reset_processed_stack()
25392574

25402575
return self._response_builder_class(
2541-
response=self._to_response(
2576+
response=self._to_response( # type: ignore[arg-type]
25422577
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
25432578
),
25442579
serializer=self._serializer,
@@ -2627,7 +2662,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
26272662

26282663
return None
26292664

2630-
def _to_response(self, result: dict | tuple | Response) -> Response:
2665+
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
26312666
"""Convert the route's result to a Response
26322667
26332668
3 main result types are supported:
@@ -2638,7 +2673,7 @@ def _to_response(self, result: dict | tuple | Response) -> Response:
26382673
- Response: returned as is, and allows for more flexibility
26392674
"""
26402675
status_code = HTTPStatus.OK
2641-
if isinstance(result, Response):
2676+
if isinstance(result, (Response, BedrockResponse)):
26422677
return result
26432678
elif isinstance(result, tuple) and len(result) == 2:
26442679
# Unpack result dict and status code from tuple
@@ -2971,8 +3006,9 @@ def _get_base_path(self) -> str:
29713006
# ALB doesn't have a stage variable, so we just return an empty string
29723007
return ""
29733008

3009+
# BedrockResponse is not used here but adding the same signature to keep strong typing
29743010
@override
2975-
def _to_response(self, result: dict | tuple | Response) -> Response:
3011+
def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse:
29763012
"""Convert the route's result to a Response
29773013
29783014
ALB requires a non-null body otherwise it converts as HTTP 5xx

aws_lambda_powertools/event_handler/bedrock_agent.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aws_lambda_powertools.event_handler import ApiGatewayResolver
99
from aws_lambda_powertools.event_handler.api_gateway import (
1010
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
11+
BedrockResponse,
1112
ProxyEventType,
1213
ResponseBuilder,
1314
)
@@ -32,14 +33,11 @@ class BedrockResponseBuilder(ResponseBuilder):
3233

3334
@override
3435
def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
35-
"""Build the full response dict to be returned by the lambda"""
36-
self._route(event, None)
37-
3836
body = self.response.body
3937
if self.response.is_json() and not isinstance(self.response.body, str):
4038
body = self.serializer(self.response.body)
4139

42-
return {
40+
response = {
4341
"messageVersion": "1.0",
4442
"response": {
4543
"actionGroup": event.action_group,
@@ -54,6 +52,19 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]:
5452
},
5553
}
5654

55+
# Add Bedrock-specific attributes
56+
if isinstance(self.response, BedrockResponse):
57+
if self.response.session_attributes:
58+
response["sessionAttributes"] = self.response.session_attributes
59+
60+
if self.response.prompt_session_attributes:
61+
response["promptSessionAttributes"] = self.response.prompt_session_attributes
62+
63+
if self.response.knowledge_bases_configuration:
64+
response["knowledgeBasesConfiguration"] = self.response.knowledge_bases_configuration
65+
66+
return response
67+
5768

5869
class BedrockAgentResolver(ApiGatewayResolver):
5970
"""Bedrock Agent Resolver

docs/core/event_handler/bedrock_agents.md

+11
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,17 @@ You can enable user confirmation with Bedrock Agents to have your application as
323323

324324
1. Add an openapi extension
325325

326+
### Fine grained responses
327+
328+
???+ info "Note"
329+
The default response only includes the essential fields to keep the payload size minimal, as AWS Lambda has a maximum response size of 25 KB.
330+
331+
You can use `BedrockResponse` class to add additional fields as needed, such as [session attributes, prompt session attributes, and knowledge base configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-response){target="_blank"}.
332+
333+
```python title="working_with_bedrockresponse.py" title="Customzing your Bedrock Response" hl_lines="5 16"
334+
--8<-- "examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py"
335+
```
336+
326337
## Testing your code
327338

328339
Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from http import HTTPStatus
2+
3+
from aws_lambda_powertools import Logger, Tracer
4+
from aws_lambda_powertools.event_handler import BedrockAgentResolver
5+
from aws_lambda_powertools.event_handler.api_gateway import BedrockResponse
6+
from aws_lambda_powertools.utilities.typing import LambdaContext
7+
8+
tracer = Tracer()
9+
logger = Logger()
10+
app = BedrockAgentResolver()
11+
12+
13+
@app.get("/return_with_session", description="Returns a hello world with session attributes")
14+
@tracer.capture_method
15+
def hello_world():
16+
return BedrockResponse(
17+
status_code=HTTPStatus.OK.value,
18+
body={"message": "Hello from Bedrock!"},
19+
session_attributes={"user_id": "123"},
20+
prompt_session_attributes={"context": "testing"},
21+
knowledge_bases_configuration=[
22+
{
23+
"knowledgeBaseId": "kb-123",
24+
"retrievalConfiguration": {
25+
"vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"},
26+
},
27+
},
28+
],
29+
)
30+
31+
32+
@logger.inject_lambda_context
33+
@tracer.capture_lambda_handler
34+
def lambda_handler(event: dict, context: LambdaContext):
35+
return app.resolve(event, context)

0 commit comments

Comments
 (0)