Skip to content

Commit 0c6614a

Browse files
RobotSail and claude authored
add support for qwen3.5 vl model (#693)
* add support for qwen3.5 vl model
* enable detection of VLM models and allow using non-Hopper GPUs for GPT-OSS
* add broader vlm support
* add general vlm support
* support gemma3n
* address coderabbit review comments
  - Fix eos_token_id truthiness check (0 is valid)
  - Add isinstance guard for RopeParameters in mrope detection
  - Add hasattr fallback for non-dict rope objects
  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix CI: import sorting, pylint, and test mocks
  - Fix isort ordering in vlm_utils.py and model.py
  - Fix pylint: use 'from torch import nn', mark unused-argument
  - Mock needs_sdpa and get_module_class_from_name in unit tests
  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix ruff formatting for CI version (0.12.11)
  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* address remaining review comments
  - accelerator.py: fall back to warning + default wrap policy instead of ValueError when no _no_split_modules resolve; try underlying HF model as secondary target
  - model.py: use torch.cuda.current_device() instead of hardcoded 0
  - vlm_utils.py: add trust_remote_code param (default False) to all config-loading functions; use init_empty_weights for CausalLM shell; copy quantization metadata from VLM
  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* fix mamba kernel comments and exception handling
  - Remove fabricated claim about C API incompatibility
  - Accurately describe the issue as PyTorch/CUDA ABI mismatch
  - Broaden exception handling to catch AttributeError
  Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1f02ea6 commit 0c6614a

6 files changed

Lines changed: 491 additions & 45 deletions

File tree

src/instructlab/training/accelerator.py

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -135,17 +135,41 @@ def __getattr__(self, name):
135135

136136
def get_fsdp_config(self):
137137
is_lora = self.model.lora_config is not None
138-
block_name = next(iter(self.model._no_split_modules))
139138

140139
wrap_policy = None
141140
if is_lora > 0:
142141
wrap_policy = fsdp_auto_wrap_policy(self.model)
143142
else:
143+
# Resolve all _no_split_modules names to actual classes present
144+
# in the model. Some models (e.g. Qwen3.5) declare module names
145+
# for architectures not loaded (e.g. vision blocks in a CausalLM),
146+
# so we must filter out None results.
147+
layer_classes = set()
148+
# Try resolving against the wrapper model first, then the
149+
# underlying HF model if the first pass yields nothing.
150+
targets = [self.model]
151+
hf_model = getattr(self.model, "model", None)
152+
if hf_model is not None:
153+
targets.append(hf_model)
154+
155+
for target in targets:
156+
for block_name in self.model._no_split_modules:
157+
cls = get_module_class_from_name(target, block_name)
158+
if cls is not None:
159+
layer_classes.add(cls)
160+
if layer_classes:
161+
break
162+
163+
if not layer_classes:
164+
logger.warning(
165+
"Could not resolve any _no_split_modules "
166+
"(%s) to actual module classes in the model. "
167+
"FSDP will use the default wrap policy.",
168+
self.model._no_split_modules,
169+
)
144170
wrap_policy = partial(
145171
transformer_auto_wrap_policy,
146-
transformer_layer_cls={
147-
get_module_class_from_name(self.model, block_name),
148-
},
172+
transformer_layer_cls=layer_classes,
149173
)
150174

151175
# TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA

src/instructlab/training/model.py

Lines changed: 136 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,14 @@
4545
)
4646
from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model
4747
from instructlab.training.type_definitions import ModelInputs, ModelLosses
48+
from instructlab.training.vlm_utils import (
49+
extract_causal_lm_from_vlm,
50+
has_timm_vision_tower,
51+
is_vlm_for_direct_loading,
52+
is_vlm_with_causal_lm,
53+
load_vlm_for_text_training,
54+
needs_sdpa,
55+
)
4856

4957

5058
class Model:
@@ -67,6 +75,13 @@ def __init__(
6775
# check model type & set on the mclasss
6876
self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
6977
self.is_gpt_oss = is_gpt_oss(model_path)
78+
79+
# Pre-populate the Hub kernel cache with locally installed mamba_ssm
80+
# and causal_conv1d to avoid PyTorch/CUDA ABI mismatches with the
81+
# Hub-provided kernel builds.
82+
if self.is_granitemoehybrid:
83+
self._use_local_mamba_kernels()
84+
7085
if self.is_gpt_oss:
7186
# Third Party
7287
quant_config = Mxfp4Config(dequantize=True)
@@ -95,12 +110,92 @@ def __init__(
95110

96111
# set flash attention accordingly
97112
if flash_enabled:
98-
self.base_model_args["attn_implementation"] = "flash_attention_2"
99-
if self.is_gpt_oss:
100-
self.base_model_args["attn_implementation"] = (
101-
"kernels-community/vllm-flash-attn3"
113+
# Some models are incompatible with Flash Attention 2:
114+
# - M-RoPE models produce 3D position_ids that FA2 misinterprets
115+
# - Models with timm vision towers (TimmWrapperModel rejects FA2)
116+
# Detect these and fall back to SDPA.
117+
use_sdpa = needs_sdpa(model_path)
118+
if use_sdpa:
119+
logger.warning(
120+
"Disabling flash_attention_2 — model is incompatible "
121+
"(M-RoPE or timm vision tower). Using SDPA instead."
122+
)
123+
self.base_model_args["attn_implementation"] = "sdpa"
124+
else:
125+
self.base_model_args["attn_implementation"] = "flash_attention_2"
126+
if self.is_gpt_oss:
127+
# vllm-flash-attn3 requires Hopper (SM 9.0+) GPUs;
128+
# GPT-OSS only supports flash-attn3 or eager
129+
device = (
130+
torch.cuda.current_device() if torch.cuda.is_available() else 0
131+
)
132+
major, _ = torch.cuda.get_device_capability(device)
133+
if major >= 9:
134+
self.base_model_args["attn_implementation"] = (
135+
"kernels-community/vllm-flash-attn3"
136+
)
137+
else:
138+
self.base_model_args["attn_implementation"] = "eager"
139+
logger.warning(
140+
"GPT-OSS: flash-attn3 requires Hopper (SM 9.0+) GPUs, "
141+
"but found SM %d.x. Using eager attention instead.",
142+
major,
143+
)
144+
145+
# For models with timm vision towers: set vision config to eager
146+
# while keeping the text model's attention implementation.
147+
# timm's TimmWrapperModel rejects both FA2 and SDPA.
148+
if has_timm_vision_tower(model_path):
149+
attn_impl = self.base_model_args.get(
150+
"attn_implementation", "flash_attention_2"
151+
)
152+
self.base_model_args["attn_implementation"] = {
153+
"text_config": attn_impl,
154+
"vision_config": "eager",
155+
}
156+
logger.info(
157+
"Model has timm vision tower — using eager attention for vision, "
158+
"%s for text model.",
159+
attn_impl,
102160
)
103161

162+
@staticmethod
163+
def _use_local_mamba_kernels():
164+
"""Use locally installed mamba_ssm/causal_conv1d instead of Hub kernels.
165+
166+
Pre-populate the transformers Hub kernel cache with the locally
167+
installed packages to avoid PyTorch/CUDA ABI mismatches with the
168+
Hub-provided kernel builds.
169+
"""
170+
try:
171+
# Third Party
172+
from mamba_ssm.ops.triton.selective_state_update import (
173+
selective_state_update,
174+
)
175+
from mamba_ssm.ops.triton.ssd_combined import (
176+
mamba_chunk_scan_combined,
177+
mamba_split_conv1d_scan_combined,
178+
)
179+
from transformers.integrations.hub_kernels import _KERNEL_MODULE_MAPPING
180+
import causal_conv1d
181+
import mamba_ssm
182+
183+
mamba_ssm.selective_state_update = selective_state_update
184+
mamba_ssm.mamba_chunk_scan_combined = mamba_chunk_scan_combined
185+
mamba_ssm.mamba_split_conv1d_scan_combined = (
186+
mamba_split_conv1d_scan_combined
187+
)
188+
189+
_KERNEL_MODULE_MAPPING["causal-conv1d"] = causal_conv1d
190+
_KERNEL_MODULE_MAPPING["mamba-ssm"] = mamba_ssm
191+
logger.info("Using local mamba_ssm/causal_conv1d instead of Hub kernels")
192+
except (ImportError, AttributeError) as e:
193+
logger.warning(
194+
"Could not patch mamba kernels (%s); "
195+
"GraniteMoeHybrid may use Hub kernels",
196+
e,
197+
)
198+
104199
def _post_model_init(self):
105200
"""Common initialization steps that should happen after model initialization."""
106201
self.reconcile_tokenizer()
@@ -271,61 +366,61 @@ def _is_causal_lm_model(self) -> bool:
271366
bool: True if the model is a causal language model, False otherwise.
272367
"""
273368
# Third Party
274-
return "ForCausalLM" in self.model.__class__.__name__
369+
class_name = self.model.__class__.__name__
370+
return "ForCausalLM" in class_name or "ForConditionalGeneration" in class_name
371+
372+
def _get_text_config(self):
373+
"""Get the text-relevant config, falling back to text_config for VLMs."""
374+
config = self.model.config
375+
if not hasattr(config, "vocab_size") and hasattr(config, "text_config"):
376+
return config.text_config
377+
return config
275378

276379
def reconcile_tokenizer(self):
277-
if len(self.tokenizer) > self.model.config.vocab_size:
380+
text_config = self._get_text_config()
381+
if len(self.tokenizer) > text_config.vocab_size:
278382
logger.warning(
279-
f"WARNING: tokenizer has {len(self.tokenizer)} tokens but model has {self.model.config.vocab_size} vocab size"
383+
f"WARNING: tokenizer has {len(self.tokenizer)} tokens but model has {text_config.vocab_size} vocab size"
280384
)
281385
self.model.resize_token_embeddings(
282386
int(8 * math.ceil(len(self.tokenizer) / 8.0))
283387
) # make the vocab size multiple of 8 for sharding the embedding layer.
284388

285389
# Fix any discrepancy between model and tokenizer
286390
if (
287-
self.model.config.pad_token_id is not None
391+
text_config.pad_token_id is not None
288392
and self.tokenizer.pad_token_id is not None
289-
and self.model.config.pad_token_id != self.tokenizer.pad_token_id
393+
and text_config.pad_token_id != self.tokenizer.pad_token_id
290394
):
291395
logger.warning(
292-
f"WARNING: There is a mismatch between pad token id of model ({self.model.config.pad_token_id}) and tokenizer({self.tokenizer.pad_token_id}). Fixing model pad token id to be same as tokenizer's pad token id"
396+
f"WARNING: There is a mismatch between pad token id of model ({text_config.pad_token_id}) and tokenizer({self.tokenizer.pad_token_id}). Fixing model pad token id to be same as tokenizer's pad token id"
293397
)
294-
self.model.config.pad_token_id = self.tokenizer.pad_token_id
398+
text_config.pad_token_id = self.tokenizer.pad_token_id
295399
if (
296-
self.model.config.bos_token_id is not None
400+
text_config.bos_token_id is not None
297401
and self.tokenizer.bos_token_id is not None
298-
and self.model.config.bos_token_id != self.tokenizer.bos_token_id
402+
and text_config.bos_token_id != self.tokenizer.bos_token_id
299403
):
300404
logging.warning(
301-
f"WARNING: There is a mismatch between bos token id of model({self.model.config.bos_token_id}) and tokenizer({self.tokenizer.bos_token_id}). Fixing model bos token id to be same as tokenizer's bos token id"
405+
f"WARNING: There is a mismatch between bos token id of model({text_config.bos_token_id}) and tokenizer({self.tokenizer.bos_token_id}). Fixing model bos token id to be same as tokenizer's bos token id"
302406
)
303-
self.model.config.bos_token_id = self.tokenizer.bos_token_id
407+
text_config.bos_token_id = self.tokenizer.bos_token_id
304408
if (
305-
self.model.config.eos_token_id is not None
306-
and self.tokenizer.eos_token_id
307-
and self.model.config.eos_token_id != self.tokenizer.eos_token_id
409+
text_config.eos_token_id is not None
410+
and self.tokenizer.eos_token_id is not None
411+
and text_config.eos_token_id != self.tokenizer.eos_token_id
308412
):
309413
logger.warning(
310-
f"WARNING: There is a mismatch between eos token id of model({self.model.config.eos_token_id}) and tokenizer({self.tokenizer.eos_token_id}). Fixing model eos token id to be same as tokenizer's eos token id"
414+
f"WARNING: There is a mismatch between eos token id of model({text_config.eos_token_id}) and tokenizer({self.tokenizer.eos_token_id}). Fixing model eos token id to be same as tokenizer's eos token id"
311415
)
312-
self.model.config.eos_token_id = self.tokenizer.eos_token_id
416+
text_config.eos_token_id = self.tokenizer.eos_token_id
313417

314-
if (
315-
self.tokenizer.pad_token_id is not None
316-
and self.model.config.pad_token_id is None
317-
):
318-
self.model.config.pad_token_id = self.tokenizer.pad_token_id
319-
if (
320-
self.tokenizer.bos_token_id is not None
321-
and self.model.config.bos_token_id is None
322-
):
323-
self.model.config.bos_token_id = self.tokenizer.bos_token_id
324-
if (
325-
self.tokenizer.eos_token_id is not None
326-
and self.model.config.eos_token_id is None
327-
):
328-
self.model.config.eos_token_id = self.tokenizer.eos_token_id
418+
if self.tokenizer.pad_token_id is not None and text_config.pad_token_id is None:
419+
text_config.pad_token_id = self.tokenizer.pad_token_id
420+
if self.tokenizer.bos_token_id is not None and text_config.bos_token_id is None:
421+
text_config.bos_token_id = self.tokenizer.bos_token_id
422+
if self.tokenizer.eos_token_id is not None and text_config.eos_token_id is None:
423+
text_config.eos_token_id = self.tokenizer.eos_token_id
329424

330425
if not self._is_causal_lm_model():
331426
raise ValueError(
@@ -501,7 +596,12 @@ def __init__(
501596
lora_config=lora_config,
502597
lora_quant_bits=lora_quant_bits,
503598
)
504-
self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
599+
if is_vlm_with_causal_lm(model_path):
600+
self.model = extract_causal_lm_from_vlm(model_path, self.base_model_args)
601+
elif is_vlm_for_direct_loading(model_path):
602+
self.model = load_vlm_for_text_training(model_path, self.base_model_args)
603+
else:
604+
self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
505605
self._post_model_init()
506606
self.model.gradient_checkpointing_enable()
507607

0 commit comments

Comments
 (0)