Skip to content

Commit 09869ee

Browse files
authored
[Bugfix] Route Gemma4 ClippableLinear clip buffers during weight loading (#352)
Gemma4ClippableLinear registers input_max/input_min/output_max/output_min as buffers rather than parameters, so AutoWeightsLoader cannot find them via named_parameters(). Intercept these weights and load them directly into the corresponding buffers before passing the remaining weights to the loader. Signed-off-by: GrootLiu <1219671600@qq.com>
1 parent 189a443 commit 09869ee

1 file changed

Lines changed: 37 additions & 1 deletion

File tree

vllm_kunlun/models/gemma4_mm.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1359,11 +1359,47 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
13591359
"embed_audio.",
13601360
]
13611361
)
1362+
# Gemma4ClippableLinear registers input_max/input_min/output_max/
1363+
# output_min as buffers (not parameters), so AutoWeightsLoader cannot
1364+
# find them via named_parameters(). We intercept these weights here,
1365+
# load them directly into the corresponding buffer, and hide them
1366+
# from the loader.
1367+
clip_suffixes = (
1368+
".input_max",
1369+
".output_max",
1370+
".input_min",
1371+
".output_min",
1372+
)
1373+
1374+
def _route_clip_buffers(ws):
    """Load clip-range tensors into module buffers; pass other weights on.

    Weights whose name ends with one of ``clip_suffixes`` are copied
    directly into the buffer of the same name on the module addressed by
    the dotted prefix, and are never yielded. Every other ``(name,
    tensor)`` pair is yielded unchanged for the regular weight loader.

    Loading stays best-effort: a clip weight whose module path or buffer
    cannot be resolved is dropped, but a warning is logged so the missing
    clip range does not go unnoticed.

    Args:
        ws: Iterable of ``(name, tensor)`` checkpoint entries.

    Yields:
        The ``(name, tensor)`` pairs that are not clip buffers.
    """
    # Function-scope import: the surrounding file's logger (if any) is not
    # visible from this block, so rely on the standard library only.
    import logging

    log = logging.getLogger(__name__)

    # NOTE(review): routing happens on the incoming names, i.e. before
    # ``hf_to_vllm_mapper`` is applied by the loader -- confirm that clip
    # buffer names already match this model's module hierarchy.
    for name, tensor in ws:
        if not name.endswith(clip_suffixes):
            yield name, tensor
            continue

        # Resolve module by hierarchical name, e.g.
        # audio_tower.layers.0.feed_forward1.ffw_layer_1.input_max
        module_path, _, buf_name = name.rpartition(".")
        module = self
        try:
            for attr in module_path.split("."):
                module = (
                    module[int(attr)]
                    if attr.isdigit()
                    else getattr(module, attr)
                )
            buf = getattr(module, buf_name)
        except (AttributeError, IndexError, KeyError, TypeError):
            # Unknown attribute, out-of-range index, or a container that
            # cannot be indexed by position: treat as "no such buffer".
            buf = None

        if not isinstance(buf, torch.Tensor):
            log.warning(
                "Skipping clip buffer weight %r: no matching tensor "
                "buffer found on the model.",
                name,
            )
            continue

        # Copy in place so the registered buffer object is preserved.
        buf.data.copy_(tensor.to(buf.device, buf.dtype))
1395+
13621396
loader = AutoWeightsLoader(
13631397
self,
13641398
ignore_unexpected_prefixes=ignore_prefixes,
13651399
)
1366-
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
1400+
return loader.load_weights(
1401+
_route_clip_buffers(weights), mapper=self.hf_to_vllm_mapper
1402+
)
13671403

13681404
# ------------------------------------------------------------------ #
13691405
# LoRA / multimodal mapping

0 commit comments

Comments
 (0)