Skip to content

Commit 58a8a62

Browse files
mikolajblazericharper
authored andcommitted
ADLR/megatron-lm!1823 - PyT Dist fix for 24.05 container
1 parent 1efefa7 commit 58a8a62

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

  • megatron/core/dist_checkpointing/strategies

megatron/core/dist_checkpointing/strategies/torch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def sharded_tensor_to_torch_sharded_tensor(
209209
]
210210

211211
# Create a ShardedTensor without invoking communication. Determine global shards
212+
world_size = torch.distributed.get_world_size()
212213
shard_metadata = []
213214
# NOTE: here we assume a regular grid of shards
214215
for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
@@ -232,13 +233,16 @@ def sharded_tensor_to_torch_sharded_tensor(
232233

233234
else:
234235
# for shards from other ranks we provide simplistic data - this information will be discarded
235-
# during TorchShardedTensor._init_from_local_shards_and_global_metadata call
236+
# during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
237+
# Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
238+
# The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
239+
placement = f"rank:{(rank + 1) % world_size}/cuda"
236240
if has_flattened_range and not is_flattened_range_1d:
237241
offset = offset + (0,)
238242
size = (1,) * len(offsets_shape) + global_shape[-1:]
239243
else:
240244
size = offsets_shape
241-
shard_metadata.append(ShardMetadata(offset, size, "cuda"))
245+
shard_metadata.append(ShardMetadata(offset, size, placement))
242246

243247
tensor = some_sh_ten.data
244248
sharded_tensor_metadata = ShardedTensorMetadata(

0 commit comments

Comments
 (0)