@@ -39,6 +39,11 @@ class EtlApiException(Exception):
39
39
logger = logging .getLogger ("uvicorn.error" )
40
40
41
41
42
+ class MessageChannels (BaseModel ):
43
+ infos : list [str ] = Field (default_factory = list )
44
+ warnings : list [str ] = Field (default_factory = list )
45
+
46
+
42
47
def log_func_and_body (func : Callable , body : Optional [str ] = None ) -> None :
43
48
msg = None
44
49
if logger .level == LOG_LEVELS .get ("debug" , logging .NOTSET ):
@@ -135,6 +140,7 @@ class InvokeResponse(BaseModel):
135
140
filedata_meta : Optional [filedata_meta_model ]
136
141
status_code_text : Optional [str ] = None
137
142
output : Optional [response_type ] = None
143
+ message_channels : MessageChannels = Field (default_factory = MessageChannels )
138
144
139
145
input_schema = get_input_schema (func , omit = ["usage" , "filedata_meta" ])
140
146
input_schema_model = schema_to_base_model (input_schema )
@@ -146,11 +152,14 @@ class InvokeResponse(BaseModel):
146
152
async def wrap_fn (func : Callable , kwargs : Optional [dict [str , Any ]] = None ) -> ResponseType :
147
153
usage : list [UsageData ] = []
148
154
filedata_meta = FileDataMeta ()
155
+ message_channels = MessageChannels ()
149
156
request_dict = kwargs if kwargs else {}
150
157
if "usage" in inspect .signature (func ).parameters :
151
158
request_dict ["usage" ] = usage
152
159
else :
153
160
logger .warning ("usage data not an expected parameter, omitting" )
161
+ if "message_channels" in inspect .signature (func ).parameters :
162
+ request_dict ["message_channels" ] = message_channels
154
163
if "filedata_meta" in inspect .signature (func ).parameters :
155
164
request_dict ["filedata_meta" ] = filedata_meta
156
165
try :
@@ -161,6 +170,7 @@ async def _stream_response():
161
170
async for output in func (** (request_dict or {})):
162
171
yield InvokeResponse (
163
172
usage = usage ,
173
+ message_channels = message_channels ,
164
174
filedata_meta = filedata_meta_model .model_validate (
165
175
filedata_meta .model_dump ()
166
176
),
@@ -171,6 +181,7 @@ async def _stream_response():
171
181
logger .error (f"Failure streaming response: { e } " , exc_info = True )
172
182
yield InvokeResponse (
173
183
usage = usage ,
184
+ message_channels = message_channels ,
174
185
filedata_meta = None ,
175
186
status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
176
187
status_code_text = f"[{ e .__class__ .__name__ } ] { e } " ,
@@ -181,6 +192,7 @@ async def _stream_response():
181
192
output = await invoke_func (func = func , kwargs = request_dict )
182
193
return InvokeResponse (
183
194
usage = usage ,
195
+ message_channels = message_channels ,
184
196
filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
185
197
status_code = status .HTTP_200_OK ,
186
198
output = output ,
@@ -189,6 +201,7 @@ async def _stream_response():
189
201
logger .info ("Unrecoverable error occurred during plugin invocation" )
190
202
return InvokeResponse (
191
203
usage = usage ,
204
+ message_channels = message_channels ,
192
205
status_code = 512 ,
193
206
status_code_text = ex .message ,
194
207
filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
@@ -198,6 +211,7 @@ async def _stream_response():
198
211
http_error = wrap_error (invoke_error )
199
212
return InvokeResponse (
200
213
usage = usage ,
214
+ message_channels = message_channels ,
201
215
filedata_meta = filedata_meta_model .model_validate (filedata_meta .model_dump ()),
202
216
status_code = http_error .status_code ,
203
217
status_code_text = f"[{ invoke_error .__class__ .__name__ } ] { invoke_error } " ,
0 commit comments