Skip to content

Commit 2481b52

Browse files
committed
Refactor: Improve block allocation and expert string parsing
This commit refactors the DisTorch safetensor loading and allocation logic for improved performance and correctness.

The main changes to the block assignment are:
- The primary compute device is now included in the pool of "donor" devices, allowing for more holistic memory quota calculation across all available GPUs.
- Unassigned "orphan" blocks are now allocated to the compute device instead of the CPU. This keeps more of the model in VRAM, reducing potential bottlenecks.

Additionally, this commit:
- Fixes a bug in the byte expert string parser where the wildcard `*` was incorrectly checked in the device name instead of the value.
- Standardizes variable names like `allocations_string` for better code clarity and consistency.
1 parent a5ff7fd commit 2481b52

1 file changed

Lines changed: 13 additions & 18 deletions

File tree

distorch_2.py

Lines changed: 13 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -130,16 +130,16 @@ def new_partially_load(self, device_to, extra_memory=0, full_load=False, force_p
130130
logger.info("[MultiGPU_DisTorch2] Successfully patched ModelPatcher.partially_load")
131131

132132

133-
def analyze_safetensor_loading(model_patcher, allocations_str):
133+
def analyze_safetensor_loading(model_patcher, allocations_string):
134134
"""
135135
Analyze and distribute safetensor model blocks across devices
136136
"""
137137
DEVICE_RATIOS_DISTORCH = {}
138138
device_table = {}
139-
distorch_alloc = allocations_str
139+
distorch_alloc = allocations_string
140140
virtual_vram_gb = 0.0
141141

142-
distorch_alloc, virtual_vram_str = allocations_str.split('#')
142+
distorch_alloc, virtual_vram_str = allocations_string.split('#')
143143

144144
compute_device = virtual_vram_str.split(';')[0]
145145
logger.info(f"[MultiGPU_DisTorch2] Compute Device: {compute_device}")
@@ -260,7 +260,7 @@ def analyze_safetensor_loading(model_patcher, allocations_str):
260260
block_assignments = {}
261261

262262
# Create a memory quota for each donor device based on its calculated allocation.
263-
donor_devices = [d for d in sorted_devices if d != compute_device]
263+
donor_devices = [d for d in sorted_devices]
264264
donor_quotas = {
265265
dev: device_table[dev]["alloc_gb"] * (1024**3)
266266
for dev in donor_devices
@@ -276,8 +276,8 @@ def analyze_safetensor_loading(model_patcher, allocations_str):
276276
assigned_to_donor = True
277277
break # Move to the next block
278278

279-
if not assigned_to_donor:
280-
block_assignments[block_name] = "cpu"
279+
if not assigned_to_donor: #Note - small rounding errors and tensor-fitting on devices make a block occasionally an orphan. We treat orphans the same as tiny_block_list as they are generally small rounding errors
280+
block_assignments[block_name] = compute_device
281281

282282
if tiny_block_list:
283283
for block_name, module, block_type, block_memory in tiny_block_list:
@@ -372,10 +372,9 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
372372
if ',' not in allocation:
373373
continue
374374
dev_name, val_str = allocation.split(',', 1)
375-
is_wildcard = '*' in dev_name
375+
is_wildcard = '*' in val_str
376376

377377
if is_wildcard:
378-
dev_name = dev_name.replace('*', '').strip()
379378
wildcard_device = dev_name
380379
# Don't add wildcard to the priority list yet
381380
else:
@@ -414,9 +413,9 @@ def calculate_fraction_from_byte_expert_string(model_patcher, byte_str):
414413
fraction = bytes_alloc / total_device_vram
415414
allocation_parts.append(f"{dev},{fraction:.4f}")
416415

417-
result_string = ";".join(allocation_parts)
416+
allocations_string = ";".join(allocation_parts)
418417

419-
return result_string
418+
return allocations_string
420419

421420
def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str):
422421
"""
@@ -464,9 +463,9 @@ def calculate_fraction_from_ratio_expert_string(model_patcher, ratio_str):
464463

465464
logger.info(f"[MultiGPU_DisTorch2] Ratio(%) Mode - {ratio_str} -> {ratio_string} ratio, put {put_part}")
466465

467-
result_string = ";".join(allocation_parts)
466+
allocations_string = ";".join(allocation_parts)
468467

469-
return result_string
468+
return allocations_string
470469

471470
def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str):
472471
"""Calculate virtual VRAM allocation string for distributed safetensor loading"""
@@ -545,12 +544,8 @@ def calculate_safetensor_vvram_allocation(model_patcher, virtual_vram_str):
545544
donor_percent = donor_allocations[donor] / donor_vram
546545
allocation_parts.append(f"{donor},{donor_percent:.4f}")
547546

548-
allocation_string = ";".join(allocation_parts)
549-
550-
fmt_mem = "{:<20}{:>20}"
551-
logger.info(fmt_mem.format("[MultiGPU_DisTorch2] Virtual VRAM Expert String", allocation_string))
552-
553-
return allocation_string
547+
allocations_string = ";".join(allocation_parts)
548+
return allocations_string
554549

555550
def override_class_with_distorch_safetensor_v2(cls):
556551
"""DisTorch 2.0 wrapper for safetensor models"""

0 commit comments

Comments
 (0)