File tree Expand file tree Collapse file tree
megatron/core/dist_checkpointing/strategies Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -209,6 +209,7 @@ def sharded_tensor_to_torch_sharded_tensor(
209209 ]
210210
211211 # Create a ShardedTensor without invoking communication. Determine global shards
212+ world_size = torch .distributed .get_world_size ()
212213 shard_metadata = []
213214 # NOTE: here we assume a regular grid of shards
214215 for fragment_offsets in itertools .product (* map (range , some_sh_ten .axis_fragmentations )):
@@ -232,13 +233,16 @@ def sharded_tensor_to_torch_sharded_tensor(
232233
233234 else :
234235 # for shards from other ranks we provide simplistic data - this information will be discarded
235- # during TorchShardedTensor._init_from_local_shards_and_global_metadata call
236+ # during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
237+ # Due to a bug in the PyT 24.05 container, we must specify a concrete rank within the world size.
238+ # The exact rank doesn't matter as long as it differs from this rank - hence (rank + 1) % world_size.
239+ placement = f"rank:{ (rank + 1 ) % world_size } /cuda"
236240 if has_flattened_range and not is_flattened_range_1d :
237241 offset = offset + (0 ,)
238242 size = (1 ,) * len (offsets_shape ) + global_shape [- 1 :]
239243 else :
240244 size = offsets_shape
241- shard_metadata .append (ShardMetadata (offset , size , "cuda" ))
245+ shard_metadata .append (ShardMetadata (offset , size , placement ))
242246
243247 tensor = some_sh_ten .data
244248 sharded_tensor_metadata = ShardedTensorMetadata (
You can’t perform that action at this time.
0 commit comments