Skip to content

Commit 43ec3ed

Browse files
committed
[quantization] Apply SpinQuant to Qwen3-VL
This commit applies SpinQuant to Qwen3-VL. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 43b2f1e commit 43ec3ed

11 files changed

Lines changed: 2419 additions & 2 deletions

File tree

test/quantization/algorithm/test_qwen3_vl_spinquant.py

Lines changed: 716 additions & 0 deletions
Large diffs are not rendered by default.

tico/quantization/algorithm/spinquant/quantizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,9 @@ def _copy_runtime_attributes(
360360

361361
if hasattr(src, "config"):
362362
dst.config = src.config
363+
364+
365+
# Register additional SpinQuant variants.
366+
from tico.quantization.algorithm.spinquant.qwen3_vl_quantizer import ( # noqa: E402,F401
367+
Qwen3VLSpinQuantQuantizer,
368+
)
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from typing import Any
17+
18+
import torch.nn as nn
19+
20+
from tico.quantization.config.qwen3_vl_spinquant import Qwen3VLSpinQuantConfig
21+
22+
23+
@dataclass
class Qwen3VLSpinQuantComponents:
    """
    Bundle of the Qwen3-VL submodule references that SpinQuant works with.

    Attributes:
        language_model: The Qwen3-VL text model.
        text_layers: Decoder layers belonging to the text model.
        lm_head: Language modeling head at the end of the model.
        visual_deepstack_mergers: Merger modules of the DeepStack visual path.
    """

    language_model: nn.Module
    text_layers: nn.ModuleList
    lm_head: nn.Linear
    visual_deepstack_mergers: nn.ModuleList
39+
40+
41+
def get_module_by_path(root: nn.Module, path: str) -> Any:
    """
    Resolve a dotted attribute path from a root module.

    Each dot-separated segment is looked up as an attribute of the object
    reached so far. When that object is an ``nn.ModuleList``, ``list`` or
    ``tuple`` and the segment is a decimal number, the segment is used as an
    index into the container instead.

    Parameters:
        root: Root module.
        path: Dotted path such as ``"model.language_model.layers"``.

    Returns:
        The resolved object.

    Raises:
        AttributeError: If the path cannot be fully resolved.
        ValueError: If the path is empty.
    """
    if not path:
        raise ValueError("path must be a non-empty string.")

    current: Any = root
    for part in path.split("."):
        # ``str.isdecimal`` matches exactly what ``int`` can parse.
        # ``str.isdigit`` additionally accepts characters such as "²", for
        # which ``int`` raises ValueError and would break the documented
        # AttributeError contract.
        if isinstance(current, (nn.ModuleList, list, tuple)) and part.isdecimal():
            index = int(part)
            try:
                current = current[index]
            except IndexError as exc:
                raise AttributeError(
                    f"Failed to resolve path {path!r}. Index {index} is out of range."
                ) from exc
            continue

        if not hasattr(current, part):
            raise AttributeError(
                f"Failed to resolve attribute path {path!r}. "
                f"Missing attribute {part!r} on object of type {type(current).__name__}."
            )
        current = getattr(current, part)

    return current
79+
80+
81+
def require_linear_attr(module: nn.Module, attr_name: str) -> nn.Linear:
    """
    Fetch an attribute that must exist and must be an ``nn.Linear``.

    Parameters:
        module: Module that owns the attribute.
        attr_name: Name of the attribute to fetch.

    Returns:
        The ``nn.Linear`` stored under ``attr_name``.

    Raises:
        AttributeError: If ``module`` has no attribute named ``attr_name``.
        TypeError: If the attribute exists but is not an ``nn.Linear``.
    """
    # A private sentinel distinguishes "missing" from any real attribute value.
    absent = object()
    candidate = getattr(module, attr_name, absent)
    if candidate is absent:
        raise AttributeError(
            f"Expected attribute {attr_name!r} on module {type(module).__name__}."
        )

    if isinstance(candidate, nn.Linear):
        return candidate

    raise TypeError(
        f"Expected {attr_name!r} to be nn.Linear, got {type(candidate).__name__}."
    )
108+
109+
110+
def resolve_qwen3_vl_spinquant_components(
    model: nn.Module,
    config: Qwen3VLSpinQuantConfig,
) -> Qwen3VLSpinQuantComponents:
    """
    Look up the Qwen3-VL modules that SpinQuant needs and type-check them.

    The language model, text layers and LM head are always resolved from the
    paths stored in ``config``. The DeepStack mergers are resolved only when
    ``config.fuse_deepstack_visual_outputs`` is set; otherwise an empty
    ``nn.ModuleList`` is used in their place.

    Parameters:
        model: Target model.
        config: Qwen3-VL SpinQuant configuration.

    Returns:
        Resolved component references.

    Raises:
        TypeError: If a resolved module has an unexpected type.
    """
    # Resolve every path before validating any type, so a path that cannot be
    # resolved is reported ahead of a type mismatch.
    resolved_lm = get_module_by_path(model, config.language_model_attr)
    resolved_layers = get_module_by_path(model, config.text_layers_attr)
    resolved_head = get_module_by_path(model, config.lm_head_attr)
    resolved_mergers = (
        get_module_by_path(model, config.visual_deepstack_mergers_attr)
        if config.fuse_deepstack_visual_outputs
        else nn.ModuleList()
    )

    # (resolved object, config field holding its path, required type, label).
    # The config field is read lazily, only when building an error message.
    expectations = (
        (resolved_lm, "language_model_attr", nn.Module, "nn.Module"),
        (resolved_layers, "text_layers_attr", nn.ModuleList, "nn.ModuleList"),
        (resolved_head, "lm_head_attr", nn.Linear, "nn.Linear"),
        (
            resolved_mergers,
            "visual_deepstack_mergers_attr",
            nn.ModuleList,
            "nn.ModuleList",
        ),
    )
    for candidate, path_field, required_type, type_label in expectations:
        if isinstance(candidate, required_type):
            continue
        raise TypeError(
            f"{getattr(config, path_field)!r} must resolve to {type_label}, "
            f"got {type(candidate).__name__}."
        )

    return Qwen3VLSpinQuantComponents(
        language_model=resolved_lm,
        text_layers=resolved_layers,
        lm_head=resolved_head,
        visual_deepstack_mergers=resolved_mergers,
    )
169+
170+
171+
def is_tied_word_embedding(
    model: nn.Module,
    config: Qwen3VLSpinQuantConfig,
) -> bool:
    """
    Tell whether the input embedding and the LM head use the same storage.

    Parameters:
        model: Target model.
        config: Qwen3-VL SpinQuant configuration.

    Returns:
        True if both weights report the same data pointer. False if they
        differ, or if the language model has no ``embed_tokens`` attribute
        that is an ``nn.Embedding``.
    """
    resolved = resolve_qwen3_vl_spinquant_components(model, config)

    # A missing attribute and a non-Embedding value both count as "not tied".
    embedding = getattr(resolved.language_model, "embed_tokens", None)
    if not isinstance(embedding, nn.Embedding):
        return False

    embedding_ptr = embedding.weight.data_ptr()
    head_ptr = resolved.lm_head.weight.data_ptr()
    return embedding_ptr == head_ptr
195+
196+
197+
def assert_tied_word_embedding(
    model: nn.Module,
    config: Qwen3VLSpinQuantConfig,
) -> None:
    """
    Ensure the Qwen3-VL input embedding and LM head are tied.

    Parameters:
        model: Target model.
        config: Qwen3-VL SpinQuant configuration.

    Raises:
        ValueError: If ``is_tied_word_embedding`` reports that the weights
            are not tied.
    """
    if is_tied_word_embedding(model, config):
        return

    raise ValueError(
        "Qwen3-VL SpinQuant assumes tied word embeddings, but "
        "`model.language_model.embed_tokens.weight` and `lm_head.weight` "
        "do not share storage."
    )
217+
218+
219+
def validate_qwen3_vl_for_spinquant(
    model: nn.Module,
    config: Qwen3VLSpinQuantConfig,
    *,
    require_spin_runtime: bool = False,
) -> None:
    """
    Validate that a model exposes the modules required by Qwen3-VL SpinQuant.

    Checks run in this order: the model object and its ``config``, the
    language model, every text layer, the DeepStack mergers (only when
    ``config.fuse_deepstack_visual_outputs`` is set), the tied word embedding
    assumption, and finally the SpinQuant runtime layers (only when
    ``require_spin_runtime`` is set). The first failing check raises.

    Parameters:
        model: Target model.
        config: Qwen3-VL SpinQuant configuration.
        require_spin_runtime: Whether to require added SpinQuant runtime layers.

    Raises:
        TypeError: If the input is not a module or a submodule has an invalid type.
        ValueError: If the model type or tied embedding assumption is invalid.
        AttributeError: If a required module is missing.
    """
    if not isinstance(model, nn.Module):
        raise TypeError(f"Expected an nn.Module, got {type(model).__name__}.")

    # Both lookups tolerate a missing attribute so that a model without a
    # config is reported through the model_type error below.
    model_config = getattr(model, "config", None)
    model_type = getattr(model_config, "model_type", None)
    if model_type != "qwen3_vl":
        raise ValueError(
            "Qwen3-VL SpinQuant supports only Qwen3-VL dense models, "
            f"but got model_type={model_type!r}."
        )

    if not hasattr(model_config, "text_config"):
        raise ValueError("Qwen3-VL SpinQuant requires `model.config.text_config`.")

    components = resolve_qwen3_vl_spinquant_components(model, config)

    _validate_qwen3_vl_language_model(components.language_model)

    for layer_idx, layer in enumerate(components.text_layers):
        _validate_qwen3_vl_text_layer(layer_idx, layer)

    if config.fuse_deepstack_visual_outputs:
        _validate_qwen3_vl_deepstack_mergers(components.visual_deepstack_mergers)

    assert_tied_word_embedding(model, config)

    if require_spin_runtime:
        require_linear_attr(components.language_model, "rotate_embedding")
        require_linear_attr(model, "rotate_lm_head")


def _validate_qwen3_vl_language_model(language_model: nn.Module) -> None:
    """Check that the language model has an nn.Embedding `embed_tokens` and a `norm`."""
    if not hasattr(language_model, "embed_tokens"):
        raise AttributeError("Expected language model to expose `embed_tokens`.")
    if not isinstance(language_model.embed_tokens, nn.Embedding):
        raise TypeError(
            "Expected language_model.embed_tokens to be nn.Embedding, "
            f"got {type(language_model.embed_tokens).__name__}."
        )

    if not hasattr(language_model, "norm"):
        raise AttributeError("Expected language model to expose final `norm`.")


def _validate_qwen3_vl_text_layer(layer_idx: int, layer: nn.Module) -> None:
    """Check one text layer for its submodules and its Linear projections."""
    for required in (
        "self_attn",
        "mlp",
        "input_layernorm",
        "post_attention_layernorm",
    ):
        if not hasattr(layer, required):
            raise AttributeError(f"Text layer {layer_idx} is missing `{required}`.")

    for attr_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
        require_linear_attr(layer.self_attn, attr_name)

    for attr_name in ("gate_proj", "up_proj", "down_proj"):
        require_linear_attr(layer.mlp, attr_name)


def _validate_qwen3_vl_deepstack_mergers(mergers: nn.ModuleList) -> None:
    """Check that every DeepStack merger has an nn.Linear `linear_fc2`."""
    for merger_idx, merger in enumerate(mergers):
        try:
            require_linear_attr(merger, "linear_fc2")
        # Only the two documented failure types are re-wrapped: rebuilding an
        # arbitrary exception class from a single message is not safe.
        except (AttributeError, TypeError) as exc:
            raise type(exc)(f"Invalid DeepStack merger {merger_idx}: {exc}") from exc

0 commit comments

Comments
 (0)