22 | 22 | from megatron.core.distributed.custom_fsdp import ( |
23 | 23 | FullyShardedDataParallel as custom_FSDP, |
24 | 24 | ) |
| 25 | +from megatron.core.distributed.distributed_data_parallel_config import ( |
| 26 | + DistributedDataParallelConfig, |
| 27 | +) |
| 28 | +from megatron.core.distributed.torch_fully_sharded_data_parallel import ( |
| 29 | + TorchFullyShardedDataParallel as torch_FSDP, |
| 30 | +) |
25 | 31 | from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config |
26 | 32 | from megatron.training.checkpointing import ( |
27 | 33 | checkpoint_exists, |
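The two new imports pull in the DDP config object and the Torch FSDP2 wrapper that the later hunks rely on. A minimal, hypothetical helper (not part of the patch) showing how the two wrapper classes imported here could be told apart at runtime:

```python
from megatron.core.distributed.custom_fsdp import (
    FullyShardedDataParallel as custom_FSDP,
)
from megatron.core.distributed.torch_fully_sharded_data_parallel import (
    TorchFullyShardedDataParallel as torch_FSDP,
)


def fsdp_backend(model_chunk) -> str:
    """Hypothetical helper: report which parallel wrapper a model chunk uses."""
    if isinstance(model_chunk, torch_FSDP):
        return "torch_fsdp2"
    if isinstance(model_chunk, custom_FSDP):
        return "custom_fsdp"
    return "ddp_or_unwrapped"
```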
@@ -297,7 +303,8 @@ def update_primus_config( |
297 | 303 | log_kv_rank_0(f"-world_size", f"{args.world_size}") |
298 | 304 | |
299 | 305 | ###################################################cuda |
300 | | - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" |
| 306 | + if not args.use_torch_fsdp2: |
| 307 | + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" |
301 | 308 | |
302 | 309 | ###################################################checkpoint |
303 | 310 | ckpt_path = os.path.abspath(os.path.join(exp_root_path, "checkpoints")) |
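Context for the gating above: `CUDA_DEVICE_MAX_CONNECTIONS=1` limits the device to a single hardware work queue, which Megatron's kernel-launch ordering for communication/computation overlap relies on; the patch skips it under Torch FSDP2, presumably so FSDP2's collectives are not serialized onto one queue (that motivation is an assumption, not stated in the diff). A minimal sketch of the same gating pattern, assuming a `use_torch_fsdp2` flag like the one in this hunk:

```python
import os


def configure_cuda_connections(use_torch_fsdp2: bool) -> None:
    """Illustrative helper (not in the patch): only pin the device to a
    single CUDA work queue when the non-FSDP2 path is used."""
    if not use_torch_fsdp2:
        # A single connection preserves kernel launch ordering, which the
        # Megatron overlap path depends on.
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
    # Otherwise keep the CUDA default so independent streams can make
    # progress concurrently (assumed rationale for the conditional).
```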
@@ -823,6 +830,12 @@ def setup_model_and_optimizer( |
823 | 830 | |
824 | 831 | log_rank_0(f"-run get_model") |
825 | 832 | model = get_model(model_provider_func, model_type) |
| 833 | + |
| 834 | + # get_megatron_optimizer will use the ddp_config |
| 835 | + if isinstance(model[0], torch_FSDP): |
| 836 | + model[0].ddp_config = DistributedDataParallelConfig() |
| 837 | + model[0].ddp_config.use_custom_fsdp = False |
| 838 | + |
826 | 839 | unwrapped_model = unwrap_model(model) |
827 | 840 | |
828 | 841 | kwargs = {} |
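The hunk above attaches a default `DistributedDataParallelConfig` with `use_custom_fsdp = False` to the Torch FSDP2 wrapper because, per the inline comment, `get_megatron_optimizer` later reads `ddp_config` off the model chunk (and the Torch wrapper apparently does not carry one by default, an inference from this patch). A hedged sketch of how a downstream consumer might branch on that flag; the function `build_optimizer_for` and its body are illustrative, not Megatron's actual implementation:

```python
from megatron.core.distributed.distributed_data_parallel_config import (
    DistributedDataParallelConfig,
)


def build_optimizer_for(model_chunk):
    """Hypothetical consumer showing why the wrapper needs a ddp_config."""
    ddp_config = getattr(model_chunk, "ddp_config", None)
    if ddp_config is None:
        # Mirror the patch: fall back to a default config for torch FSDP2.
        ddp_config = DistributedDataParallelConfig()
        ddp_config.use_custom_fsdp = False
    if ddp_config.use_custom_fsdp:
        # Custom-FSDP path.
        ...
    else:
        # Torch FSDP2 / plain DDP path.
        ...
```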