Skip to content

Commit cb9a792

Browse files
committed
Revert "Use global user buffer when the bucket size does not fit FixedPoolAllocator (#2857)"
This reverts commit afe443b.
1 parent 2050da3 commit cb9a792

File tree

5 files changed

+27
-106
lines changed

5 files changed

+27
-106
lines changed

megatron/core/distributed/distributed_data_parallel_config.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,7 @@ class DistributedDataParallelConfig:
122122
This option will cause additional memory overhead, however, it is necessary for
123123
to register user buffer (nccl_ub=True) for the Megatron FSDP.
124124
This option will be automatically set to True when nccl_ub=True.
125-
"""
126-
127-
fsdp_db_use_persist_buf_on_alloc_fail: bool = False
128-
"""Whether to fall back to persistent buffer when a bucket does not
129-
fit FSDP double buffer size. If true, FSDP will use the persistently
130-
allocated buffer for the bucket that does not fit, it will enable NCCL
131-
user buffer with the cost of more memory usage. If false, FSDP will use
132-
Dynamic memory allocator, NCCL user buffer won't be enabled, which
133-
usually leads to low performance.
134-
"""
125+
"""
135126

136127
fsdp_all_gather_in_start_param_sync: bool = True
137128
"""

megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,6 @@ class DistributedDataParallelConfig:
119119
This option will be automatically set to True when nccl_ub=True.
120120
"""
121121

122-
fsdp_all_gather_in_start_param_sync: bool = True
123-
"""
124-
If True, use all-gather during the initial Megatron-FSDP parameter
125-
synchronization step. This can increase overlap between the first
126-
parameter all-gather and computation, helping to better hide the
127-
initial communication cost.
128-
"""
129-
130-
fsdp_db_use_persist_buf_on_alloc_fail: bool = False
131-
"""Whether to fall back to persistent buffer when a bucket does not
132-
fit FSDP double buffer size. If true, FSDP will use the persistently
133-
allocated buffer for the bucket that does not fit, it will enable NCCL
134-
user buffer with the cost of more memory usage. If false, FSDP will use
135-
Dynamic memory allocator, NCCL user buffer won't be enabled, which
136-
usually leads to low performance.
137-
"""
138-
139122
outer_dp_sharding_strategy: str = 'no_shard'
140123
"""
141124
Sharding strategy for outer data parallel group in Hybrid Sharded Data Parallel (HSDP) mode.

megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def fully_shard_model(
9898
keep_fp8_transpose_cache: bool = False,
9999
nccl_ub: bool = False,
100100
fsdp_double_buffer: bool = False,
101-
fsdp_db_use_persist_buf_on_alloc_fail: bool = False,
102101
disable_symmetric_registration: bool = False,
103102
enable_fine_grained_param_gather: bool = False,
104103
) -> torch.nn.Module:
@@ -233,10 +232,6 @@ class that schedules the sharding lifecycle of the model parameters and gradient
233232
fsdp_double_buffer (bool):
234233
Whether to use double buffer for FSDP. Defaults to False.
235234
236-
fsdp_db_use_persist_buf_on_alloc_fail (bool):
237-
Whether to fall back to persistent buffer allocator when a bucket does not
238-
fit FSDP double buffer size.
239-
240235
disable_symmetric_registration (bool):
241236
Whether to disable symmetric (window) registration for NCCL UB registration.
242237
This option forces conventional (local) UB registration when nccl_ub is set.
@@ -342,7 +337,6 @@ class that schedules the sharding lifecycle of the model parameters and gradient
342337
keep_fp8_transpose_cache=keep_fp8_transpose_cache, # pylint: disable=C0301
343338
nccl_ub=nccl_ub,
344339
fsdp_double_buffer=fsdp_double_buffer or nccl_ub,
345-
fsdp_db_use_persist_buf_on_alloc_fail=fsdp_db_use_persist_buf_on_alloc_fail,
346340
disable_symmetric_registration=disable_symmetric_registration,
347341
check_for_nan_in_grad=check_for_nan_in_grad,
348342
)
@@ -640,7 +634,6 @@ def fully_shard(
640634
keep_fp8_transpose_cache: bool = False,
641635
nccl_ub: bool = False,
642636
fsdp_double_buffer: bool = False,
643-
fsdp_db_use_persist_buf_on_alloc_fail: bool = False,
644637
disable_symmetric_registration: bool = False,
645638
enable_fine_grained_param_gather: bool = False,
646639
) -> tuple[MegatronFSDP, torch.optim.Optimizer]:
@@ -689,7 +682,6 @@ def fully_shard(
689682
keep_fp8_transpose_cache=keep_fp8_transpose_cache,
690683
nccl_ub=nccl_ub,
691684
fsdp_double_buffer=fsdp_double_buffer,
692-
fsdp_db_use_persist_buf_on_alloc_fail=fsdp_db_use_persist_buf_on_alloc_fail,
693685
disable_symmetric_registration=disable_symmetric_registration,
694686
enable_fine_grained_param_gather=enable_fine_grained_param_gather,
695687
)

megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,6 @@ class MegatronFSDP(torch.nn.Module):
136136
fsdp_double_buffer (bool): Whether to use persistently allocated double buffers
137137
for the temporary memory needed in the FSDP communication. This flag is
138138
automatically set to True when nccl_ub is True.
139-
fsdp_db_use_persist_buf_on_alloc_fail (bool): Whether to fall back to persistent buffer
140-
allocator when a bucket does not fit FSDP double buffer size.
141139
disable_symmetric_registration (bool): Whether to disable symmetric (window) registration
142140
for NCCL userbuffer registration. This option will force to use conventional (local)
143141
userbuffer registration when nccl_ub is set.
@@ -157,7 +155,6 @@ class MegatronFSDP(torch.nn.Module):
157155
... keep_fp8_transpose_cache=False,
158156
... nccl_ub=False,
159157
... fsdp_double_buffer=False,
160-
... fsdp_db_use_persist_buf_on_alloc_fail=False,
161158
... disable_symmetric_registration=False,
162159
... )
163160
"""
@@ -176,7 +173,6 @@ def __init__(
176173
keep_fp8_transpose_cache: bool = False,
177174
nccl_ub: bool = False,
178175
fsdp_double_buffer: bool = False,
179-
fsdp_db_use_persist_buf_on_alloc_fail: bool = False,
180176
disable_symmetric_registration: bool = False,
181177
enable_fine_grained_param_gather_hook: bool = False,
182178
):
@@ -221,7 +217,6 @@ def __init__(
221217
keep_fp8_transpose_cache=keep_fp8_transpose_cache, # pylint: disable=C0301
222218
nccl_ub=nccl_ub,
223219
fsdp_double_buffer=fsdp_double_buffer or nccl_ub,
224-
fsdp_db_use_persist_buf_on_alloc_fail=fsdp_db_use_persist_buf_on_alloc_fail,
225220
disable_symmetric_registration=disable_symmetric_registration,
226221
)
227222
else:

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 26 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -658,13 +658,7 @@ class FixedPoolAllocator(TemporaryBucketAllocator):
658658
deallocation of temporary buffers during FSDP operations.
659659
"""
660660

661-
def __init__(
662-
self,
663-
name: str,
664-
fsdp_param_groups: List["ParameterGroup"],
665-
size: int = 2,
666-
fallback_to_persistent_buffer: bool = False,
667-
):
661+
def __init__(self, name: str, fsdp_param_groups: List["ParameterGroup"], size: int = 2):
668662
self.name = name
669663
self.fsdp_param_groups = fsdp_param_groups
670664
self.size = size # Number of buffers in the pool (default is 2 for double buffering)
@@ -697,29 +691,6 @@ def __init__(
697691
), "Found no FSDP units to use fixed-size buffering"
698692
self.fsdp_double_buffer_units = fsdp_units_to_double_buffer
699693

700-
if torch.distributed.get_rank() == 0:
701-
for bucket_id, param_group in enumerate(fsdp_param_groups):
702-
if (
703-
param_group.fsdp_unit_id == -1
704-
or param_group.fsdp_unit_id is None
705-
or param_group.fsdp_unit_id not in self.fsdp_double_buffer_units
706-
):
707-
logging.info(
708-
f"FSDP unit (id={param_group.fsdp_unit_id}) does not fit "
709-
"in FixedPoolAllocator"
710-
)
711-
if fallback_to_persistent_buffer is False:
712-
logging.info(
713-
"It will fall back to dynamic memory allocator, NCCL user "
714-
"buffer is not supported"
715-
)
716-
else:
717-
logging.info(
718-
"It will be allocated a persistent buffer. If the memory "
719-
"budget is tight, set "
720-
"trainer.strategy.ddp.fsdp_db_use_persist_buf_on_alloc_fail to False."
721-
)
722-
723694
# Initialize buffer group status.
724695
# Each buffer group represents a set of buffers associated with an FSDP unit's bucket group.
725696
self.idle_buffer = [] # List of available (buf_group_id, offset) tuples.
@@ -732,7 +703,6 @@ def __init__(
732703
self.idle_buffer.append((buf_group_id, bucket_offset))
733704

734705
# Fallback allocator used if the fixed pool allocator cannot fulfill a request.
735-
self.fallback_to_persistent_buffer = fallback_to_persistent_buffer
736706
self.backup_allocator = TemporaryBucketAllocator()
737707

738708
def _is_two_bucket_group_equal(self, group_a, group_b):
@@ -785,31 +755,28 @@ def allocate(
785755
f"current using_buffer: {self.using_buffer} \n"
786756
f"current idle_buffer: {self.idle_buffer}"
787757
)
788-
elif self.fallback_to_persistent_buffer is True:
789-
buffer_name = f"{self.name}_not_fit_in_fixed_pool_{bucket_id}_{size}_{dtype}_{device}"
790-
else:
791-
# If the bucket is not eligible for fixed pool buffering, or no buffer is available,
792-
# fall back to dynamic allocation via the backup allocator. This means that we
793-
# will do dynamic memory allocation.
794-
logging.debug(f"[FSDP] Using backup allocator for {bucket_id} {fsdp_unit_id}")
795-
return self.backup_allocator.allocate(
796-
bucket_id=bucket_id, size=size, dtype=dtype, device=device
758+
# Synchronization is required before the allocation for the user buffer
759+
if mem_alloc_context is not None and mem_alloc_context != nullcontext:
760+
# Check if a new buffer allocation is required
761+
if (
762+
self.allocation_tracker.get((buffer_name, dtype), None) is None
763+
or self.allocation_tracker[(buffer_name, dtype)] < size
764+
):
765+
# Requires synchronization for new buffer allocation
766+
self.allocation_tracker[(buffer_name, dtype)] = size
767+
torch.cuda.synchronize()
768+
return Bucket(
769+
data=get_global_memory_buffer().get_tensor(
770+
[size], dtype=dtype, name=buffer_name, mem_alloc_context=mem_alloc_context
771+
)
797772
)
798773

799-
# Use buffer_name to get memory from global memory.
800-
if mem_alloc_context is not None and mem_alloc_context != nullcontext:
801-
# Check if a new buffer allocation is required
802-
if (
803-
self.allocation_tracker.get((buffer_name, dtype), None) is None
804-
or self.allocation_tracker[(buffer_name, dtype)] < size
805-
):
806-
# Requires synchronization for new buffer allocation
807-
self.allocation_tracker[(buffer_name, dtype)] = size
808-
torch.cuda.synchronize()
809-
return Bucket(
810-
data=get_global_memory_buffer().get_tensor(
811-
[size], dtype=dtype, name=buffer_name, mem_alloc_context=mem_alloc_context
812-
)
774+
# If the bucket is not eligible for fixed pool buffering, or no buffer is available,
775+
# fall back to dynamic allocation via the backup allocator. This means that we
776+
# will do dynamic memory allocation.
777+
logging.debug(f"[FSDP] Using backup allocator for {bucket_id} {fsdp_unit_id}")
778+
return self.backup_allocator.allocate(
779+
bucket_id=bucket_id, size=size, dtype=dtype, device=device
813780
)
814781

815782
def _get_gbuf_name(self, buf_group_id: int, bucket_index: int):
@@ -828,10 +795,9 @@ def free(self, bucket_id: int):
828795
self.idle_buffer.append(self.using_buffer[bucket_id])
829796
del self.using_buffer[bucket_id]
830797
return
831-
if self.fallback_to_persistent_buffer is False:
832-
# If not managed by fixed pool allocator, delegate to the backup allocator.
833-
logging.debug(f"[FSDP] Free from the backup allocator for {bucket_id} {fsdp_unit_id}")
834-
self.backup_allocator.free(bucket_id)
798+
# If not managed by fixed pool allocator, delegate to the backup allocator.
799+
logging.debug(f"[FSDP] Free from the backup allocator for {bucket_id} {fsdp_unit_id}")
800+
self.backup_allocator.free(bucket_id)
835801

836802

837803
class DataParallelBuffer:
@@ -1908,21 +1874,15 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
19081874
if self.ddp_config.fsdp_double_buffer and len(self.bucketing_policy.fsdp_unit_modules) > 0:
19091875
UB_BUFFER_NUM = 2
19101876
self.weight_alloc = FixedPoolAllocator(
1911-
name="fsdp_params",
1912-
fsdp_param_groups=self.parameter_groups,
1913-
size=UB_BUFFER_NUM,
1914-
fallback_to_persistent_buffer=self.ddp_config.fsdp_db_use_persist_buf_on_alloc_fail,
1877+
name="fsdp_params", fsdp_param_groups=self.parameter_groups, size=UB_BUFFER_NUM
19151878
)
19161879
self.transpose_weight_alloc = FixedPoolAllocator(
19171880
name="fsdp_fp8_transpose_params",
19181881
fsdp_param_groups=self.parameter_groups,
19191882
size=UB_BUFFER_NUM,
19201883
)
19211884
self.main_grad_alloc = FixedPoolAllocator(
1922-
name="fsdp_grads",
1923-
fsdp_param_groups=self.parameter_groups,
1924-
size=UB_BUFFER_NUM,
1925-
fallback_to_persistent_buffer=self.ddp_config.fsdp_db_use_persist_buf_on_alloc_fail,
1885+
name="fsdp_grads", fsdp_param_groups=self.parameter_groups, size=UB_BUFFER_NUM
19261886
)
19271887
self.double_buf_units = self.weight_alloc.fsdp_double_buffer_units
19281888
else:

0 commit comments

Comments
 (0)