File tree 5 files changed +8
-10
lines changed
5 files changed +8
-10
lines changed Original file line number Diff line number Diff line change @@ -143,9 +143,7 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
143
143
"""
144
144
145
145
tc = unittest .TestCase ()
146
- with patch (
147
- "torchtnt.utils.version.get_torch_version" , return_value = Version ("2.0.0" )
148
- ):
146
+ with patch ("torchtnt.utils.version.is_torch_version_geq" , return_value = False ):
149
147
with tc .assertRaisesRegex (
150
148
RuntimeError ,
151
149
"Torch version >= 2.1.0 required" ,
Original file line number Diff line number Diff line change @@ -49,4 +49,4 @@ def test_get_torch_version(self) -> None:
49
49
50
50
def test_torch_version_comparators (self ) -> None :
51
51
with patch .object (torch , "__version__" , "2.0.0a0" ):
52
- self .assertFalse (version .is_torch_version_geq_2_1 ( ))
52
+ self .assertFalse (version .is_torch_version_geq ( "2.1.0" ))
Original file line number Diff line number Diff line change 74
74
from .version import (
75
75
get_python_version ,
76
76
get_torch_version ,
77
- is_torch_version_geq_2_1 ,
77
+ is_torch_version_geq ,
78
78
is_windows ,
79
79
)
80
80
144
144
"TLRScheduler" ,
145
145
"get_python_version" ,
146
146
"get_torch_version" ,
147
- "is_torch_version_geq_2_1 " ,
147
+ "is_torch_version_geq " ,
148
148
"is_windows" ,
149
149
"get_pet_launch_config" ,
150
150
"spawn_multi_process" ,
Original file line number Diff line number Diff line change 41
41
)
42
42
43
43
from torchtnt .utils .rank_zero_log import rank_zero_warn
44
- from torchtnt .utils .version import is_torch_version_geq_2_1
44
+ from torchtnt .utils .version import is_torch_version_geq
45
45
46
46
47
47
@dataclass
@@ -318,7 +318,7 @@ def prepare_module(
318
318
if (
319
319
torch_compile_params
320
320
and strategy .static_graph is True
321
- and not is_torch_version_geq_2_1 ( )
321
+ and not is_torch_version_geq ( "2.1.0" )
322
322
):
323
323
raise RuntimeError (
324
324
"Torch version >= 2.1.0 required for Torch compile + DDP with static graph"
Original file line number Diff line number Diff line change @@ -56,5 +56,5 @@ def get_torch_version() -> Version:
56
56
return pkg_version
57
57
58
58
59
- def is_torch_version_geq_2_1 ( ) -> bool :
60
- return get_torch_version () >= Version ("2.1.0" )
59
+ def is_torch_version_geq ( version : str ) -> bool :
60
+ return get_torch_version () >= Version (version )
You can’t perform that action at this time.
0 commit comments