-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexport_voxtral_ggml.py
More file actions
350 lines (307 loc) · 13 KB
/
export_voxtral_ggml.py
File metadata and controls
350 lines (307 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
"""Export Voxtral-Mini-4B-Realtime-2602 to ExecuTorch with GGML backend.
Reuses model wrappers from the upstream Voxtral export script and adds
GGML-specific lowering.
Usage:
python export_voxtral_ggml.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602
"""
import os
import torch
import torch.nn as nn
from torch.export import Dim, export
from executorch.examples.models.voxtral_realtime.export_voxtral_rt import (
AudioEncoderExport,
TextDecoderExport,
TokenEmbeddingExport,
)
from executorch.examples.models.voxtral_realtime.model import (
StreamingAudioEncoderExport,
load_model,
)
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
to_edge_transform_and_lower,
)
from executorch.exir.passes import MemoryPlanningPass
from executorch_ggml import GgmlPartitioner
from executorch_ggml.passes import BF16UnsafeOpsCastPass, RemoveGraphAssertsPass
from executorch_ggml.passes.replace_copy_ops_pass import ReplaceCopyOpsPass
class MelPreprocessor(nn.Module):
    """Log-mel front end built on ``torchaudio.transforms.MelSpectrogram``.

    The pipeline is:

    1. compute a power (``power=2.0``) mel spectrogram;
    2. take ``log10`` after clamping to ``1e-10``;
    3. floor everything at 8 below the tensor-wide maximum;
    4. rescale with ``(x + 4) / 4``.

    Args:
        sample_rate: Audio sample rate in Hz, forwarded to ``MelSpectrogram``.
        n_fft: FFT size, forwarded to ``MelSpectrogram``.
        hop_length: Hop between STFT frames in samples, forwarded to
            ``MelSpectrogram``.
        n_mels: Number of mel bins, forwarded to ``MelSpectrogram``.
    """

    def __init__(self, sample_rate=16000, n_fft=400, hop_length=160, n_mels=128):
        super().__init__()
        # Imported here, so torchaudio is only required once this class is
        # actually instantiated.
        import torchaudio

        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            power=2.0,
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the normalised log-mel spectrogram of ``waveform``.

        Expected layout: ``(1, N_samples)`` in, ``(1, n_mels, T_mel)`` out.
        """
        power_mel = self.mel_spec(waveform)
        # Clamp before the log so silent bins do not become -inf.
        log_mel = power_mel.clamp(min=1e-10).log10()
        # Dynamic-range floor. The peak is taken over the whole tensor, not
        # per batch item.
        floor = log_mel.max() - 8.0
        log_mel = torch.maximum(log_mel, floor)
        return (log_mel + 4.0) / 4.0
def export_all_ggml(model, max_seq_len, dtype=torch.float32):
    """Export the offline (non-streaming) Voxtral components as ExportedPrograms.

    Each component is traced with ``torch.export.export(..., strict=False)``:

    * ``audio_encoder``: mel features; the mel time axis (dim 2) uses
      ``Dim.AUTO``.
    * ``text_decoder``: input embeddings plus cache positions; both share a
      dynamic sequence axis in ``[1, max_seq_len]``.
    * ``token_embedding``: token ids with a dynamic sequence axis in
      ``[1, max_seq_len]``.

    The mel preprocessor is deliberately not exported. It runs in Python at
    inference time, because torchaudio's STFT ops lack the out-variant
    registrations that ExecuTorch's portable runtime requires.

    ``model`` is modified in place by ``swap_voxtral_attention``.

    Args:
        model: Loaded Voxtral model.
        max_seq_len: Upper bound of the dynamic decoder/embedding length.
        dtype: Dtype of the floating-point example inputs (mel features and
            decoder embeddings). Integer example inputs are always
            ``torch.long``. NOTE(review): model weights are not cast here;
            presumably the caller already loaded ``model`` in ``dtype``.
            Confirm.

    Returns:
        ``(programs, metadata)``: a dict mapping component name to
        ``ExportedProgram``, and a dict of scalar metadata.
    """
    # Replace the llama custom ops (update_cache, custom_sdpa) with standard
    # ATen ops (index_copy_, F.scaled_dot_product_attention). The custom ops
    # get wrapped in auto_functionalized_v2, which breaks ExecuTorch delegation.
    from executorch_ggml.modules.voxtral_attention import swap_voxtral_attention

    swap_voxtral_attention(model)

    cfg = model.config
    exported = {}

    def _run_export(label, wrapper, example_args, dynamic_shapes):
        # Shared tail for every component: non-strict export, store, report.
        exported[label] = export(
            wrapper,
            example_args,
            dynamic_shapes=dynamic_shapes,
            strict=False,
        )
        print(f" {label} exported (sample input: {example_args[0].shape})")

    # Audio encoder: a large example mel input, with a dynamic time axis.
    print("\nExporting audio_encoder...")
    encoder = AudioEncoderExport(model)
    encoder.eval()
    max_t_mel = 3000 * 8  # 24000 mel frames in the example input
    example_mel = torch.randn(1, cfg.num_mel_bins, max_t_mel, dtype=dtype)
    _run_export("audio_encoder", encoder, (example_mel,), {"mel": {2: Dim.AUTO}})

    # Text decoder: embeddings and positions share one dynamic length.
    print("\nExporting text_decoder...")
    decoder = TextDecoderExport(model)
    decoder.eval()
    seq_len = Dim("seq_len", min=1, max=max_seq_len)
    example_embeds = torch.randn(1, 4, cfg.dim, dtype=dtype)
    example_positions = torch.arange(4, dtype=torch.long)
    _run_export(
        "text_decoder",
        decoder,
        (example_embeds, example_positions),
        {"input_embeds": {1: seq_len}, "cache_position": {0: seq_len}},
    )

    # Token embedding: dynamic number of token ids.
    print("\nExporting token_embedding...")
    embedder = TokenEmbeddingExport(model)
    embedder.eval()
    tok_len = Dim("tok_seq_len", min=1, max=max_seq_len)
    example_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
    _run_export("token_embedding", embedder, (example_ids,), {"token_ids": {1: tok_len}})

    metadata = {
        "sample_rate": 16000,
        "num_mel_bins": cfg.num_mel_bins,
        "hop_length": 160,
        "window_size": 400,
        "downsample_factor": cfg.downsample_factor,
        "dim": cfg.dim,
        "vocab_size": cfg.vocab_size,
        "max_seq_len": max_seq_len,
    }
    return exported, metadata
def export_streaming_ggml(model, max_seq_len, max_enc_len=750, dtype=torch.float32):
    """Export the streaming Voxtral components as ExportedPrograms.

    Three programs are produced, each with ``torch.export.export(...,
    strict=False)``:

    * ``encode_audio_chunk``: the streaming audio encoder, exported with fixed
      example shapes (one ``chunk_mel_len``-frame mel chunk and 4 encoder
      positions; no dynamic shapes). Its ``KVCache``/``SDPA`` modules are
      replaced by ``IndexCopyKVCache``/``StandardSDPA`` first.
    * ``text_decoder``: input embeddings plus cache positions, sharing a
      dynamic sequence axis in ``[1, max_seq_len]``.
    * ``token_embedding``: token ids with a dynamic sequence axis in
      ``[1, max_seq_len]``.

    The mel preprocessor is not exported; it runs in Python at inference time.

    NOTE(review): unlike ``export_all_ggml``, this function does not call
    ``swap_voxtral_attention(model)``. The text decoder is therefore exported
    with whatever attention modules ``model`` currently holds. Confirm the
    caller performs that swap if the decoder must avoid the llama custom ops.

    Args:
        model: Loaded Voxtral model.
        max_seq_len: Upper bound of the dynamic decoder/embedding length.
        max_enc_len: Encoder KV-cache length, forwarded to
            ``StreamingAudioEncoderExport``.
        dtype: Target floating-point dtype. The streaming encoder's float
            parameters/buffers are cast to it, and float example inputs use
            it. Integer example inputs stay ``torch.long``.

    Returns:
        ``(programs, metadata)``, where ``metadata`` also carries the audio
        framing constants used for streaming.
    """
    programs = {}
    param_dtype = dtype
    float_dtype = dtype
    # Mel frames per encoder chunk. Shared by the example input and metadata.
    chunk_mel_len = 8

    # --- Streaming audio encoder ---
    print("\nExporting encode_audio_chunk...")
    streaming_enc = StreamingAudioEncoderExport(model, max_enc_len=max_enc_len)

    # The wrapper holds its own KV caches and SDPA module. Replace them with
    # the standard-op variants for GGML export.
    from executorch.examples.models.voxtral_realtime.model import KVCache, SDPA
    from executorch_ggml.modules.voxtral_attention import IndexCopyKVCache, StandardSDPA

    for i, kv in enumerate(streaming_enc.kv_caches):
        if isinstance(kv, KVCache):
            new_kv = IndexCopyKVCache(kv.max_seq_len, kv.n_kv_heads, kv.head_dim)
            # Keep the existing cache tensors instead of fresh allocations.
            new_kv.k_cache = kv.k_cache
            new_kv.v_cache = kv.v_cache
            streaming_enc.kv_caches[i] = new_kv
    if isinstance(streaming_enc.sdpa, SDPA):
        # NOTE(review): n_heads is passed for both head counts. This assumes
        # the encoder has no grouped-query attention, whereas the caches above
        # are built from kv.n_kv_heads. Confirm the two are equal.
        streaming_enc.sdpa = StandardSDPA(
            streaming_enc.n_heads, streaming_enc.n_heads, streaming_enc.head_dim,
        )
    print(" Swapped encoder KVCache/SDPA for GGML export")

    if param_dtype == torch.bfloat16:
        # Cast floating-point tensors in place (via .data) and log each one.
        # (nn.Module.to(dtype) likewise leaves non-float tensors untouched.)
        print(" Converting only float parameters to BF16, preserving integer types...")
        named_tensors = [
            *streaming_enc.named_parameters(),
            *streaming_enc.named_buffers(),
        ]
        for name, tensor in named_tensors:
            # Read the dtype BEFORE casting. Otherwise the log would always
            # report the post-conversion dtype (bfloat16).
            src_dtype = tensor.dtype
            if src_dtype.is_floating_point:
                tensor.data = tensor.data.to(param_dtype)
                print(f" {name}: {src_dtype} -> converted to BF16")
            else:
                print(f" {name}: {src_dtype} -> preserved (integer type)")
    else:
        streaming_enc.to(dtype=param_dtype)
    streaming_enc.eval()

    sample_mel_chunk = torch.randn(
        1, model.config.num_mel_bins, chunk_mel_len, dtype=float_dtype
    )
    sample_enc_pos = torch.arange(4, dtype=torch.long)
    programs["encode_audio_chunk"] = export(
        streaming_enc,
        (sample_mel_chunk, sample_enc_pos),
        dynamic_shapes=None,
        strict=False,
    )
    print(f" encode_audio_chunk exported (fixed shapes: mel_chunk={sample_mel_chunk.shape})")

    # --- Text decoder ---
    print("\nExporting text_decoder...")
    text_decoder = TextDecoderExport(model)
    text_decoder.eval()
    seq_dim = Dim("seq_len", min=1, max=max_seq_len)
    sample_embeds = torch.randn(1, 4, model.config.dim, dtype=float_dtype)
    sample_pos = torch.arange(4, dtype=torch.long)
    programs["text_decoder"] = export(
        text_decoder,
        (sample_embeds, sample_pos),
        dynamic_shapes={
            "input_embeds": {1: seq_dim},
            "cache_position": {0: seq_dim},
        },
        strict=False,
    )
    print(f" text_decoder exported (sample input: {sample_embeds.shape})")

    # --- Token embedding ---
    print("\nExporting token_embedding...")
    tok_emb = TokenEmbeddingExport(model)
    tok_emb.eval()
    tok_seq_dim = Dim("tok_seq_len", min=1, max=max_seq_len)
    sample_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
    programs["token_embedding"] = export(
        tok_emb,
        (sample_ids,),
        dynamic_shapes={"token_ids": {1: tok_seq_dim}},
        strict=False,
    )
    print(f" token_embedding exported (sample input: {sample_ids.shape})")

    # --- Streaming audio framing constants ---
    hop_length = 160
    n_fft = 400
    sample_rate = 16000
    frame_rate = 12.5
    step_samples = int(sample_rate / frame_rate)  # 1280 with these values
    # Left STFT context (n_fft // 2), rounded up to a whole number of hops.
    stft_left_overlap = ((n_fft // 2 + hop_length - 1) // hop_length) * hop_length
    mel_skip_frames = stft_left_overlap // hop_length
    # Algebraically this equals n_fft // 2 - hop_length.
    stft_right_lookahead = (
        (chunk_mel_len - 1) * hop_length + n_fft // 2 - chunk_mel_len * hop_length
    )
    metadata = {
        "sample_rate": sample_rate,
        "num_mel_bins": model.config.num_mel_bins,
        "hop_length": hop_length,
        "window_size": n_fft,
        "downsample_factor": model.config.downsample_factor,
        "dim": model.config.dim,
        "enc_dim": model.config.enc_dim,
        "vocab_size": model.config.vocab_size,
        "max_seq_len": max_seq_len,
        "streaming": 1,
        "step_samples": step_samples,
        "chunk_mel_len": chunk_mel_len,
        "max_enc_len": max_enc_len,
        "conv1_pad": 2,
        "conv2_pad": 2,
        "stft_left_overlap": stft_left_overlap,
        "stft_right_lookahead": stft_right_lookahead,
        "mel_skip_frames": mel_skip_frames,
    }
    return programs, metadata
def _edge_compile_config():
    """Return the EdgeCompileConfig shared by every lowering call below."""
    # IR validity checking is disabled and dim-order handling is skipped.
    return EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=True,
    )


def _freeze_graph_parameters(gm):
    """Replace every parameter of ``gm`` with a detached copy that does not require grad.

    This targets the "leaf Variable that requires grad is being used in an
    in-place operation" error that in-place graph ops can raise during edge
    passes.
    """
    for name, param in list(gm.named_parameters()):
        owner_path, _, attr = name.rpartition(".")
        # get_submodule("") returns gm itself for top-level parameters.
        owner = gm.get_submodule(owner_path)
        # nn.Parameter defaults to requires_grad=True. Left at the default,
        # the new leaf would still require grad and the in-place error would
        # remain, so pass requires_grad=False explicitly.
        setattr(owner, attr, torch.nn.Parameter(param.detach(), requires_grad=False))


def lower_to_ggml(programs, metadata=None, quant_config=None, target_dtype=None):
    """Lower exported programs to an ExecuTorch program with the GGML backend.

    An ``"encode_audio_chunk"`` entry, if present, is handled separately. It is
    lowered without a partitioner (portable ops, no delegation) and without
    the transform passes, then merged into the GGML-delegated manager.

    Args:
        programs: Dict mapping method name to ``ExportedProgram``. The dict
            itself is not modified. The encoder program's graph module does
            have its parameters replaced in place.
        metadata: Optional dict emitted as constant methods. It is attached to
            the GGML lowering call only.
        quant_config: Forwarded to each ``GgmlPartitioner``.
        target_dtype: ``"BF16"`` prepends ``BF16UnsafeOpsCastPass``. ``"FP16"``
            only logs that no cast pass exists. Any other value adds no cast
            pass.

    Returns:
        The result of ``to_executorch`` with delegate segments extracted and
        ``alloc_graph_input=False`` memory planning.
    """
    print("\nLowering to ExecuTorch with GGML backend...")

    # Work on a copy, so popping the encoder does not remove it from the
    # caller's dict.
    programs = dict(programs)
    # The streaming encoder uses data-dependent scalar_tensor ops for its
    # KV-cache masks, which the GGML backend can't handle. It is lowered
    # separately with portable ops, then merged with the delegated methods.
    encoder_prog = programs.pop("encode_audio_chunk", None)
    partitioner = {key: [GgmlPartitioner(quant_config=quant_config)] for key in programs}
    constant_methods = dict(metadata) if metadata else {}

    transform_passes = [ReplaceCopyOpsPass(), RemoveGraphAssertsPass()]
    if target_dtype == "BF16":
        # NOTE(review): this shadows the module-level import of the same name
        # from executorch_ggml.passes. Presumably it is the same class; confirm.
        from executorch_ggml.passes.bf16_cast_pass import BF16UnsafeOpsCastPass

        # Runs first; meant to protect integer index ops from BF16 casts.
        transform_passes.insert(0, BF16UnsafeOpsCastPass())
        print(" Using BF16UnsafeOpsCastPass to protect index operations from corruption")
    elif target_dtype == "FP16":
        print(" FP16 cast pass not yet implemented, using F32")

    def _lower_delegated():
        # GGML lowering of every remaining (non-encoder) method.
        return to_edge_transform_and_lower(
            programs,
            transform_passes=transform_passes,
            partitioner=partitioner,
            compile_config=_edge_compile_config(),
            constant_methods=constant_methods or None,
        )

    if encoder_prog is None:
        et_prog = _lower_delegated()
    else:
        _freeze_graph_parameters(encoder_prog.graph_module)
        print(" Lowering encode_audio_chunk (portable, no delegation)...")
        enc_edge = to_edge_transform_and_lower(
            {"encode_audio_chunk": encoder_prog},
            compile_config=_edge_compile_config(),
        )
        print(" Lowering decoder + embedding (GGML)...")
        et_prog = _lower_delegated()
        # Merge the portable encoder into the GGML manager. This relies on
        # EdgeProgramManager's private ``_edge_programs`` dict.
        et_prog._edge_programs.update(enc_edge._edge_programs)

    return et_prog.to_executorch(
        config=ExecutorchBackendConfig(
            extract_delegate_segments=True,
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
        ),
    )