|
40 | 40 | import torch.distributed as dist |
41 | 41 | import torch.multiprocessing as mp |
42 | 42 |
|
43 | | -# pyre-fixme[21]: no attribute ProcessGroupNCCL |
44 | 43 | # pyre-fixme[21]: no attribute ProcessGroupGloo |
45 | 44 | from torch.distributed import ( |
46 | 45 | DeviceMesh, |
47 | 46 | PrefixStore, |
48 | 47 | ProcessGroup as BaseProcessGroup, |
49 | 48 | ProcessGroupGloo as BaseProcessGroupGloo, |
50 | | - ProcessGroupNCCL as BaseProcessGroupNCCL, |
51 | 49 | Store, |
52 | 50 | TCPStore, |
53 | 51 | ) |
@@ -687,6 +685,9 @@ def _wrap_work(self, work: Work, opts: object) -> Work: |
687 | 685 | return _WorkCUDATimeout(self, work, timeout) |
688 | 686 |
|
689 | 687 | def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup: |
| 688 | + # pyre-fixme[21]: no attribute ProcessGroupNCCL |
| 689 | + from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL |
| 690 | + |
690 | 691 | self._errored = None |
691 | 692 |
|
692 | 693 | pg = BaseProcessGroup(store, rank, world_size) |
@@ -1717,6 +1718,8 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby): |
1717 | 1718 |
|
1718 | 1719 | @classmethod |
1719 | 1720 | def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup: |
| 1721 | + from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL |
| 1722 | + |
1720 | 1723 | pg = BaseProcessGroup(store, rank, world_size) |
1721 | 1724 | pg._set_default_backend(ProcessGroup.BackendType.NCCL) |
1722 | 1725 | # pyre-fixme[16]: no attribute ProcessGroupNCCL |
|
0 commit comments