@@ -2,7 +2,6 @@
 import gc
 import math
 import os
-from contextlib import nullcontext
 from functools import partial

 import torch
@@ -26,9 +25,6 @@
 from .loss import loss_function
 from .model_provider import get_model_provider_func

-if torch.version.hip:
-    from vllm.device_allocator.cumem import CuMemAllocator
-

 def get_optimizer_param_scheduler(args, optimizer):
     """Build the learning rate scheduler."""
@@ -80,71 +76,64 @@ def setup_model_and_optimizer(

     model = get_model(get_model_provider_func(args, role), ModelType.encoder_or_decoder, wrap_with_ddp=False)

-    with (
-        CuMemAllocator.get_instance().use_memory_pool(tag="model")
-        if args.offload and torch.version.hip
-        else nullcontext()
-    ):
-        config = get_model_config(model[0])
-
-        kwargs = {}
-        for f in dataclasses.fields(DistributedDataParallelConfig):
-            if hasattr(args, f.name):
-                kwargs[f.name] = getattr(args, f.name)
-        kwargs["grad_reduce_in_fp32"] = args.accumulate_allreduce_grads_in_fp32
-        kwargs["check_for_nan_in_grad"] = args.check_for_nan_in_loss_and_grad
-        kwargs["check_for_large_grads"] = args.check_for_large_grads
-        kwargs["bucket_size"] = args.ddp_bucket_size
-        kwargs["pad_buckets_for_high_nccl_busbw"] = args.ddp_pad_buckets_for_high_nccl_busbw
-        kwargs["average_in_collective"] = args.ddp_average_in_collective
-        ddp_config = DistributedDataParallelConfig(**kwargs)
-
-        # In the custom FSDP and DDP use path, we need to initialize the bucket size.
-        # If bucket_size is not provided as an input, use sane default.
-        # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
-        # ring-reduce implementations are large enough to remain bandwidth-bound rather than
-        # latency-bound.
-        if ddp_config.bucket_size is None:
-            ddp_config.bucket_size = max(
-                40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
-            )
-        # Set bucket_size to infinity if overlap_grad_reduce is False.
-        if not ddp_config.overlap_grad_reduce:
-            ddp_config.bucket_size = None
-
-        model = [
-            DDP(
-                config=config,
-                ddp_config=ddp_config,
-                module=model_chunk,
-                # Turn off bucketing for model_chunk 2 onwards, since communication for these
-                # model chunks is overlapped with compute anyway.
-                disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step,
-            )
-            for (model_chunk_idx, model_chunk) in enumerate(model)
-        ]
-
-        # Optimizer
-        kwargs = {}
-        for f in dataclasses.fields(OptimizerConfig):
-            if hasattr(args, f.name):
-                kwargs[f.name] = getattr(args, f.name)
-        config = OptimizerConfig(**kwargs)
-        config.timers = None
-
-        optimizer = get_megatron_optimizer(
-            config,
-            model,
-            no_wd_decay_cond,
-            scale_lr_cond,
-            lr_mult,
-            use_gloo_process_groups=args.enable_gloo_process_groups,
+    config = get_model_config(model[0])
+
+    kwargs = {}
+    for f in dataclasses.fields(DistributedDataParallelConfig):
+        if hasattr(args, f.name):
+            kwargs[f.name] = getattr(args, f.name)
+    kwargs["grad_reduce_in_fp32"] = args.accumulate_allreduce_grads_in_fp32
+    kwargs["check_for_nan_in_grad"] = args.check_for_nan_in_loss_and_grad
+    kwargs["check_for_large_grads"] = args.check_for_large_grads
+    kwargs["bucket_size"] = args.ddp_bucket_size
+    kwargs["pad_buckets_for_high_nccl_busbw"] = args.ddp_pad_buckets_for_high_nccl_busbw
+    kwargs["average_in_collective"] = args.ddp_average_in_collective
+    ddp_config = DistributedDataParallelConfig(**kwargs)
+
+    # In the custom FSDP and DDP use path, we need to initialize the bucket size.
+    # If bucket_size is not provided as an input, use sane default.
+    # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
+    # ring-reduce implementations are large enough to remain bandwidth-bound rather than
+    # latency-bound.
+    if ddp_config.bucket_size is None:
+        ddp_config.bucket_size = max(40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True))
+    # Set bucket_size to infinity if overlap_grad_reduce is False.
+    if not ddp_config.overlap_grad_reduce:
+        ddp_config.bucket_size = None
+
+    model = [
+        DDP(
+            config=config,
+            ddp_config=ddp_config,
+            module=model_chunk,
+            # Turn off bucketing for model_chunk 2 onwards, since communication for these
+            # model chunks is overlapped with compute anyway.
+            disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step,
         )
-        opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer)
-        for optimizer in optimizer.chained_optimizers:
-            if not getattr(optimizer, "init_state_fn", None):
-                continue
-            optimizer.init_state_fn(optimizer.optimizer, optimizer.config)
+        for (model_chunk_idx, model_chunk) in enumerate(model)
+    ]
+
+    # Optimizer
+    kwargs = {}
+    for f in dataclasses.fields(OptimizerConfig):
+        if hasattr(args, f.name):
+            kwargs[f.name] = getattr(args, f.name)
+    config = OptimizerConfig(**kwargs)
+    config.timers = None
+
+    optimizer = get_megatron_optimizer(
+        config,
+        model,
+        no_wd_decay_cond,
+        scale_lr_cond,
+        lr_mult,
+        use_gloo_process_groups=args.enable_gloo_process_groups,
+    )
+    opt_param_scheduler = get_optimizer_param_scheduler(args, optimizer)
+    for optimizer in optimizer.chained_optimizers:
+        if not getattr(optimizer, "init_state_fn", None):
+            continue
+        optimizer.init_state_fn(optimizer.optimizer, optimizer.config)

     return model, optimizer, opt_param_scheduler

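A few notes on what this diff removes and on the logic it leaves in place. The deleted lines were the ROCm-only offload path: model and optimizer setup ran inside vLLM's CuMemAllocator memory pool when args.offload was set on a HIP build, and inside a no-op nullcontext() otherwise. A minimal sketch of that conditional-context-manager idiom, using the same vLLM call the removed code made (build_model is a hypothetical stand-in for the real get_model(...) path):

    from contextlib import nullcontext

    import torch

    def model_memory_pool(offload: bool):
        # Offload runs on ROCm enter vLLM's pooled allocator so the weights
        # can later be released to the pool; every other configuration is a no-op.
        if offload and torch.version.hip:
            from vllm.device_allocator.cumem import CuMemAllocator
            return CuMemAllocator.get_instance().use_memory_pool(tag="model")
        return nullcontext()

    with model_memory_pool(offload=False):
        model = build_model()  # hypothetical stand-in for get_model(...)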
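The config plumbing that moves out of the with-block is otherwise unchanged: it copies every DistributedDataParallelConfig (and later OptimizerConfig) field that has a same-named attribute on args, then explicitly overrides the fields whose CLI flags are named differently (args.ddp_bucket_size -> bucket_size, and so on). A self-contained sketch of that fields-from-namespace pattern, with a toy config and namespace (both hypothetical):

    import dataclasses
    from argparse import Namespace
    from typing import Optional

    @dataclasses.dataclass
    class ToyDDPConfig:
        overlap_grad_reduce: bool = False
        bucket_size: Optional[int] = None

    args = Namespace(overlap_grad_reduce=True, ddp_bucket_size=50_000_000, lr=1e-4)

    # Copy every dataclass field that args happens to carry under the same name...
    kwargs = {f.name: getattr(args, f.name) for f in dataclasses.fields(ToyDDPConfig) if hasattr(args, f.name)}
    # ...then patch in the fields whose CLI names differ, as the diff does for bucket_size.
    kwargs["bucket_size"] = args.ddp_bucket_size
    config = ToyDDPConfig(**kwargs)
    assert config.overlap_grad_reduce and config.bucket_size == 50_000_000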
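The bucket-size default is the one piece of real logic here: buckets get a 40M-element floor but grow linearly with the data-parallel world size (context parallel included), so the per-rank chunks in NCCL's ring reduce stay large enough to be bandwidth-bound rather than latency-bound; and when overlap_grad_reduce is off, bucket_size is reset to None ("infinity" in the comment), meaning gradients are no longer split into buckets at all. A quick check of the arithmetic:

    def default_bucket_size(dp_size: int) -> int:
        # Same formula as the diff: max(40M, 1M * dp_size) elements.
        return max(40_000_000, 1_000_000 * dp_size)

    assert default_bucket_size(8) == 40_000_000   # the floor dominates up to dp_size 40
    assert default_bucket_size(64) == 64_000_000  # past that, buckets scale with dp_size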
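Finally, the tail loop eagerly materializes optimizer state: get_megatron_optimizer returns a chained optimizer, and each member that exposes an init_state_fn gets it called up front, while members without one are skipped via the getattr(..., None) guard. A minimal sketch of that duck-typed dispatch (both classes below are hypothetical stand-ins):

    class EagerMember:
        # Stand-ins for the wrapped torch optimizer and its config.
        optimizer, config = object(), object()

        def init_state_fn(self, opt, cfg):
            print("optimizer state initialized")

    class LazyMember:
        optimizer, config = object(), object()  # exposes no init_state_fn

    for member in (EagerMember(), LazyMember()):
        if not getattr(member, "init_state_fn", None):
            continue  # this member builds its state lazily on the first step
        member.init_state_fn(member.optimizer, member.config)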