18
18
import httpx
19
19
import structlog
20
20
from attrs import define
21
- from fastapi .encoders import jsonable_encoder
22
21
23
22
from .. import schema , types
24
23
from .clients import SKIP_START_EVENT , ClientManager
@@ -80,6 +79,9 @@ def __init__(
80
79
81
80
self .client_manager = ClientManager ()
82
81
82
+ # bind logger instead of the module-level logger proxy for performance
83
+ self .log = log .bind ()
84
+
83
85
def make_error_handler (self , activity : str ) -> Callable [[RunnerTask ], None ]:
84
86
def handle_error (task : RunnerTask ) -> None :
85
87
exc = task .exception ()
@@ -91,7 +93,7 @@ def handle_error(task: RunnerTask) -> None:
91
93
try :
92
94
raise exc
93
95
except Exception :
94
- log .error (f"caught exception while running { activity } " , exc_info = True )
96
+ self . log .error (f"caught exception while running { activity } " , exc_info = True )
95
97
if self ._shutdown_event is not None :
96
98
self ._shutdown_event .set ()
97
99
@@ -128,7 +130,7 @@ def predict(
128
130
# if upload url was not set, we can respect output_file_prefix
129
131
# but maybe we should just throw an error
130
132
upload_url = request .output_file_prefix or self ._upload_url
131
- event_handler = PredictionEventHandler (request , self .client_manager , upload_url )
133
+ event_handler = PredictionEventHandler (request , self .client_manager , upload_url , self . log )
132
134
self ._response = event_handler .response
133
135
134
136
#prediction_input = PredictionInput.from_request(request)
@@ -152,13 +154,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
152
154
tb = traceback .format_exc ()
153
155
await event_handler .append_logs (tb )
154
156
await event_handler .failed (error = str (e ))
155
- log .warn ("failed to download url path from input" , exc_info = True )
157
+ self . log .warn ("failed to download url path from input" , exc_info = True )
156
158
return event_handler .response
157
159
except Exception as e :
158
160
tb = traceback .format_exc ()
159
161
await event_handler .append_logs (tb )
160
162
await event_handler .failed (error = str (e ))
161
- log .error ("caught exception while running prediction" , exc_info = True )
163
+ self . log .error ("caught exception while running prediction" , exc_info = True )
162
164
if self ._shutdown_event is not None :
163
165
self ._shutdown_event .set ()
164
166
raise # we don't actually want to raise anymore but w/e
@@ -204,8 +206,10 @@ def __init__(
204
206
request : schema .PredictionRequest ,
205
207
client_manager : ClientManager ,
206
208
upload_url : Optional [str ],
209
+ logger : Optional [structlog .BoundLogger ] = None ,
207
210
) -> None :
208
- log .info ("starting prediction" )
211
+ self .logger = logger or log .bind ()
212
+ self .logger .info ("starting prediction" )
209
213
# maybe this should be a deep copy to not share File state with child worker
210
214
self .p = schema .PredictionResponse (** request .dict ())
211
215
self .p .status = schema .Status .PROCESSING
@@ -255,7 +259,7 @@ async def append_logs(self, logs: str) -> None:
255
259
await self ._send_webhook (schema .WebhookEvent .LOGS )
256
260
257
261
async def succeeded (self ) -> None :
258
- log .info ("prediction succeeded" )
262
+ self . logger .info ("prediction succeeded" )
259
263
self .p .status = schema .Status .SUCCEEDED
260
264
self ._set_completed_at ()
261
265
# These have been set already: this is to convince the typechecker of
@@ -268,14 +272,14 @@ async def succeeded(self) -> None:
268
272
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
269
273
270
274
async def failed (self , error : str ) -> None :
271
- log .info ("prediction failed" , error = error )
275
+ self . logger .info ("prediction failed" , error = error )
272
276
self .p .status = schema .Status .FAILED
273
277
self .p .error = error
274
278
self ._set_completed_at ()
275
279
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
276
280
277
281
async def canceled (self ) -> None :
278
- log .info ("prediction canceled" )
282
+ self . logger .info ("prediction canceled" )
279
283
self .p .status = schema .Status .CANCELED
280
284
self ._set_completed_at ()
281
285
await self ._send_webhook (schema .WebhookEvent .COMPLETED )
@@ -284,8 +288,7 @@ def _set_completed_at(self) -> None:
284
288
self .p .completed_at = datetime .now (tz = timezone .utc )
285
289
286
290
async def _send_webhook (self , event : schema .WebhookEvent ) -> None :
287
- dict_response = jsonable_encoder (self .response .dict (exclude_unset = True ))
288
- await self ._webhook_sender (dict_response , event )
291
+ await self ._webhook_sender (self .response , event )
289
292
290
293
async def _upload_files (self , output : Any ) -> Any :
291
294
try :
0 commit comments