@@ -38,11 +38,17 @@ class StorageOperationMetrics:
3838 total_duration : timedelta = dataclasses .field (default_factory = timedelta )
3939 """Wall-clock time spent on external storage operations."""
4040
41- def record_batch (self , count : int , size : int , duration : timedelta ) -> None :
41+ driver_names : set [str ] = dataclasses .field (default_factory = set )
42+ """Names of the drivers that participated in the operations."""
43+
44+ def record_batch (
45+ self , count : int , size : int , duration : timedelta , driver_names : set [str ]
46+ ) -> None :
4247 """Record metrics from a batch of storage operations."""
4348 self .payload_count += count
4449 self .total_size += size
4550 self .total_duration += duration
51+ self .driver_names .update (driver_names )
4652
4753 @contextlib .contextmanager
4854 def track (self ) -> Generator [Self , None , None ]:
@@ -362,7 +368,7 @@ async def _store_payload(self, payload: Payload) -> Payload:
362368 )
363369 reference_payload .external_payloads .add ().size_bytes = external_size
364370
365- ExternalStorage ._record_metrics (1 , external_size , start_time )
371+ ExternalStorage ._record_metrics (1 , external_size , start_time , { driver . name ()} )
366372
367373 return reference_payload
368374
@@ -407,6 +413,7 @@ async def _store_payload_sequence(
407413
408414 external_count = 0
409415 external_size = 0
416+ driver_names : set [str ] = set ()
410417 for (driver , indexed_payloads ), claims in zip (driver_group_list , all_claims ):
411418 indices = [idx for idx , _ in indexed_payloads ]
412419 sizes = [p .ByteSize () for _ , p in indexed_payloads ]
@@ -428,8 +435,11 @@ async def _store_payload_sequence(
428435 external_size += sizes [i ]
429436
430437 external_count += len (claims )
438+ driver_names .add (driver .name ())
431439
432- ExternalStorage ._record_metrics (external_count , external_size , start_time )
440+ ExternalStorage ._record_metrics (
441+ external_count , external_size , start_time , driver_names
442+ )
433443
434444 return results
435445
@@ -452,7 +462,9 @@ async def _retrieve_payload(self, payload: Payload) -> Payload:
452462
453463 stored_payload = stored_payloads [0 ]
454464
455- ExternalStorage ._record_metrics (1 , stored_payload .ByteSize (), start_time )
465+ ExternalStorage ._record_metrics (
466+ 1 , stored_payload .ByteSize (), start_time , {driver .name ()}
467+ )
456468
457469 return stored_payload
458470
@@ -501,6 +513,7 @@ async def _retrieve_payload_sequence(
501513
502514 external_count = 0
503515 external_size = 0
516+ driver_names : set [str ] = set ()
504517 for (driver , indexed_claims ), stored_payloads in zip (
505518 driver_claim_list , all_stored
506519 ):
@@ -517,14 +530,17 @@ async def _retrieve_payload_sequence(
517530 external_size += stored_payload .ByteSize ()
518531
519532 external_count += len (stored_payloads )
533+ driver_names .add (driver .name ())
520534
521535 retrieve_indices = sorted (stored_by_index .keys ())
522536 stored_list = [stored_by_index [idx ] for idx in retrieve_indices ]
523537
524538 for i , retrieved_payload in enumerate (stored_list ):
525539 results [retrieve_indices [i ]] = retrieved_payload
526540
527- ExternalStorage ._record_metrics (external_count , external_size , start_time )
541+ ExternalStorage ._record_metrics (
542+ external_count , external_size , start_time , driver_names
543+ )
528544
529545 return results
530546
@@ -545,9 +561,14 @@ def _validate_payload_length(
545561 )
546562
547563 @staticmethod
548- def _record_metrics (count : int , size : int , start_time : float ):
564+ def _record_metrics (
565+ count : int , size : int , start_time : float , driver_names : set [str ]
566+ ):
549567 metrics = _current_storage_metrics .get ()
550568 if metrics is not None :
551569 metrics .record_batch (
552- count , size , timedelta (seconds = time .monotonic () - start_time )
570+ count ,
571+ size ,
572+ timedelta (seconds = time .monotonic () - start_time ),
573+ driver_names ,
553574 )
0 commit comments