Skip to content

Commit f8dd856

Browse files
committed
Dummy inputs to device
Signed-off-by: J. Pablo Muñoz <[email protected]>
1 parent ba88aa0 commit f8dd856

File tree

1 file changed

+49
-15
lines changed
  • examples/llm_compression/torch/qat_with_lora

1 file changed

+49
-15
lines changed

examples/llm_compression/torch/qat_with_lora/main_nls.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,14 @@
2323
import datasets
2424
import numpy as np
2525
import torch
26-
import transformers
2726
from lm_eval import evaluator
2827
from lm_eval.models.huggingface import HFLM
2928
from torch import Tensor
3029
from torch import nn
3130
from torch.jit import TracerWarning
3231
from torch.utils.tensorboard import SummaryWriter
33-
from transformers import AutoModelForCausalLM
34-
from transformers import AutoTokenizer
35-
from transformers import get_cosine_schedule_with_warmup
3632

33+
import transformers
3734
from examples.llm_compression.torch.qat_with_lora.main import load_checkpoint
3835
from examples.llm_compression.torch.qat_with_lora.main import save_checkpoint
3936
from examples.llm_compression.torch.qat_with_lora.main import set_trainable
@@ -46,6 +43,9 @@
4643
from nncf.torch.function_hook.wrapper import get_hook_storage
4744
from nncf.torch.quantization.layers import AsymmetricLoraNLSQuantizer
4845
from nncf.torch.quantization.layers import SymmetricLoraNLSQuantizer
46+
from transformers import AutoModelForCausalLM
47+
from transformers import AutoTokenizer
48+
from transformers import get_cosine_schedule_with_warmup
4949

5050
warnings.filterwarnings("ignore", category=TracerWarning)
5151

@@ -188,7 +188,10 @@ def lm_eval(model: nn.Module, tokenizer: AutoTokenizer, task: str, batch_size: i
188188

189189

190190
def tokenize(
191-
tokenizer: AutoTokenizer, prompt: str, add_eos_token: bool = True, max_length: int = 256
191+
tokenizer: AutoTokenizer,
192+
prompt: str,
193+
add_eos_token: bool = True,
194+
max_length: int = 256,
192195
) -> dict[str, list[int]]:
193196
"""
194197
Tokenize the given prompt.
@@ -324,7 +327,14 @@ def get_argument_parser() -> argparse.ArgumentParser:
324327
parser.add_argument(
325328
"--task",
326329
type=str,
327-
choices=["openbookqa", "winogrande", "arc_challenge", "arc_easy", "gsm8k", "hellaswag"],
330+
choices=[
331+
"openbookqa",
332+
"winogrande",
333+
"arc_challenge",
334+
"arc_easy",
335+
"gsm8k",
336+
"hellaswag",
337+
],
328338
default="openbookqa",
329339
help="Evaluation task",
330340
)
@@ -439,7 +449,11 @@ def main(argv) -> float:
439449
train_dataset = [tokenize(tokenizer, sample) for sample in train_dataset]
440450
random.shuffle(train_dataset)
441451

442-
model = compress_weights(model, dataset=Dataset([model_input]), **compression_config)
452+
model = compress_weights(
453+
model,
454+
dataset=Dataset([{k: v.to(device) for k, v in model_input.items()}]),
455+
**compression_config,
456+
)
443457
results_of_compressed_model = lm_eval(model, tokenizer, task=args.task, batch_size=args.eval_batch_size)
444458
print(f"Results of NNCF compressed model={json.dumps(results_of_compressed_model, indent=4)}")
445459
overall_result["results_of_compressed_model"] = results_of_compressed_model
@@ -482,7 +496,9 @@ def main(argv) -> float:
482496
else:
483497
# Initialize the counter for tracking activation counts during training
484498
maximal_lora_rank_config = configure_lora_adapters(
485-
layer_id_vs_lora_quantizers_map, lora_rank_space=args.lora_rank_space, adapter_strategy="maximal"
499+
layer_id_vs_lora_quantizers_map,
500+
lora_rank_space=args.lora_rank_space,
501+
adapter_strategy="maximal",
486502
)
487503
activation_counter = [
488504
{rank: 0 for rank in args.lora_rank_space} for _ in range(len(maximal_lora_rank_config))
@@ -498,7 +514,9 @@ def main(argv) -> float:
498514
# configure the LoRA adapters with a random rank configuration from the specified rank space.
499515
if not disable_nls and grad_steps == 0:
500516
current_config = configure_lora_adapters(
501-
layer_id_vs_lora_quantizers_map, lora_rank_space=args.lora_rank_space, adapter_strategy="random"
517+
layer_id_vs_lora_quantizers_map,
518+
lora_rank_space=args.lora_rank_space,
519+
adapter_strategy="random",
502520
)
503521
# Update the activation counter
504522
for idx, rank in enumerate(current_config):
@@ -600,12 +618,16 @@ def get_top_k_min_loss_configs(loss_recorder, k=5):
600618
"results": results_of_nls_finetuned_compressed_model_median,
601619
}
602620
)
603-
best_result = max(best_result, results_of_nls_finetuned_compressed_model_median[args.lm_eval_metric])
621+
best_result = max(
622+
best_result,
623+
results_of_nls_finetuned_compressed_model_median[args.lm_eval_metric],
624+
)
604625

605626
# Test the most frequent configuration
606627
most_frequent_lora_rank_config = get_most_frequent_config(activation_counter)
607628
configure_lora_adapters(
608-
layer_id_vs_lora_quantizers_map, specific_rank_config=most_frequent_lora_rank_config
629+
layer_id_vs_lora_quantizers_map,
630+
specific_rank_config=most_frequent_lora_rank_config,
609631
)
610632
results_of_nls_finetuned_compressed_model_most_frequent = lm_eval(
611633
model, tokenizer, task=args.task, batch_size=args.eval_batch_size
@@ -621,12 +643,18 @@ def get_top_k_min_loss_configs(loss_recorder, k=5):
621643
"results": results_of_nls_finetuned_compressed_model_most_frequent,
622644
}
623645
)
624-
best_result = max(best_result, results_of_nls_finetuned_compressed_model_most_frequent[args.lm_eval_metric])
646+
best_result = max(
647+
best_result,
648+
results_of_nls_finetuned_compressed_model_most_frequent[args.lm_eval_metric],
649+
)
625650

626651
# Test the top 5 min loss configurations
627652
top_5_min_loss_configs = get_top_k_min_loss_configs(loss_recorder, k=5)
628653
for i, min_loss_config in enumerate(top_5_min_loss_configs):
629-
configure_lora_adapters(layer_id_vs_lora_quantizers_map, specific_rank_config=min_loss_config)
654+
configure_lora_adapters(
655+
layer_id_vs_lora_quantizers_map,
656+
specific_rank_config=min_loss_config,
657+
)
630658
results_of_nls_finetuned_compressed_model_min_loss = lm_eval(
631659
model, tokenizer, task=args.task, batch_size=args.eval_batch_size
632660
)
@@ -641,10 +669,16 @@ def get_top_k_min_loss_configs(loss_recorder, k=5):
641669
"results": results_of_nls_finetuned_compressed_model_min_loss,
642670
}
643671
)
644-
best_result = max(best_result, results_of_nls_finetuned_compressed_model_min_loss[args.lm_eval_metric])
672+
best_result = max(
673+
best_result,
674+
results_of_nls_finetuned_compressed_model_min_loss[args.lm_eval_metric],
675+
)
645676
else:
646677
assert args.custom_rank_config is not None, "Please provide `custom_rank_config` for evaluation."
647-
configure_lora_adapters(layer_id_vs_lora_quantizers_map, specific_rank_config=args.custom_rank_config)
678+
configure_lora_adapters(
679+
layer_id_vs_lora_quantizers_map,
680+
specific_rank_config=args.custom_rank_config,
681+
)
648682
results_of_nls_finetuned_compressed_model_custom = lm_eval(
649683
model, tokenizer, task=args.task, batch_size=args.eval_batch_size
650684
)

0 commit comments

Comments
 (0)