55from __future__ import annotations
66
77import asyncio
8+ import contextlib
9+ import contextvars
810import dataclasses
11+ import time
912from abc import ABC , abstractmethod
10- from collections .abc import Callable , Coroutine , Mapping , Sequence
13+ from collections .abc import Callable , Coroutine , Generator , Mapping , Sequence
1114from dataclasses import dataclass
15+ from datetime import timedelta
1216from typing import Any , ClassVar , TypeVar
1317
1418from typing_extensions import Self
2529_REFERENCE_ENCODING = b"json/external-storage-reference"
2630
2731
32+ @dataclass
33+ class StorageOperationMetrics :
34+ """Accumulates metrics from external storage operations."""
35+
36+ payload_count : int = 0
37+ """Number of payloads stored or retrieved externally."""
38+
39+ total_size : int = 0
40+ """Total size in bytes of externally stored/retrieved payloads."""
41+
42+ total_duration : timedelta = dataclasses .field (default_factory = timedelta )
43+ """Wall-clock time spent on external storage operations."""
44+
45+ def record_batch (self , count : int , size : int , duration : timedelta ) -> None :
46+ """Record metrics from a batch of storage operations."""
47+ self .payload_count += count
48+ self .total_size += size
49+ self .total_duration += duration
50+
51+ @contextlib .contextmanager
52+ def track (self ) -> Generator [Self , None , None ]:
53+ """Set this instance as the current metrics context and reset on exit."""
54+ token = _current_storage_metrics .set (self )
55+ try :
56+ yield self
57+ finally :
58+ _current_storage_metrics .reset (token )
59+
60+
61+ _current_storage_metrics : contextvars .ContextVar [StorageOperationMetrics | None ] = (
62+ contextvars .ContextVar ("_current_storage_metrics" , default = None )
63+ )
64+
65+
2866async def _gather_cancel_on_error (
2967 coros : Sequence [Coroutine [Any , Any , _T ]],
3068) -> list [_T ]:
@@ -255,6 +293,7 @@ def _get_driver_by_name(self, name: str) -> StorageDriver:
255293 return driver
256294
257295 async def _store_payload (self , payload : Payload ) -> Payload :
296+ start_time = time .monotonic ()
258297 context = StorageDriverStoreContext (serialization_context = self ._context )
259298
260299 driver = self ._select_driver (context , payload )
@@ -265,6 +304,7 @@ async def _store_payload(self, payload: Payload) -> Payload:
265304
266305 self ._validate_claim_length (claims , expected = 1 , driver = driver )
267306
307+ external_size = payload .ByteSize ()
268308 reference = _StorageReference (
269309 driver_name = driver .name (),
270310 driver_claim = claims [0 ],
@@ -274,7 +314,10 @@ async def _store_payload(self, payload: Payload) -> Payload:
274314 raise ValueError (
275315 f"Failed to serialize storage reference for driver '{ driver .name ()} '"
276316 )
277- reference_payload .external_payloads .add ().size_bytes = payload .ByteSize ()
317+ reference_payload .external_payloads .add ().size_bytes = external_size
318+
319+ ExternalStorage ._record_metrics (1 , external_size , start_time )
320+
278321 return reference_payload
279322
280323 async def _store_payloads (self , payloads : Payloads ):
@@ -289,6 +332,8 @@ async def _store_payload_sequence(
289332 if len (payloads ) == 1 :
290333 return [await self ._store_payload (payloads [0 ])]
291334
335+ start_time = time .monotonic ()
336+
292337 results = list (payloads )
293338 context = StorageDriverStoreContext (serialization_context = self ._context )
294339
@@ -315,6 +360,8 @@ async def _store_payload_sequence(
315360 ]
316361 )
317362
363+ external_count = 0
364+ external_size = 0
318365 for (driver , indexed_payloads ), claims in zip (driver_group_list , all_claims ):
319366 indices = [idx for idx , _ in indexed_payloads ]
320367 sizes = [p .ByteSize () for _ , p in indexed_payloads ]
@@ -333,13 +380,20 @@ async def _store_payload_sequence(
333380 )
334381 reference_payload .external_payloads .add ().size_bytes = sizes [i ]
335382 results [indices [i ]] = reference_payload
383+ external_size += sizes [i ]
384+
385+ external_count += len (claims )
386+
387+ ExternalStorage ._record_metrics (external_count , external_size , start_time )
336388
337389 return results
338390
339391 async def _retrieve_payload (self , payload : Payload ) -> Payload :
340392 if len (payload .external_payloads ) == 0 :
341393 return payload
342394
395+ start_time = time .monotonic ()
396+
343397 reference = self ._claim_converter .from_payload (payload , _StorageReference )
344398 if not isinstance (reference , _StorageReference ):
345399 return payload
@@ -351,7 +405,11 @@ async def _retrieve_payload(self, payload: Payload) -> Payload:
351405
352406 self ._validate_payload_length (stored_payloads , expected = 1 , driver = driver )
353407
354- return stored_payloads [0 ]
408+ stored_payload = stored_payloads [0 ]
409+
410+ ExternalStorage ._record_metrics (1 , stored_payload .ByteSize (), start_time )
411+
412+ return stored_payload
355413
356414 async def _retrieve_payloads (self , payloads : Payloads ):
357415 stored_payloads = await self ._retrieve_payload_sequence (payloads .payloads )
@@ -362,11 +420,13 @@ async def _retrieve_payload_sequence(
362420 self ,
363421 payloads : Sequence [Payload ],
364422 ) -> list [Payload ]:
365- results = list (payloads )
366-
367423 if len (payloads ) == 1 :
368424 return [await self ._retrieve_payload (payloads [0 ])]
369425
426+ start_time = time .monotonic ()
427+
428+ results = list (payloads )
429+
370430 driver_claims : dict [StorageDriver , list [tuple [int , StorageDriverClaim ]]] = {}
371431 for index , payload in enumerate (payloads ):
372432 if len (payload .external_payloads ) == 0 :
@@ -394,6 +454,8 @@ async def _retrieve_payload_sequence(
394454 ]
395455 )
396456
457+ external_count = 0
458+ external_size = 0
397459 for (driver , indexed_claims ), stored_payloads in zip (
398460 driver_claim_list , all_stored
399461 ):
@@ -407,13 +469,18 @@ async def _retrieve_payload_sequence(
407469
408470 for idx , stored_payload in zip (indices , stored_payloads ):
409471 stored_by_index [idx ] = stored_payload
472+ external_size += stored_payload .ByteSize ()
473+
474+ external_count += len (stored_payloads )
410475
411476 retrieve_indices = sorted (stored_by_index .keys ())
412477 stored_list = [stored_by_index [idx ] for idx in retrieve_indices ]
413478
414479 for i , retrieved_payload in enumerate (stored_list ):
415480 results [retrieve_indices [i ]] = retrieved_payload
416481
482+ ExternalStorage ._record_metrics (external_count , external_size , start_time )
483+
417484 return results
418485
419486 def _validate_claim_length (
@@ -431,3 +498,11 @@ def _validate_payload_length(
431498 raise ValueError (
432499 f"Driver '{ driver .name ()} ' returned { len (payloads )} payloads, expected { expected } " ,
433500 )
501+
502+ @staticmethod
503+ def _record_metrics (count : int , size : int , start_time : float ):
504+ metrics = _current_storage_metrics .get ()
505+ if metrics is not None :
506+ metrics .record_batch (
507+ count , size , timedelta (seconds = time .monotonic () - start_time )
508+ )
0 commit comments