50 changes: 31 additions & 19 deletions auto_round/compressors/base.py
@@ -1441,7 +1441,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])

 if is_complex_device_mapping(self.device_map):
     set_auto_device_map_for_block_with_tuning(
-        block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
+        block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
     )
 # Dispatch model if needed
 if is_complex_device_mapping(self.device_map):
@@ -2332,10 +2332,10 @@ def _quantize_layer(
 if total_loss < best_loss:
     best_loss = total_loss
     if not self.not_use_best_mse:
-        best_params = collect_best_params(wrapper_linear)
+        best_params = collect_best_params(wrapper_linear, self.cache_device)
     last_best_iter = i
 if self.not_use_best_mse and i == self.iters - 1:
-    best_params = collect_best_params(wrapper_linear)
+    best_params = collect_best_params(wrapper_linear, self.cache_device)

 if not self.not_use_best_mse:
     if 0 < self.dynamic_max_gap <= i - last_best_iter:
@@ -2413,6 +2413,7 @@ def _get_current_q_output(
     input_others: dict,
     indices: list[int],
     device: str,
+    output_device: str = "cpu",
 ) -> torch.Tensor:
     current_input_ids, current_input_others = self._sampling_inputs(
         input_ids,
@@ -2423,7 +2424,7 @@
         share_cache_keys=self.shared_cache_keys,
     )
     output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
-    return output_q
+    return output_q.to(output_device)

 def _get_current_num_elm(
     self,
@@ -2458,13 +2459,11 @@ def _quantize_block(
     if is_fp8_linear(m):
         new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device)
         set_module(block, n, new_layer)
-
-if is_complex_device_mapping(self.device_map):
-    set_auto_device_map_for_block_with_tuning(
-        block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
-    )
-else:
-    block = block.to(device)
+# card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights
+# loss_device is used to calculate loss on the second device if available and card_0_in_high_risk
+card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning(
+    block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
+)

 if is_complex_device_mapping(self.device_map):
     for n, m in block.named_modules():
@@ -2592,18 +2591,18 @@ def _quantize_block(

 current_output = self._get_current_output(output, indices)

-current_output = to_device(current_output, device)
+current_output = to_device(current_output, loss_device)

-output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)

 if self.attention_mask:
     tmp_attention_mask = [self.attention_mask[i] for i in indices]
-    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device)
     tmp_attention_mask.unsqueeze_(-1)
 else:
     tmp_attention_mask = 1.0
 if self.amp:
-    with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+    with autocast(device_type=loss_device.split(":")[0], dtype=self.amp_dtype):
         loss = mse_loss(  # pylint: disable=not-callable
             output_q * tmp_attention_mask, current_output * tmp_attention_mask
         )
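Editor's note on the loss_device change above: the quantized block still runs its forward pass on the tuning device, but when card 0 is flagged as high-risk the reference output, the quantized output, the attention mask, and the MSE loss are all moved to a secondary device, so the loss buffers stop competing for card-0 memory. The snippet below is a minimal standalone sketch of that pattern only; the function name, the two-device setup, and the argument layout are illustrative assumptions, not the repository's API.

import torch
import torch.nn.functional as F

def block_mse_loss(block, inputs, reference, compute_device="cuda:0", loss_device="cuda:1"):
    # Forward on the (possibly nearly full) compute device ...
    out_q = block(inputs.to(compute_device))
    # ... then move both sides to the secondary device before the elementwise
    # ops and the loss, so those temporaries are allocated there instead.
    out_q = out_q.to(loss_device)
    out_ref = reference.to(loss_device)
    # backward() from this loss still reaches the block's parameters on compute_device
    return F.mse_loss(out_q, out_ref)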
@@ -2614,21 +2613,29 @@
     )

 total_loss += loss.item() / num_elm

+if self.low_gpu_mem_usage and card_0_in_high_risk:
+    # clear memory to avoid OOM due to memory fragmentation
+    clear_memory_if_reached_threshold(threshold=0.5)
+
 self._scale_loss_and_backward(scaler, loss)
-clear_memory_if_reached_threshold(threshold=0.85)
+
+if self.low_gpu_mem_usage and card_0_in_high_risk:
+    # clear memory to avoid OOM due to memory fragmentation
+    clear_memory_if_reached_threshold(threshold=0.8)

 if i == 0:
     init_loss = total_loss

 if total_loss < best_loss:
     best_loss = total_loss
     if not self.not_use_best_mse:
-        best_params = collect_best_params(block)
+        best_params = collect_best_params(block, self.cache_device)
         # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True)

     last_best_iter = i
 if self.not_use_best_mse and i == self.iters - 1:
-    best_params = collect_best_params(block)
+    best_params = collect_best_params(block, self.cache_device)

 if not self.not_use_best_mse:
     if 0 < self.dynamic_max_gap <= i - last_best_iter:
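The two clear_memory_if_reached_threshold calls added above only run when low_gpu_mem_usage is set and card 0 was flagged as high-risk, once before and once after the backward pass, to fight OOM caused by memory fragmentation. As a rough illustration of threshold-based cache clearing, here is a hypothetical clear_cache_if_reached_threshold built on the public torch.cuda APIs; it is an assumption, not the repository's implementation.

import torch

def clear_cache_if_reached_threshold(threshold: float = 0.85) -> bool:
    """Release the CUDA caching allocator's unused blocks when any visible
    device has reserved more than `threshold` of its total memory."""
    if not torch.cuda.is_available():
        return False
    for idx in range(torch.cuda.device_count()):
        total = torch.cuda.get_device_properties(idx).total_memory
        reserved = torch.cuda.memory_reserved(idx)
        if reserved / total >= threshold:
            torch.cuda.empty_cache()
            return True
    return False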
@@ -2645,6 +2652,8 @@ def _quantize_block(
     f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
 )
 logger.info(dump_info)
+if self.low_gpu_mem_usage:
+    clear_memory()  # clear cached memory during training
 if len(unquantized_layer_names) != 0:
     logger.info(f"{unquantized_layer_names} have not been quantized")
 with torch.no_grad():
@@ -2655,7 +2664,8 @@ def _quantize_block(
     set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")

 if self.enable_quanted_input:
-    clear_memory()
+    if not self.low_gpu_mem_usage:  # In case of clearing memory twice
+        clear_memory()  # clear cached memory during training
     q_outputs = self._get_block_outputs(
         block,
         input_ids,
@@ -2781,13 +2791,15 @@ def _quantize_blocks(
 modules = [get_module(model, n) for n in names]
 m = WrapperMultiblock(modules)

+m.config = model.config if hasattr(model, "config") else None
 q_input, input_ids = quantize_block(
     m,
     input_ids,
     input_others,
     q_input=q_input,
     device=device,
 )
+del m.config
 if self.is_packing_immediate:
     from auto_round.export import PACKING_LAYER_WITH_FORMAT
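Editor's note on the m.config lines above: the model's config is temporarily attached to the WrapperMultiblock so code invoked inside quantize_block can look it up on the wrapper, and it is deleted again right after the call. A try/finally (or context-manager) variant of the same temporary-attribute pattern would guarantee the cleanup even if quantize_block raises; a hypothetical sketch, not the repository's code:

from contextlib import contextmanager

@contextmanager
def borrowed_config(wrapper, model):
    # Expose model.config on the wrapper only for the duration of the block.
    wrapper.config = getattr(model, "config", None)
    try:
        yield wrapper
    finally:
        del wrapper.config

# Usage: with borrowed_config(m, model): q_input, input_ids = quantize_block(m, ...)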

5 changes: 3 additions & 2 deletions auto_round/compressors/utils.py
@@ -199,13 +199,14 @@ def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=Non
     return True, ""


-def collect_best_params(block):
+def collect_best_params(block, cache_device="cpu"):
+    """Collect the best parameters from the block to the specified device."""
     params = {}
     for n, m in block.named_modules():
         if hasattr(m, "orig_layer"):
             params[n] = {}
             for key in m.params.keys():
-                params[n][key] = copy.deepcopy(m.params[key].data)
+                params[n][key] = m.params[key].data.to(cache_device, copy=True)
     return params
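With this change the best-so-far parameters are snapshotted straight onto cache_device (CPU by default) instead of being deep-copied on whatever device the block currently occupies, so repeated snapshots during tuning no longer accumulate on the GPU; Tensor.to(..., copy=True) always returns a fresh copy even when the tensor already lives on the target device. A minimal usage sketch, assuming collect_best_params above is in scope and using a stand-in wrapped layer (the real WrapperLinear internals are simplified assumptions):

import torch

class _StubWrapped(torch.nn.Module):
    """Stand-in for a wrapped layer: exposes .orig_layer and a .params dict."""
    def __init__(self, device):
        super().__init__()
        self.orig_layer = torch.nn.Linear(8, 8)
        self.params = {"value": torch.zeros(8, 8, device=device)}

device = "cuda" if torch.cuda.is_available() else "cpu"
block = torch.nn.ModuleDict({"proj": _StubWrapped(device)})

best = collect_best_params(block, cache_device="cpu")
assert best["proj"]["value"].device.type == "cpu"                        # snapshot lives on the cache device
assert best["proj"]["value"] is not block["proj"].params["value"].data   # and is an independent copy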

