|
| 1 | +import itertools |
| 2 | +import math |
| 3 | +from collections.abc import Iterable |
| 4 | +from typing import Any |
| 5 | + |
| 6 | +import einops |
| 7 | +import torch |
| 8 | +import torch.nn as nn |
| 9 | +import torch.nn.functional as F |
| 10 | +from torch import Tensor |
| 11 | +from transformers.configuration_utils import PretrainedConfig |
| 12 | +from transformers.modeling_outputs import BaseModelOutputWithPooling |
| 13 | +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config |
| 14 | +from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel |
| 15 | + |
| 16 | +import sglang.srt.managers.mm_utils as mm_utils |
| 17 | +import sglang.srt.model_loader.weight_utils as weight_utils |
| 18 | +import sglang.srt.utils as utils |
| 19 | +from sglang.srt.layers.logits_processor import LogitsProcessorOutput |
| 20 | +from sglang.srt.layers.quantization.base_config import QuantizationConfig |
| 21 | +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens |
| 22 | +from sglang.srt.managers.schedule_batch import ( |
| 23 | + Modality, |
| 24 | + MultimodalDataItem, |
| 25 | + MultimodalInputs, |
| 26 | +) |
| 27 | +from sglang.srt.model_executor.forward_batch_info import ForwardBatch |
| 28 | +from sglang.srt.models.qwen2 import Qwen2ForCausalLM |
| 29 | + |
| 30 | +MM_HIDDEN_SIZE = 3456 |
| 31 | + |
| 32 | + |
| 33 | +class NVILAConfig(PretrainedConfig): |
| 34 | + model_type = "nvila" |
| 35 | + sub_configs = { |
| 36 | + "text_config": Qwen2Config, |
| 37 | + "vision_config": SiglipVisionConfig, |
| 38 | + } |
| 39 | + _auto_class = "AutoConfig" |
| 40 | + |
| 41 | + def __init__( |
| 42 | + self, |
| 43 | + *, |
| 44 | + text_config: dict[str, Any] | None = None, |
| 45 | + vision_config: dict[str, Any] | None = None, |
| 46 | + image_token_id: int | None = None, |
| 47 | + video_token_id: int | None = None, |
| 48 | + **kwargs, |
| 49 | + ): |
| 50 | + self.text_config = ( |
| 51 | + Qwen2Config(**text_config) if text_config is not None else Qwen2Config() |
| 52 | + ) |
| 53 | + self.vision_config = ( |
| 54 | + SiglipVisionConfig(**vision_config) |
| 55 | + if vision_config is not None |
| 56 | + else SiglipVisionConfig() |
| 57 | + ) |
| 58 | + |
| 59 | + self.image_token_id = image_token_id if image_token_id is not None else -1 |
| 60 | + self.video_token_id = video_token_id if video_token_id is not None else -1 |
| 61 | + |
| 62 | + super().__init__(**kwargs) |
| 63 | + |
| 64 | + |
| 65 | +class NVILAMultiModalProjectorDownsampleBlock(nn.Module): |
| 66 | + def forward(self, x: Tensor) -> Tensor: |
| 67 | + batch_size, sequence_length, hidden_size = x.shape |
| 68 | + |
| 69 | + feat_size = math.isqrt(sequence_length) |
| 70 | + |
| 71 | + features = x.reshape(batch_size, feat_size, feat_size, hidden_size) |
| 72 | + |
| 73 | + pad_after = feat_size % 2 |
| 74 | + if pad_after > 0: |
| 75 | + features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after)) |
| 76 | + feat_size = feat_size + pad_after |
| 77 | + |
| 78 | + features = features.reshape( |
| 79 | + batch_size, feat_size // 2, 2, feat_size // 2, 2, hidden_size |
| 80 | + ) |
| 81 | + features = features.permute(0, 1, 3, 2, 4, 5).contiguous() |
| 82 | + features = features.reshape(batch_size, -1, 4 * hidden_size) |
| 83 | + |
| 84 | + return features |
| 85 | + |
| 86 | + |
| 87 | +class NVILAMultiModalProjector(nn.Module): |
| 88 | + def __init__(self, config: NVILAConfig): |
| 89 | + super().__init__() |
| 90 | + |
| 91 | + self.layers = nn.Sequential( |
| 92 | + NVILAMultiModalProjectorDownsampleBlock(), |
| 93 | + nn.LayerNorm(MM_HIDDEN_SIZE * 4), |
| 94 | + nn.Linear(MM_HIDDEN_SIZE * 4, config.text_config.hidden_size), |
| 95 | + nn.GELU(), |
| 96 | + nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size), |
| 97 | + ) |
| 98 | + |
| 99 | + def forward(self, x: Tensor) -> Tensor: |
| 100 | + return self.layers(x) |
| 101 | + |
| 102 | + |
| 103 | +class NVILAForConditionalGeneration(nn.Module): |
| 104 | + def __init__( |
| 105 | + self, |
| 106 | + config: NVILAConfig, |
| 107 | + quant_config: QuantizationConfig | None = None, |
| 108 | + prefix: str = "", |
| 109 | + ) -> None: |
| 110 | + super().__init__() |
| 111 | + |
| 112 | + self.config = config |
| 113 | + |
| 114 | + self.vision_tower = SiglipVisionModel(config.vision_config) |
| 115 | + self.mm_projector = NVILAMultiModalProjector(config) |
| 116 | + self.llm = Qwen2ForCausalLM( |
| 117 | + config=config.text_config, |
| 118 | + quant_config=quant_config, |
| 119 | + prefix=utils.add_prefix("llm", prefix), |
| 120 | + ) |
| 121 | + |
| 122 | + def forward( |
| 123 | + self, |
| 124 | + input_ids: Tensor, |
| 125 | + positions: Tensor, |
| 126 | + forward_batch: ForwardBatch, |
| 127 | + get_embedding: bool = False, |
| 128 | + ) -> LogitsProcessorOutput: |
| 129 | + output = mm_utils.general_mm_embed_routine( |
| 130 | + input_ids=input_ids, |
| 131 | + forward_batch=forward_batch, |
| 132 | + language_model=self.llm, |
| 133 | + data_embedding_funcs={ |
| 134 | + Modality.IMAGE: self.get_image_feature, |
| 135 | + Modality.VIDEO: self.get_image_feature, |
| 136 | + }, |
| 137 | + get_embedding=get_embedding, |
| 138 | + positions=positions, |
| 139 | + ) |
| 140 | + |
| 141 | + assert isinstance(output, LogitsProcessorOutput) |
| 142 | + |
| 143 | + return output |
| 144 | + |
| 145 | + def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor: |
| 146 | + block_sizes = ( |
| 147 | + list( |
| 148 | + itertools.chain.from_iterable( |
| 149 | + x.block_sizes for x in mm_input if hasattr(x, "block_sizes") |
| 150 | + ) |
| 151 | + ) |
| 152 | + or None |
| 153 | + ) |
| 154 | + pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0) |
| 155 | + |
| 156 | + vision_tower_output: BaseModelOutputWithPooling = self.vision_tower( |
| 157 | + pixel_values.to( |
| 158 | + device=self.vision_tower.device, dtype=self.vision_tower.dtype |
| 159 | + ), |
| 160 | + output_hidden_states=True, |
| 161 | + ) |
| 162 | + assert vision_tower_output.hidden_states is not None |
| 163 | + |
| 164 | + vision_features: Tensor = vision_tower_output.hidden_states[-2] |
| 165 | + |
| 166 | + vision_features_list, block_sizes = merge_features_for_dynamic_s2( |
| 167 | + vision_features, |
| 168 | + block_sizes=( |
| 169 | + block_sizes |
| 170 | + if block_sizes is not None |
| 171 | + else [None] * vision_features.shape[0] |
| 172 | + ), |
| 173 | + resize_output_to_scale_idx=-1, |
| 174 | + scales=[448, 896, 1344], |
| 175 | + ) |
| 176 | + |
| 177 | + vision_features_list = [ |
| 178 | + split_chessboard(x, block_size[0], block_size[1]) |
| 179 | + for x, block_size in zip(vision_features_list, block_sizes) |
| 180 | + ] |
| 181 | + |
| 182 | + vision_features = torch.cat( |
| 183 | + [einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list] |
| 184 | + ) |
| 185 | + |
| 186 | + vision_features = self.mm_projector(vision_features) |
| 187 | + |
| 188 | + vision_features_list = list( |
| 189 | + vision_features.split( |
| 190 | + [block_size[0] * block_size[1] for block_size in block_sizes], dim=0 |
| 191 | + ) |
| 192 | + ) |
| 193 | + vision_features_list = [ |
| 194 | + merge_chessboard(x, block_size[0], block_size[1]) |
| 195 | + for x, block_size in zip(vision_features_list, block_sizes) |
| 196 | + ] |
| 197 | + |
| 198 | + vision_features = torch.stack( |
| 199 | + [einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list] |
| 200 | + ) |
| 201 | + |
| 202 | + vision_features = einops.rearrange(vision_features, "n p d -> (n p) d") |
| 203 | + |
| 204 | + return vision_features |
| 205 | + |
| 206 | + def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None: |
| 207 | + params_dict = dict(self.named_parameters()) |
| 208 | + |
| 209 | + for name, loaded_weight in weights: |
| 210 | + if name.startswith("llm."): |
| 211 | + self.llm.load_weights([(name[len("llm.") :], loaded_weight)]) |
| 212 | + else: |
| 213 | + param = params_dict[name] |
| 214 | + weight_loader = getattr( |
| 215 | + param, "weight_loader", weight_utils.default_weight_loader |
| 216 | + ) |
| 217 | + weight_loader(param, loaded_weight) |
| 218 | + |
| 219 | + def pad_input_ids( |
| 220 | + self, input_ids: list[int], mm_inputs: MultimodalInputs |
| 221 | + ) -> list[int]: |
| 222 | + pattern = MultiModalityDataPaddingPatternMultimodalTokens() |
| 223 | + return pattern.pad_input_tokens(input_ids, mm_inputs) |
| 224 | + |
| 225 | + |
| 226 | +def merge_chessboard(x, num_split_h, num_split_w): |
| 227 | + """ |
| 228 | + x: b * n * c or b * h * w * c |
| 229 | + out: b * c * h * w |
| 230 | + Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square. |
| 231 | + """ |
| 232 | + B = x.shape[0] |
| 233 | + if x.dim() == 3: |
| 234 | + N = x.shape[1] |
| 235 | + x = einops.rearrange( |
| 236 | + x, "b (h w) c -> b c h w", h=math.isqrt(N), w=math.isqrt(N) |
| 237 | + ) |
| 238 | + |
| 239 | + assert B % (num_split_h * num_split_w) == 0 |
| 240 | + b = B // (num_split_h * num_split_w) |
| 241 | + |
| 242 | + x_merge = torch.cat( |
| 243 | + [ |
| 244 | + torch.cat( |
| 245 | + [ |
| 246 | + x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b] |
| 247 | + for j in range(num_split_w) |
| 248 | + ], |
| 249 | + dim=-1, |
| 250 | + ) |
| 251 | + for i in range(num_split_h) |
| 252 | + ], |
| 253 | + dim=-2, |
| 254 | + ) |
| 255 | + |
| 256 | + return x_merge |
| 257 | + |
| 258 | + |
| 259 | +def merge_features_for_dynamic_s2( |
| 260 | + image_features, block_sizes, *, scales, resize_output_to_scale_idx |
| 261 | +): |
| 262 | + image_features_each_image = [] |
| 263 | + new_block_sizes = [] |
| 264 | + block_cnt = 0 |
| 265 | + for block_size_each_image in block_sizes: |
| 266 | + if block_size_each_image is None: |
| 267 | + cur_features = image_features[block_cnt : block_cnt + 1] |
| 268 | + cur_features = einops.rearrange( |
| 269 | + cur_features, |
| 270 | + "1 (h w) c -> 1 c h w", |
| 271 | + h=math.isqrt(cur_features.shape[1]), |
| 272 | + ) |
| 273 | + cur_features = cur_features.repeat(1, len(scales), 1, 1) |
| 274 | + image_features_each_image.append(cur_features) |
| 275 | + new_block_sizes.append((1, 1)) |
| 276 | + block_cnt += 1 |
| 277 | + else: |
| 278 | + cur_features_each_scale = [] |
| 279 | + for scale in scales[:-1]: |
| 280 | + num_blocks_this_scale = (scale // scales[0]) ** 2 |
| 281 | + cur_features_each_scale.append( |
| 282 | + merge_chessboard( |
| 283 | + image_features[block_cnt : block_cnt + num_blocks_this_scale], |
| 284 | + num_split_h=scale // scales[0], |
| 285 | + num_split_w=scale // scales[0], |
| 286 | + ) |
| 287 | + ) # 1 * C * H * W |
| 288 | + block_cnt += num_blocks_this_scale |
| 289 | + num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1] |
| 290 | + cur_features_each_scale.append( |
| 291 | + merge_chessboard( |
| 292 | + image_features[block_cnt : block_cnt + num_blocks_last_scale], |
| 293 | + num_split_h=block_size_each_image[0], |
| 294 | + num_split_w=block_size_each_image[1], |
| 295 | + ) |
| 296 | + ) # 1 * C * H * W |
| 297 | + block_cnt += num_blocks_last_scale |
| 298 | + |
| 299 | + # resize and concat features from different scales |
| 300 | + output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:] |
| 301 | + cur_features = torch.cat( |
| 302 | + [ |
| 303 | + F.interpolate( |
| 304 | + cur_features_each_scale[i].to(torch.float32), |
| 305 | + size=output_size, |
| 306 | + mode="area", |
| 307 | + ).to(cur_features_each_scale[i].dtype) |
| 308 | + for i in range(len(cur_features_each_scale)) |
| 309 | + ], |
| 310 | + dim=1, |
| 311 | + ) |
| 312 | + |
| 313 | + image_features_each_image.append(cur_features) |
| 314 | + |
| 315 | + if ( |
| 316 | + resize_output_to_scale_idx == len(scales) - 1 |
| 317 | + or resize_output_to_scale_idx == -1 |
| 318 | + ): |
| 319 | + new_block_sizes.append(block_size_each_image) |
| 320 | + else: |
| 321 | + new_block_sizes.append( |
| 322 | + ( |
| 323 | + scales[resize_output_to_scale_idx] // scales[0], |
| 324 | + scales[resize_output_to_scale_idx] // scales[0], |
| 325 | + ) |
| 326 | + ) |
| 327 | + |
| 328 | + assert block_cnt == len( |
| 329 | + image_features |
| 330 | + ), f"The number of blocks ({block_cnt}) does not match length of image_features ({len(image_features)})!" |
| 331 | + |
| 332 | + return image_features_each_image, new_block_sizes |
| 333 | + |
| 334 | + |
| 335 | +def split_chessboard(x, num_split_h, num_split_w): |
| 336 | + """ |
| 337 | + x: b * c * h * w |
| 338 | + out: b * c * h * w |
| 339 | + Deividing x into num_split**2 sub-squares, and concatenate all the sub-squares on the batch dimension |
| 340 | + """ |
| 341 | + B, C, H, W = x.shape |
| 342 | + assert H % num_split_h == 0 and W % num_split_w == 0 |
| 343 | + h, w = H // num_split_h, W // num_split_w |
| 344 | + x_split = torch.cat( |
| 345 | + [ |
| 346 | + x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w] |
| 347 | + for i in range(num_split_h) |
| 348 | + for j in range(num_split_w) |
| 349 | + ], |
| 350 | + dim=0, |
| 351 | + ) |
| 352 | + return x_split |
| 353 | + |
| 354 | + |
| 355 | +EntryClass = [NVILAForConditionalGeneration] |
0 commit comments