Skip to content

Commit 05c1b3d

Browse files
jeffkbkim authored and facebook-github-bot committed
RecMetricModule: apply reduction function before gloo all gathers (meta-pytorch#3593)
Summary: Pull Request resolved: meta-pytorch#3593 Differential Revision: D88297404
1 parent 8487722 commit 05c1b3d

File tree

2 files changed

+404
-24
lines changed

2 files changed

+404
-24
lines changed

torchrec/metrics/metric_module.py

Lines changed: 125 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,19 @@
2121
import torch.nn as nn
2222
from torch.distributed.tensor import DeviceMesh
2323
from torch.profiler import record_function
24+
from torchmetrics.utilities.data import (
25+
dim_zero_cat,
26+
dim_zero_max,
27+
dim_zero_mean,
28+
dim_zero_min,
29+
dim_zero_sum,
30+
)
2431
from torchrec.metrics.accuracy import AccuracyMetric
25-
from torchrec.metrics.auc import AUCMetric
32+
from torchrec.metrics.auc import (
33+
_grouping_keys_state_reduction,
34+
_state_reduction,
35+
AUCMetric,
36+
)
2637
from torchrec.metrics.auprc import AUPRCMetric
2738
from torchrec.metrics.average import AverageMetric
2839
from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
@@ -63,7 +74,7 @@
6374
from torchrec.metrics.recall import RecallMetric
6475
from torchrec.metrics.recall_session import RecallSessionMetric
6576
from torchrec.metrics.scalar import ScalarMetric
66-
from torchrec.metrics.segmented_ne import SegmentedNEMetric
77+
from torchrec.metrics.segmented_ne import _state_reduction_sum, SegmentedNEMetric
6778
from torchrec.metrics.serving_calibration import ServingCalibrationMetric
6879
from torchrec.metrics.serving_ne import ServingNEMetric
6980
from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric
@@ -75,6 +86,66 @@
7586

7687

7788
logger: logging.Logger = logging.getLogger(__name__)
89+
# TorchRec-specific custom reduction functions.
90+
# These work correctly with local+global reduction pattern.
91+
# Requirements: Associative AND (Commutative OR post-processing makes result order-invariant)
92+
SAFE_CALLABLE_REDUCTIONS: frozenset[Any] = frozenset(
93+
{
94+
_state_reduction, # Concatenation + AUC sorts data, making final result order-invariant
95+
_grouping_keys_state_reduction, # Concatenation along dim=0 + sorting makes result order-invariant
96+
_state_reduction_sum, # Sum on dimension 0.
97+
}
98+
)
99+
100+
# torchmetrics.Metric built-in reduction functions.
101+
# All dim_zero_* functions are both associative and commutative (dim_zero_cat is not commutative
102+
# but torchmetrics.Metric also reduces before sync_dist to reduce the number of collectives).
103+
TORCHMETRICS_REDUCTIONS: frozenset[Any] = frozenset(
104+
{
105+
dim_zero_sum,
106+
dim_zero_mean,
107+
dim_zero_max,
108+
dim_zero_min,
109+
dim_zero_cat,
110+
}
111+
)
112+
113+
114+
def _validate_reduction_function(
115+
reduction_fn: Union[str, Any, None],
116+
state_name: str,
117+
metric_namespace: str,
118+
) -> None:
119+
"""
120+
Validate that a reduction function is safe for local+global reduction pattern.
121+
122+
Only validates custom reduction functions. TorchMetrics built-in functions
123+
(dim_zero_*) are skipped as they're safe by construction (all are associative & commutative).
124+
125+
Mathematical Requirements:
126+
1. **Associativity**: f([f([a,b]), f([c,d])]) = f([a,b,c,d])
127+
- Required so local reduction + global reduction = direct reduction
128+
129+
2. **Commutativity**: f([a, b]) = f([b, a])
130+
- Required so rank ordering doesn't affect the result
131+
- OR the metric's computation must make the final result order-invariant
132+
(e.g., AUC concatenates in rank order but sorts before computing, making final result order-invariant)
133+
"""
134+
# Skip validation for None and torchmetrics.Metric built-in functions (safe by construction)
135+
if reduction_fn is None or reduction_fn in TORCHMETRICS_REDUCTIONS:
136+
return
137+
138+
# Validate custom callable reductions
139+
if callable(reduction_fn):
140+
if reduction_fn not in SAFE_CALLABLE_REDUCTIONS:
141+
raise RecMetricException(
142+
f"Unknown custom reduction '{reduction_fn}' for state '{state_name}' in '{metric_namespace}'. "
143+
f"Must be associative: f([f([a,b]), f([c,d])]) == f([a,b,c,d]) "
144+
f"AND commutative: f([a,b]) == f([b,a]) (or metric makes result order-invariant). "
145+
f"Known safe custom reductions: {[f for f in SAFE_CALLABLE_REDUCTIONS if f not in TORCHMETRICS_REDUCTIONS]}. "
146+
f"Add to SAFE_CALLABLE_REDUCTIONS if verified safe."
147+
)
148+
78149

79150
REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
80151
RecMetricEnum.NE: NEMetric,
@@ -218,6 +289,8 @@ def __init__(
218289
self.oom_count = 0
219290
self.compute_count = 0
220291

292+
self._validate_all_reduction_functions()
293+
221294
self.compute_interval_steps = compute_interval_steps
222295
self.min_compute_interval = min_compute_interval
223296
self.max_compute_interval = max_compute_interval
@@ -240,6 +313,20 @@ def __init__(
240313

241314
self._register_load_state_dict_pre_hook(self.load_state_dict_hook)
242315

316+
def _validate_all_reduction_functions(self) -> None:
317+
"""
318+
Validate all reduction functions in rec_metrics during initialization.
319+
This ensures that all reduction functions are safe for the local+global reduction pattern.
320+
"""
321+
for metric in self.rec_metrics.rec_metrics:
322+
for computation in metric._metrics_computations: # pyre-ignore[16]
323+
for state_name, reduction_fn in computation._reductions.items(): # pyre-ignore[16]
324+
_validate_reduction_function(
325+
reduction_fn,
326+
state_name,
327+
metric._namespace.value, # pyre-ignore[16]
328+
)
329+
243330
def load_state_dict_hook(
244331
self,
245332
state_dict: OrderedDict[str, torch.Tensor],
@@ -408,22 +495,24 @@ def _get_metric_states(
408495
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
409496
# `items`.
410497
for state_name, reduction_fn in computation._reductions.items():
411-
tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
412-
computation, state_name
413-
)
414-
415-
if isinstance(tensor_or_list, list):
416-
gathered = _all_gather_tensor_list(
417-
tensor_or_list, world_size, process_group
418-
)
419-
else:
420-
gathered = torch.stack(
421-
_all_gather_tensor(tensor_or_list, world_size, process_group)
498+
with record_function(f"## RecMetricModule: {state_name} all gather ##"):
499+
tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
500+
computation, state_name
422501
)
423-
reduced = (
424-
reduction_fn(gathered) if reduction_fn is not None else gathered
425-
)
426-
result[task.name][state_name] = reduced
502+
503+
if isinstance(tensor_or_list, list):
504+
local_reduced = reduction_fn(tensor_or_list)
505+
gathered = _all_gather_tensor_list(
506+
local_reduced, world_size, process_group
507+
)
508+
else:
509+
gathered = torch.stack(
510+
_all_gather_tensor(
511+
tensor_or_list, world_size, process_group
512+
)
513+
)
514+
global_reduced = reduction_fn(gathered)
515+
result[task.name][state_name] = global_reduced
427516

428517
return result
429518

@@ -472,7 +561,8 @@ def get_pre_compute_states(
472561
# throughput metric requires special handling, since it's not a RecMetric
473562
throughput_metric = self.throughput_metric
474563
if throughput_metric is not None:
475-
aggregated_states[throughput_metric._namespace.value] = (
564+
# Merge in case there are rec metric namespaces that overlap with throughput metric namespace
565+
aggregated_states.setdefault(throughput_metric._namespace.value, {}).update(
476566
self._get_throughput_metric_states(throughput_metric)
477567
)
478568

@@ -666,8 +756,23 @@ def _all_gather_tensor_list(
666756
world_size: int,
667757
pg: Union[dist.ProcessGroup, DeviceMesh],
668758
) -> List[torch.Tensor]:
669-
"""All-gather every tensor in a list and flatten the result."""
670-
gathered: List[torch.Tensor] = [] # pragma: no cover
759+
"""
760+
All-gather every tensor in a list and flatten the result.
761+
762+
Note: In the current implementation with local reduction in _get_metric_states,
763+
this function should only receive a list with at most 1 tensor after local reduction.
764+
"""
765+
if not tensors:
766+
return []
767+
768+
# After local reduction in _get_metric_states, tensors should contain at most 1 element
769+
if len(tensors) > 1:
770+
raise ValueError(
771+
f"_all_gather_tensor_list expected at most 1 tensor after local reduction, "
772+
f"but received {len(tensors)} tensors. This indicates a bug in _get_metric_states."
773+
)
774+
775+
gathered: List[torch.Tensor] = []
671776
for t in tensors:
672777
gathered.extend(_all_gather_tensor(t, world_size, pg))
673778
return gathered

0 commit comments

Comments (0)