Skip to content

Commit b8dc2dc

Browse files
committed
typing
1 parent 26b62be commit b8dc2dc

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchft/process_group.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,7 @@ def __getstate__(self) -> Dict[str, Any]:
898898

899899
def __setstate__(self, state: Dict[str, Any]) -> None:
900900
self.__dict__.update(state)
901+
assert self.replicate_pg_singleton is not None
901902
self.replicate_pg = self.replicate_pg_singleton
902903

903904
def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
@@ -921,10 +922,10 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
921922
assert self.mesh is not None
922923
return self.mesh[mesh_dim_names]
923924
else:
924-
assert self.mesh is not None
925925
mesh_dim_names_wo_replicate = tuple(
926926
n for n in mesh_dim_names if n != self.replicate_dim_name
927927
)
928+
assert self.mesh is not None
928929
return ManagedDeviceMesh(
929930
self.mesh[mesh_dim_names_wo_replicate],
930931
mesh_dim_names,

0 commit comments

Comments
 (0)