Skip to content

Commit f32861c

Browse files
authored
[misc] update torch version (#6206)
* [misc] update torch version * fix test * fix test * fix test * fix test
1 parent b9e6055 commit f32861c

File tree

5 files changed

+7
-6
lines changed

5 files changed

+7
-6
lines changed

.compatibility

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
-2.2.2-12.1.0
 2.3.0-12.1.0
 2.4.0-12.4.1
+2.5.1-12.4.1

.cuda_ext.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
 {
     "build": [
         {
-            "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121",
+            "torch_command": "pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121",
             "cuda_image": "hpcaitech/cuda-conda:12.1"
         },
         {
-            "torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
+            "torch_command": "pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124",
             "cuda_image": "hpcaitech/cuda-conda:12.4"
         }
     ]

requirements/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ click
 fabric
 contexttimer
 ninja
-torch>=2.2.0,<=2.4.1
+torch>=2.2.0,<=2.5.1
 safetensors
 einops
 pydantic

tests/test_cluster/test_device_mesh_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
 from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn


 def check_device_mesh_manager(rank, world_size, port):
@@ -24,6 +24,7 @@ def check_device_mesh_manager(rank, world_size, port):
     assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]]


+@rerun_if_address_is_in_use()
 def test_device_mesh_manager():
     spawn(check_device_mesh_manager, 4)

tests/test_shardformer/test_model/test_shard_t5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     if test_config["precision"] == "fp32":
         atol, rtol = 1e-5, 1e-3
     else:
-        atol, rtol = 5e-2, 5e-2
+        atol, rtol = 9e-2, 0
     if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
         row_layer_grads = get_grad_tensors_for_check(
             t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0

0 commit comments

Comments
 (0)