
Commit 4269c5b

enforce pytorch >= 2.0

JKSenthil authored and facebook-github-bot committed

Differential Revision: D56446353

1 parent e7b9e64 · commit 4269c5b

17 files changed: +44 / -293 lines changed
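
Every file below follows the same pattern: import-time PyTorch version gates, the tests that exercised them, and the helper functions backing them are deleted, since the package now requires torch >= 2.0 at install time. For context, here is a minimal sketch of how such a comparator can be written; it assumes packaging-style version parsing and is an illustration, not the actual torchtnt source:

    # Hypothetical sketch (not the torchtnt source) of the kind of
    # comparator removed in this commit.
    import torch
    from packaging.version import Version

    def get_torch_version() -> Version:
        # Drop any local build suffix such as "+cu118" before parsing.
        return Version(torch.__version__.split("+")[0])

    def is_torch_version_geq_2_0() -> bool:
        # Pre-releases such as "2.0.0a0" compare as < 2.0.0, matching the
        # assertion kept in tests/utils/test_version.py below.
        return get_torch_version() >= Version("2.0.0")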

tests/framework/callbacks/test_module_summary.py (-9)

@@ -19,11 +19,6 @@
 
 from torchtnt.framework.callbacks.module_summary import ModuleSummary
 from torchtnt.framework.state import EntryPoint, PhaseState, State
-from torchtnt.utils.version import is_torch_version_geq_1_13
-
-MODULE_SUMMARY_FLOPS_AVAILABLE = False
-if is_torch_version_geq_1_13():
-    MODULE_SUMMARY_FLOPS_AVAILABLE = True
 
 
 class ModuleSummaryTest(unittest.TestCase):
@@ -85,10 +80,6 @@ def forward(self, x):
         self.assertTrue("b1" in ms.submodule_summaries)
         self.assertTrue("l2" in ms.submodule_summaries)
 
-    @unittest.skipUnless(
-        condition=MODULE_SUMMARY_FLOPS_AVAILABLE,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     def test_module_summary_retrieve_module_summaries_module_inputs(self) -> None:
         """
         Test ModuleSummary callback in train

tests/framework/test_auto_unit.py (+2 -9)

@@ -12,15 +12,6 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-from torchtnt.framework.auto_unit import TrainStepResults
-from torchtnt.utils.test_utils import skip_if_not_distributed
-
-from torchtnt.utils.version import is_torch_version_geq_1_13
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
 
 from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
 
@@ -37,6 +28,7 @@
     AutoUnit,
     SWALRParams,
     SWAParams,
+    TrainStepResults,
 )
 from torchtnt.framework.evaluate import evaluate
 from torchtnt.framework.predict import predict
@@ -49,6 +41,7 @@
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.prepare_module import DDPStrategy
 from torchtnt.utils.progress import Progress
+from torchtnt.utils.test_utils import skip_if_not_distributed
 from torchtnt.utils.timer import Timer
 
 TParams = ParamSpec("TParams")

tests/framework/test_auto_unit_gpu.py (+4 -15)

@@ -8,24 +8,16 @@
 # pyre-strict
 
 import unittest
+
+from copy import deepcopy
 from typing import TypeVar
 from unittest.mock import MagicMock, patch
 
 import torch
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
-
-from torchtnt.utils.version import is_torch_version_geq_1_13
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
-
-from copy import deepcopy
 
 from pyre_extensions import ParameterSpecification as ParamSpec
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     generate_random_dataloader,
@@ -40,6 +32,7 @@
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env, seed
 from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
+from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
 
 TParams = ParamSpec("TParams")
 T = TypeVar("T")
@@ -320,10 +313,6 @@ def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
             device_type="cuda", dtype=torch.float16, enabled=True
         )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     @skip_if_not_gpu
     @patch("torch.compile")
     def test_compile_predict(self, mock_dynamo: MagicMock) -> None:

tests/utils/test_memory_snapshot_profiler.py (-8)

@@ -14,17 +14,9 @@
     MemorySnapshotParams,
     MemorySnapshotProfiler,
 )
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 
 class MemorySnapshotProfilerTest(unittest.TestCase):
-
-    torch_version_geq_2_0: bool = is_torch_version_geq_2_0()
-
-    @unittest.skipUnless(
-        condition=torch_version_geq_2_0,
-        reason="This test needs changes from PyTorch 2.0 to run.",
-    )
     def test_validation(self) -> None:
         """Test parameter validation."""
         with tempfile.TemporaryDirectory() as temp_dir:
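
This gate existed because the allocator-history hooks that MemorySnapshotProfiler builds on only landed in PyTorch 2.0. A rough sketch of those underlying hooks follows; these are private torch.cuda.memory APIs whose names are assumed here from 2.x builds, not taken from this diff:

    import torch

    # Assumed private 2.x APIs: start recording allocator events, then dump
    # them to a pickle that snapshot-viewer tools can load.
    torch.cuda.memory._record_memory_history()
    # ... run a CUDA workload ...
    torch.cuda.memory._dump_snapshot("snapshot.pickle")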

tests/utils/test_memory_snapshot_profiler_gpu.py (-8)

@@ -18,18 +18,10 @@
     MemorySnapshotProfiler,
 )
 from torchtnt.utils.test_utils import skip_if_not_gpu
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 
 class MemorySnapshotProfilerGPUTest(unittest.TestCase):
-
-    torch_version_geq_2_0: bool = is_torch_version_geq_2_0()
-
     @skip_if_not_gpu
-    @unittest.skipUnless(
-        condition=torch_version_geq_2_0,
-        reason="This test needs changes from PyTorch 2.0 to run.",
-    )
     def test_stop_step(self) -> None:
         """Test that a memory snapshot is saved when stop_step is reached."""
         with tempfile.TemporaryDirectory() as temp_dir:

tests/utils/test_oom_gpu.py (-5)

@@ -16,15 +16,10 @@
 from torchtnt.utils.oom import log_memory_snapshot
 
 from torchtnt.utils.test_utils import skip_if_not_gpu
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 
 class OomGPUTest(unittest.TestCase):
     @skip_if_not_gpu
-    @unittest.skipUnless(
-        condition=bool(is_torch_version_geq_2_0()),
-        reason="This test needs changes from PyTorch 2.0 to run.",
-    )
     def test_log_memory_snapshot(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             # Record history

tests/utils/test_prepare_module.py (+1 -14)

@@ -22,12 +22,7 @@
     TorchCompileParams,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed
-from torchtnt.utils.version import is_torch_version_geq_1_13, Version
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
+from torchtnt.utils.version import Version
 
 
 class PrepareModelTest(unittest.TestCase):
@@ -170,10 +165,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
             torch_compile_params=TorchCompileParams(backend="inductor"),
         )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     def test_prepare_module_compile_invalid_backend(self) -> None:
         """
         verify error is thrown on invalid backend
@@ -199,10 +190,6 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
             torch_compile_params=TorchCompileParams(),
        )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     def test_prepare_module_compile_module_state_dict(self) -> None:
         device = init_from_env()
         my_module = torch.nn.Linear(2, 2, device=device)

tests/utils/test_prepare_module_gpu.py (+4 -40)

@@ -7,9 +7,10 @@
 
 # pyre-strict
 import unittest
-from unittest.mock import patch
 
 import torch
+
+from torch.distributed._composable import fully_shard
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -24,15 +25,6 @@
     prepare_module,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
-from torchtnt.utils.version import is_torch_version_geq_1_13, is_torch_version_geq_2_0
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
-
-if is_torch_version_geq_2_0():
-    from torch.distributed._composable import fully_shard
 
 
 class PrepareModelGPUTest(unittest.TestCase):
@@ -81,33 +73,6 @@ def _test_prepare_fsdp() -> No
         tc = unittest.TestCase()
         tc.assertTrue(isinstance(fsdp_module, FSDP))
 
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_fsdp_pytorch_version(self) -> None:
-        """
-        Test that a RuntimeError is thrown when using FSDP, and PyTorch < v1.12
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fsdp_pytorch_version,
-        )
-
-    @staticmethod
-    def _test_fsdp_pytorch_version() -> None:
-        device = init_from_env()
-        module = torch.nn.Linear(2, 2).to(device)
-
-        tc = unittest.TestCase()
-        with patch(
-            "torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
-            return_value=False,
-        ), tc.assertRaisesRegex(
-            RuntimeError,
-            "Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
-        ):
-            _ = prepare_fsdp(module, device, FSDPStrategy())
-
     @skip_if_not_distributed
     @unittest.skipUnless(
         condition=bool(torch.cuda.device_count() >= 2),
@@ -128,9 +93,8 @@ def _test_is_fsdp_module() -> No
         model = FSDP(torch.nn.Linear(1, 1, device=device))
         assert _is_fsdp_module(model)
         model = torch.nn.Linear(1, 1, device=device)
-        if is_torch_version_geq_2_0():
-            fully_shard(model)
-            assert _is_fsdp_module(model)
+        fully_shard(model)
+        assert _is_fsdp_module(model)
 
     @skip_if_not_distributed
     @skip_if_not_gpu
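
The deleted test_fsdp_pytorch_version only verified the old "< 1.12" error path, which can no longer trigger, and fully_shard (the composable FSDP API added in PyTorch 2.0) is now imported unconditionally. As the remaining test shows, it still needs an initialized process group; a rough standalone sketch under that assumption:

    import torch
    import torch.distributed as dist
    from torch.distributed._composable import fully_shard

    # Assumes a process group was initialized beforehand (e.g. by torchrun
    # or spawn_multi_process, as in these tests).
    assert dist.is_initialized()
    model = torch.nn.Linear(1, 1, device="cuda")
    fully_shard(model)  # shards parameters in place; the module class is unchanged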

tests/utils/test_version.py (+1 -44)

@@ -48,48 +48,5 @@ def test_get_torch_version(self) -> None:
         self.assertEqual(version.get_torch_version(), Version("1.12.0"))
 
     def test_torch_version_comparators(self) -> None:
-        with patch.object(torch, "__version__", "1.7.0"):
-            self.assertFalse(version.is_torch_version_geq_1_8())
-            self.assertFalse(version.is_torch_version_geq_1_9())
-            self.assertFalse(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.8.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertFalse(version.is_torch_version_geq_1_9())
-            self.assertFalse(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.9.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertFalse(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.10.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertTrue(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.11.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertTrue(version.is_torch_version_geq_1_10())
-            self.assertTrue(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.12.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertTrue(version.is_torch_version_geq_1_10())
-            self.assertTrue(version.is_torch_version_geq_1_11())
-            self.assertTrue(version.is_torch_version_geq_1_12())
-
         with patch.object(torch, "__version__", "2.0.0a0"):
-            self.assertTrue(version.is_torch_version_ge_1_13_1())
-            self.assertFalse(version.is_torch_version_geq_2_0())
+            self.assertFalse(version.is_torch_version_geq_2_1())
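
The one retained case leans on PEP 440 ordering, under which a pre-release sorts before its final release, so a "2.0.0a0" nightly is not >= 2.1 (nor even >= 2.0.0). A quick illustration, assuming packaging-style parsing:

    from packaging.version import Version

    # Pre-releases sort before the corresponding final release under PEP 440.
    assert Version("2.0.0a0") < Version("2.0.0") < Version("2.1.0")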

torchtnt/framework/auto_unit.py (-11)

@@ -50,7 +50,6 @@
     TorchCompileParams,
 )
 from torchtnt.utils.swa import AveragedModel
-from torchtnt.utils.version import is_torch_version_ge_1_13_1
 from typing_extensions import Literal
 
 
@@ -166,8 +165,6 @@ def __init__(
         torch_compile_params: Optional[TorchCompileParams] = None,
     ) -> None:
         super().__init__()
-        if torch_compile_params:
-            _validate_torch_compile_available()
 
         self.device: torch.device = device or init_from_env()
         self.precision: Optional[torch.dtype] = (
@@ -879,11 +876,3 @@ def _update_lr_and_swa(self, state: State, number_of_steps_or_epochs: int) -> No
             state, f"{self.__class__.__name__}.lr_scheduler_step"
         ):
             self.step_lr_scheduler()
-
-
-def _validate_torch_compile_available() -> None:
-    if not is_torch_version_ge_1_13_1():
-        raise RuntimeError(
-            "Torch compile support is available only in PyTorch 2.0 or higher. "
-            "Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
-        )
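
With torch >= 2.0 guaranteed, torch.compile is always present, so the constructor-time _validate_torch_compile_available guard (which checked is_torch_version_ge_1_13_1 even though its message asked for 2.0) is dead weight. A minimal illustration of the now-unconditional call, not the AutoUnit internals:

    import torch

    module = torch.nn.Linear(2, 2)
    # No availability check needed: torch.compile ships with every version
    # the package now supports.
    compiled_module = torch.compile(module)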

torchtnt/utils/__init__.py (-18)

@@ -74,15 +74,6 @@
 from .version import (
     get_python_version,
     get_torch_version,
-    is_torch_version_ge_1_13_1,
-    is_torch_version_geq_1_10,
-    is_torch_version_geq_1_11,
-    is_torch_version_geq_1_12,
-    is_torch_version_geq_1_13,
-    is_torch_version_geq_1_14,
-    is_torch_version_geq_1_8,
-    is_torch_version_geq_1_9,
-    is_torch_version_geq_2_0,
     is_torch_version_geq_2_1,
     is_windows,
 )
@@ -153,15 +144,6 @@
     "TLRScheduler",
     "get_python_version",
     "get_torch_version",
-    "is_torch_version_ge_1_13_1",
-    "is_torch_version_geq_1_10",
-    "is_torch_version_geq_1_11",
-    "is_torch_version_geq_1_12",
-    "is_torch_version_geq_1_13",
-    "is_torch_version_geq_1_14",
-    "is_torch_version_geq_1_8",
-    "is_torch_version_geq_1_9",
-    "is_torch_version_geq_2_0",
     "is_torch_version_geq_2_1",
     "is_windows",
     "get_pet_launch_config",
