
Commit f98c811

support ds on CUDA and 70b on XPU
Signed-off-by: He, Xin3 <[email protected]>
1 parent e04db30 commit f98c811

2 files changed: 8 additions & 10 deletions


auto_round/compressors/base.py

Lines changed: 4 additions & 2 deletions
@@ -2604,9 +2604,11 @@ def _quantize_block(

                 total_loss += loss.item() / num_elm
                 # Sometimes the cached memory is not released during training and cause OOM
-                if self.low_gpu_mem_usage:
-                    clear_memory_if_reached_threshold(threshold=0.85)
+                if self.low_gpu_mem_usage and torch.xpu.is_available():
+                    clear_memory_if_reached_threshold(threshold=0.5)
                 self._scale_loss_and_backward(scaler, loss)
+                if self.low_gpu_mem_usage:
+                    clear_memory_if_reached_threshold(threshold=0.9)

             if i == 0:
                 init_loss = total_loss
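
In _quantize_block, the pre-backward cache check is now restricted to XPU and tightened to a 0.5 threshold, and a new post-backward check at 0.9 runs on any device when low_gpu_mem_usage is set. The sketch below shows the same two-stage pattern in a generic training step; the model, loss, and optimizer are placeholders, and the import path is only inferred from the file layout in this commit.

# Sketch of the two-stage cache clearing introduced by this commit; only the
# clear_memory_if_reached_threshold calls mirror the patched code, the rest is
# a placeholder training step.
import torch

# Import path inferred from auto_round/utils/device.py in this repository.
from auto_round.utils.device import clear_memory_if_reached_threshold


def training_step(model, batch, optimizer, low_gpu_mem_usage=True):
    loss = model(batch).mean()  # placeholder loss computation

    # Before backward: the commit clears only on XPU, with a stricter 0.5 threshold.
    if low_gpu_mem_usage and torch.xpu.is_available():
        clear_memory_if_reached_threshold(threshold=0.5)

    loss.backward()

    # After backward: clear on any device once the reserved memory passes 90%.
    if low_gpu_mem_usage:
        clear_memory_if_reached_threshold(threshold=0.9)

    optimizer.step()
    optimizer.zero_grad()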

auto_round/utils/device.py

Lines changed: 4 additions & 8 deletions
@@ -437,18 +437,14 @@ def clear_memory_if_reached_threshold(threshold=0.85):
     for i in range(num_devices):
         try:
             total_memory = device_api.get_device_properties(i).total_memory
-            allocated_memory = device_api.memory_reserved(i) if name == "CUDA" else device_api.memory_allocated(i)
-            memory_usage_ratio = allocated_memory / total_memory
+            reserved_memory = device_api.memory_reserved(i)
+            memory_usage_ratio = reserved_memory / total_memory

             if memory_usage_ratio >= threshold:
                 logger.warning_once(
-                    f"{name} device {i}: Memory usage {memory_usage_ratio*100:.2f}% "
-                    f"exceeds threshold {threshold*100:.2f}%. Clearing memory..."
+                    f"{name} device {i} has reached memory threshold. During the tuning process, a memory clearing operation will be called, which will result in more time consumption."
                 )
                 clear_memory()
-                allocated_memory = device_api.memory_reserved(i) if name == "CUDA" else device_api.memory_allocated(i)
-                memory_usage_ratio = allocated_memory / total_memory
-                logger.warning_once(f"Cleared memory. {name} device {i}: Memory usage {memory_usage_ratio*100:.2f}%")
                 return True
         except Exception as e:
             logger.warning_once(f"Failed to check memory for {name} device {i}: {e}")
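
With this change, clear_memory_if_reached_threshold reads reserved (cached) memory through memory_reserved on both CUDA and XPU, instead of special-casing XPU with memory_allocated, and the post-clear re-check and second log line are dropped. A rough standalone sketch of the resulting logic is below; plain logging and empty_cache() stand in for the project's logger.warning_once and clear_memory() helpers, and the backend selection at the top is an assumption rather than code copied from the file.

# Rough sketch only; the real implementation lives in auto_round/utils/device.py.
import gc
import logging

import torch

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def clear_memory_if_reached_threshold(threshold: float = 0.85) -> bool:
    # Assumed backend selection: prefer XPU when present, otherwise CUDA.
    if torch.xpu.is_available():
        name, device_api = "XPU", torch.xpu
    elif torch.cuda.is_available():
        name, device_api = "CUDA", torch.cuda
    else:
        return False

    for i in range(device_api.device_count()):
        try:
            total_memory = device_api.get_device_properties(i).total_memory
            # Reserved (cached) memory is compared against the threshold on both
            # backends; there is no longer a memory_allocated() path for XPU.
            reserved_memory = device_api.memory_reserved(i)
            if reserved_memory / total_memory >= threshold:
                logger.warning(
                    f"{name} device {i} has reached the memory threshold; "
                    f"clearing memory, which adds some tuning time."
                )
                gc.collect()  # stand-in for auto_round's clear_memory()
                device_api.empty_cache()
                return True
        except Exception as e:
            logger.warning(f"Failed to check memory for {name} device {i}: {e}")
    return False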
@@ -890,7 +886,7 @@ def estimate_tuning_block_mem(
     # TODO: Cannot estimate the memory usage correctly for MoE models yet.
     # For MoE models, additional memory usage can be higher due to routing, gating,
     # and multiple expert activations. Here we use a conservative estimate.
-    moe_additional_memory = additional_memory * 3  # GB
+    moe_additional_memory = additional_memory * 6  # GB
     additional_memory += moe_additional_memory
     if torch.xpu.is_available():
         # https://github.com/intel/torch-xpu-ops/issues/2232
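
The only change in this hunk is the MoE safety factor, raised from 3x to 6x of the base additional-memory estimate, so a MoE block is now budgeted 7x the base figure in total. A quick illustration with a hypothetical 2 GB base value (the number is not from the commit):

additional_memory = 2.0  # GB, hypothetical base per-block overhead
moe_additional_memory = additional_memory * 6  # GB, conservative MoE factor after this commit
additional_memory += moe_additional_memory
print(additional_memory)  # 14.0 GB total, versus 8.0 GB with the previous 3x factor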
