
Commit d0da8db

Dev/xiaoming/fix fsdp2 (#27)
1 parent 87b9c6d

File tree

1 file changed: +14 -1 lines changed

primus/modules/trainer/megatron/trainer.py

Lines changed: 14 additions & 1 deletion
@@ -22,6 +22,12 @@
 from megatron.core.distributed.custom_fsdp import (
     FullyShardedDataParallel as custom_FSDP,
 )
+from megatron.core.distributed.distributed_data_parallel_config import (
+    DistributedDataParallelConfig,
+)
+from megatron.core.distributed.torch_fully_sharded_data_parallel import (
+    TorchFullyShardedDataParallel as torch_FSDP,
+)
 from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config
 from megatron.training.checkpointing import (
     checkpoint_exists,
@@ -297,7 +303,8 @@ def update_primus_config(
         log_kv_rank_0(f"-world_size", f"{args.world_size}")
 
         ###################################################cuda
-        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+        if not args.use_torch_fsdp2:
+            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 
         ###################################################checkpoint
         ckpt_path = os.path.abspath(os.path.join(exp_root_path, "checkpoints"))
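
The guard in update_primus_config above only pins CUDA_DEVICE_MAX_CONNECTIONS to 1 when Torch FSDP2 is not in use: pinning restricts the GPU to a single hardware work queue, the conventional Megatron setting for predictable ordering of overlapped communication and compute, whereas FSDP2's multi-stream prefetching is generally better served by more connections. A minimal standalone sketch of the same guard follows; the helper name configure_cuda_connections and the getattr default are illustrative assumptions, not part of the commit.

import os
from argparse import Namespace


def configure_cuda_connections(args: Namespace) -> None:
    # Hypothetical helper mirroring the guarded setup in update_primus_config:
    # only pin the device to a single hardware work queue when Torch FSDP2 is
    # not requested, leaving the environment untouched otherwise.
    if not getattr(args, "use_torch_fsdp2", False):
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


# With FSDP2 requested, the variable is left alone.
configure_cuda_connections(Namespace(use_torch_fsdp2=True))
print(os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS"))  # None unless set elsewhere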
@@ -823,6 +830,12 @@ def setup_model_and_optimizer(
 
         log_rank_0(f"-run get_model")
         model = get_model(model_provider_func, model_type)
+
+        # get_megatron_optimizer will use the ddp_config
+        if isinstance(model[0], torch_FSDP):
+            model[0].ddp_config = DistributedDataParallelConfig()
+            model[0].ddp_config.use_custom_fsdp = False
+
         unwrapped_model = unwrap_model(model)
 
         kwargs = {}
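
In setup_model_and_optimizer, the in-diff comment notes that get_megatron_optimizer will use the ddp_config, but a model chunk wrapped in TorchFullyShardedDataParallel does not arrive with one, so the trainer attaches a fresh DistributedDataParallelConfig and marks use_custom_fsdp as False. The sketch below reproduces that pattern with stand-in classes; DDPConfigStub, TorchFSDPStub, and build_optimizer_kwargs are assumptions for illustration, not Megatron's actual API.

from dataclasses import dataclass


@dataclass
class DDPConfigStub:
    # Simplified stand-in for megatron.core's DistributedDataParallelConfig.
    use_custom_fsdp: bool = True


class TorchFSDPStub:
    # Stand-in for a TorchFullyShardedDataParallel-wrapped model chunk that
    # has no ddp_config attribute of its own.
    pass


def build_optimizer_kwargs(model_chunk):
    # Downstream consumer (in the spirit of get_megatron_optimizer) that
    # expects every model chunk to carry a ddp_config.
    return {"use_custom_fsdp": model_chunk.ddp_config.use_custom_fsdp}


model = [TorchFSDPStub()]

# The commit's pattern: patch a default config onto the FSDP2 wrapper and make
# it explicit that Megatron's custom FSDP path is not the one in use.
if isinstance(model[0], TorchFSDPStub):
    model[0].ddp_config = DDPConfigStub()
    model[0].ddp_config.use_custom_fsdp = False

print(build_optimizer_kwargs(model[0]))  # {'use_custom_fsdp': False}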
