Skip to content

Commit 2dea44b

Browse files
committed
fix pre commit
1 parent f353ff7 commit 2dea44b

13 files changed

Lines changed: 91 additions & 201 deletions

File tree

src/parallax/models/kimi_vl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from typing import Any, List, Optional
1010

1111
import mlx.core as mx
12-
from mlx import nn
1312
from mlx_lm.models.base import scaled_dot_product_attention
1413

1514
# Import from mlx-vlm kimi_vl language module
1615
from mlx_vlm.models.kimi_vl.language import DeepseekV3Attention as MLXKimiVLAttention
17-
from mlx_vlm.models.kimi_vl.language import DeepseekV3DecoderLayer as MLXKimiVLDecoderLayer
18-
from mlx_vlm.models.kimi_vl.language import DeepseekV3MLP, DeepseekV3MoE
16+
from mlx_vlm.models.kimi_vl.language import (
17+
DeepseekV3DecoderLayer as MLXKimiVLDecoderLayer,
18+
)
1919

2020
from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache
2121
from parallax.server.cache.base import BaseCache

src/parallax/server/executor/factory.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,21 @@ def run_executor_process(args, shared_state=None, conn=None):
111111
"""Run executor as a subprocess"""
112112
# Set rank to suppress logs on non-zero ranks
113113
# Must be called AFTER set_log_level to override the level
114-
tp_rank = getattr(args, 'tp_rank', 0)
115-
tp_size = getattr(args, 'tp_size', 1)
116-
114+
tp_rank = getattr(args, "tp_rank", 0)
115+
tp_size = getattr(args, "tp_size", 1)
116+
117117
# For non-zero ranks, suppress logs before any imports
118118
if tp_size > 1 and tp_rank != 0:
119119
import logging
120+
120121
logging.getLogger().setLevel(logging.CRITICAL + 1)
121-
122+
122123
set_log_level(args.log_level)
123-
124+
124125
# Now set rank properly (will re-suppress for non-zero ranks)
125126
if tp_size > 1:
126127
set_rank(tp_rank, enable_filter=True)
127-
128+
128129
executor = None
129130
try:
130131
executor = create_from_args(args, shared_state, conn)

src/parallax/server/executor/mlx_executor.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,6 @@ def __init__(
159159
self.vlm_processor = None
160160
except Exception as e:
161161
logger.debug(f"AutoProcessor failed: {e}")
162-
if not processor_loaded:
163-
try:
164-
# Must import torch first to avoid flex_attention import errors in transformers
165-
from transformers import Qwen2VLProcessor
166-
167-
self.vlm_processor = Qwen2VLProcessor.from_pretrained(
168-
processor_path, trust_remote_code=True
169-
)
170-
logger.info(f"Loaded VLM processor (Qwen2VLProcessor) for {self.model_type}")
171-
processor_loaded = True
172-
except Exception as e:
173-
logger.debug(f"Qwen2VLProcessor failed: {e}")
174162

175163
if not processor_loaded:
176164
logger.warning(

src/parallax/server/model.py

Lines changed: 18 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -112,29 +112,25 @@ def __init__(
112112
if has_norm_in:
113113
self.norm_in = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
114114

115-
# Initialize vision components for VLM on first shard
116115
if self.is_vlm:
117116
logger.info(
118117
f"Initializing VLM components: vision_tower ({self.vision_config.model_type})"
119118
)
120119
self.vision_tower = vision_tower_class(self.vision_config)
121-
# Some VLMs (e.g., Qwen2-VL, Qwen3-VL) have the projector/merger built into VisionModel
122-
# In these cases, multi_modal_projector_class can be None
123120
if multi_modal_projector_class is not None:
124-
# Some projectors (e.g., KimiVL) need both vision_config and text_config
125-
# Create a combined config object if the projector expects it
126121
try:
127122
self.multi_modal_projector = multi_modal_projector_class(config)
128123
except (TypeError, AttributeError):
129-
# Projector expects a combined config with vision_config + text_config
130-
combined_config = type("CombinedConfig", (), {
131-
"vision_config": self.vision_config,
132-
"text_config": config,
133-
})()
124+
combined_config = type(
125+
"CombinedConfig",
126+
(),
127+
{
128+
"vision_config": self.vision_config,
129+
"text_config": config,
130+
},
131+
)()
134132
self.multi_modal_projector = multi_modal_projector_class(combined_config)
135-
logger.info(
136-
"Initialized projector with combined vision+text config"
137-
)
133+
logger.info("Initialized projector with combined vision+text config")
138134
else:
139135
self.multi_modal_projector = None
140136
logger.info(
@@ -197,17 +193,10 @@ def get_input_embeddings(
197193
if not self.is_first_shard:
198194
raise ValueError("get_input_embeddings should only be called on the first shard")
199195

200-
# Get text embeddings
201196
inputs_embeds = self.embed_tokens(input_ids)
202-
203-
# If no images or not a VLM, return text embeddings directly
204197
if pixel_values is None or not self.is_vlm:
205198
return InputEmbeddingsOutput(inputs_embeds=inputs_embeds)
206-
207-
# Process vision features
208199
image_features = self._encode_images(pixel_values, **kwargs)
209-
210-
# Merge image features with text embeddings
211200
final_embeds = self._merge_input_ids_with_image_features(
212201
image_features, inputs_embeds, input_ids
213202
)
@@ -220,20 +209,10 @@ def _encode_images(
220209
image_grid_thw: Optional[mx.array] = None,
221210
**kwargs,
222211
) -> mx.array:
223-
"""Encode images through vision tower and projector.
224-
225-
Args:
226-
pixel_values: Image tensor, typically (batch, C, H, W) or (num_patches, C, H, W)
227-
image_grid_thw: Grid size (T, H, W) for Qwen-VL models
228-
**kwargs: Additional model-specific arguments
229-
230-
Returns:
231-
Projected image features ready to be merged with text embeddings
232-
"""
212+
"""Encode images through vision tower and projector."""
233213
if self.vision_tower is None:
234214
raise ValueError("Vision tower not initialized for this model")
235215

236-
# Check if this is a model that uses grid_thw for vision encoding
237216
model_type = getattr(self.vision_config, "model_type", "") if self.vision_config else ""
238217
is_qwen_vl = "qwen" in model_type.lower() and "vl" in model_type.lower()
239218
is_moonvit = model_type.lower() == "moonvit"
@@ -243,74 +222,57 @@ def _encode_images(
243222
if hasattr(self.vision_tower, "patch_embed") and hasattr(
244223
self.vision_tower.patch_embed, "proj"
245224
):
246-
target_dtype = self.vision_tower.patch_embed.proj.weight.dtype
247-
pixel_values = pixel_values.astype(target_dtype)
225+
pixel_values = pixel_values.astype(self.vision_tower.patch_embed.proj.weight.dtype)
248226
else:
249227
pixel_values = pixel_values.astype(self.dtype)
250228

251-
# Get vision features from vision tower
252229
if uses_grid_thw and image_grid_thw is not None:
253230
if is_moonvit:
254-
# KimiVL/MoonViT style: VisionModel expects NHWC input and grid_thw
255-
# pixel_values may be NCHW from processor, convert to NHWC
231+
# MoonViT (KimiVL) expects NHWC input
256232
if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]:
257233
pixel_values = pixel_values.transpose(0, 2, 3, 1)
258234
vision_outputs = self.vision_tower(
259235
pixel_values, grid_thw=image_grid_thw, output_hidden_states=True
260236
)
261237
else:
262-
# Qwen-VL style: VisionModel(pixel_values, grid_thw) -> hidden_states
263-
# No format conversion needed - Qwen-VL expects flat patches
238+
# Qwen-VL expects flat patches
264239
vision_outputs = self.vision_tower(pixel_values, image_grid_thw)
265240

266241
if isinstance(vision_outputs, tuple):
267-
# First element is the merged hidden states (already projected by merger)
268242
selected_features = vision_outputs[0]
269243
elif isinstance(vision_outputs, list):
270-
# KimiVL patch_merger returns a list of arrays
271244
selected_features = vision_outputs
272245
else:
273246
selected_features = vision_outputs
274247
else:
275-
# Standard CLIP/SigLIP style
276-
# Convert to vision tower expected format (typically NHWC for MLX)
248+
# CLIP/SigLIP style: NCHW -> NHWC
277249
if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]:
278-
# NCHW -> NHWC
279250
pixel_values = pixel_values.transpose(0, 2, 3, 1)
280251

281252
vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
282253

283-
# Handle different output formats
284254
if isinstance(vision_outputs, tuple):
285-
# CLIP/SigLIP style: (pooler_output, last_hidden_state, hidden_states)
286255
if len(vision_outputs) >= 3:
287-
hidden_states = vision_outputs[2] # All hidden states
256+
hidden_states = vision_outputs[2]
288257
if isinstance(self.vision_feature_layer, int):
289258
selected_features = hidden_states[self.vision_feature_layer]
290259
if self.vision_feature_select_strategy == "default":
291-
# Remove CLS token
292260
selected_features = selected_features[:, 1:]
293261
else:
294-
# Multiple layers
295262
hs_pool = [hidden_states[idx] for idx in self.vision_feature_layer]
296263
if self.vision_feature_select_strategy == "default":
297264
hs_pool = [hs[:, 1:] for hs in hs_pool]
298265
selected_features = mx.concatenate(hs_pool, axis=-1)
299266
else:
300-
# Simple (pooler, hidden_state) output
301267
selected_features = vision_outputs[1]
302268
if self.vision_feature_select_strategy == "default":
303269
selected_features = selected_features[:, 1:]
304270
else:
305-
# Direct hidden state output
306271
selected_features = vision_outputs
307272

308-
# Project to language model dimension if projector exists
309-
# Qwen-VL models have projection built into VisionModel's merger
310273
if self.multi_modal_projector is not None:
311274
image_features = self.multi_modal_projector(selected_features)
312275
else:
313-
# VisionModel already outputs projected features
314276
image_features = selected_features
315277

316278
return image_features
@@ -321,36 +283,19 @@ def _merge_input_ids_with_image_features(
321283
inputs_embeds: mx.array,
322284
input_ids: mx.array,
323285
) -> mx.array:
324-
"""Merge image features into input embeddings at image token positions.
325-
326-
This replaces <image> placeholder tokens with actual image feature embeddings.
327-
328-
Args:
329-
image_features: (num_images, num_patches, hidden_dim) or (total_patches, hidden_dim)
330-
inputs_embeds: (batch, seq_len, hidden_dim) Text embeddings
331-
input_ids: (batch, seq_len) Token IDs for finding image positions
332-
333-
Returns:
334-
Merged embeddings with image features inserted at image token positions
335-
"""
286+
"""Replace <image> placeholder tokens with actual image feature embeddings."""
336287
if self.image_token_index is None:
337288
logger.warning("image_token_index not set, cannot merge image features")
338289
return inputs_embeds
339290

340291
batch_size, seq_len, hidden_dim = inputs_embeds.shape
341-
342-
# Find positions of image tokens
343292
image_positions = input_ids == self.image_token_index
344293

345-
# Flatten image features if needed
346294
if image_features.ndim == 3:
347-
# (num_images, num_patches, dim) -> (total_patches, dim)
348295
image_features = image_features.reshape(-1, image_features.shape[-1])
349296

350-
# Cast image features to match embedding dtype
351297
image_features = image_features.astype(inputs_embeds.dtype)
352298

353-
# Process each batch item
354299
batch_outputs = []
355300
feature_start_idx = 0
356301

@@ -359,28 +304,23 @@ def _merge_input_ids_with_image_features(
359304
num_positions = int(mx.sum(batch_mask).item())
360305

361306
if num_positions > 0:
362-
# Extract features for this batch
363307
batch_features = image_features[
364308
feature_start_idx : feature_start_idx + num_positions
365309
]
366310

367311
if batch_features.shape[0] != num_positions:
368312
raise ValueError(
369-
f"Number of image token positions ({num_positions}) does not match "
370-
f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
313+
f"Image token positions ({num_positions}) does not match "
314+
f"image features ({batch_features.shape[0]}) for batch {batch_idx}"
371315
)
372316

373-
# Create indices for gathering
374317
cumsum = mx.cumsum(batch_mask.astype(mx.int32))
375318
feature_indices = mx.where(batch_mask, cumsum - 1, 0)
376-
377-
# Gather features and create merged output
378319
gathered_features = batch_features[feature_indices]
379320
batch_mask_expanded = mx.expand_dims(batch_mask, axis=-1)
380321
batch_output = mx.where(
381322
batch_mask_expanded, gathered_features, inputs_embeds[batch_idx]
382323
)
383-
384324
feature_start_idx += num_positions
385325
else:
386326
batch_output = inputs_embeds[batch_idx]

src/parallax/server/request.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,10 @@ class VLMInputs:
7878
receive pre-computed image embeddings merged into hidden_states.
7979
"""
8080

81-
# Preprocessed image tensor, shape varies by model:
82-
# - LLaVA: (num_images, C, H, W) or (num_patches, C, patch_H, patch_W)
83-
# - Qwen-VL: (num_patches, C, patch_H, patch_W) with temporal dim for video
84-
# Can be numpy array or PyTorch tensor - mx.array() can convert both
8581
pixel_values: Optional[Any] = None
86-
87-
# For models with dynamic resolution (e.g., Qwen2-VL):
88-
# Tuple of (temporal, height, width) grid sizes for each image
89-
# Shape: (num_images, 3) where each row is (t, h, w)
90-
# Can be numpy array or PyTorch tensor
9182
image_grid_thw: Optional[Any] = None
92-
93-
# Number of image tokens per image (for variable-length image tokens)
9483
image_token_counts: Optional[List[int]] = None
95-
96-
# Original image sizes before preprocessing (height, width)
97-
# Useful for models that need aspect ratio information
9884
image_sizes: Optional[List[tuple]] = None
99-
100-
# Whether images have been processed into embeddings
101-
# (set to True after first peer processes images)
10285
images_processed: bool = False
10386

10487
def has_images(self) -> bool:
@@ -410,16 +393,14 @@ def from_initial_request(
410393
else:
411394
next_token_id = initial_request.output_ids[-1]
412395

413-
# For VLM: after first peer processes images, mark as processed
414-
# and don't pass pixel_values to subsequent peers (only metadata)
415396
vlm_inputs = None
416397
if initial_request.vlm_inputs is not None:
417398
vlm_inputs = VLMInputs(
418-
pixel_values=None, # Don't pass raw pixels to next peers
399+
pixel_values=None,
419400
image_grid_thw=initial_request.vlm_inputs.image_grid_thw,
420401
image_token_counts=initial_request.vlm_inputs.image_token_counts,
421402
image_sizes=initial_request.vlm_inputs.image_sizes,
422-
images_processed=True, # Mark as processed by first peer
403+
images_processed=True,
423404
)
424405

425406
return IntermediateRequest(

0 commit comments

Comments
 (0)