Skip to content

Commit dcdb254

Browse files
oulgen authored and pytorchmergebot committed
Make FX Graph Cache work with distributed training (pytorch#133374)

During distributed training, if all ranks except one hit the cache, the rank that missed the cache can cause a NCCL timeout, since the rest of the ranks will already have entered the collective and started the timer. This PR uses the new PTD API to increase the timeout for the ranks that hit the cache by the amount of time the cache saved.

Pull Request resolved: pytorch#133374
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#133362, pytorch#133363
1 parent 6d42874 commit dcdb254

File tree

4 files changed

+79
-5
lines changed

4 files changed

+79
-5
lines changed

test/distributed/test_dynamo_distributed.py

+55
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,61 @@ def f(x):
964964
for r in res[1:]:
965965
self.assertEqual(res[0], r)
966966

@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", True)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
@patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
def test_asymmetric_compilation_with_fx_cache(self):
    """Exercise asymmetric compilation where some ranks hit the FX graph cache.

    First compile populates the local FX graph cache on every rank (all miss).
    Then, after a dynamo reset, rank 0 recompiles under a fresh cache dir (miss)
    while the other ranks hit the cache. With sleep_sec_TESTING_ONLY=10 the
    compile path is slowed past the 5 s collective timeout set below, so the
    final allreduce only succeeds if the cache-hit ranks extend their timeout
    (see the ephemeral-timeout bump on cache hit in codecache.py).
    """
    from torch._dynamo.utils import counters
    from torch._inductor.utils import fresh_inductor_cache

    with fresh_inductor_cache(), _dynamo_dist_per_rank_init(
        self.rank, self.world_size
    ):
        torch._dynamo.utils.clear_compilation_metrics()

        device = f"cuda:{self.rank}"

        pg = dist.distributed_c10d._get_default_group()

        @torch.compile
        def f(x):
            y = 2 * x
            return y.sum()

        # Short collective timeout so a slow (cache-missing) rank would
        # trip NCCL unless the cache-hit ranks extend their timeout.
        backend = pg._get_backend(torch.device(device))
        backend._set_default_timeout(timedelta(seconds=5))
        counters.clear()

        x = torch.ones(4, device=device)

        # Cold compile: every rank misses and populates the cache.
        f(x)

        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

        w = pg.allreduce(x)
        w.wait()
        torch.cuda.synchronize(device)
        # Drop compiled artifacts so the next call recompiles (and can
        # hit/miss the on-disk FX graph cache).
        torch._dynamo.reset()

        if self.rank == 0:
            # Fresh cache dir forces a second miss on rank 0 only,
            # creating the asymmetric compile-time across ranks.
            with fresh_inductor_cache():
                f(x)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
        else:
            # Remaining ranks hit the cache populated by the first compile.
            f(x)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

        # Would time out (5 s < 10 s sleep) without the ephemeral
        # timeout extension applied on cache hit.
        w = pg.allreduce(x)
        w.wait()
        torch.cuda.synchronize(device)
9671022

9681023
@requires_nccl()
9691024
@requires_cuda

torch/_inductor/codecache.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from bisect import bisect_right
2727
from copy import copy
2828
from ctypes import c_void_p, CDLL, cdll
29+
from datetime import timedelta
2930
from functools import partial
3031
from pathlib import Path
3132
from time import time, time_ns
@@ -50,6 +51,7 @@
5051
from typing_extensions import TypeAlias
5152

5253
import torch
54+
import torch.distributed as dist
5355
from torch import SymInt, Tensor
5456
from torch._dynamo.utils import ChromiumEventLogger, counters, dynamo_timed
5557
from torch._inductor import config, exc, metrics
@@ -1144,7 +1146,6 @@ def _save_graph(
11441146
key: str,
11451147
compiled_graph: CompiledFxGraph,
11461148
example_inputs: List[torch.Tensor],
1147-
time_taken_ns: int,
11481149
local: bool,
11491150
remote_cache: None,
11501151
) -> None:
@@ -1196,8 +1197,8 @@ def _save_graph(
11961197
cache_data = (
11971198
{
11981199
"data": content,
1199-
"time_taken_ms": time_taken_ns
1200-
// 1000000, # Convert from NS to MS
1200+
"time_taken_ms": disk_compiled_graph._time_taken_ns
1201+
// 1e6, # Convert from NS to MS
12011202
}
12021203
if config.is_fbcode()
12031204
else content
@@ -1291,12 +1292,11 @@ def load( # type: ignore[no-untyped-def]
12911292
compiled_graph = compile_fx_fn(
12921293
gm, example_inputs, inputs_to_check, fx_kwargs
12931294
)
1294-
time_taken_ns = time_ns() - start_time
1295+
compiled_graph._time_taken_ns = time_ns() - start_time
12951296
FxGraphCache._save_graph(
12961297
key,
12971298
compiled_graph,
12981299
example_inputs,
1299-
time_taken_ns,
13001300
local,
13011301
remote_cache,
13021302
)
@@ -1305,6 +1305,15 @@ def load( # type: ignore[no-untyped-def]
13051305
counters["inductor"]["fxgraph_cache_hit"] += 1
13061306
cache_state = "hit"
13071307
cache_event_time = time_ns()
1308+
if (
1309+
dist.distributed_c10d.is_initialized()
1310+
and (time_taken_ns := compiled_graph._time_taken_ns) is not None
1311+
):
1312+
increased_timeout_sec = time_taken_ns // 1e9 # convert to seconds
1313+
log.info("Increasing NCCL timeout by %d", increased_timeout_sec)
1314+
dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
1315+
timedelta(seconds=increased_timeout_sec)
1316+
)
13081317
compiled_graph._fx_graph_cache_key = key
13091318
except BypassFxGraphCache:
13101319
counters["inductor"]["fxgraph_cache_bypass"] += 1
@@ -1380,6 +1389,7 @@ class CompiledFxGraph:
13801389
inputs_to_check: Sequence[int]
13811390
boxed_forward_device_index: Optional[BoxedDeviceIndex]
13821391

1392+
_time_taken_ns: Optional[int] = None
13831393
_boxed_call: Optional[bool] = None
13841394
_fx_graph_cache_key: Optional[str] = None
13851395

torch/_inductor/compile_fx.py

+6
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,12 @@ def fx_codegen_and_compile(
705705
layout_opt: Optional[bool] = None,
706706
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,
707707
) -> Union[CompiledFxGraph, str]:
708+
if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
709+
import time
710+
711+
log.warning("Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec)
712+
time.sleep(sleep_sec)
713+
708714
with dynamo_utils.preserve_rng_state():
709715
if is_tf32_warning_applicable(gm):
710716
_warn_tf32_disabled()

torch/_inductor/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def autotune_remote_cache_default() -> Optional[bool]:
5757
# Force disabled all inductor level caching -- This will override any other caching flag
5858
force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"
5959

60+
# sleep in inductor for testing
61+
sleep_sec_TESTING_ONLY: Optional[int] = None
62+
6063
# use cpp wrapper instead of python wrapper
6164
cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"
6265

0 commit comments

Comments
 (0)