
Commit 4daa523

jeffkbkim authored and facebook-github-bot committed
RecMetricModule: apply reduction function before gloo all gathers (meta-pytorch#3593)
Summary:

metric_module's get_pre_compute_states() provides an API to perform gloo all gathers instead of the default torchmetrics.Metric sync_dist (nccl). However, the mechanism calls a gloo all gather for each element in a list of tensors. This is problematic because:

- AUC's 3 state tensors each hold a list of tensors, not a single tensor.
- The size of the tensor list is theoretically unbounded (in practice, it can grow to the order of 100K).
- gloo all gathers are inherently much slower.

Instead, this patch:

- applies the reduction function prior to the all gather when processing a tensor list
- enforces that the reduction_fn does not rely on ordering

Differential Revision: D88297404
1 parent 1b69fd6 commit 4daa523
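
The "does not rely on ordering" requirement can be made concrete with a small standalone sketch (illustrative only, not part of the patch): an associative, order-insensitive reduction such as sum gives the same result whether it is applied once over all shards or locally per rank and again over the gathered results, while a mean does not.

# Illustrative sketch (not part of the patch): the reduction applied locally
# and again on the gathered results must match reducing everything at once.
import torch

rank0 = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]  # rank 0's tensor-list state
rank1 = [torch.tensor([10.0])]                           # rank 1's tensor-list state

def sum_reduction(tensors):
    # Associative and commutative: safe to apply before the all gather.
    return torch.cat(tensors).sum().reshape(1)

direct = sum_reduction(rank0 + rank1)                                   # 16.0
two_level = sum_reduction([sum_reduction(rank0), sum_reduction(rank1)])
assert torch.equal(direct, two_level)

def mean_reduction(tensors):
    # NOT associative across unevenly sized shards: a mean of per-rank means
    # is wrong, so such a reduction must not be applied before the gather.
    return torch.cat(tensors).mean().reshape(1)

assert not torch.equal(
    mean_reduction(rank0 + rank1),                                      # 16 / 4 = 4.0
    mean_reduction([mean_reduction(rank0), mean_reduction(rank1)]),     # (2 + 10) / 2 = 6.0
)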

File tree

2 files changed: +420 -27 lines

torchrec/metrics/metric_module.py

Lines changed: 141 additions & 23 deletions
@@ -21,9 +21,24 @@
 import torch.nn as nn
 from torch.distributed.tensor import DeviceMesh
 from torch.profiler import record_function
+from torchmetrics.utilities.data import (
+    dim_zero_cat,
+    dim_zero_max,
+    dim_zero_mean,
+    dim_zero_min,
+    dim_zero_sum,
+)
 from torchrec.metrics.accuracy import AccuracyMetric
-from torchrec.metrics.auc import AUCMetric
-from torchrec.metrics.auprc import AUPRCMetric
+from torchrec.metrics.auc import (
+    _grouping_keys_state_reduction,
+    _state_reduction,
+    AUCMetric,
+)
+from torchrec.metrics.auprc import (
+    _grouping_keys_state_reduction as auprc_grouping_keys_state_reduction,
+    _state_reduction as auprc_state_reduction,
+    AUPRCMetric,
+)
 from torchrec.metrics.average import AverageMetric
 from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
 from torchrec.metrics.calibration import CalibrationMetric
@@ -58,23 +73,92 @@
 from torchrec.metrics.output import OutputMetric
 from torchrec.metrics.precision import PrecisionMetric
 from torchrec.metrics.precision_session import PrecisionSessionMetric
-from torchrec.metrics.rauc import RAUCMetric
+from torchrec.metrics.rauc import (
+    _grouping_keys_state_reduction as rauc_grouping_keys_state_reduction,
+    _state_reduction as rauc_state_reduction,
+    RAUCMetric,
+)
 from torchrec.metrics.rec_metric import RecMetric, RecMetricException, RecMetricList
 from torchrec.metrics.recall import RecallMetric
 from torchrec.metrics.recall_session import RecallSessionMetric
 from torchrec.metrics.scalar import ScalarMetric
-from torchrec.metrics.segmented_ne import SegmentedNEMetric
+from torchrec.metrics.segmented_ne import _state_reduction_sum, SegmentedNEMetric
 from torchrec.metrics.serving_calibration import ServingCalibrationMetric
 from torchrec.metrics.serving_ne import ServingNEMetric
 from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric
 from torchrec.metrics.throughput import ThroughputMetric
-from torchrec.metrics.tower_qps import TowerQPSMetric
+from torchrec.metrics.tower_qps import _max_reduction, TowerQPSMetric
 from torchrec.metrics.unweighted_ne import UnweightedNEMetric
 from torchrec.metrics.weighted_avg import WeightedAvgMetric
 from torchrec.metrics.xauc import XAUCMetric


 logger: logging.Logger = logging.getLogger(__name__)
+
+# TorchRec-specific custom reduction functions.
+# These work correctly with the local+global reduction pattern.
+# Requirements: associative AND (commutative OR post-processing makes the result order-invariant).
+SAFE_CALLABLE_REDUCTIONS: frozenset[Any] = frozenset(
+    {
+        _state_reduction,  # Concatenation + AUC sorts data, making the final result order-invariant
+        _grouping_keys_state_reduction,  # Concatenation along dim=0 + sorting makes the result order-invariant
+        auprc_state_reduction,
+        auprc_grouping_keys_state_reduction,
+        rauc_state_reduction,
+        rauc_grouping_keys_state_reduction,
+        _state_reduction_sum,  # Sum along dimension 0.
+        _max_reduction,  # Max is associative and commutative.
+    }
+)
+
+# torchmetrics.Metric built-in reduction functions.
+# All dim_zero_* functions are both associative and commutative (dim_zero_cat is not
+# commutative, but torchmetrics.Metric also reduces before sync_dist to reduce the
+# number of collectives).
+TORCHMETRICS_REDUCTIONS: frozenset[Any] = frozenset(
+    {
+        dim_zero_sum,
+        dim_zero_mean,
+        dim_zero_max,
+        dim_zero_min,
+        dim_zero_cat,
+    }
+)
+
+
+def _validate_reduction_function(
+    reduction_fn: Union[str, Any, None],
+    state_name: str,
+    metric_namespace: str,
+) -> None:
+    """
+    Validate that a reduction function is safe for the local+global reduction pattern.
+
+    Only validates custom reduction functions. TorchMetrics built-in functions
+    (dim_zero_*) are skipped as they're safe by construction (all are associative
+    and commutative).
+
+    Mathematical Requirements:
+    1. **Associativity**: f([f([a,b]), f([c,d])]) = f([a,b,c,d])
+       - Required so local reduction + global reduction = direct reduction
+
+    2. **Commutativity**: f([a, b]) = f([b, a])
+       - Required so rank ordering doesn't affect the result
+       - OR the metric's computation must make the final result order-invariant
+         (e.g., AUC concatenates in rank order but sorts before computing, making
+         the final result order-invariant)
+    """
+    # Skip validation for None and torchmetrics.Metric built-in functions (safe by construction)
+    if reduction_fn is None or reduction_fn in TORCHMETRICS_REDUCTIONS:
+        return
+
+    # Validate custom callable reductions
+    if callable(reduction_fn):
+        if reduction_fn not in SAFE_CALLABLE_REDUCTIONS:
+            raise RecMetricException(
+                f"Unknown custom reduction '{reduction_fn}' for state '{state_name}' in '{metric_namespace}'. "
+                f"Must be associative: f([f([a,b]), f([c,d])]) == f([a,b,c,d]) "
+                f"AND commutative: f([a,b]) == f([b,a]) (or the metric makes the result order-invariant). "
+                f"Known safe custom reductions: {[f for f in SAFE_CALLABLE_REDUCTIONS if f not in TORCHMETRICS_REDUCTIONS]}. "
+                f"Add to SAFE_CALLABLE_REDUCTIONS if verified safe."
+            )

 REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
     RecMetricEnum.NE: NEMetric,
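
A hypothetical usage sketch of the validator above (state and namespace names are made up; the functions come from the imports in this file): built-ins and allow-listed reductions pass, while any other callable is rejected.

# Hypothetical usage sketch; state/namespace names are illustrative only.
_validate_reduction_function(dim_zero_sum, "some_state", "ne")        # built-in: passes
_validate_reduction_function(None, "some_state", "ne")                # no reduction: passes
_validate_reduction_function(_state_reduction, "predictions", "auc")  # allow-listed: passes

try:
    # An ad-hoc callable is rejected even if it happens to be mathematically
    # safe, because membership in SAFE_CALLABLE_REDUCTIONS is the only proof
    # the validator accepts.
    _validate_reduction_function(lambda ts: ts, "some_state", "custom_metric")
except RecMetricException as exc:
    print(exc)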
@@ -218,6 +302,8 @@ def __init__(
         self.oom_count = 0
         self.compute_count = 0

+        self._validate_all_reduction_functions()
+
         self.compute_interval_steps = compute_interval_steps
         self.min_compute_interval = min_compute_interval
         self.max_compute_interval = max_compute_interval
@@ -240,6 +326,20 @@ def __init__(

         self._register_load_state_dict_pre_hook(self.load_state_dict_hook)

+    def _validate_all_reduction_functions(self) -> None:
+        """
+        Validate all reduction functions in rec_metrics during initialization.
+        This ensures that all reduction functions are safe for the local+global reduction pattern.
+        """
+        for metric in self.rec_metrics.rec_metrics:
+            for computation in metric._metrics_computations:  # pyre-ignore[16]
+                for state_name, reduction_fn in computation._reductions.items():  # pyre-ignore[16]
+                    _validate_reduction_function(
+                        reduction_fn,
+                        state_name,
+                        metric._namespace.value,  # pyre-ignore[16]
+                    )
+
     def load_state_dict_hook(
         self,
         state_dict: OrderedDict[str, torch.Tensor],
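
A toy sketch of the traversal this method performs, using stand-in objects instead of real RecMetrics (illustrative only):

# Toy sketch (illustrative only): stand-in objects mimic the attributes the
# traversal reads from real RecMetrics.
class FakeComputation:
    _reductions = {"predictions": _state_reduction, "labels": _state_reduction}

class FakeMetric:
    _metrics_computations = [FakeComputation()]
    class _namespace:  # stand-in for the real enum member
        value = "fake_auc"

for metric in [FakeMetric()]:
    for computation in metric._metrics_computations:
        for state_name, reduction_fn in computation._reductions.items():
            _validate_reduction_function(reduction_fn, state_name, metric._namespace.value)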
@@ -408,22 +508,24 @@ def _get_metric_states(
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             # `items`.
             for state_name, reduction_fn in computation._reductions.items():
-                tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
-                    computation, state_name
-                )
-
-                if isinstance(tensor_or_list, list):
-                    gathered = _all_gather_tensor_list(
-                        tensor_or_list, world_size, process_group
-                    )
-                else:
-                    gathered = torch.stack(
-                        _all_gather_tensor(tensor_or_list, world_size, process_group)
+                with record_function(f"## RecMetricModule: {state_name} all gather ##"):
+                    tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
+                        computation, state_name
                     )
-                reduced = (
-                    reduction_fn(gathered) if reduction_fn is not None else gathered
-                )
-                result[task.name][state_name] = reduced
+
+                    if isinstance(tensor_or_list, list):
+                        local_reduced = reduction_fn(tensor_or_list)
+                        gathered = _all_gather_tensor_list(
+                            local_reduced, world_size, process_group
+                        )
+                    else:
+                        gathered = torch.stack(
+                            _all_gather_tensor(
+                                tensor_or_list, world_size, process_group
+                            )
+                        )
+                    global_reduced = reduction_fn(gathered)
+                    result[task.name][state_name] = global_reduced

         return result
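
A single-process sketch of the new list-state path (the gloo gather is simulated, and cat_reduction stands in for the AUC-style _state_reduction; sizes are illustrative): the local reduction collapses the unbounded tensor list to one tensor per rank, so only one all gather is needed per state.

# Single-process sketch of the new flow for a list-valued state; the gloo
# gather is simulated. Sizes are illustrative.
import torch

def cat_reduction(state):              # stands in for an AUC-style _state_reduction
    return [torch.cat(state, dim=-1)]  # tensor list in, single-element list out

rank_states = [
    [torch.rand(1, 4) for _ in range(1_000)],  # previously: 1,000 gloo all gathers
    [torch.rand(1, 4) for _ in range(500)],    # previously: 500 more
]

# New path: each rank reduces first, so only one tensor per rank is gathered.
locally_reduced = [cat_reduction(s)[0] for s in rank_states]
global_reduced = cat_reduction(locally_reduced)[0]
assert global_reduced.shape == (1, 6_000)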

@@ -472,7 +574,8 @@ def get_pre_compute_states(
         # throughput metric requires special handling, since it's not a RecMetric
         throughput_metric = self.throughput_metric
         if throughput_metric is not None:
-            aggregated_states[throughput_metric._namespace.value] = (
+            # Merge in case there are rec metric namespaces that overlap with throughput metric namespace
+            aggregated_states.setdefault(throughput_metric._namespace.value, {}).update(
                 self._get_throughput_metric_states(throughput_metric)
             )
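
The setdefault/update merge matters when a rec metric namespace collides with the throughput namespace; plain assignment would silently drop the rec metric's states. A minimal illustration with made-up keys:

# Why setdefault(...).update(...) instead of plain assignment: if a rec metric
# already wrote states under the same namespace, assignment would silently
# drop them, while merging keeps both. Keys below are illustrative.
aggregated_states = {"throughput": {"rec_metric_state": 1}}
throughput_states = {"total_examples": 2}

aggregated_states.setdefault("throughput", {}).update(throughput_states)
assert aggregated_states == {
    "throughput": {"rec_metric_state": 1, "total_examples": 2}
}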

@@ -666,8 +769,23 @@ def _all_gather_tensor_list(
     world_size: int,
     pg: Union[dist.ProcessGroup, DeviceMesh],
 ) -> List[torch.Tensor]:
-    """All-gather every tensor in a list and flatten the result."""
-    gathered: List[torch.Tensor] = []  # pragma: no cover
+    """
+    All-gather every tensor in a list and flatten the result.
+
+    Note: In the current implementation with local reduction in _get_metric_states,
+    this function should only receive a list with at most 1 tensor after local reduction.
+    """
+    if not tensors:
+        return []
+
+    # After local reduction in _get_metric_states, tensors should contain at most 1 element
+    if len(tensors) > 1:
+        raise ValueError(
+            f"_all_gather_tensor_list expected at most 1 tensor after local reduction, "
+            f"but received {len(tensors)} tensors. This indicates a bug in _get_metric_states."
+        )
+
+    gathered: List[torch.Tensor] = []
     for t in tensors:
         gathered.extend(_all_gather_tensor(t, world_size, pg))
     return gathered
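
A sanity sketch of the invariant the new guard enforces, assuming (as the guard implies) that list-state reductions such as the AUC-style _state_reduction return a single-element list:

# Assumption: _state_reduction (imported from torchrec.metrics.auc above)
# collapses a tensor list into a single-element list, so the gather loop in
# _all_gather_tensor_list runs at most once per state.
import torch

state = [torch.rand(1, 4) for _ in range(3)]
reduced = _state_reduction(state)
assert len(reduced) == 1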
