-
Notifications
You must be signed in to change notification settings - Fork 132
Expand file tree
/
Copy path use_case_handler.py
More file actions
142 lines (119 loc) · 6.05 KB
/
use_case_handler.py
File metadata and controls
142 lines (119 loc) · 6.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import os
from typing import Any, Dict
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.typing import LambdaContext
from clients.llm_chat_client import LLMChatClient
from shared.callbacks.websocket_error_handler import WebsocketErrorHandler
from shared.callbacks.websocket_handler import WebsocketHandler
from utils.constants import (
DEFAULT_RAG_ENABLED_MODE,
END_CONVERSATION_TOKEN,
MESSAGE_KEY,
REQUEST_CONTEXT_KEY,
TRACE_ID_ENV_VAR,
USER_ID_EVENT_KEY,
)
logger = Logger(utc=True)
tracer = Tracer()
class UseCaseHandler:
"""
Abstract class for lambda handlers that use LLMs.
Methods:
handle_event(event: Dict[str, Any], context: LambdaContext) -> Dict[str, Any]:
get_llm_client(event: Dict[str, Any]) -> LLMClient: abstract method who
"""
def __init__(self, llm_client_type: LLMChatClient.__class__):
self.llm_client_type = llm_client_type
def check_streaming_failed(self, callbacks) -> bool:
"""
Check if streaming failed by examining callbacks for has_streamed attribute.
:param callbacks: List of callback objects to check
:return: True if streaming was attempted but failed, False otherwise
"""
streaming_failed = any(
hasattr(callback, "has_streamed") and not callback.has_streamed for callback in callbacks
)
if streaming_failed:
logger.info("Streaming was enabled but failed - using fallback")
return streaming_failed
def handle_event(self, event: Dict[str, Any], context: LambdaContext) -> Dict:
"""
Create a LLMChatClient concrete object type based on the configuration in `event` and
admin configuration and use it to answer user questions
:param event (Dict): AWS Lambda Event
:param context (LambdaContext): AWS Lambda Context
:return: the generated response from the chatbot
"""
batch_item_failures = []
loop_index = 0
total_records = len(event["Records"])
logger.debug(f"Total records received in the event: {total_records}")
sqs_batch_response = {}
while loop_index < total_records:
logger.debug(f"Processing record number {loop_index}")
connection_id = None
conversation_id = None
record = event["Records"][loop_index]
try:
event_body = json.loads(record["body"])
request_context = event_body[REQUEST_CONTEXT_KEY]
connection_id = request_context["connectionId"]
llm_client = self.llm_client_type(
connection_id=connection_id,
)
conversation_id = llm_client.get_event_conversation_id(event_body)
llm_client.check_env()
updated_event_body = llm_client.check_event(event_body, conversation_id)
event_message = updated_event_body[MESSAGE_KEY]
llm_client.rag_enabled = llm_client.use_case_config.get("LlmParams", {}).get(
"RAGEnabled", DEFAULT_RAG_ENABLED_MODE
)
llm_chat = llm_client.get_model(
event_message,
request_context["authorizer"][USER_ID_EVENT_KEY],
)
ai_response = llm_chat.generate(event_message["question"])
socket_handler = WebsocketHandler(
connection_id=connection_id,
conversation_id=conversation_id,
message_id=llm_client.builder.message_id,
)
# Send response via WebSocket if streaming is disabled OR if streaming failed
streaming_failed = self.check_streaming_failed(llm_client.builder.callbacks)
if not llm_client.builder.is_streaming or streaming_failed:
socket_handler.post_response_to_connection(ai_response)
socket_handler.post_token_to_connection(END_CONVERSATION_TOKEN)
loop_index = loop_index + 1
# check if under 20 seconds remaining, proceed with aborting processing of records
while context.get_remaining_time_in_millis() < 20000 and loop_index < total_records:
logger.debug(
f"Lambda reaching timeout and hence adding {loop_index}th message to batch_item_failures"
)
batch_item_failures.append({"itemIdentifier": event["Records"][loop_index]["messageId"]})
loop_index = loop_index + 1
except Exception as ex:
tracer_id = os.getenv(TRACE_ID_ENV_VAR)
chat_error = f"Chat service failed to respond. Please contact your administrator for support and quote the following trace id: {tracer_id}"
logger.error(f"An exception occurred in the processing of chat: {ex}", xray_trace_id=tracer_id)
error_handler = WebsocketErrorHandler(
connection_id=connection_id, trace_id=tracer_id, conversation_id=conversation_id
)
error_handler.post_token_to_connection(chat_error)
# append error records with the same connection id
# fmt:off
while (
loop_index < total_records
and event["Records"][loop_index]["messageAttributes"]["connectionId"]["stringValue"] == connection_id
):
# fmt:on
logger.debug(
f"Record with {loop_index} has the same connectionId, hence to maintain FIFO sequence, pushing them back to the queue"
)
batch_item_failures.append({"itemIdentifier": event["Records"][loop_index]["messageId"]})
loop_index = loop_index + 1
sqs_batch_response["batchItemFailures"] = batch_item_failures
return sqs_batch_response