Skip to content

Commit dc7aace

Browse files
committed
optimize webhook serialization and logging (#1651)
* optimize webhook serialization and logging * optimize logging by binding structlog proxies * fix tests --------- Signed-off-by: technillogue <[email protected]>
1 parent ce25ad5 commit dc7aace

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

Diff for: python/cog/server/clients.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
import httpx
1919
import structlog
20+
from fastapi.encoders import jsonable_encoder
2021

21-
from ..schema import Status, WebhookEvent
22+
from ..schema import PredictionResponse, Status, WebhookEvent
2223
from ..types import Path
2324
from .response_throttler import ResponseThrottler
2425
from .retry_transport import RetryTransport
@@ -147,6 +148,7 @@ def __init__(self) -> None:
147148
self.retry_webhook_client = httpx_retry_client()
148149
self.file_client = httpx_file_client()
149150
self.download_client = httpx.AsyncClient(follow_redirects=True, http2=True)
151+
self.log = structlog.get_logger(__name__).bind()
150152

151153
async def aclose(self) -> None:
152154
# not used but it's not actually critical to close them
@@ -159,26 +161,29 @@ async def aclose(self) -> None:
159161

160162
async def send_webhook(self, url: str, response: Dict[str, Any]) -> None:
161163
if Status.is_terminal(response["status"]):
162-
log.info("sending terminal webhook with status %s", response["status"])
164+
self.log.info("sending terminal webhook with status %s", response["status"])
163165
# For terminal updates, retry persistently
164166
await self.retry_webhook_client.post(url, json=response)
165167
else:
166-
log.info("sending webhook with status %s", response["status"])
168+
self.log.info("sending webhook with status %s", response["status"])
167169
# For other requests, don't retry, and ignore any errors
168170
try:
169171
await self.webhook_client.post(url, json=response)
170172
except httpx.RequestError:
171-
log.warn("caught exception while sending webhook", exc_info=True)
173+
self.log.warn("caught exception while sending webhook", exc_info=True)
172174

173175
def make_webhook_sender(
174176
self, url: Optional[str], webhook_events_filter: Collection[WebhookEvent]
175177
) -> WebhookSenderType:
176178
throttler = ResponseThrottler(response_interval=_response_interval)
177179

178-
async def sender(response: Any, event: WebhookEvent) -> None:
180+
async def sender(response: PredictionResponse, event: WebhookEvent) -> None:
179181
if url and event in webhook_events_filter:
180182
if throttler.should_send_response(response):
181-
await self.send_webhook(url, response)
183+
# jsonable_encoder is quite slow in context, it would be ideal
184+
# to skip the heavy parts of this for well-known output types
185+
dict_response = jsonable_encoder(response.dict(exclude_unset=True))
186+
await self.send_webhook(url, dict_response)
182187
throttler.update_last_sent_response_time()
183188

184189
return sender
@@ -259,6 +264,9 @@ async def upload_files(
259264
Iterates through an object from make_encodeable and uploads any files.
260265
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
261266
"""
267+
# skip four isinstance checks for fast text models
268+
if type(obj) == str: # noqa: E721
269+
return obj
262270
# # it would be kind of cleaner to make the default file_url
263271
# # instead of skipping entirely, we need to convert to datauri
264272
# if url is None:

Diff for: python/cog/server/runner.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import httpx
1919
import structlog
2020
from attrs import define
21-
from fastapi.encoders import jsonable_encoder
2221

2322
from .. import schema, types
2423
from .clients import SKIP_START_EVENT, ClientManager
@@ -80,6 +79,9 @@ def __init__(
8079

8180
self.client_manager = ClientManager()
8281

82+
# bind logger instead of the module-level logger proxy for performance
83+
self.log = log.bind()
84+
8385
def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
8486
def handle_error(task: RunnerTask) -> None:
8587
exc = task.exception()
@@ -91,7 +93,7 @@ def handle_error(task: RunnerTask) -> None:
9193
try:
9294
raise exc
9395
except Exception:
94-
log.error(f"caught exception while running {activity}", exc_info=True)
96+
self.log.error(f"caught exception while running {activity}", exc_info=True)
9597
if self._shutdown_event is not None:
9698
self._shutdown_event.set()
9799

@@ -128,7 +130,7 @@ def predict(
128130
# if upload url was not set, we can respect output_file_prefix
129131
# but maybe we should just throw an error
130132
upload_url = request.output_file_prefix or self._upload_url
131-
event_handler = PredictionEventHandler(request, self.client_manager, upload_url)
133+
event_handler = PredictionEventHandler(request, self.client_manager, upload_url, self.log)
132134
self._response = event_handler.response
133135

134136
#prediction_input = PredictionInput.from_request(request)
@@ -152,13 +154,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
152154
tb = traceback.format_exc()
153155
await event_handler.append_logs(tb)
154156
await event_handler.failed(error=str(e))
155-
log.warn("failed to download url path from input", exc_info=True)
157+
self.log.warn("failed to download url path from input", exc_info=True)
156158
return event_handler.response
157159
except Exception as e:
158160
tb = traceback.format_exc()
159161
await event_handler.append_logs(tb)
160162
await event_handler.failed(error=str(e))
161-
log.error("caught exception while running prediction", exc_info=True)
163+
self.log.error("caught exception while running prediction", exc_info=True)
162164
if self._shutdown_event is not None:
163165
self._shutdown_event.set()
164166
raise # we don't actually want to raise anymore but w/e
@@ -204,8 +206,10 @@ def __init__(
204206
request: schema.PredictionRequest,
205207
client_manager: ClientManager,
206208
upload_url: Optional[str],
209+
logger: Optional[structlog.BoundLogger] = None,
207210
) -> None:
208-
log.info("starting prediction")
211+
self.logger = logger or log.bind()
212+
self.logger.info("starting prediction")
209213
# maybe this should be a deep copy to not share File state with child worker
210214
self.p = schema.PredictionResponse(**request.dict())
211215
self.p.status = schema.Status.PROCESSING
@@ -255,7 +259,7 @@ async def append_logs(self, logs: str) -> None:
255259
await self._send_webhook(schema.WebhookEvent.LOGS)
256260

257261
async def succeeded(self) -> None:
258-
log.info("prediction succeeded")
262+
self.logger.info("prediction succeeded")
259263
self.p.status = schema.Status.SUCCEEDED
260264
self._set_completed_at()
261265
# These have been set already: this is to convince the typechecker of
@@ -268,14 +272,14 @@ async def succeeded(self) -> None:
268272
await self._send_webhook(schema.WebhookEvent.COMPLETED)
269273

270274
async def failed(self, error: str) -> None:
271-
log.info("prediction failed", error=error)
275+
self.logger.info("prediction failed", error=error)
272276
self.p.status = schema.Status.FAILED
273277
self.p.error = error
274278
self._set_completed_at()
275279
await self._send_webhook(schema.WebhookEvent.COMPLETED)
276280

277281
async def canceled(self) -> None:
278-
log.info("prediction canceled")
282+
self.logger.info("prediction canceled")
279283
self.p.status = schema.Status.CANCELED
280284
self._set_completed_at()
281285
await self._send_webhook(schema.WebhookEvent.COMPLETED)
@@ -284,8 +288,7 @@ def _set_completed_at(self) -> None:
284288
self.p.completed_at = datetime.now(tz=timezone.utc)
285289

286290
async def _send_webhook(self, event: schema.WebhookEvent) -> None:
287-
dict_response = jsonable_encoder(self.response.dict(exclude_unset=True))
288-
await self._webhook_sender(dict_response, event)
291+
await self._webhook_sender(self.response, event)
289292

290293
async def _upload_files(self, output: Any) -> Any:
291294
try:

0 commit comments

Comments
 (0)