Skip to content

Commit 4166d97

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
add and use generic torch version comparator
Summary: # Context Before we would have an individual function for each version of pytorch we want to check for. Let's remove this in favor of a single universal function # This Diff Pass the version as a string instead and replaces old uses Differential Revision: D56446382
1 parent 4269c5b commit 4166d97

File tree

5 files changed

+8
-10
lines changed

5 files changed

+8
-10
lines changed

tests/utils/test_prepare_module.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
143143
"""
144144

145145
tc = unittest.TestCase()
146-
with patch(
147-
"torchtnt.utils.version.get_torch_version", return_value=Version("2.0.0")
148-
):
146+
with patch("torchtnt.utils.version.is_torch_version_geq", return_value=False):
149147
with tc.assertRaisesRegex(
150148
RuntimeError,
151149
"Torch version >= 2.1.0 required",

tests/utils/test_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ def test_get_torch_version(self) -> None:
4949

5050
def test_torch_version_comparators(self) -> None:
5151
with patch.object(torch, "__version__", "2.0.0a0"):
52-
self.assertFalse(version.is_torch_version_geq_2_1())
52+
self.assertFalse(version.is_torch_version_geq("2.1.0"))

torchtnt/utils/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
from .version import (
7575
get_python_version,
7676
get_torch_version,
77-
is_torch_version_geq_2_1,
77+
is_torch_version_geq,
7878
is_windows,
7979
)
8080

@@ -144,7 +144,7 @@
144144
"TLRScheduler",
145145
"get_python_version",
146146
"get_torch_version",
147-
"is_torch_version_geq_2_1",
147+
"is_torch_version_geq",
148148
"is_windows",
149149
"get_pet_launch_config",
150150
"spawn_multi_process",

torchtnt/utils/prepare_module.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242

4343
from torchtnt.utils.rank_zero_log import rank_zero_warn
44-
from torchtnt.utils.version import is_torch_version_geq_2_1
44+
from torchtnt.utils.version import is_torch_version_geq
4545

4646

4747
@dataclass
@@ -318,7 +318,7 @@ def prepare_module(
318318
if (
319319
torch_compile_params
320320
and strategy.static_graph is True
321-
and not is_torch_version_geq_2_1()
321+
and not is_torch_version_geq("2.1.0")
322322
):
323323
raise RuntimeError(
324324
"Torch version >= 2.1.0 required for Torch compile + DDP with static graph"

torchtnt/utils/version.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ def get_torch_version() -> Version:
5656
return pkg_version
5757

5858

59-
def is_torch_version_geq_2_1() -> bool:
60-
return get_torch_version() >= Version("2.1.0")
59+
def is_torch_version_geq(version: str) -> bool:
60+
return get_torch_version() >= Version(version)

0 commit comments

Comments
 (0)