 import torch.nn as nn
 from torch.distributed.tensor import DeviceMesh
 from torch.profiler import record_function
+from torchmetrics.utilities.data import (
+    dim_zero_cat,
+    dim_zero_max,
+    dim_zero_mean,
+    dim_zero_min,
+    dim_zero_sum,
+)
 from torchrec.metrics.accuracy import AccuracyMetric
-from torchrec.metrics.auc import AUCMetric
+from torchrec.metrics.auc import (
+    _grouping_keys_state_reduction,
+    _state_reduction,
+    AUCMetric,
+)
 from torchrec.metrics.auprc import AUPRCMetric
 from torchrec.metrics.average import AverageMetric
 from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
 from torchrec.metrics.recall import RecallMetric
 from torchrec.metrics.recall_session import RecallSessionMetric
 from torchrec.metrics.scalar import ScalarMetric
-from torchrec.metrics.segmented_ne import SegmentedNEMetric
+from torchrec.metrics.segmented_ne import _state_reduction_sum, SegmentedNEMetric
 from torchrec.metrics.serving_calibration import ServingCalibrationMetric
 from torchrec.metrics.serving_ne import ServingNEMetric
 from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric


 logger: logging.Logger = logging.getLogger(__name__)
+# TorchRec-specific custom reduction functions.
+# These work correctly with the local+global reduction pattern.
+# Requirements: associative AND (commutative OR post-processing that makes the result order-invariant).
+SAFE_CALLABLE_REDUCTIONS: frozenset[Any] = frozenset(
+    {
+        _state_reduction,  # Concatenation + AUC sorts the data, making the final result order-invariant
+        _grouping_keys_state_reduction,  # Concatenation along dim=0 + sorting makes the result order-invariant
+        _state_reduction_sum,  # Sum on dimension 0.
+    }
+)
+
+# torchmetrics.Metric built-in reduction functions.
+# All dim_zero_* functions are associative; all except dim_zero_cat are also
+# commutative (dim_zero_cat is order-sensitive, but torchmetrics.Metric also
+# reduces locally before sync_dist to cut the number of collectives).
+TORCHMETRICS_REDUCTIONS: frozenset[Any] = frozenset(
+    {
+        dim_zero_sum,
+        dim_zero_mean,
+        dim_zero_max,
+        dim_zero_min,
+        dim_zero_cat,
+    }
+)
+
+
+def _validate_reduction_function(
+    reduction_fn: Union[str, Any, None],
+    state_name: str,
+    metric_namespace: str,
+) -> None:
+    """
+    Validate that a reduction function is safe for the local+global reduction pattern.
+
+    Only custom reduction functions are validated. TorchMetrics built-in functions
+    (dim_zero_*) are skipped as they are safe by construction (see TORCHMETRICS_REDUCTIONS).
+
+    Mathematical Requirements:
+    1. **Associativity**: f([f([a,b]), f([c,d])]) = f([a,b,c,d])
+       - Required so that local reduction + global reduction = direct reduction
+
+    2. **Commutativity**: f([a, b]) = f([b, a])
+       - Required so that rank ordering does not affect the result
+       - OR the metric's computation must make the final result order-invariant
+         (e.g., AUC concatenates in rank order but sorts before computing)
+    """
+    # Skip validation for None and torchmetrics built-in functions (safe by construction)
+    if reduction_fn is None or reduction_fn in TORCHMETRICS_REDUCTIONS:
+        return
+
+    # Validate custom callable reductions
+    if callable(reduction_fn):
+        if reduction_fn not in SAFE_CALLABLE_REDUCTIONS:
+            raise RecMetricException(
+                f"Unknown custom reduction '{reduction_fn}' for state '{state_name}' in '{metric_namespace}'. "
+                f"Must be associative: f([f([a,b]), f([c,d])]) == f([a,b,c,d]) "
+                f"AND commutative: f([a,b]) == f([b,a]) (or the metric must make the result order-invariant). "
+                f"Known safe custom reductions: {[f for f in SAFE_CALLABLE_REDUCTIONS if f not in TORCHMETRICS_REDUCTIONS]}. "
+                f"Add to SAFE_CALLABLE_REDUCTIONS if verified safe."
+            )

 REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
     RecMetricEnum.NE: NEMetric,
@@ -218,6 +289,8 @@ def __init__(
         self.oom_count = 0
         self.compute_count = 0
 
+        self._validate_all_reduction_functions()
+
         self.compute_interval_steps = compute_interval_steps
         self.min_compute_interval = min_compute_interval
         self.max_compute_interval = max_compute_interval
@@ -240,6 +313,20 @@ def __init__(
 
         self._register_load_state_dict_pre_hook(self.load_state_dict_hook)
 
+    def _validate_all_reduction_functions(self) -> None:
+        """
+        Validate all reduction functions in rec_metrics during initialization.
+        This ensures that all reduction functions are safe for the local+global reduction pattern.
+        """
+        for metric in self.rec_metrics.rec_metrics:
+            for computation in metric._metrics_computations:  # pyre-ignore[16]
+                for state_name, reduction_fn in computation._reductions.items():  # pyre-ignore[16]
+                    _validate_reduction_function(
+                        reduction_fn,
+                        state_name,
+                        metric._namespace.value,  # pyre-ignore[16]
+                    )
+
     def load_state_dict_hook(
         self,
         state_dict: OrderedDict[str, torch.Tensor],
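As a hedged usage sketch of the init-time validation added above: an ad-hoc reduction that is neither associative nor commutative fails closed with RecMetricException. The state name and namespace strings below are illustrative, not taken from the diff:

from torchrec.metrics.rec_metric import RecMetricException

try:
    _validate_reduction_function(
        lambda xs: xs[-1],  # last-write-wins: neither associative nor commutative
        state_name="cross_entropy_sum",   # illustrative
        metric_namespace="ne",            # illustrative
    )
except RecMetricException as e:
    print(f"rejected unsafe reduction: {e}")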
@@ -408,22 +495,24 @@ def _get_metric_states(
         # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
         #  `items`.
         for state_name, reduction_fn in computation._reductions.items():
-            tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
-                computation, state_name
-            )
-
-            if isinstance(tensor_or_list, list):
-                gathered = _all_gather_tensor_list(
-                    tensor_or_list, world_size, process_group
-                )
-            else:
-                gathered = torch.stack(
-                    _all_gather_tensor(tensor_or_list, world_size, process_group)
+            with record_function(f"## RecMetricModule: {state_name} all gather ##"):
+                tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
+                    computation, state_name
                 )
-            reduced = (
-                reduction_fn(gathered) if reduction_fn is not None else gathered
-            )
-            result[task.name][state_name] = reduced
+
+                if isinstance(tensor_or_list, list):
+                    local_reduced = reduction_fn(tensor_or_list)
+                    gathered = _all_gather_tensor_list(
+                        local_reduced, world_size, process_group
+                    )
+                else:
+                    gathered = torch.stack(
+                        _all_gather_tensor(
+                            tensor_or_list, world_size, process_group
+                        )
+                    )
+                global_reduced = reduction_fn(gathered)
+                result[task.name][state_name] = global_reduced
 
     return result
 
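A hedged, single-process sketch of the new list-state path: each rank first reduces its local list to one tensor, so the all-gather ships one tensor per rank (one collective instead of one per list element), and the same reduction_fn then runs again over the gathered results. Associativity is exactly what makes the two-stage result match a direct reduction. The per-rank lists are illustrative, and the collective is mocked as a plain Python list:

import torch
from torchmetrics.utilities.data import dim_zero_cat

reduction_fn = dim_zero_cat
rank0_state = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
rank1_state = [torch.tensor([4.0])]

# Local reduction: one tensor per rank; the "all-gather" here is just a list.
local = [reduction_fn(rank0_state), reduction_fn(rank1_state)]
global_reduced = reduction_fn(local)

# Matches reducing everything directly -- the associativity requirement above.
assert torch.equal(global_reduced, reduction_fn(rank0_state + rank1_state))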
@@ -472,7 +561,8 @@ def get_pre_compute_states(
         # throughput metric requires special handling, since it's not a RecMetric
         throughput_metric = self.throughput_metric
         if throughput_metric is not None:
-            aggregated_states[throughput_metric._namespace.value] = (
+            # Merge in case a rec metric namespace overlaps with the throughput metric namespace
+            aggregated_states.setdefault(throughput_metric._namespace.value, {}).update(
                 self._get_throughput_metric_states(throughput_metric)
             )
 
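A small sketch of the merge semantics (the dict keys and values are illustrative): setdefault(...).update(...) keeps any states already registered under the same namespace, where the old plain assignment would have clobbered them:

aggregated_states = {"throughput": {"ne_state": 1}}   # pre-existing rec metric state
throughput_states = {"total_examples": 2048}
aggregated_states.setdefault("throughput", {}).update(throughput_states)
assert aggregated_states["throughput"] == {"ne_state": 1, "total_examples": 2048}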
@@ -666,8 +756,23 @@ def _all_gather_tensor_list(
     world_size: int,
     pg: Union[dist.ProcessGroup, DeviceMesh],
 ) -> List[torch.Tensor]:
-    """All-gather every tensor in a list and flatten the result."""
-    gathered: List[torch.Tensor] = []  # pragma: no cover
+    """
+    All-gather every tensor in a list and flatten the result.
+
+    Note: with local reduction now done in _get_metric_states, this function
+    should only ever receive a list with at most one tensor.
+    """
+    if not tensors:
+        return []
+
+    # After local reduction in _get_metric_states, tensors should contain at most 1 element
+    if len(tensors) > 1:
+        raise ValueError(
+            f"_all_gather_tensor_list expected at most 1 tensor after local reduction, "
+            f"but received {len(tensors)} tensors. This indicates a bug in _get_metric_states."
+        )
+
+    gathered: List[torch.Tensor] = []
     for t in tensors:
         gathered.extend(_all_gather_tensor(t, world_size, pg))
     return gathered
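Two hedged spot-checks of the new guard. Neither path reaches a collective, so None stands in for a real process group here purely for illustration (the annotated type expects a ProcessGroup or DeviceMesh):

import torch

# Empty list: returns [] before any collective would be issued.
assert _all_gather_tensor_list([], world_size=2, pg=None) == []  # pg unused on this path

# More than one tensor means local reduction was skipped upstream: fail fast.
try:
    _all_gather_tensor_list([torch.ones(1), torch.ones(1)], world_size=2, pg=None)
except ValueError as e:
    print(f"caught upstream bug: {e}")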