Introduce ManagedDeviceMesh to integrate DeviceMesh with TorchFT #56
@@ -28,10 +28,12 @@
     get_world_size,
 )
 from torch.distributed.device_mesh import init_device_mesh
+from torch.testing._internal.common_distributed import MultiProcessTestCase

 from torchft.manager import Manager
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
+    ManagedDeviceMesh,
     ManagedProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
@@ -44,6 +46,7 @@
     _ErrorSwallowingWork,
     _ManagedWork,
     extend_device_mesh,
+    ft_init_device_mesh,
 )
@@ -234,6 +237,7 @@ def test_device_mesh(self) -> None:
         pg.configure(store_addr, 0, 1)

         mesh_2d = extend_device_mesh(mesh_1d, pg)
+        mesh_2d.get_group("dp")
         assert mesh_2d.ndim == 2

         pg.unregister()
@@ -299,3 +303,46 @@ def test_managed_process_group(self) -> None:

         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
+
+
+class DeviceMeshTest(MultiProcessTestCase):
+    @property
+    def world_size(self):
+        return 4
+
+    def setUp(self):
+        super().setUp()
+        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
+        self._spawn_processes()
+
+    def test_init_device_mesh(self) -> None:
+        os.environ["MASTER_PORT"] = str(12346)
+        os.environ["RANK"] = str(self.rank)
+        os.environ["WORLD_SIZE"] = str(4)
+
+        manager = Mock(spec=Manager)
+        # Even though we only have 4 workers, we can still initialize a (2, 4)
+        # mesh. That's because the replicate group is NOT physically created in
+        # the real mesh but is virtually added to the mesh via ManagedDeviceMesh.
+        device_mesh = ft_init_device_mesh(
+            device_type="cpu",
+            mesh_shape=(2, 4),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+            replicate_dim=0,
+            manager=manager,
+        )
+
+        self.assertTrue(
+            isinstance(device_mesh.get_group("dp_replicate"), ManagedProcessGroup)
+        )
+        self.assertTrue(
+            not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
+        )
+        replicate_group = device_mesh.get_group("dp_replicate")
+        self.assertEqual(replicate_group._manager, manager)
+        replicate_mesh = device_mesh["dp_replicate"]
+        self.assertEqual(replicate_mesh.get_group(), replicate_group)
+        flatten_mesh = device_mesh._flatten("dp")
+        manager.num_participants.return_value = 1
+        self.assertEqual(flatten_mesh.size(), 4)
+        self.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())

Review comment on `mesh_shape=(2, 4),`:
[Reviewer] Is this value used at all? I assume it doesn't really matter what it's set to?
[Author] The replicate part is not going to be valid, but the other parts are valid and will be used.

Review comment on `self.assertEqual(flatten_mesh.size(), 4)`:
[Reviewer] Should this be equal to …
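As context for the mesh_shape thread above, here is a minimal usage sketch (not part of this PR's diff) assuming a 4-worker launch such as `torchrun --nproc_per_node=4`: only the "dp_shard" dimension has to map onto the 4 real ranks, while "dp_replicate" is served virtually through the manager. The `build_ft_mesh` helper and the elided Manager construction are illustrative assumptions, not code from the PR.

from torchft.process_group import ManagedProcessGroup, ft_init_device_mesh

def build_ft_mesh(manager):
    # `manager` is assumed to be a torchft.manager.Manager; its construction
    # is omitted here.
    mesh = ft_init_device_mesh(
        device_type="cpu",
        mesh_shape=(2, 4),  # the replicate dim is virtual, so 4 ranks suffice
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,
    )
    # The replicate group is routed through the manager rather than backed
    # by a physically created process group.
    assert isinstance(mesh.get_group("dp_replicate"), ManagedProcessGroup)
    return mesh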
Review comment on `self.assertEqual(flatten_mesh.get_local_rank(), dist.get_rank())`:
[Reviewer] Oh nice! Does this solve the issue with flattening in FSDP, or does it just throw an error for now?
[Author] It should work for the case where we flatten the device mesh to compute the global loss average, but not for the data loader. I think we need to customize the data loader anyway.
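The loss-averaging use case the author describes could look roughly like the following sketch (illustrative, not from the PR; `global_loss_average` is a hypothetical helper). Summing and then dividing by the flattened mesh size keeps the reduction backend-agnostic, since gloo has no built-in AVG reduce op.

import torch
import torch.distributed as dist

def global_loss_average(local_loss: torch.Tensor, device_mesh) -> torch.Tensor:
    # Flatten the whole mesh into one "dp" dimension, as in the test above.
    flat_mesh = device_mesh._flatten("dp")
    loss = local_loss.detach().clone()
    # All-reduce the scalar loss over the flattened group, then average.
    dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=flat_mesh.get_group())
    return loss / flat_mesh.size()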