# You can run the whole script locally with
-# serve run rag.serve:deployment
+# serve run rag.serve:deployment --runtime-env-json='{"env_vars": {"RAY_ASSISTANT_LOGS": "/mnt/shared_storage/ray-assistant-logs/info.log", "RAY_ASSISTANT_SECRET": "ray-assistant-prod"}}'

import json
+import logging
import os
import pickle
from pathlib import Path
-from typing import List
+from typing import Any, Dict, List

import openai
import ray
...
from slack_bolt import App
from slack_bolt.adapter.socket_mode import SocketModeHandler
from starlette.responses import StreamingResponse
+import structlog

from rag.config import MAX_CONTEXT_LENGTHS, ROOT_DIR
from rag.generate import QueryAgent
@@ -37,7 +39,7 @@ def get_secret(secret_name):
    import boto3

    client = boto3.client("secretsmanager", region_name="us-west-2")
-    response = client.get_secret_value(SecretId="ray-assistant")
+    response = client.get_secret_value(SecretId=os.environ["RAY_ASSISTANT_SECRET"])
    return json.loads(response["SecretString"])[secret_name]

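Note: the secret name is no longer hardcoded; it comes from the RAY_ASSISTANT_SECRET environment variable set in the runtime env at the top of the file. As a rough sketch of provisioning a matching secret with the AWS CLI (the secret name mirrors the runtime-env value above; which keys the secret must contain beyond ANYSCALE_API_KEY is an assumption, not shown in this diff):

    aws secretsmanager create-secret \
        --region us-west-2 \
        --name ray-assistant-prod \
        --secret-string '{"ANYSCALE_API_KEY": "<your-key>"}'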
@@ -78,6 +80,17 @@ class Answer(BaseModel):
@serve.ingress(app)
class RayAssistantDeployment:
    def __init__(self, num_chunks, embedding_model_name, llm, run_slack=False):
+        # Configure logging
+        logging.basicConfig(filename=os.environ["RAY_ASSISTANT_LOGS"], level=logging.INFO, encoding='utf-8')
+        structlog.configure(
+            processors=[
+                structlog.processors.TimeStamper(fmt="iso"),
+                structlog.processors.JSONRenderer(),
+            ],
+            logger_factory=structlog.stdlib.LoggerFactory(),
+        )
+        self.logger = structlog.get_logger()
+
        # Set credentials
        os.environ["ANYSCALE_API_BASE"] = "https://api.endpoints.anyscale.com/v1"
        os.environ["ANYSCALE_API_KEY"] = get_secret("ANYSCALE_API_KEY")
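With this configuration, stdlib logging writes to the file named by RAY_ASSISTANT_LOGS, and structlog renders each entry as one JSON object: TimeStamper adds an ISO "timestamp" field, the bound key-value pairs are serialized as-is, and the log message lands under "event". A sketch of roughly what the "finished streaming query" line added further down should look like (values are illustrative, not real output):

    {"query": "...", "document_ids": [...], "llm": "meta-llama/Llama-2-70b-chat-hf", "answer": "...", "event": "finished streaming query", "timestamp": "2023-10-30T18:00:00.000000Z"}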
@@ -111,33 +124,48 @@ def __init__(self, num_chunks, embedding_model_name, llm, run_slack=False):
            self.slack_app = SlackApp.remote()
            self.runner = self.slack_app.run.remote()

-    @app.post("/query")
-    def query(self, query: Query) -> Answer:
+    def predict(self, query: Query, stream: bool) -> Dict[str, Any]:
        use_oss_agent = self.router.predict([query.query])[0]
        agent = self.oss_agent if use_oss_agent else self.gpt_agent
-        result = agent(query=query.query, num_chunks=self.num_chunks, stream=False)
+        result = agent(query=query.query, num_chunks=self.num_chunks, stream=stream)
+        return result
+
+    @app.post("/query")
+    def query(self, query: Query) -> Answer:
+        result = self.predict(query, stream=False)
        return Answer.parse_obj(result)

-    def produce_streaming_answer(self, result):
+    def produce_streaming_answer(self, query, result):
+        answer = []
        for answer_piece in result["answer"]:
+            answer.append(answer_piece)
            yield answer_piece
+
        if result["sources"]:
            yield "\n\n**Sources:**\n"
            for source in result["sources"]:
                yield "* " + source + "\n"

+        self.logger.info(
+            "finished streaming query",
+            query=query,
+            document_ids=result["document_ids"],
+            llm=result["llm"],
+            answer="".join(answer)
+        )
+
    @app.post("/stream")
    def stream(self, query: Query) -> StreamingResponse:
-        use_oss_agent = self.router.predict([query.query])[0]
-        agent = self.oss_agent if use_oss_agent else self.gpt_agent
-        result = agent(query=query.query, num_chunks=self.num_chunks, stream=True)
+        result = self.predict(query, stream=True)
        return StreamingResponse(
-            self.produce_streaming_answer(result), media_type="text/plain")
+            self.produce_streaming_answer(query.query, result),
+            media_type="text/plain"
+        )


# Deploy the Ray Serve app
deployment = RayAssistantDeployment.bind(
-    num_chunks=7,
+    num_chunks=5,
    embedding_model_name="thenlper/gte-large",
    llm="meta-llama/Llama-2-70b-chat-hf",
)
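Once the app is up (e.g. via the serve run command at the top of the file), both endpoints can be exercised with a short client script. This is a sketch, assuming the default Serve address of http://localhost:8000:

    import requests

    # /query blocks until the full answer is ready and returns the Answer model as JSON.
    resp = requests.post("http://localhost:8000/query", json={"query": "What is Ray Serve?"})
    print(resp.json())

    # /stream returns plain text incrementally; consume it chunk by chunk.
    with requests.post("http://localhost:8000/stream", json={"query": "What is Ray Serve?"}, stream=True) as r:
        for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)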