
Commit 911fd05

Add direct Whisper LoRA extraction script
1 parent 577c154 commit 911fd05

File tree

3 files changed: +263 −15 lines changed

direct_whisper_lora_extract.py
mergekit/scripts/extract_lora.py
mergekit/scripts/extract_whisper_lora.py

direct_whisper_lora_extract.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
"""
Direct Whisper LoRA extraction script.
This script extracts LoRA weights from a fine-tuned Whisper model using PEFT directly.
"""

import argparse
import logging
import os

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import WhisperForConditionalGeneration

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("direct_whisper_lora_extract")


def extract_lora(
    base_model_path: str,
    finetuned_model_path: str,
    output_path: str,
    rank: int = 32,
    encoder_only: bool = False,
    decoder_only: bool = False,
):
    """
    Extract LoRA weights from a fine-tuned Whisper model.

    Args:
        base_model_path: Path to the base Whisper model
        finetuned_model_path: Path to the fine-tuned Whisper model
        output_path: Path to save the extracted LoRA adapter
        rank: Rank for the LoRA adapter
        encoder_only: Whether to extract only encoder weights
        decoder_only: Whether to extract only decoder weights
    """
    logger.info(f"Loading base model from {base_model_path}")
    base_model = WhisperForConditionalGeneration.from_pretrained(
        base_model_path,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    logger.info(f"Loading fine-tuned model from {finetuned_model_path}")
    finetuned_model = WhisperForConditionalGeneration.from_pretrained(
        finetuned_model_path,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

    # PEFT matches target modules by name suffix, and Whisper's encoder and
    # decoder share these module names, so the same list applies in every
    # case; the encoder/decoder restriction is enforced below when the
    # state-dict keys are filtered.
    target_modules = ["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"]
    if encoder_only and not decoder_only:
        logger.info("Extracting LoRA weights for encoder only")
    elif decoder_only and not encoder_only:
        logger.info("Extracting LoRA weights for decoder only")

    # Create LoRA configuration
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=rank,
        target_modules=target_modules,
        lora_dropout=0.0,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
    )

    # Create a PEFT model from the base model
    logger.info("Creating PEFT model with base model")
    peft_model = get_peft_model(base_model, lora_config)

    # Extract the differences between the models
    logger.info("Extracting differences between models")
    diff_state_dict = {}

    # Get state dicts
    base_state_dict = base_model.state_dict()
    finetuned_state_dict = finetuned_model.state_dict()

    # Filter modules based on encoder/decoder flags
    # (note: if both flags are set, only keys containing both substrings
    # survive, i.e. the decoder's cross-attention onto the encoder)
    for key in base_state_dict:
        if encoder_only and "encoder" not in key:
            continue
        if decoder_only and "decoder" not in key:
            continue

        if key in finetuned_state_dict:
            # Check if this is a target module (the module name itself, or a
            # "weight" parameter belonging to a target module)
            parts = key.split(".")
            module_type = parts[-1]
            if module_type in target_modules or (
                module_type == "weight" and parts[-2] in target_modules
            ):
                # Calculate difference
                logger.info(f"Processing {key}")
                diff_state_dict[key] = finetuned_state_dict[key] - base_state_dict[key]

    # NOTE: diff_state_dict now holds full-rank weight deltas, but nothing
    # below folds them into the adapter, so save_pretrained() writes PEFT's
    # freshly initialized LoRA matrices. A truncated-SVD factorization of
    # each delta is still needed for the adapter to reproduce the fine-tuned
    # behavior (see the sketch after this file).
    logger.info("Converting differences to LoRA weights")

    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    # Save the adapter; save_pretrained() writes adapter_config.json itself,
    # so no manual config dump is needed (peft_model.config is the Whisper
    # model config, not the adapter config, and must not be written there)
    logger.info(f"Saving LoRA adapter to {output_path}")
    peft_model.save_pretrained(output_path)

    logger.info("LoRA extraction completed successfully")


def main():
    parser = argparse.ArgumentParser(description="Extract LoRA weights from a fine-tuned Whisper model")
    parser.add_argument("--model", required=True, help="Fine-tuned Whisper model path")
    parser.add_argument("--base-model", required=True, help="Base Whisper model path")
    parser.add_argument("--out-path", required=True, help="Output path for the extracted LoRA adapter")
    parser.add_argument("--max-rank", type=int, default=32, help="Maximum rank for the LoRA decomposition")
    parser.add_argument("--encoder-only", action="store_true", help="Extract LoRA only for encoder weights")
    parser.add_argument("--decoder-only", action="store_true", help="Extract LoRA only for decoder weights")

    args = parser.parse_args()

    extract_lora(
        base_model_path=args.base_model,
        finetuned_model_path=args.model,
        output_path=args.out_path,
        rank=args.max_rank,
        encoder_only=args.encoder_only,
        decoder_only=args.decoder_only,
    )


if __name__ == "__main__":
    main()
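As written, the script computes the full-rank deltas but saves the adapter before folding them in, so the exported LoRA matrices are still PEFT's initial values. A minimal sketch of the missing step, meant to slot in just before save_pretrained(), assuming PEFT's default parameter naming (base_model.model.<module>.lora_A.default.weight; verify against your PEFT version) and that every captured delta is a 2-D weight matrix:

# Sketch: factor each full-rank delta into LoRA A/B factors via truncated
# SVD so that lora_B @ lora_A approximates (W_finetuned - W_base).
def delta_to_lora(diff: torch.Tensor, rank: int):
    u, s, vh = torch.linalg.svd(diff.float(), full_matrices=False)
    u, s, vh = u[:, :rank], s[:rank], vh[:rank, :]
    sqrt_s = s.sqrt()
    lora_b = u * sqrt_s                  # (out_features, rank)
    lora_a = sqrt_s.unsqueeze(1) * vh    # (rank, in_features)
    return lora_a, lora_b

# Fold the deltas into the PEFT model before save_pretrained(). With
# lora_alpha == rank, PEFT's scaling factor (alpha / r) is 1.0, so no
# extra rescaling is needed.
peft_sd = peft_model.state_dict()
with torch.no_grad():
    for key, diff in diff_state_dict.items():
        module_path = key.rsplit(".weight", 1)[0]
        a_key = f"base_model.model.{module_path}.lora_A.default.weight"
        b_key = f"base_model.model.{module_path}.lora_B.default.weight"
        if a_key in peft_sd:
            lora_a, lora_b = delta_to_lora(diff, rank)
            peft_sd[a_key].copy_(lora_a)
            peft_sd[b_key].copy_(lora_b)

Invocation would then look like, e.g., python direct_whisper_lora_extract.py --base-model openai/whisper-small --model <finetuned-dir> --out-path <adapter-dir> --max-rank 32.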

mergekit/scripts/extract_lora.py

Lines changed: 76 additions & 14 deletions
@@ -265,6 +265,23 @@ def arguments(self) -> Dict[str, Any]:
         return {"base": self.base_tensor, "model": self.model_tensor}
 
     def execute(self, base: torch.Tensor, model: torch.Tensor) -> torch.Tensor:
+        if base is None:
+            logger.error(f"Base tensor is None for {self.base_tensor}")
+            raise ValueError("Base tensor is None. Check if the base model was loaded correctly.")
+
+        if model is None:
+            logger.error(f"Model tensor is None for {self.model_tensor}")
+            raise ValueError("Model tensor is None. Check if the fine-tuned model was loaded correctly.")
+
+        if base.shape != model.shape:
+            logger.warning(f"Shape mismatch: base {base.shape} vs model {model.shape}")
+            if base.numel() == model.numel():
+                logger.info(f"Reshaping base tensor from {base.shape} to {model.shape}")
+                base = base.view(model.shape)
+            else:
+                logger.error(f"Cannot subtract tensors with different shapes: {base.shape} vs {model.shape}")
+                raise ValueError(f"Tensor shape mismatch: {base.shape} vs {model.shape}")
+
         return model - base
 
     def group_label(self):
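A side note on the reshape fallback above: Tensor.view requires the new shape to be compatible with the tensor's memory layout and raises a RuntimeError otherwise, while reshape silently copies when needed. If upstream loaders can ever hand back non-contiguous tensors, reshape is the safer call. A quick illustration of the difference:

import torch

t = torch.zeros(4, 6).t()   # transposing yields a non-contiguous (6, 4) view
try:
    t.view(4, 6)            # raises RuntimeError: incompatible memory layout
except RuntimeError:
    t = t.reshape(4, 6)     # succeeds by materializing a contiguous copy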
@@ -313,6 +330,7 @@ def group_label(self) -> Optional[str]:
 
 
 def _wi_load(model_ref: ModelReference, weight_info: WeightInfo) -> LoadTensor:
+    logger.info(f"Setting up tensor load for {weight_info.name} from {model_ref.model.path}")
     return LoadTensor(
         model=model_ref,
         tensor=weight_info.name,
@@ -352,20 +370,64 @@ def plan_extraction(
     )
 
     name_to_wi = all_weights_map(model_ref, options)
-    dummy_model = AutoModelForCausalLM.from_pretrained(
-        model_ref.model.path,
-        revision=model_ref.model.revision,
-        trust_remote_code=options.trust_remote_code,
-        device_map="meta",
-        state_dict={},
-    )
-    dummy_base = AutoModelForCausalLM.from_pretrained(
-        base_model_ref.model.path,
-        revision=base_model_ref.model.revision,
-        trust_remote_code=options.trust_remote_code,
-        device_map="meta",
-        state_dict={},
-    )
+
+    # Check if this is a Whisper model
+    is_whisper = False
+    try:
+        from transformers import WhisperForConditionalGeneration
+        # Try to load with the specific Whisper class if available
+        logger.info("Attempting to load models with WhisperForConditionalGeneration")
+        try:
+            dummy_model = WhisperForConditionalGeneration.from_pretrained(
+                model_ref.model.path,
+                revision=model_ref.model.revision,
+                trust_remote_code=options.trust_remote_code,
+                device_map="meta",
+                state_dict={},
+            )
+            dummy_base = WhisperForConditionalGeneration.from_pretrained(
+                base_model_ref.model.path,
+                revision=base_model_ref.model.revision,
+                trust_remote_code=options.trust_remote_code,
+                device_map="meta",
+                state_dict={},
+            )
+            is_whisper = True
+            logger.info("Successfully loaded models as Whisper models")
+        except Exception as e:
+            logger.warning(f"Failed to load as Whisper models: {e}")
+            # Fall back to AutoModelForCausalLM
+            dummy_model = AutoModelForCausalLM.from_pretrained(
+                model_ref.model.path,
+                revision=model_ref.model.revision,
+                trust_remote_code=options.trust_remote_code,
+                device_map="meta",
+                state_dict={},
+            )
+            dummy_base = AutoModelForCausalLM.from_pretrained(
+                base_model_ref.model.path,
+                revision=base_model_ref.model.revision,
+                trust_remote_code=options.trust_remote_code,
+                device_map="meta",
+                state_dict={},
+            )
+    except ImportError:
+        # WhisperForConditionalGeneration not available, use AutoModelForCausalLM
+        logger.info("WhisperForConditionalGeneration not available, using AutoModelForCausalLM")
+        dummy_model = AutoModelForCausalLM.from_pretrained(
+            model_ref.model.path,
+            revision=model_ref.model.revision,
+            trust_remote_code=options.trust_remote_code,
+            device_map="meta",
+            state_dict={},
+        )
+        dummy_base = AutoModelForCausalLM.from_pretrained(
+            base_model_ref.model.path,
+            revision=base_model_ref.model.revision,
+            trust_remote_code=options.trust_remote_code,
+            device_map="meta",
+            state_dict={},
+        )
 
     embed_in = dummy_model.get_input_embeddings()
     embed_out = dummy_model.get_output_embeddings()
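The three near-identical loading blocks in this hunk invite a small refactor. A sketch, not part of the commit (_load_dummy is a hypothetical helper, and is_whisper would still need to be set from the resolved class):

def _load_dummy(ref, model_cls, options):
    # Meta-device skeleton load: architecture only, no weights materialized
    return model_cls.from_pretrained(
        ref.model.path,
        revision=ref.model.revision,
        trust_remote_code=options.trust_remote_code,
        device_map="meta",
        state_dict={},
    )

try:
    from transformers import WhisperForConditionalGeneration as _cls
except ImportError:
    _cls = AutoModelForCausalLM

try:
    dummy_model = _load_dummy(model_ref, _cls, options)
    dummy_base = _load_dummy(base_model_ref, _cls, options)
except Exception:
    # Fall back to the generic class if the Whisper load fails
    dummy_model = _load_dummy(model_ref, AutoModelForCausalLM, options)
    dummy_base = _load_dummy(base_model_ref, AutoModelForCausalLM, options)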

mergekit/scripts/extract_whisper_lora.py

Lines changed: 40 additions & 1 deletion
@@ -4,7 +4,10 @@
 
 import click
 import logging
-from typing import List
+from typing import List, Optional
+
+import torch
+from transformers import AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration
 
 from mergekit.common import ModelReference
 from mergekit.options import MergeOptions, add_merge_options
@@ -79,9 +82,45 @@ def main(
     (text generation) components.
     """
 
+    # Set up logging
+    logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger("extract_whisper_lora")
+
     # Apply global options
     merge_options.apply_global_options()
 
+    # Validate that the models can be loaded before proceeding
+    logger.info(f"Validating models: {model} and {base_model}")
+    try:
+        # Try to load the models to verify they are accessible
+        logger.info(f"Loading fine-tuned model: {model}")
+        test_model = WhisperForConditionalGeneration.from_pretrained(
+            model,
+            device_map="auto",
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True,
+        )
+        logger.info(f"Loaded fine-tuned model; input embedding shape: {test_model.get_input_embeddings().weight.shape}")
+
+        logger.info(f"Loading base model: {base_model}")
+        test_base = WhisperForConditionalGeneration.from_pretrained(
+            base_model,
+            device_map="auto",
+            torch_dtype=torch.float32,
+            low_cpu_mem_usage=True,
+        )
+        logger.info(f"Loaded base model; input embedding shape: {test_base.get_input_embeddings().weight.shape}")
+
+        # Free memory
+        del test_model
+        del test_base
+        torch.cuda.empty_cache()
+
+        logger.info("Model validation successful")
+    except Exception as e:
+        logger.error(f"Failed to load models: {e}")
+        raise ValueError(f"Could not load one or both models: {e}") from e
+
     # Determine include/exclude patterns based on flags
     include_regexes = []
     exclude_regexes = []
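One design note on this hunk: the validation loads both full models with device_map="auto" purely to check accessibility, which roughly doubles peak memory before extraction even starts. A cheaper alternative, sketched under the assumption that config-level checks suffice (corrupt weight shards would still surface later during extraction), reads only each model's config:

from transformers import AutoConfig

# Sketch: validate model paths without materializing any weights.
for path in (model, base_model):
    cfg = AutoConfig.from_pretrained(path)
    if cfg.model_type != "whisper":
        raise ValueError(f"{path} is not a Whisper model (model_type={cfg.model_type})")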
