@@ -658,13 +658,7 @@ class FixedPoolAllocator(TemporaryBucketAllocator):
658658 deallocation of temporary buffers during FSDP operations.
659659 """
660660
661- def __init__ (
662- self ,
663- name : str ,
664- fsdp_param_groups : List ["ParameterGroup" ],
665- size : int = 2 ,
666- fallback_to_persistent_buffer : bool = False ,
667- ):
661+ def __init__ (self , name : str , fsdp_param_groups : List ["ParameterGroup" ], size : int = 2 ):
668662 self .name = name
669663 self .fsdp_param_groups = fsdp_param_groups
670664 self .size = size # Number of buffers in the pool (default is 2 for double buffering)
@@ -697,29 +691,6 @@ def __init__(
697691 ), "Found no FSDP units to use fixed-size buffering"
698692 self .fsdp_double_buffer_units = fsdp_units_to_double_buffer
699693
700- if torch .distributed .get_rank () == 0 :
701- for bucket_id , param_group in enumerate (fsdp_param_groups ):
702- if (
703- param_group .fsdp_unit_id == - 1
704- or param_group .fsdp_unit_id is None
705- or param_group .fsdp_unit_id not in self .fsdp_double_buffer_units
706- ):
707- logging .info (
708- f"FSDP unit (id={ param_group .fsdp_unit_id } ) does not fit "
709- "in FixedPoolAllcator"
710- )
711- if fallback_to_persistent_buffer is False :
712- logging .info (
713- "It will fall back to dynamic memory allocator, NCCL user "
714- "buffer is not supported"
715- )
716- else :
717- logging .info (
718- "It will be allocated a persistent buffer. If the memory "
719- "budget is tight, set "
720- "trainer.strategy.ddp.fsdp_db_use_persist_buf_on_alloc_fail to False."
721- )
722-
723694 # Initialize buffer group status.
724695 # Each buffer group represents a set of buffers associated with an FSDP unit's bucket group.
725696 self .idle_buffer = [] # List of available (buf_group_id, offset) tuples.
@@ -732,7 +703,6 @@ def __init__(
732703 self .idle_buffer .append ((buf_group_id , bucket_offset ))
733704
734705 # Fallback allocator used if the fixed pool allocator cannot fulfill a request.
735- self .fallback_to_persistent_buffer = fallback_to_persistent_buffer
736706 self .backup_allocator = TemporaryBucketAllocator ()
737707
738708 def _is_two_bucket_group_equal (self , group_a , group_b ):
@@ -785,31 +755,28 @@ def allocate(
785755 f"current using_buffer: { self .using_buffer } \n "
786756 f"current idle_buffer: { self .idle_buffer } "
787757 )
788- elif self .fallback_to_persistent_buffer is True :
789- buffer_name = f"{ self .name } _not_fit_in_fixed_pool_{ bucket_id } _{ size } _{ dtype } _{ device } "
790- else :
791- # If the bucket is not eligible for fixed pool buffering, or no buffer is available,
792- # fall back to dynamic allocation via the backup allocator. This means that we
793- # will do dynamic memory allocation.
794- logging .debug (f"[FSDP] Using backup allocator for { bucket_id } { fsdp_unit_id } " )
795- return self .backup_allocator .allocate (
796- bucket_id = bucket_id , size = size , dtype = dtype , device = device
758+ # Synchronization is required before the allocation for the user buffer
759+ if mem_alloc_context is not None and mem_alloc_context != nullcontext :
760+ # Check if a new buffer allocation is required
761+ if (
762+ self .allocation_tracker .get ((buffer_name , dtype ), None ) is None
763+ or self .allocation_tracker [(buffer_name , dtype )] < size
764+ ):
765+ # Requires synchronization for new buffer allocation
766+ self .allocation_tracker [(buffer_name , dtype )] = size
767+ torch .cuda .synchronize ()
768+ return Bucket (
769+ data = get_global_memory_buffer ().get_tensor (
770+ [size ], dtype = dtype , name = buffer_name , mem_alloc_context = mem_alloc_context
771+ )
797772 )
798773
799- # Use buffer_name to get memory from global memory.
800- if mem_alloc_context is not None and mem_alloc_context != nullcontext :
801- # Check if a new buffer allocation is required
802- if (
803- self .allocation_tracker .get ((buffer_name , dtype ), None ) is None
804- or self .allocation_tracker [(buffer_name , dtype )] < size
805- ):
806- # Requires synchronization for new buffer allocation
807- self .allocation_tracker [(buffer_name , dtype )] = size
808- torch .cuda .synchronize ()
809- return Bucket (
810- data = get_global_memory_buffer ().get_tensor (
811- [size ], dtype = dtype , name = buffer_name , mem_alloc_context = mem_alloc_context
812- )
774+ # If the bucket is not eligible for fixed pool buffering, or no buffer is available,
775+ # fall back to dynamic allocation via the backup allocator. This means that we
776+ # will do dynamic memory allocation.
777+ logging .debug (f"[FSDP] Using backup allocator for { bucket_id } { fsdp_unit_id } " )
778+ return self .backup_allocator .allocate (
779+ bucket_id = bucket_id , size = size , dtype = dtype , device = device
813780 )
814781
815782 def _get_gbuf_name (self , buf_group_id : int , bucket_index : int ):
@@ -828,10 +795,9 @@ def free(self, bucket_id: int):
828795 self .idle_buffer .append (self .using_buffer [bucket_id ])
829796 del self .using_buffer [bucket_id ]
830797 return
831- if self .fallback_to_persistent_buffer is False :
832- # If not managed by fixed pool allocator, delegate to the backup allocator.
833- logging .debug (f"[FSDP] Free from the backup allocator for { bucket_id } { fsdp_unit_id } " )
834- self .backup_allocator .free (bucket_id )
798+ # If not managed by fixed pool allocator, delegate to the backup allocator.
799+ logging .debug (f"[FSDP] Free from the backup allocator for { bucket_id } { fsdp_unit_id } " )
800+ self .backup_allocator .free (bucket_id )
835801
836802
837803class DataParallelBuffer :
@@ -1908,21 +1874,15 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
19081874 if self .ddp_config .fsdp_double_buffer and len (self .bucketing_policy .fsdp_unit_modules ) > 0 :
19091875 UB_BUFFER_NUM = 2
19101876 self .weight_alloc = FixedPoolAllocator (
1911- name = "fsdp_params" ,
1912- fsdp_param_groups = self .parameter_groups ,
1913- size = UB_BUFFER_NUM ,
1914- fallback_to_persistent_buffer = self .ddp_config .fsdp_db_use_persist_buf_on_alloc_fail ,
1877+ name = "fsdp_params" , fsdp_param_groups = self .parameter_groups , size = UB_BUFFER_NUM
19151878 )
19161879 self .transpose_weight_alloc = FixedPoolAllocator (
19171880 name = "fsdp_fp8_transpose_params" ,
19181881 fsdp_param_groups = self .parameter_groups ,
19191882 size = UB_BUFFER_NUM ,
19201883 )
19211884 self .main_grad_alloc = FixedPoolAllocator (
1922- name = "fsdp_grads" ,
1923- fsdp_param_groups = self .parameter_groups ,
1924- size = UB_BUFFER_NUM ,
1925- fallback_to_persistent_buffer = self .ddp_config .fsdp_db_use_persist_buf_on_alloc_fail ,
1885+ name = "fsdp_grads" , fsdp_param_groups = self .parameter_groups , size = UB_BUFFER_NUM
19261886 )
19271887 self .double_buf_units = self .weight_alloc .fsdp_double_buffer_units
19281888 else :
0 commit comments