@@ -112,29 +112,25 @@ def __init__(
112112 if has_norm_in :
113113 self .norm_in = nn .RMSNorm (self .hidden_size , eps = config .rms_norm_eps )
114114
115- # Initialize vision components for VLM on first shard
116115 if self .is_vlm :
117116 logger .info (
118117 f"Initializing VLM components: vision_tower ({ self .vision_config .model_type } )"
119118 )
120119 self .vision_tower = vision_tower_class (self .vision_config )
121- # Some VLMs (e.g., Qwen2-VL, Qwen3-VL) have the projector/merger built into VisionModel
122- # In these cases, multi_modal_projector_class can be None
123120 if multi_modal_projector_class is not None :
124- # Some projectors (e.g., KimiVL) need both vision_config and text_config
125- # Create a combined config object if the projector expects it
126121 try :
127122 self .multi_modal_projector = multi_modal_projector_class (config )
128123 except (TypeError , AttributeError ):
129- # Projector expects a combined config with vision_config + text_config
130- combined_config = type ("CombinedConfig" , (), {
131- "vision_config" : self .vision_config ,
132- "text_config" : config ,
133- })()
124+ combined_config = type (
125+ "CombinedConfig" ,
126+ (),
127+ {
128+ "vision_config" : self .vision_config ,
129+ "text_config" : config ,
130+ },
131+ )()
134132 self .multi_modal_projector = multi_modal_projector_class (combined_config )
135- logger .info (
136- "Initialized projector with combined vision+text config"
137- )
133+ logger .info ("Initialized projector with combined vision+text config" )
138134 else :
139135 self .multi_modal_projector = None
140136 logger .info (
@@ -197,17 +193,10 @@ def get_input_embeddings(
197193 if not self .is_first_shard :
198194 raise ValueError ("get_input_embeddings should only be called on the first shard" )
199195
200- # Get text embeddings
201196 inputs_embeds = self .embed_tokens (input_ids )
202-
203- # If no images or not a VLM, return text embeddings directly
204197 if pixel_values is None or not self .is_vlm :
205198 return InputEmbeddingsOutput (inputs_embeds = inputs_embeds )
206-
207- # Process vision features
208199 image_features = self ._encode_images (pixel_values , ** kwargs )
209-
210- # Merge image features with text embeddings
211200 final_embeds = self ._merge_input_ids_with_image_features (
212201 image_features , inputs_embeds , input_ids
213202 )
@@ -220,20 +209,10 @@ def _encode_images(
220209 image_grid_thw : Optional [mx .array ] = None ,
221210 ** kwargs ,
222211 ) -> mx .array :
223- """Encode images through vision tower and projector.
224-
225- Args:
226- pixel_values: Image tensor, typically (batch, C, H, W) or (num_patches, C, H, W)
227- image_grid_thw: Grid size (T, H, W) for Qwen-VL models
228- **kwargs: Additional model-specific arguments
229-
230- Returns:
231- Projected image features ready to be merged with text embeddings
232- """
212+ """Encode images through vision tower and projector."""
233213 if self .vision_tower is None :
234214 raise ValueError ("Vision tower not initialized for this model" )
235215
236- # Check if this is a model that uses grid_thw for vision encoding
237216 model_type = getattr (self .vision_config , "model_type" , "" ) if self .vision_config else ""
238217 is_qwen_vl = "qwen" in model_type .lower () and "vl" in model_type .lower ()
239218 is_moonvit = model_type .lower () == "moonvit"
@@ -243,74 +222,57 @@ def _encode_images(
243222 if hasattr (self .vision_tower , "patch_embed" ) and hasattr (
244223 self .vision_tower .patch_embed , "proj"
245224 ):
246- target_dtype = self .vision_tower .patch_embed .proj .weight .dtype
247- pixel_values = pixel_values .astype (target_dtype )
225+ pixel_values = pixel_values .astype (self .vision_tower .patch_embed .proj .weight .dtype )
248226 else :
249227 pixel_values = pixel_values .astype (self .dtype )
250228
251- # Get vision features from vision tower
252229 if uses_grid_thw and image_grid_thw is not None :
253230 if is_moonvit :
254- # KimiVL/MoonViT style: VisionModel expects NHWC input and grid_thw
255- # pixel_values may be NCHW from processor, convert to NHWC
231+ # MoonViT (KimiVL) expects NHWC input
256232 if pixel_values .ndim == 4 and pixel_values .shape [1 ] in [1 , 3 , 4 ]:
257233 pixel_values = pixel_values .transpose (0 , 2 , 3 , 1 )
258234 vision_outputs = self .vision_tower (
259235 pixel_values , grid_thw = image_grid_thw , output_hidden_states = True
260236 )
261237 else :
262- # Qwen-VL style: VisionModel(pixel_values, grid_thw) -> hidden_states
263- # No format conversion needed - Qwen-VL expects flat patches
238+ # Qwen-VL expects flat patches
264239 vision_outputs = self .vision_tower (pixel_values , image_grid_thw )
265240
266241 if isinstance (vision_outputs , tuple ):
267- # First element is the merged hidden states (already projected by merger)
268242 selected_features = vision_outputs [0 ]
269243 elif isinstance (vision_outputs , list ):
270- # KimiVL patch_merger returns a list of arrays
271244 selected_features = vision_outputs
272245 else :
273246 selected_features = vision_outputs
274247 else :
275- # Standard CLIP/SigLIP style
276- # Convert to vision tower expected format (typically NHWC for MLX)
248+ # CLIP/SigLIP style: NCHW -> NHWC
277249 if pixel_values .ndim == 4 and pixel_values .shape [1 ] in [1 , 3 , 4 ]:
278- # NCHW -> NHWC
279250 pixel_values = pixel_values .transpose (0 , 2 , 3 , 1 )
280251
281252 vision_outputs = self .vision_tower (pixel_values , output_hidden_states = True )
282253
283- # Handle different output formats
284254 if isinstance (vision_outputs , tuple ):
285- # CLIP/SigLIP style: (pooler_output, last_hidden_state, hidden_states)
286255 if len (vision_outputs ) >= 3 :
287- hidden_states = vision_outputs [2 ] # All hidden states
256+ hidden_states = vision_outputs [2 ]
288257 if isinstance (self .vision_feature_layer , int ):
289258 selected_features = hidden_states [self .vision_feature_layer ]
290259 if self .vision_feature_select_strategy == "default" :
291- # Remove CLS token
292260 selected_features = selected_features [:, 1 :]
293261 else :
294- # Multiple layers
295262 hs_pool = [hidden_states [idx ] for idx in self .vision_feature_layer ]
296263 if self .vision_feature_select_strategy == "default" :
297264 hs_pool = [hs [:, 1 :] for hs in hs_pool ]
298265 selected_features = mx .concatenate (hs_pool , axis = - 1 )
299266 else :
300- # Simple (pooler, hidden_state) output
301267 selected_features = vision_outputs [1 ]
302268 if self .vision_feature_select_strategy == "default" :
303269 selected_features = selected_features [:, 1 :]
304270 else :
305- # Direct hidden state output
306271 selected_features = vision_outputs
307272
308- # Project to language model dimension if projector exists
309- # Qwen-VL models have projection built into VisionModel's merger
310273 if self .multi_modal_projector is not None :
311274 image_features = self .multi_modal_projector (selected_features )
312275 else :
313- # VisionModel already outputs projected features
314276 image_features = selected_features
315277
316278 return image_features
@@ -321,36 +283,19 @@ def _merge_input_ids_with_image_features(
321283 inputs_embeds : mx .array ,
322284 input_ids : mx .array ,
323285 ) -> mx .array :
324- """Merge image features into input embeddings at image token positions.
325-
326- This replaces <image> placeholder tokens with actual image feature embeddings.
327-
328- Args:
329- image_features: (num_images, num_patches, hidden_dim) or (total_patches, hidden_dim)
330- inputs_embeds: (batch, seq_len, hidden_dim) Text embeddings
331- input_ids: (batch, seq_len) Token IDs for finding image positions
332-
333- Returns:
334- Merged embeddings with image features inserted at image token positions
335- """
286+ """Replace <image> placeholder tokens with actual image feature embeddings."""
336287 if self .image_token_index is None :
337288 logger .warning ("image_token_index not set, cannot merge image features" )
338289 return inputs_embeds
339290
340291 batch_size , seq_len , hidden_dim = inputs_embeds .shape
341-
342- # Find positions of image tokens
343292 image_positions = input_ids == self .image_token_index
344293
345- # Flatten image features if needed
346294 if image_features .ndim == 3 :
347- # (num_images, num_patches, dim) -> (total_patches, dim)
348295 image_features = image_features .reshape (- 1 , image_features .shape [- 1 ])
349296
350- # Cast image features to match embedding dtype
351297 image_features = image_features .astype (inputs_embeds .dtype )
352298
353- # Process each batch item
354299 batch_outputs = []
355300 feature_start_idx = 0
356301
@@ -359,28 +304,23 @@ def _merge_input_ids_with_image_features(
359304 num_positions = int (mx .sum (batch_mask ).item ())
360305
361306 if num_positions > 0 :
362- # Extract features for this batch
363307 batch_features = image_features [
364308 feature_start_idx : feature_start_idx + num_positions
365309 ]
366310
367311 if batch_features .shape [0 ] != num_positions :
368312 raise ValueError (
369- f"Number of image token positions ({ num_positions } ) does not match "
370- f"number of image features ({ batch_features .shape [0 ]} ) for batch { batch_idx } "
313+ f"Image token positions ({ num_positions } ) does not match "
314+ f"image features ({ batch_features .shape [0 ]} ) for batch { batch_idx } "
371315 )
372316
373- # Create indices for gathering
374317 cumsum = mx .cumsum (batch_mask .astype (mx .int32 ))
375318 feature_indices = mx .where (batch_mask , cumsum - 1 , 0 )
376-
377- # Gather features and create merged output
378319 gathered_features = batch_features [feature_indices ]
379320 batch_mask_expanded = mx .expand_dims (batch_mask , axis = - 1 )
380321 batch_output = mx .where (
381322 batch_mask_expanded , gathered_features , inputs_embeds [batch_idx ]
382323 )
383-
384324 feature_start_idx += num_positions
385325 else :
386326 batch_output = inputs_embeds [batch_idx ]
0 commit comments