18
18
import httpx
19
19
import structlog
20
20
from attrs import define
21
- from fastapi .encoders import jsonable_encoder
22
21
23
22
from .. import schema , types
24
23
from .clients import SKIP_START_EVENT , ClientManager
@@ -80,6 +79,9 @@ def __init__(
80
79
81
80
self .client_manager = ClientManager ()
82
81
82
+ # bind logger instead of the module-level logger proxy for performance
83
+ self .log = log .bind ()
84
+
83
85
def make_error_handler (self , activity : str ) -> Callable [[RunnerTask ], None ]:
84
86
def handle_error (task : RunnerTask ) -> None :
85
87
exc = task .exception ()
@@ -91,7 +93,7 @@ def handle_error(task: RunnerTask) -> None:
91
93
try :
92
94
raise exc
93
95
except Exception :
94
- log .error (f"caught exception while running { activity } " , exc_info = True )
96
+ self . log .error (f"caught exception while running { activity } " , exc_info = True )
95
97
if self ._shutdown_event is not None :
96
98
self ._shutdown_event .set ()
97
99
@@ -128,7 +130,7 @@ def predict(
128
130
# if upload url was not set, we can respect output_file_prefix
129
131
# but maybe we should just throw an error
130
132
upload_url = request .output_file_prefix or self ._upload_url
131
- event_handler = PredictionEventHandler (request , self .client_manager , upload_url )
133
+ event_handler = PredictionEventHandler (request , self .client_manager , upload_url , self . log )
132
134
self ._response = event_handler .response
133
135
134
136
#prediction_input = PredictionInput.from_request(request)
@@ -152,13 +154,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
152
154
tb = traceback .format_exc ()
153
155
await event_handler .append_logs (tb )
154
156
await event_handler .failed (error = str (e ))
155
- log .warn ("failed to download url path from input" , exc_info = True )
157
+ self . log .warn ("failed to download url path from input" , exc_info = True )
156
158
return event_handler .response
157
159
except Exception as e :
158
160
tb = traceback .format_exc ()
159
161
await event_handler .append_logs (tb )
160
162
await event_handler .failed (error = str (e ))
161
- log .error ("caught exception while running prediction" , exc_info = True )
163
+ self . log .error ("caught exception while running prediction" , exc_info = True )
162
164
if self ._shutdown_event is not None :
163
165
self ._shutdown_event .set ()
164
166
raise # we don't actually want to raise anymore but w/e
@@ -204,8 +206,10 @@ def __init__(
204
206
request : schema .PredictionRequest ,
205
207
client_manager : ClientManager ,
206
208
upload_url : Optional [str ],
209
+ logger : Optional [structlog .BoundLogger ] = None ,
207
210
) -> None :
208
- log .info ("starting prediction" )
211
+ self .logger = logger or log .bind ()
212
+ self .logger .info ("starting prediction" )
209
213
# maybe this should be a deep copy to not share File state with child worker
210
214
self .p = schema .PredictionResponse (** request .dict ())
211
215
self .p .status = schema .Status .PROCESSING
@@ -256,7 +260,7 @@ async def append_logs(self, logs: str) -> None:
256
260
await self ._send_webhook (schema .WebhookEvent .LOGS )
257
261
258
262
async def succeeded (self ) -> None :
259
- log .info ("prediction succeeded" )
263
+ self . logger .info ("prediction succeeded" )
260
264
self .p .status = schema .Status .SUCCEEDED
261
265
self ._set_completed_at ()
262
266
# These have been set already: this is to convince the typechecker of
@@ -269,14 +273,14 @@ async def succeeded(self) -> None:
269
273
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
270
274
271
275
async def failed (self , error : str ) -> None :
272
- log .info ("prediction failed" , error = error )
276
+ self . logger .info ("prediction failed" , error = error )
273
277
self .p .status = schema .Status .FAILED
274
278
self .p .error = error
275
279
self ._set_completed_at ()
276
280
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
277
281
278
282
async def canceled (self ) -> None :
279
- log .info ("prediction canceled" )
283
+ self . logger .info ("prediction canceled" )
280
284
self .p .status = schema .Status .CANCELED
281
285
self ._set_completed_at ()
282
286
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
@@ -285,8 +289,7 @@ def _set_completed_at(self) -> None:
285
289
self .p .completed_at = datetime .now (tz = timezone .utc )
286
290
287
291
async def _send_webhook (self , event : schema .WebhookEvent ) -> None :
288
- dict_response = jsonable_encoder (self .response .dict (exclude_unset = True ))
289
- await self ._webhook_sender (dict_response , event )
292
+ await self ._webhook_sender (self .response , event )
290
293
291
294
async def _upload_files (self , output : Any ) -> Any :
292
295
try :
0 commit comments