|
3 | 3 | import json |
4 | 4 | import re |
5 | 5 |
|
6 | | -from typing import Callable, Iterable, TYPE_CHECKING |
| 6 | +from typing import Callable, Iterable, TYPE_CHECKING, Sequence |
7 | 7 |
|
8 | 8 | import torch |
9 | 9 |
|
@@ -765,6 +765,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter |
765 | 765 | yield from super().modify_tensors(data_torch, name, bid) |
766 | 766 |
|
767 | 767 |
|
| 768 | +@ModelBase.register("Gemma4UnifiedForConditionalGeneration") |
| 769 | +class Gemma4UnifiedModel(Gemma4Model): |
| 770 | + model_arch = gguf.MODEL_ARCH.GEMMA4 |
| 771 | + |
| 772 | + def _get_suppress_tokens(self) -> Sequence[int] | None: |
| 773 | + gen_cfg_path = self.dir_model / "generation_config.json" |
| 774 | + if gen_cfg_path.is_file(): |
| 775 | + with open(gen_cfg_path, encoding="utf-8") as f: |
| 776 | + gen_cfg = json.load(f) |
| 777 | + return gen_cfg.get("suppress_tokens") |
| 778 | + return None |
| 779 | + |
| 780 | + def set_gguf_parameters(self): |
| 781 | + super().set_gguf_parameters() |
| 782 | + |
| 783 | + suppress_tokens = self._get_suppress_tokens() |
| 784 | + if suppress_tokens is not None: |
| 785 | + self.gguf_writer.add_suppress_tokens(suppress_tokens) |
| 786 | + |
| 787 | + |
768 | 788 | @ModelBase.register("Gemma4ForConditionalGeneration") |
769 | 789 | class Gemma4VisionAudioModel(MmprojModel): |
770 | 790 | has_audio_encoder = True |
@@ -839,3 +859,61 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter |
839 | 859 | data_torch = data_torch.permute(0, 3, 1, 2).contiguous() |
840 | 860 | mapped_name = self.map_tensor_name(name, (".weight", ".bias", ".input_max", ".input_min", ".output_max", ".output_min")) |
841 | 861 | yield (mapped_name, data_torch) |
| 862 | + |
| 863 | + |
| 864 | +@ModelBase.register("Gemma4UnifiedForConditionalGeneration") |
| 865 | +class Gemma4UnifiedVisionAudioModel(Gemma4VisionAudioModel): |
| 866 | + has_audio_encoder = True |
| 867 | + has_vision_encoder = True |
| 868 | + |
| 869 | + def __init__(self, *args, **kwargs): |
| 870 | + super().__init__(*args, **kwargs) |
| 871 | + assert self.hparams_vision is not None |
| 872 | + assert self.hparams_audio is not None |
| 873 | + text_embd_dim = self.hparams_vision["mm_embed_dim"] |
| 874 | + self.hparams_vision["hidden_size"] = text_embd_dim |
| 875 | + self.hparams_audio["hidden_size"] = text_embd_dim |
| 876 | + # this is a transformer-less vision tower, the params below are redundant but set to avoid error |
| 877 | + self.hparams_vision["intermediate_size"] = 0 |
| 878 | + self.hparams_vision["num_layers"] = 0 |
| 879 | + self.hparams_vision["num_attention_heads"] = 0 |
| 880 | + self.hparams_audio["intermediate_size"] = 0 |
| 881 | + self.hparams_audio["num_layers"] = 0 |
| 882 | + self.hparams_audio["num_attention_heads"] = 0 |
| 883 | + |
| 884 | + def set_gguf_parameters(self): |
| 885 | + super().set_gguf_parameters() |
| 886 | + self.gguf_writer.add_clip_vision_projector_type(gguf.VisionProjectorType.GEMMA4UV) |
| 887 | + self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4UA) |
| 888 | + |
| 889 | + def modify_tensors(self, data_torch, name, bid): |
| 890 | + if name.endswith("pos_embedding"): |
| 891 | + name += ".weight" |
| 892 | + data_torch = data_torch.permute(1, 0, 2) |
| 893 | + elif ".pos_norm." in name: |
| 894 | + # rename to patch_ln3 to reuse the tensor name scheme |
| 895 | + name = name.replace(".pos_norm.", ".patch_ln3.") |
| 896 | + elif "patch_dense.weight" in name: |
| 897 | + # ggml im2col outputs in RR..GG..BB.. (CHW) order, but weight expects RGBRGB.. (HWC). |
| 898 | + # Permute columns so column i aligns with CHW input position i. |
| 899 | + assert self.hparams_vision is not None |
| 900 | + p = self.hparams_vision["model_patch_size"] |
| 901 | + i = torch.arange(p * p * 3) |
| 902 | + ch = i // (p * p) |
| 903 | + row = (i % (p * p)) // p |
| 904 | + col = i % p |
| 905 | + # perm[i] = HWC column index for CHW position i |
| 906 | + perm = row * p * 3 + col * 3 + ch |
| 907 | + data_torch = data_torch[:, perm] |
| 908 | + elif "patch_ln1.weight" in name or "patch_ln1.bias" in name: |
| 909 | + # same permutation for patch_ln1 as patch_dense to align with CHW input order |
| 910 | + assert self.hparams_vision is not None |
| 911 | + p = self.hparams_vision["model_patch_size"] |
| 912 | + i = torch.arange(p * p * 3) |
| 913 | + ch = i // (p * p) |
| 914 | + row = (i % (p * p)) // p |
| 915 | + col = i % p |
| 916 | + # perm[i] = HWC index for CHW position i |
| 917 | + perm = row * p * 3 + col * 3 + ch |
| 918 | + data_torch = data_torch[perm] |
| 919 | + return super().modify_tensors(data_torch, name, bid) |
0 commit comments