2121import torch .nn as nn
2222from torch .distributed .tensor import DeviceMesh
2323from torch .profiler import record_function
24+ from torchmetrics .utilities .data import (
25+ dim_zero_cat ,
26+ dim_zero_max ,
27+ dim_zero_mean ,
28+ dim_zero_min ,
29+ dim_zero_sum ,
30+ )
2431from torchrec .metrics .accuracy import AccuracyMetric
25- from torchrec .metrics .auc import AUCMetric
26- from torchrec .metrics .auprc import AUPRCMetric
32+ from torchrec .metrics .auc import (
33+ _grouping_keys_state_reduction ,
34+ _state_reduction ,
35+ AUCMetric ,
36+ )
37+ from torchrec .metrics .auprc import (
38+ _grouping_keys_state_reduction as auprc_grouping_keys_state_reduction ,
39+ _state_reduction as auprc_state_reduction ,
40+ AUPRCMetric ,
41+ )
2742from torchrec .metrics .average import AverageMetric
2843from torchrec .metrics .cali_free_ne import CaliFreeNEMetric
2944from torchrec .metrics .calibration import CalibrationMetric
5873from torchrec .metrics .output import OutputMetric
5974from torchrec .metrics .precision import PrecisionMetric
6075from torchrec .metrics .precision_session import PrecisionSessionMetric
61- from torchrec .metrics .rauc import RAUCMetric
76+ from torchrec .metrics .rauc import (
77+ _grouping_keys_state_reduction as rauc_grouping_keys_state_reduction ,
78+ _state_reduction as rauc_state_reduction ,
79+ RAUCMetric ,
80+ )
6281from torchrec .metrics .rec_metric import RecMetric , RecMetricException , RecMetricList
6382from torchrec .metrics .recall import RecallMetric
6483from torchrec .metrics .recall_session import RecallSessionMetric
6584from torchrec .metrics .scalar import ScalarMetric
66- from torchrec .metrics .segmented_ne import SegmentedNEMetric
85+ from torchrec .metrics .segmented_ne import _state_reduction_sum , SegmentedNEMetric
6786from torchrec .metrics .serving_calibration import ServingCalibrationMetric
6887from torchrec .metrics .serving_ne import ServingNEMetric
6988from torchrec .metrics .tensor_weighted_avg import TensorWeightedAvgMetric
7089from torchrec .metrics .throughput import ThroughputMetric
71- from torchrec .metrics .tower_qps import TowerQPSMetric
90+ from torchrec .metrics .tower_qps import _max_reduction , TowerQPSMetric
7291from torchrec .metrics .unweighted_ne import UnweightedNEMetric
7392from torchrec .metrics .weighted_avg import WeightedAvgMetric
7493from torchrec .metrics .xauc import XAUCMetric
7594
7695
7796logger : logging .Logger = logging .getLogger (__name__ )
# TorchRec-specific custom reduction functions.
# These work correctly with the local+global reduction pattern.
# Requirements: Associative AND (Commutative OR post-processing makes result order-invariant)
SAFE_CALLABLE_REDUCTIONS: frozenset[Any] = frozenset(
    {
        _state_reduction,  # Concatenation + AUC sorts data, making final result order-invariant
        _grouping_keys_state_reduction,  # Concatenation along dim=0 + sorting makes result order-invariant
        auprc_state_reduction,  # presumably same concat-then-sort pattern as AUC — verify in auprc.py
        auprc_grouping_keys_state_reduction,
        rauc_state_reduction,  # presumably same concat-then-sort pattern as AUC — verify in rauc.py
        rauc_grouping_keys_state_reduction,
        _state_reduction_sum,  # Sum on dimension 0.
        _max_reduction,  # Max is associative and commutative.
    }
)

# torchmetrics.Metric built-in reduction functions.
# All dim_zero_* functions are both associative and commutative (dim_zero_cat is not commutative
# but torchmetrics.Metric also reduce before sync_dist to reduce number of collectives).
TORCHMETRICS_REDUCTIONS: frozenset[Any] = frozenset(
    {
        dim_zero_sum,
        dim_zero_mean,
        dim_zero_max,
        dim_zero_min,
        dim_zero_cat,
    }
)
125+
126+
127+ def _validate_reduction_function (
128+ reduction_fn : Union [str , Any , None ],
129+ state_name : str ,
130+ metric_namespace : str ,
131+ ) -> None :
132+ """
133+ Validate that a reduction function is safe for local+global reduction pattern.
134+
135+ Only validates custom reduction functions. TorchMetrics built-in functions
136+ (dim_zero_*) are skipped as they're safe by construction (all are associative & commutative).
137+
138+ Mathematical Requirements:
139+ 1. **Associativity**: f([f([a,b]), f([c,d])]) = f([a,b,c,d])
140+ - Required so local reduction + global reduction = direct reduction
141+
142+ 2. **Commutativity**: f([a, b]) = f([b, a])
143+ - Required so rank ordering doesn't affect the result
144+ - OR the metric's computation must make the final result order-invariant
145+ (e.g., AUC concatenates in rank order but sorts before computing, making final result order-invariant)
146+ """
147+ # Skip validation for None and torchmetrics.Metric built-in functions (safe by construction)
148+ if reduction_fn is None or reduction_fn in TORCHMETRICS_REDUCTIONS :
149+ return
150+
151+ # Validate custom callable reductions
152+ if callable (reduction_fn ):
153+ if reduction_fn not in SAFE_CALLABLE_REDUCTIONS :
154+ raise RecMetricException (
155+ f"Unknown custom reduction '{ reduction_fn } ' for state '{ state_name } ' in '{ metric_namespace } '. "
156+ f"Must be associative: f([f([a,b]), f([c,d])]) == f([a,b,c,d]) "
157+ f"AND commutative: f([a,b]) == f([b,a]) (or metric makes result order-invariant). "
158+ f"Known safe custom reductions: { [f for f in SAFE_CALLABLE_REDUCTIONS if f not in TORCHMETRICS_REDUCTIONS ]} . "
159+ f"Add to SAFE_CALLABLE_REDUCTIONS if verified safe."
160+ )
161+
78162
79163REC_METRICS_MAPPING : Dict [RecMetricEnumBase , Type [RecMetric ]] = {
80164 RecMetricEnum .NE : NEMetric ,
@@ -218,6 +302,8 @@ def __init__(
218302 self .oom_count = 0
219303 self .compute_count = 0
220304
305+ self ._validate_all_reduction_functions ()
306+
221307 self .compute_interval_steps = compute_interval_steps
222308 self .min_compute_interval = min_compute_interval
223309 self .max_compute_interval = max_compute_interval
@@ -240,6 +326,20 @@ def __init__(
240326
241327 self ._register_load_state_dict_pre_hook (self .load_state_dict_hook )
242328
329+ def _validate_all_reduction_functions (self ) -> None :
330+ """
331+ Validate all reduction functions in rec_metrics during initialization.
332+ This ensures that all reduction functions are safe for the local+global reduction pattern.
333+ """
334+ for metric in self .rec_metrics .rec_metrics :
335+ for computation in metric ._metrics_computations : # pyre-ignore[16]
336+ for state_name , reduction_fn in computation ._reductions .items (): # pyre-ignore[16]
337+ _validate_reduction_function (
338+ reduction_fn ,
339+ state_name ,
340+ metric ._namespace .value , # pyre-ignore[16]
341+ )
342+
243343 def load_state_dict_hook (
244344 self ,
245345 state_dict : OrderedDict [str , torch .Tensor ],
@@ -408,22 +508,24 @@ def _get_metric_states(
408508 # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
409509 # `items`.
410510 for state_name , reduction_fn in computation ._reductions .items ():
411- tensor_or_list : Union [List [torch .Tensor ], torch .Tensor ] = getattr (
412- computation , state_name
413- )
414-
415- if isinstance (tensor_or_list , list ):
416- gathered = _all_gather_tensor_list (
417- tensor_or_list , world_size , process_group
418- )
419- else :
420- gathered = torch .stack (
421- _all_gather_tensor (tensor_or_list , world_size , process_group )
511+ with record_function (f"## RecMetricModule: { state_name } all gather ##" ):
512+ tensor_or_list : Union [List [torch .Tensor ], torch .Tensor ] = getattr (
513+ computation , state_name
422514 )
423- reduced = (
424- reduction_fn (gathered ) if reduction_fn is not None else gathered
425- )
426- result [task .name ][state_name ] = reduced
515+
516+ if isinstance (tensor_or_list , list ):
517+ local_reduced = reduction_fn (tensor_or_list )
518+ gathered = _all_gather_tensor_list (
519+ local_reduced , world_size , process_group
520+ )
521+ else :
522+ gathered = torch .stack (
523+ _all_gather_tensor (
524+ tensor_or_list , world_size , process_group
525+ )
526+ )
527+ global_reduced = reduction_fn (gathered )
528+ result [task .name ][state_name ] = global_reduced
427529
428530 return result
429531
@@ -472,7 +574,8 @@ def get_pre_compute_states(
472574 # throughput metric requires special handling, since it's not a RecMetric
473575 throughput_metric = self .throughput_metric
474576 if throughput_metric is not None :
475- aggregated_states [throughput_metric ._namespace .value ] = (
577+ # Merge in case there are rec metric namespaces that overlap with throughput metric namespace
578+ aggregated_states .setdefault (throughput_metric ._namespace .value , {}).update (
476579 self ._get_throughput_metric_states (throughput_metric )
477580 )
478581
def _all_gather_tensor_list(
    tensors: List[torch.Tensor],
    world_size: int,
    pg: Union[dist.ProcessGroup, DeviceMesh],
) -> List[torch.Tensor]:
    """
    All-gather every tensor in a list and flatten the result.

    Note: In the current implementation with local reduction in _get_metric_states,
    this function should only receive a list with at most 1 tensor after local reduction.

    Raises:
        ValueError: if more than one tensor is passed (indicates that local
            reduction in _get_metric_states did not collapse the state list).
    """
    if not tensors:
        return []

    # Local reduction upstream should have collapsed the list to one element;
    # anything larger means _get_metric_states skipped that step.
    if len(tensors) > 1:
        raise ValueError(
            f"_all_gather_tensor_list expected at most 1 tensor after local reduction, "
            f"but received {len(tensors)} tensors. This indicates a bug in _get_metric_states."
        )

    # Flatten the per-rank gather results for each tensor into one list.
    return [
        shard
        for tensor in tensors
        for shard in _all_gather_tensor(tensor, world_size, pg)
    ]
0 commit comments