#!/usr/bin/env python3
"""
Direct Whisper LoRA extraction script.
This script extracts LoRA weights from a fine-tuned Whisper model using PEFT directly.
"""
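
# Overview: load the base and fine-tuned checkpoints, compute the weight difference
# for each targeted linear layer, factorize each difference into low-rank LoRA A/B
# matrices via truncated SVD, and save the result as a PEFT adapter that can be
# re-applied to the base model with `peft.PeftModel.from_pretrained`.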

import os
import torch
import logging
import argparse
from transformers import WhisperForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("direct_whisper_lora_extract")

def extract_lora(
    base_model_path: str,
    finetuned_model_path: str,
    output_path: str,
    rank: int = 32,
    encoder_only: bool = False,
    decoder_only: bool = False,
):
    """
    Extract LoRA weights from a fine-tuned Whisper model.

    Args:
        base_model_path: Path to the base Whisper model
        finetuned_model_path: Path to the fine-tuned Whisper model
        output_path: Path to save the extracted LoRA adapter
        rank: Rank for the LoRA adapter
        encoder_only: Whether to extract only encoder weights
        decoder_only: Whether to extract only decoder weights
    """
    logger.info(f"Loading base model from {base_model_path}")
    base_model = WhisperForConditionalGeneration.from_pretrained(
        base_model_path,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    logger.info(f"Loading fine-tuned model from {finetuned_model_path}")
    finetuned_model = WhisperForConditionalGeneration.from_pretrained(
        finetuned_model_path,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    # Whisper uses the same linear-layer names in the encoder and the decoder, so the
    # target module list is identical in every case; the encoder/decoder restriction is
    # applied later, when the state dicts are compared key by key.
    target_modules = ["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"]
    if encoder_only:
        logger.info("Extracting LoRA weights for encoder only")
    elif decoder_only:
        logger.info("Extracting LoRA weights for decoder only")
    else:
        logger.info("Extracting LoRA weights for encoder and decoder")

    # Create LoRA configuration
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
    )
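    # Note: with lora_alpha equal to r, PEFT's scaling factor (alpha / r) is 1.0, so the
    # product lora_B @ lora_A is added to the base weight unscaled and can approximate
    # the raw weight difference directly.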

    # Capture the plain state dicts *before* wrapping with PEFT: get_peft_model injects
    # LoRA modules in place and, in recent PEFT releases, renames the wrapped linears
    # (e.g. `q_proj.weight` becomes `q_proj.base_layer.weight`), which would break the
    # key-by-key comparison below.
    base_state_dict = base_model.state_dict()
    finetuned_state_dict = finetuned_model.state_dict()

    # Create a PEFT model around the base model; its LoRA weights are filled in below
    logger.info("Creating PEFT model with base model")
    peft_model = get_peft_model(base_model, lora_config)

    # Extract the differences between the models
    logger.info("Extracting differences between models")
    diff_state_dict = {}

    # Keep only the targeted projection weights, optionally restricted to encoder or decoder.
    # The ".encoder."/".decoder." checks avoid counting the decoder's cross-attention
    # ("encoder_attn") layers as encoder weights.
    for key in base_state_dict:
        if encoder_only and ".encoder." not in key:
            continue
        if decoder_only and ".decoder." not in key:
            continue
        if key not in finetuned_state_dict:
            continue

        # Match keys of the form `...<target_module>.weight`; biases are excluded (bias="none")
        parts = key.split(".")
        if parts[-1] == "weight" and parts[-2] in target_modules:
            logger.info(f"Processing {key}")
            diff_state_dict[key] = finetuned_state_dict[key] - base_state_dict[key]

    # Convert each weight difference into low-rank LoRA factors via truncated SVD, so that
    # lora_B @ lora_A approximates the difference (scaling is 1.0 since lora_alpha == r)
    logger.info("Converting differences to LoRA weights")
    lora_state_dict = {}
    for key, diff in diff_state_dict.items():
        u, s, vh = torch.linalg.svd(diff, full_matrices=False)
        r = min(rank, s.shape[0])
        sqrt_s = torch.sqrt(s[:r])
        lora_b = (u[:, :r] * sqrt_s).contiguous()                 # (out_features, r)
        lora_a = (sqrt_s.unsqueeze(1) * vh[:r, :]).contiguous()   # (r, in_features)

        # Map `model.encoder.layers.0.self_attn.q_proj.weight` to the corresponding PEFT
        # parameter names; the `.default` adapter name matches recent PEFT releases.
        module_path = key[: -len(".weight")]
        lora_state_dict[f"base_model.model.{module_path}.lora_A.default.weight"] = lora_a
        lora_state_dict[f"base_model.model.{module_path}.lora_B.default.weight"] = lora_b

    # Load the factorized weights into the PEFT model; strict=False because only the LoRA
    # parameters are provided here
    incompatible = peft_model.load_state_dict(lora_state_dict, strict=False)
    if incompatible.unexpected_keys:
        logger.warning(f"Unexpected keys when loading LoRA weights: {incompatible.unexpected_keys}")

    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    # save_pretrained writes both adapter_config.json and the adapter weights
    logger.info(f"Saving LoRA adapter to {output_path}")
    peft_model.save_pretrained(output_path)

    logger.info("LoRA extraction completed successfully")

def main():
    parser = argparse.ArgumentParser(description="Extract LoRA weights from a fine-tuned Whisper model")
    parser.add_argument("--model", required=True, help="Fine-tuned Whisper model path")
    parser.add_argument("--base-model", required=True, help="Base Whisper model path")
    parser.add_argument("--out-path", required=True, help="Output path for extracted LoRA adapter")
    parser.add_argument("--max-rank", type=int, default=32, help="Maximum rank for LoRA decomposition")
    scope = parser.add_mutually_exclusive_group()
    scope.add_argument("--encoder-only", action="store_true", help="Extract LoRA only for encoder weights")
    scope.add_argument("--decoder-only", action="store_true", help="Extract LoRA only for decoder weights")

    args = parser.parse_args()

    extract_lora(
        base_model_path=args.base_model,
        finetuned_model_path=args.model,
        output_path=args.out_path,
        rank=args.max_rank,
        encoder_only=args.encoder_only,
        decoder_only=args.decoder_only,
    )

if __name__ == "__main__":
    main()
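
# Example usage (script name, model ids, and paths below are illustrative):
#
#   python direct_whisper_lora_extract.py \
#       --base-model openai/whisper-small \
#       --model ./whisper-small-finetuned \
#       --out-path ./whisper-small-lora \
#       --max-rank 32
#
# The extracted adapter can then be applied back onto the base model, e.g.:
#
#   from transformers import WhisperForConditionalGeneration
#   from peft import PeftModel
#   base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
#   model = PeftModel.from_pretrained(base, "./whisper-small-lora")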