Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
fegin committed Jan 10, 2025
1 parent a92576c commit 31ba2ef
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
1 change: 1 addition & 0 deletions torchft/fsdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _test_fsdp(world_size: int, rank: int) -> None:
shard_model = fully_shard(model, mesh=device_mesh)
shard_model(batch).mean().backward()

# pyre-ignore[56]: Pyre was not able to infer the type of argument
@unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
def test_fsdp(self) -> None:
multiprocessing.set_start_method("spawn")
Expand Down
2 changes: 1 addition & 1 deletion torchft/process_group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def _test_init_device_mesh(world_size: int, rank: int) -> None:
not isinstance(device_mesh.get_group("dp_shard"), ManagedProcessGroup)
)
replicate_group = device_mesh.get_group("dp_replicate")
testcase.assertEqual(replicate_group._manager, manager)
testcase.assertEqual(cast(ManagedProcessGroup, replicate_group)._manager, manager)
replicate_mesh = device_mesh["dp_replicate"]
testcase.assertEqual(replicate_mesh.get_group(), replicate_group)
flatten_mesh = device_mesh._flatten("dp")
Expand Down

0 comments on commit 31ba2ef

Please sign in to comment.