Description
I followed https://huggingface.co/docs/transformers/en/quantization/fbgemm_fp8 and ran the example there successfully. But when I run the same code with a Qwen2 model, it fails with "RuntimeError: Invalid datatype. input must be BF16".
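For reference, here is a minimal sketch of what I ran, adapted from the doc's example; the Qwen2 checkpoint name is illustrative, not necessarily the exact one I used:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config

# Same flow as the fbgemm_fp8 doc page, only the checkpoint is swapped to Qwen2.
model_name = "Qwen/Qwen2-7B-Instruct"  # illustrative Qwen2 checkpoint
quantization_config = FbgemmFp8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=10)  # raises the RuntimeError here
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

The full traceback: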
RuntimeError Traceback (most recent call last)
Cell In[4], line 11
8 input_text = "What are we having for dinner?"
9 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
---> 11 output = quantized_model.generate(**input_ids, max_new_tokens=10)
12 print(tokenizer.decode(output[0], skip_special_tokens=True))
File /usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2016 input_ids, model_kwargs = self._expand_inputs_for_generation(
2017 input_ids=input_ids,
2018 expand_size=generation_config.num_return_sequences,
2019 is_encoder_decoder=self.config.is_encoder_decoder,
2020 **model_kwargs,
2021 )
2023 # 13. run sample (it degenerates to greedy search when generation_config.do_sample=False)
-> 2024 result = self._sample(
2025 input_ids,
2026 logits_processor=prepared_logits_processor,
2027 logits_warper=prepared_logits_warper,
2028 stopping_criteria=prepared_stopping_criteria,
2029 generation_config=generation_config,
2030 synced_gpus=synced_gpus,
2031 streamer=streamer,
2032 **model_kwargs,
2033 )
2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2036 # 11. prepare logits warper
2037 prepared_logits_warper = (
2038 self._get_logits_warper(generation_config, device=input_ids.device)
2039 if generation_config.do_sample
2040 else None
2041 )
File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
2979 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
2981 # forward pass to get next token
-> 2982 outputs = self(**model_inputs, return_dict=True)
2984 if synced_gpus and this_peer_finished:
2985 continue # don't waste resources running the code we don't need
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:1104, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1101 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1103 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1104 outputs = self.model(
1105 input_ids=input_ids,
1106 attention_mask=attention_mask,
1107 position_ids=position_ids,
1108 past_key_values=past_key_values,
1109 inputs_embeds=inputs_embeds,
1110 use_cache=use_cache,
1111 output_attentions=output_attentions,
1112 output_hidden_states=output_hidden_states,
1113 return_dict=return_dict,
1114 cache_position=cache_position,
1115 )
1117 hidden_states = outputs[0]
1118 logits = self.lm_head(hidden_states)
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:915, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
904 layer_outputs = self._gradient_checkpointing_func(
905 decoder_layer.__call__,
906 hidden_states,
(...)
912 cache_position,
913 )
914 else:
--> 915 layer_outputs = decoder_layer(
916 hidden_states,
917 attention_mask=causal_mask,
918 position_ids=position_ids,
919 past_key_value=past_key_values,
920 output_attentions=output_attentions,
921 use_cache=use_cache,
922 cache_position=cache_position,
923 )
925 hidden_states = layer_outputs[0]
927 if use_cache:
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:655, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
652 hidden_states = self.input_layernorm(hidden_states)
654 # Self Attention
--> 655 hidden_states, self_attn_weights, present_key_value = self.self_attn(
656 hidden_states=hidden_states,
657 attention_mask=attention_mask,
658 position_ids=position_ids,
659 past_key_value=past_key_value,
660 output_attentions=output_attentions,
661 use_cache=use_cache,
662 cache_position=cache_position,
663 )
664 hidden_states = residual + hidden_states
666 # Fully Connected
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:592, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
589 attn_output = attn_output.transpose(1, 2).contiguous()
590 attn_output = attn_output.view(bsz, q_len, self.hidden_size)
--> 592 attn_output = self.o_proj(attn_output)
594 return attn_output, None, past_key_value
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /usr/local/lib/python3.11/dist-packages/transformers/integrations/fbgemm_fp8.py:50, in FbgemmFp8Linear.forward(self, x)
47 num_tokens = None
48 # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
49 # FBGEMM/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu at e08af8539c391437f447173863df0f3f6f
---> 50 x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
51 x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
52 )
53 # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
54 # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
55
56 # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
57 output = torch.ops.fbgemm.f8f8bf16_rowwise(
58 x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
59 )
File /usr/local/lib/python3.11/dist-packages/torch/ops.py:1061, in OpOverloadPacket.__call__(self, *args, **kwargs)
1059 if self_._has_torchbind_op_overload and must_dispatch_in_python(args, kwargs):
1060 return call_overload_packet_from_python(self, args, kwargs)
-> 1061 return self._op(*args, **(kwargs or {}))
RuntimeError: Invalid datatype. input must be BF16
However, I compared Qwen2 and Llama 3 8B, and the dtypes are both bf16.
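This is roughly how I checked the dtypes (a sketch; the checkpoint names are illustrative). The second part is a forward pre-hook one could add to confirm what dtype actually reaches the quantized linear layers at runtime, which is where the op complains about non-BF16 input:

```python
import torch
from transformers import AutoConfig

# Compare the declared torch_dtype of both checkpoints (names are illustrative).
for name in ("Qwen/Qwen2-7B-Instruct", "meta-llama/Meta-Llama-3-8B"):
    cfg = AutoConfig.from_pretrained(name)
    print(name, cfg.torch_dtype)  # both report bfloat16

# Sketch: hook one FP8 linear layer of the already-loaded `quantized_model`
# (from the repro above) to print the dtype of the activation it receives.
from transformers.integrations.fbgemm_fp8 import FbgemmFp8Linear

def report_dtype(module, inputs):
    print(type(module).__name__, "input dtype:", inputs[0].dtype)

for module in quantized_model.modules():
    if isinstance(module, FbgemmFp8Linear):
        module.register_forward_pre_hook(report_dtype)
        break  # one layer is enough to see the dtype
```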