Skip to content

Commit 08f3780

Browse files
committed
optimize webhook serialization and logging (#1651)
* optimize webhook serialization and logging * optimize logging by binding structlog proxies * fix tests --------- Signed-off-by: technillogue <[email protected]>
1 parent 12b0abe commit 08f3780

File tree

4 files changed

+43
-30
lines changed

4 files changed

+43
-30
lines changed

Diff for: python/cog/server/clients.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import httpx
99
import structlog
10+
from fastapi.encoders import jsonable_encoder
1011

1112
from .. import types
12-
from ..schema import Status, WebhookEvent
13+
from ..schema import PredictionResponse, Status, WebhookEvent
1314
from ..types import Path
1415
from .eventtypes import PredictionInput
1516
from .response_throttler import ResponseThrottler
@@ -105,6 +106,7 @@ def __init__(self) -> None:
105106
self.retry_webhook_client = httpx_retry_client()
106107
self.file_client = httpx_file_client()
107108
self.download_client = httpx.AsyncClient(follow_redirects=True, http2=True)
109+
self.log = structlog.get_logger(__name__).bind()
108110

109111
async def aclose(self) -> None:
110112
# not used but it's not actually critical to close them
@@ -119,26 +121,29 @@ async def send_webhook(
119121
self, url: str, response: Dict[str, Any], event: WebhookEvent
120122
) -> None:
121123
if Status.is_terminal(response["status"]):
122-
log.info("sending terminal webhook with status %s", response["status"])
124+
self.log.info("sending terminal webhook with status %s", response["status"])
123125
# For terminal updates, retry persistently
124126
await self.retry_webhook_client.post(url, json=response)
125127
else:
126-
log.info("sending webhook with status %s", response["status"])
128+
self.log.info("sending webhook with status %s", response["status"])
127129
# For other requests, don't retry, and ignore any errors
128130
try:
129131
await self.webhook_client.post(url, json=response)
130132
except httpx.RequestError:
131-
log.warn("caught exception while sending webhook", exc_info=True)
133+
self.log.warn("caught exception while sending webhook", exc_info=True)
132134

133135
def make_webhook_sender(
134136
self, url: Optional[str], webhook_events_filter: Collection[WebhookEvent]
135137
) -> WebhookSenderType:
136138
throttler = ResponseThrottler(response_interval=_response_interval)
137139

138-
async def sender(response: Any, event: WebhookEvent) -> None:
140+
async def sender(response: PredictionResponse, event: WebhookEvent) -> None:
139141
if url and event in webhook_events_filter:
140142
if throttler.should_send_response(response):
141-
await self.send_webhook(url, response, event)
143+
# jsonable_encoder is quite slow in context, it would be ideal
144+
# to skip the heavy parts of this for well-known output types
145+
dict_response = jsonable_encoder(response.dict(exclude_unset=True))
146+
await self.send_webhook(url, dict_response, event)
142147
throttler.update_last_sent_response_time()
143148

144149
return sender
@@ -213,6 +218,9 @@ async def upload_files(self, obj: Any, url: Optional[str]) -> Any:
213218
Iterates through an object from make_encodeable and uploads any files.
214219
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
215220
"""
221+
# skip four isinstance checks for fast text models
222+
if type(obj) == str: # noqa: E721
223+
return obj
216224
# # it would be kind of cleaner to make the default file_url
217225
# # instead of skipping entirely, we need to convert to datauri
218226
# if url is None:

Diff for: python/cog/server/response_throttler.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import time
2-
from typing import Any, Dict
32

4-
from ..schema import Status
3+
from ..schema import PredictionResponse, Status
54

65

76
class ResponseThrottler:
87
def __init__(self, response_interval: float) -> None:
98
self.last_sent_response_time = 0.0
109
self.response_interval = response_interval
1110

12-
def should_send_response(self, response: Dict[str, Any]) -> bool:
13-
if Status.is_terminal(response["status"]):
11+
def should_send_response(self, response: PredictionResponse) -> bool:
12+
if Status.is_terminal(response.status):
1413
return True
1514

1615
return self.seconds_since_last_response() >= self.response_interval

Diff for: python/cog/server/runner.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import httpx
1010
import structlog
1111
from attrs import define
12-
from fastapi.encoders import jsonable_encoder
1312

1413
from .. import schema, types
1514
from .clients import SKIP_START_EVENT, ClientManager
@@ -72,6 +71,9 @@ def __init__(
7271

7372
self.client_manager = ClientManager()
7473

74+
# bind logger instead of the module-level logger proxy for performance
75+
self.log = log.bind()
76+
7577
def make_error_handler(self, activity: str) -> Callable[[RunnerTask], None]:
7678
def handle_error(task: RunnerTask) -> None:
7779
exc = task.exception()
@@ -83,7 +85,7 @@ def handle_error(task: RunnerTask) -> None:
8385
try:
8486
raise exc
8587
except Exception:
86-
log.error(f"caught exception while running {activity}", exc_info=True)
88+
self.log.error(f"caught exception while running {activity}", exc_info=True)
8789
if self._shutdown_event is not None:
8890
self._shutdown_event.set()
8991

@@ -121,7 +123,7 @@ def predict(
121123
# if upload url was not set, we can respect output_file_prefix
122124
# but maybe we should just throw an error
123125
upload_url = request.output_file_prefix or self._upload_url
124-
event_handler = PredictionEventHandler(request, self.client_manager, upload_url)
126+
event_handler = PredictionEventHandler(request, self.client_manager, upload_url, self.log)
125127
self._response = event_handler.response
126128

127129
prediction_input = PredictionInput.from_request(request)
@@ -143,13 +145,13 @@ async def async_predict_handling_errors() -> schema.PredictionResponse:
143145
tb = traceback.format_exc()
144146
await event_handler.append_logs(tb)
145147
await event_handler.failed(error=str(e))
146-
log.warn("failed to download url path from input", exc_info=True)
148+
self.log.warn("failed to download url path from input", exc_info=True)
147149
return event_handler.response
148150
except Exception as e:
149151
tb = traceback.format_exc()
150152
await event_handler.append_logs(tb)
151153
await event_handler.failed(error=str(e))
152-
log.error("caught exception while running prediction", exc_info=True)
154+
self.log.error("caught exception while running prediction", exc_info=True)
153155
if self._shutdown_event is not None:
154156
self._shutdown_event.set()
155157
raise # we don't actually want to raise anymore but w/e
@@ -195,8 +197,10 @@ def __init__(
195197
request: schema.PredictionRequest,
196198
client_manager: ClientManager,
197199
upload_url: Optional[str],
200+
logger: Optional[structlog.BoundLogger] = None,
198201
) -> None:
199-
log.info("starting prediction")
202+
self.logger = logger or log.bind()
203+
self.logger.info("starting prediction")
200204
# maybe this should be a deep copy to not share File state with child worker
201205
self.p = schema.PredictionResponse(**request.dict())
202206
self.p.status = schema.Status.PROCESSING
@@ -244,7 +248,7 @@ async def append_logs(self, logs: str) -> None:
244248
await self._send_webhook(schema.WebhookEvent.LOGS)
245249

246250
async def succeeded(self) -> None:
247-
log.info("prediction succeeded")
251+
self.logger.info("prediction succeeded")
248252
self.p.status = schema.Status.SUCCEEDED
249253
self._set_completed_at()
250254
# These have been set already: this is to convince the typechecker of
@@ -257,14 +261,14 @@ async def succeeded(self) -> None:
257261
await self._send_webhook(schema.WebhookEvent.COMPLETED)
258262

259263
async def failed(self, error: str) -> None:
260-
log.info("prediction failed", error=error)
264+
self.logger.info("prediction failed", error=error)
261265
self.p.status = schema.Status.FAILED
262266
self.p.error = error
263267
self._set_completed_at()
264268
await self._send_webhook(schema.WebhookEvent.COMPLETED)
265269

266270
async def canceled(self) -> None:
267-
log.info("prediction canceled")
271+
self.logger.info("prediction canceled")
268272
self.p.status = schema.Status.CANCELED
269273
self._set_completed_at()
270274
await self._send_webhook(schema.WebhookEvent.COMPLETED)
@@ -273,8 +277,7 @@ def _set_completed_at(self) -> None:
273277
self.p.completed_at = datetime.now(tz=timezone.utc)
274278

275279
async def _send_webhook(self, event: schema.WebhookEvent) -> None:
276-
dict_response = jsonable_encoder(self.response.dict(exclude_unset=True))
277-
await self._webhook_sender(dict_response, event)
280+
await self._webhook_sender(self.response, event)
278281

279282
async def _upload_files(self, output: Any) -> Any:
280283
try:

Diff for: python/tests/server/test_response_throttler.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,38 @@
11
import time
22

3-
from cog.schema import Status
3+
from cog.schema import PredictionResponse, Status
44
from cog.server.response_throttler import ResponseThrottler
55

6+
processing = PredictionResponse(input={}, status=Status.PROCESSING)
7+
succeeded = PredictionResponse(input={}, status=Status.SUCCEEDED)
8+
69

710
def test_zero_interval():
811
throttler = ResponseThrottler(response_interval=0)
912

10-
assert throttler.should_send_response({"status": Status.PROCESSING})
13+
assert throttler.should_send_response(processing)
1114
throttler.update_last_sent_response_time()
12-
assert throttler.should_send_response({"status": Status.SUCCEEDED})
15+
assert throttler.should_send_response(succeeded)
1316

1417

1518
def test_terminal_status():
1619
throttler = ResponseThrottler(response_interval=10)
1720

18-
assert throttler.should_send_response({"status": Status.PROCESSING})
21+
assert throttler.should_send_response(processing)
1922
throttler.update_last_sent_response_time()
20-
assert not throttler.should_send_response({"status": Status.PROCESSING})
23+
assert not throttler.should_send_response(processing)
2124
throttler.update_last_sent_response_time()
22-
assert throttler.should_send_response({"status": Status.SUCCEEDED})
25+
assert throttler.should_send_response(succeeded)
2326

2427

2528
def test_nonzero_internal():
2629
throttler = ResponseThrottler(response_interval=0.2)
2730

28-
assert throttler.should_send_response({"status": Status.PROCESSING})
31+
assert throttler.should_send_response(processing)
2932
throttler.update_last_sent_response_time()
30-
assert not throttler.should_send_response({"status": Status.PROCESSING})
33+
assert not throttler.should_send_response(processing)
3134
throttler.update_last_sent_response_time()
3235

3336
time.sleep(0.3)
3437

35-
assert throttler.should_send_response({"status": Status.PROCESSING})
38+
assert throttler.should_send_response(processing)

0 commit comments

Comments
 (0)