9
9
import httpx
10
10
import structlog
11
11
from attrs import define
12
- from fastapi .encoders import jsonable_encoder
13
12
14
13
from .. import schema , types
15
14
from .clients import SKIP_START_EVENT , ClientManager
@@ -72,6 +71,9 @@ def __init__(
72
71
73
72
self .client_manager = ClientManager ()
74
73
74
+ # bind logger instead of the module-level logger proxy for performance
75
+ self .log = log .bind ()
76
+
75
77
def make_error_handler (self , activity : str ) -> Callable [[RunnerTask ], None ]:
76
78
def handle_error (task : RunnerTask ) -> None :
77
79
exc = task .exception ()
@@ -83,7 +85,7 @@ def handle_error(task: RunnerTask) -> None:
83
85
try :
84
86
raise exc
85
87
except Exception :
86
- log .error (f"caught exception while running { activity } " , exc_info = True )
88
+ self . log .error (f"caught exception while running { activity } " , exc_info = True )
87
89
if self ._shutdown_event is not None :
88
90
self ._shutdown_event .set ()
89
91
@@ -121,7 +123,7 @@ def predict(
121
123
# if upload url was not set, we can respect output_file_prefix
122
124
# but maybe we should just throw an error
123
125
upload_url = request .output_file_prefix or self ._upload_url
124
- event_handler = PredictionEventHandler (request , self .client_manager , upload_url )
126
+ event_handler = PredictionEventHandler (request , self .client_manager , upload_url , self . log )
125
127
self ._response = event_handler .response
126
128
127
129
prediction_input = PredictionInput .from_request (request )
@@ -143,13 +145,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
143
145
tb = traceback .format_exc ()
144
146
await event_handler .append_logs (tb )
145
147
await event_handler .failed (error = str (e ))
146
- log .warn ("failed to download url path from input" , exc_info = True )
148
+ self . log .warn ("failed to download url path from input" , exc_info = True )
147
149
return event_handler .response
148
150
except Exception as e :
149
151
tb = traceback .format_exc ()
150
152
await event_handler .append_logs (tb )
151
153
await event_handler .failed (error = str (e ))
152
- log .error ("caught exception while running prediction" , exc_info = True )
154
+ self . log .error ("caught exception while running prediction" , exc_info = True )
153
155
if self ._shutdown_event is not None :
154
156
self ._shutdown_event .set ()
155
157
raise # we don't actually want to raise anymore but w/e
@@ -195,8 +197,10 @@ def __init__(
195
197
request : schema .PredictionRequest ,
196
198
client_manager : ClientManager ,
197
199
upload_url : Optional [str ],
200
+ logger : Optional [structlog .BoundLogger ] = None ,
198
201
) -> None :
199
- log .info ("starting prediction" )
202
+ self .logger = logger or log .bind ()
203
+ self .logger .info ("starting prediction" )
200
204
# maybe this should be a deep copy to not share File state with child worker
201
205
self .p = schema .PredictionResponse (** request .dict ())
202
206
self .p .status = schema .Status .PROCESSING
@@ -244,7 +248,7 @@ async def append_logs(self, logs: str) -> None:
244
248
await self ._send_webhook (schema .WebhookEvent .LOGS )
245
249
246
250
async def succeeded (self ) -> None :
247
- log .info ("prediction succeeded" )
251
+ self . logger .info ("prediction succeeded" )
248
252
self .p .status = schema .Status .SUCCEEDED
249
253
self ._set_completed_at ()
250
254
# These have been set already: this is to convince the typechecker of
@@ -257,14 +261,14 @@ async def succeeded(self) -> None:
257
261
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
258
262
259
263
async def failed (self , error : str ) -> None :
260
- log .info ("prediction failed" , error = error )
264
+ self . logger .info ("prediction failed" , error = error )
261
265
self .p .status = schema .Status .FAILED
262
266
self .p .error = error
263
267
self ._set_completed_at ()
264
268
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
265
269
266
270
async def canceled (self ) -> None :
267
- log .info ("prediction canceled" )
271
+ self . logger .info ("prediction canceled" )
268
272
self .p .status = schema .Status .CANCELED
269
273
self ._set_completed_at ()
270
274
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
@@ -273,8 +277,7 @@ def _set_completed_at(self) -> None:
273
277
self .p .completed_at = datetime .now (tz = timezone .utc )
274
278
275
279
async def _send_webhook (self , event : schema .WebhookEvent ) -> None :
276
- dict_response = jsonable_encoder (self .response .dict (exclude_unset = True ))
277
- await self ._webhook_sender (dict_response , event )
280
+ await self ._webhook_sender (self .response , event )
278
281
279
282
async def _upload_files (self , output : Any ) -> Any :
280
283
try :
0 commit comments