|
21 | 21 | from transformers.models.qwen2_vl.modeling_qwen2_vl import ( |
22 | 22 | Qwen2VisionTransformerPretrainedModel, |
23 | 23 | ) |
| 24 | +from .utils import default_weight_loader |
24 | 25 |
|
25 | | -# from .model_loader import default_weight_loader |
26 | 26 | dtype = "fp32" |
27 | 27 |
|
28 | 28 |
|
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` in place.

    Args:
        param: Destination tensor; its ``.data`` is overwritten.
        loaded_weight: Source tensor read from a checkpoint.

    Raises:
        AssertionError: If the two tensors are not both single-element and
            their sizes differ.
    """
    if param.numel() == 1 and loaded_weight.numel() == 1:
        # Sometimes scalar values aren't considered tensors with shapes,
        # so when both sides hold exactly one element, "broadcast" the value
        # instead of copying.
        param.data.fill_(loaded_weight.item())
        return

    # Raise explicitly rather than using a bare `assert`: assert statements
    # are removed under `python -O`, and `copy_` would then silently broadcast
    # a mismatched-but-broadcastable weight into the parameter.
    if param.size() != loaded_weight.size():
        raise AssertionError(
            f"Attempted to load weight ({loaded_weight.size()}) "
            f"into parameter ({param.size()})"
        )

    param.data.copy_(loaded_weight)
48 | | - |
49 | | - |
def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    """Return ``x * sigmoid(1.702 * x)`` as a new tensor.

    Args:
        x: Input tensor.
        inplace: Accepted for signature compatibility; this implementation
            does not read it and never modifies ``x``.

    Returns:
        A tensor holding the element-wise product of ``x`` and its gate.
    """
    gate = torch.sigmoid(1.702 * x)
    return x * gate
52 | 31 |
|
@@ -400,25 +379,8 @@ def forward(self, x, cu_seqlens, rotary_pos_emb, b) -> torch.Tensor: |
400 | 379 | x = x + self.mlp(self.norm2(x)) |
401 | 380 | return x |
402 | 381 |
|
403 | | - |
404 | | -# class Qwen2VisionTransformer(nn.Module): |
405 | 382 | class Qwen2VisionTransformer(Qwen2VisionTransformerPretrainedModel): |
406 | 383 | def __init__(self, config): |
407 | | - # img_size: int = 378, |
408 | | - # patch_size: int = 14, |
409 | | - # temporal_patch_size: int = 2, |
410 | | - # spatial_merge_size: int = 2, |
411 | | - # in_chans: int = 3, |
412 | | - # hidden_size: int = 1000, |
413 | | - # embed_dim: int = 768, |
414 | | - # depth: int = 12, |
415 | | - # num_heads: int = 16, |
416 | | - # mlp_ratio: float = 4.0, |
417 | | - # norm_layer: nn.Module = partial(LayerNorm, eps=1e-6), |
418 | | - # use_flash_attention: bool = False, |
419 | | - # *args, |
420 | | - # **kwargs, |
421 | | - # ) -> None: |
422 | 384 | super().__init__(config) |
423 | 385 | self.spatial_merge_size = config.spatial_merge_size |
424 | 386 |
|
|
0 commit comments