
Commit 2b87155

Language model of PaliGemma 1 is Gemma 1 (#613)
## Summary

@lancerts We need to match the Gemma version used by the monkey patch to the PaliGemma version: the language model of PaliGemma 1 is Gemma 1, not Gemma 2, so `apply_liger_kernel_to_paligemma` should apply the Gemma 1 kernels for those checkpoints.
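For context, you can confirm which Gemma generation a given PaliGemma checkpoint wraps by inspecting its `text_config`. A minimal sketch (the checkpoint ids are illustrative PaliGemma 1 / PaliGemma 2 releases and are gated on the Hugging Face Hub):

```python
# Sketch: check which Gemma generation a PaliGemma checkpoint uses as its language model.
# The checkpoint ids below are illustrative and gated on the Hugging Face Hub.
from transformers import AutoConfig

cfg_v1 = AutoConfig.from_pretrained("google/paligemma-3b-pt-224")
cfg_v2 = AutoConfig.from_pretrained("google/paligemma2-3b-pt-224")

print(cfg_v1.text_config.model_type)  # "gemma"  -> should be patched with apply_liger_kernel_to_gemma
print(cfg_v2.text_config.model_type)  # "gemma2" -> should be patched with apply_liger_kernel_to_gemma2
```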
1 parent 74218f4 commit 2b87155

5 files changed

Lines changed: 375 additions & 20 deletions


src/liger_kernel/transformers/model/paligemma.py

Lines changed: 184 additions & 0 deletions
@@ -21,6 +21,190 @@
 logger = logging.get_logger(__name__)
 
 
+@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    pixel_values: torch.FloatTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+    token_type_ids: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+    >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
+    >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
+
+    >>> prompt = "answer en Where is the cow standing?"
+    >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_length=30)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "answer en Where is the cow standing?\nbeach"
+    ```"""
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+        )
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # the attention mask is turned 4d after, we keep track of the original one
+    input_attention_mask = attention_mask
+
+    if inputs_embeds is None:
+        # 1. Extra the input embeddings
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        # 2. Merge text and images
+        if pixel_values is not None and input_ids.shape[1] != 1:
+            image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
+            selected_image_feature = image_outputs.last_hidden_state
+            image_features = self.multi_modal_projector(selected_image_feature)
+
+            if cache_position is None:
+                cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
+            inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+            )
+
+        else:
+            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+            # generation with cache
+            if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                # TODO @molbap this will only work for dynamic cache.
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_seqlen = cache_position[-1] + 1
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1] + 1),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses PaliGemma+ Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+    attention_mask = attention_mask.to(inputs_embeds.dtype)
+    outputs = self.language_model.model(
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :]
+        shift_labels = labels[..., 1:]
+
+        hidden_device = shift_hidden_states.device
+
+        if attention_mask is not None:
+            # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+            shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+            shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+            shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+        else:
+            shift_hidden_states = shift_hidden_states.contiguous()
+            shift_labels = shift_labels.contiguous()
+
+        # Flatten hidden state
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+        shift_labels = shift_labels.view(-1).to(hidden_device)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        logits = self.language_model.lm_head(hidden_states)
+        if labels is not None:
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if input_attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                shift_attention_mask = input_attention_mask[..., 1:]
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+
+            flat_logits = shift_logits.view(-1, self.config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return PaliGemmaCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
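In the training branch, `lce_forward_deprecated` uses the same fused pattern as `lce_forward`: instead of projecting hidden states through `lm_head` and handing the full logits tensor to `CrossEntropyLoss`, it passes the `lm_head` weight and the shifted hidden states directly to `LigerFusedLinearCrossEntropyLoss`, so the `(num_tokens, vocab_size)` logits are never materialized. A standalone sketch of the two paths (the import path and the CUDA requirement are assumptions; the shapes are illustrative):

```python
# Sketch: unfused vs. fused linear cross entropy. Assumes a CUDA device; the
# import path for LigerFusedLinearCrossEntropyLoss is assumed.
import torch
from torch.nn import CrossEntropyLoss

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

num_tokens, hidden_size, vocab_size = 8 * 128, 1024, 32000
hidden = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Unfused reference: materializes a (num_tokens, vocab_size) logits tensor.
ref_loss = CrossEntropyLoss()(hidden.float() @ lm_head_weight.float().t(), labels)

# Fused path, same call signature as in lce_forward_deprecated above:
# lce(lm_head.weight, shift_hidden_states, shift_labels)
fused_loss = LigerFusedLinearCrossEntropyLoss()(lm_head_weight, hidden, labels)
```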

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 39 additions & 13 deletions
@@ -631,6 +631,7 @@ def apply_liger_kernel_to_paligemma(
 
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
+    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
@@ -639,6 +640,7 @@ def apply_liger_kernel_to_paligemma(
     from transformers.models.siglip.modeling_siglip import SiglipVisionModel
 
     from liger_kernel.transformers.model.paligemma import lce_forward
+    from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
 
     # The vision_tower is a SiglipVisionModel
     if layer_norm:
@@ -647,13 +649,22 @@
     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
     # The multi_modal_projector is Linear, nothing to do
 
-    # The language_model is Gemma2ForCausalLM
-    apply_liger_kernel_to_gemma2(rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, geglu=geglu)
+    # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
+    apply_liger_kernel_to_gemma(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
+    apply_liger_kernel_to_gemma2(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
     # Handle loss function
     if cross_entropy:
         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -672,16 +683,31 @@ def apply_liger_kernel_to_paligemma(
                 _patch_layer_norm_module(layer.layer_norm1)
                 _patch_layer_norm_module(layer.layer_norm2)
 
-        language_model: Gemma2ForCausalLM = model.language_model
-
-        apply_liger_kernel_to_gemma2(
-            rope=rope,
-            cross_entropy=False,
-            fused_linear_cross_entropy=False,
-            rms_norm=rms_norm,
-            geglu=geglu,
-            model=language_model,
-        )
+        language_model = model.language_model
+
+        if isinstance(language_model, GemmaForCausalLM):
+            apply_liger_kernel_to_gemma(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=language_model,
+            )
+
+        elif isinstance(language_model, Gemma2ForCausalLM):
+            apply_liger_kernel_to_gemma2(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=language_model,
+            )
+        else:
+            raise TypeError(
+                "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
+            )
 
 
 def apply_liger_kernel_to_qwen2(
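With this change, the same patch call covers both generations; `apply_liger_kernel_to_paligemma` inspects the instance's `language_model` and routes to the matching Gemma kernels. A minimal usage sketch (assumes transformers >= 4.46.1, a CUDA setup, and an illustrative gated PaliGemma 1 checkpoint id):

```python
# Sketch: patch an already-instantiated PaliGemma 1 model in place.
# The checkpoint id is illustrative and gated on the Hugging Face Hub.
import torch
from transformers import PaliGemmaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_paligemma

model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-pt-224", torch_dtype=torch.bfloat16
)

# language_model is GemmaForCausalLM here, so the patch now routes through
# apply_liger_kernel_to_gemma instead of apply_liger_kernel_to_gemma2.
apply_liger_kernel_to_paligemma(
    rope=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
    rms_norm=True,
    geglu=True,
    model=model,
)
```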

test/convergence/bf16/test_mini_models_multimodal.py

Lines changed: 75 additions & 3 deletions
@@ -64,6 +64,10 @@
     MLLAMA_AVAILABLE = False
 
 try:
+    import transformers
+
+    from packaging import version
+    from transformers.models.gemma.configuration_gemma import GemmaConfig
     from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
     from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
     from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
@@ -72,7 +76,7 @@
     from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
     from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
 
-    PALIGEMMA_AVAILABLE = True
+    PALIGEMMA_AVAILABLE = version.parse(transformers.__version__) >= version.parse("4.46.0")
 except ImportError:
     PALIGEMMA_AVAILABLE = False
 
@@ -152,6 +156,55 @@
 
 if PALIGEMMA_AVAILABLE:
     MINI_MODEL_SETUPS["mini_paligemma"] = MiniModelConfig(
+        liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_paligemma, fused_linear_cross_entropy=False),
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_Paligemma,
+        model_class=PaliGemmaForConditionalGeneration,
+        mini_model_config=PaliGemmaConfig(
+            vision_config=SiglipVisionConfig(
+                attention_dropout=0.0,
+                hidden_act="gelu_pytorch_tanh",
+                hidden_size=1152,
+                image_size=224,
+                intermediate_size=2048,  # 4304
+                layer_norm_eps=1e-06,
+                num_attention_heads=4,  # 16
+                num_channels=3,
+                num_hidden_layers=4,  # 27
+                num_image_tokens=256,
+                num_positions=256,
+                patch_size=14,
+                projection_dim=1024,  # 2304
+            ),
+            text_config=GemmaConfig(
+                vocab_size=32000,  # 256000
+                hidden_size=1024,  # 3072
+                intermediate_size=2048,  # 24576
+                num_hidden_layers=4,  # 28
+                num_attention_heads=4,  # 16
+                num_key_value_heads=4,  # 16
+                head_dim=256,
+                hidden_activation="gelu_pytorch_tanh",
+                max_position_embeddings=8192,
+                initializer_range=0.02,
+                rms_norm_eps=1e-06,
+                use_cache=True,
+                pad_token_id=0,
+                # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
+                # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
+                bos_token_id=1,  # 128000
+                eos_token_id=2,  # 128001
+                tie_word_embeddings=True,
+                rope_theta=10000.0,
+                attention_bias=False,
+                attention_dropout=0.0,
+            ),
+            image_token_index=4,  # NOTE: outside the vocab size
+            attn_implementation="eager",
+            vocab_size=32000,
+            projection_dim=1024,
+        ),
+    )
+    MINI_MODEL_SETUPS["mini_paligemma2"] = MiniModelConfig(
         liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_paligemma, fused_linear_cross_entropy=False),
         liger_kernel_patch_revert_func=revert_liger_kernel_to_Paligemma,
         model_class=PaliGemmaForConditionalGeneration,
@@ -297,7 +350,7 @@
 )
 
 
-def create_processor(model_name):
+def create_processor(model_name: str):
     if model_name == "mini_qwen2_vl":
         tokenizer_config = load_tokenizer_config(
             os.path.join(FAKE_CONFIGS_PATH, "Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json")
@@ -352,7 +405,7 @@ def create_processor(model_name):
         image_processor = MllamaImageProcessor(size={"height": 560, "width": 560})
         return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer)
 
-    elif model_name == "mini_paligemma":
+    elif model_name.startswith("mini_paligemma"):
         tokenizer_config = load_tokenizer_config(
             os.path.join(
                 FAKE_CONFIGS_PATH,
@@ -580,6 +633,25 @@ def run_mini_model_multimodal(
             ),
         ],
     ),
+        pytest.param(
+            "mini_paligemma2",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-3,
+            1e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not PALIGEMMA_AVAILABLE,
+                    reason="Paligemma2 not available in this version of transformers",
+                ),
+            ],
+        ),
     ],
 )
 def test_mini_model_multimodal(
