Skip to content

Commit 1ce0173

Browse files
author
Iman Gohari
authored
feat(Gemma3/2): Added FusedRMSNorm (#2281)
1 parent 31a6581 commit 1ce0173

File tree

8 files changed

+57
-23
lines changed

8 files changed

+57
-23
lines changed

optimum/habana/transformers/modeling_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@
260260
gaudi_falcon_linear_forward,
261261
gaudi_FalconMambaForCausalLM_prepare_inputs_for_generation,
262262
gaudi_FalconMambaModel_forward,
263+
gaudi_gemma2_rmsnorm_forward,
264+
gaudi_gemma3_rmsnorm_forward,
263265
gaudi_generate_speech,
264266
gaudi_get_extended_attention_mask,
265267
gaudi_gpt2_forward,
@@ -615,6 +617,7 @@ def adapt_transformers_to_gaudi():
615617
transformers.models.gemma2.modeling_gemma2.Gemma2DecoderLayer = GaudiGemma2DecoderLayer
616618
transformers.models.gemma2.modeling_gemma2.Gemma2Model = GaudiGemma2Model
617619
transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GaudiGemma2RotaryEmbedding
620+
transformers.models.gemma2.modeling_gemma2.Gemma2RMSNorm.forward = gaudi_gemma2_rmsnorm_forward
618621

619622
# Optimization for gemma3 on Gaudi
620623
transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM = GaudiGemma3ForCausalLM
@@ -624,6 +627,7 @@ def adapt_transformers_to_gaudi():
624627
transformers.models.gemma3.modeling_gemma3.Gemma3TextModel = GaudiGemma3TextModel
625628
transformers.models.gemma3.modeling_gemma3.Gemma3Model = GaudiGemma3Model
626629
transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration = GaudiGemma3ForConditionalGeneration
630+
transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = gaudi_gemma3_rmsnorm_forward
627631

628632
# Optimization for blip Text model on Gaudi
629633
transformers.models.blip.BlipTextModel.forward = gaudi_BlipTextModel_forward

optimum/habana/transformers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
GaudiGemma2MLP,
118118
GaudiGemma2Model,
119119
GaudiGemma2RotaryEmbedding,
120+
gaudi_gemma2_rmsnorm_forward,
120121
)
121122
from .gemma3 import (
122123
GaudiGemma3Attention,
@@ -126,6 +127,7 @@
126127
GaudiGemma3MLP,
127128
GaudiGemma3Model,
128129
GaudiGemma3TextModel,
130+
gaudi_gemma3_rmsnorm_forward,
129131
)
130132
from .glm4v import (
131133
ChatGLM4Tokenizer,

optimum/habana/transformers/models/gemma2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
GaudiGemma2MLP,
66
GaudiGemma2Model,
77
GaudiGemma2RotaryEmbedding,
8+
gaudi_gemma2_rmsnorm_forward,
89
)

optimum/habana/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@ def __init__(self, config: Gemma2Config):
7474
super().__init__(config=config)
7575

7676

def gaudi_gemma2_rmsnorm_forward(self, x):
    """RMSNorm forward for Gemma2, optimized for Intel Gaudi (HPU).

    On an HPU device with the fused kernel available, normalization is
    delegated to ``FusedRMSNorm`` using an all-ones weight; otherwise the
    stock eager ``_norm`` path is used. In both cases the Gemma-specific
    ``(1 + weight)`` scaling is applied afterwards and the result is cast
    back to the input dtype.
    """
    if x.device.type == "hpu" and FusedRMSNorm is not None:
        # Run the fused kernel with a unit weight so the Gemma-style
        # (1 + w) scaling can be applied separately below.
        normed = FusedRMSNorm.apply(x.float(), torch.ones_like(self.weight), self.eps)
    else:
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.float())
    scaled = normed * (1.0 + self.weight.float())
    return scaled.type_as(x)
7790
def gaudi_gemma2_repeat_kv(
7891
query_states: torch.Tensor,
7992
key_states: torch.Tensor,

optimum/habana/transformers/models/gemma3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66
GaudiGemma3MLP,
77
GaudiGemma3Model,
88
GaudiGemma3TextModel,
9+
gaudi_gemma3_rmsnorm_forward,
910
)

optimum/habana/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,19 @@
7272
logger = logging.get_logger(__name__)
7373

7474

def gaudi_gemma3_rmsnorm_forward(self, x):
    """RMSNorm forward for Gemma3, optimized for Intel Gaudi (HPU).

    On an HPU device with the fused kernel available, normalization is
    delegated to ``FusedRMSNorm`` using an all-ones weight; otherwise the
    stock eager ``_norm`` path is used. In both cases the Gemma-specific
    ``(1 + weight)`` scaling is applied afterwards and the result is cast
    back to the input dtype.
    """
    if x.device.type == "hpu" and FusedRMSNorm is not None:
        # Run the fused kernel with a unit weight so the Gemma-style
        # (1 + w) scaling can be applied separately below.
        normed = FusedRMSNorm.apply(x.float(), torch.ones_like(self.weight), self.eps)
    else:
        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.float())
    scaled = normed * (1.0 + self.weight.float())
    return scaled.type_as(x)
7588
def gaudi_gemma3_repeat_kv(
7689
query_states: torch.Tensor,
7790
key_states: torch.Tensor,

tests/baselines/fixture/tests/test_text_generation_example.json

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -244,54 +244,54 @@
244244
"throughput": 357.46365062825083
245245
}
246246
},
247-
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-3-27b-it-1-False-True-False]": {
247+
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-3-27b-it-1-False-False-False]": {
248248
"gaudi2": {
249-
"output": "DeepSpeed is a machine learning framework designed to make distributed training easier and more efficient. It is created by Microsoft and has gained significant traction in the research and industry communities due to its ability to train large models with limited hardware resources.\nHere's a breakdown of the key features and benefits of DeepSpeed:\n\n**Core Features & Technologies:**\n\n* **ZeRO (Zero Redundancy Optimizer):** This is the cornerstone of DeepSpeed. It tackles the memory bottleneck in distributed training by partitioning model states (weights,",
250-
"throughput": 32.452270975799294
249+
"output": "DeepSpeed is a machine learning framework that enables you to train models with hundreds of billions or even trillions of parameters. Here's a breakdown of what it is, its key features, and how it compares to other approaches:\n\n**What is DeepSpeed?**\n\nDeveloped by Microsoft, DeepSpeed is a deep learning optimization library designed to make large-scale model training more efficient, accessible, and cost-effective. It's built on PyTorch and is open-source. It's particularly notable for enabling the training of",
250+
"throughput": 34.082512922376125
251251
},
252252
"gaudi3": {
253-
"output": "DeepSpeed is a machine learning framework that enables you to train models with hundreds of billions or even trillions of parameters. Here's a breakdown of what it is, its key features, and how it compares to other approaches:\n\n**What is DeepSpeed?**\n\nDeveloped by Microsoft, DeepSpeed is a deep learning optimization library designed to make large-scale model training more efficient, accessible, and cost-effective. It's built on PyTorch and is open-source. It's particularly notable for enabling the training of",
254-
"throughput": 38.03334992095207
253+
"output": "DeepSpeed is a machine learning framework that enables training very large models with high efficiency. It is often used with PyTorch, and it can significantly reduce memory usage and increase training throughput. Here's a breakdown of how it works, its key features, and how to use it:\n\n**How DeepSpeed Works**\n\nDeepSpeed achieves its efficiency through a combination of innovative techniques, primarily focusing on these areas:\n\n* **ZeRO (Zero Redundancy Optimizer):** This is the cornerstone of DeepSpeed. ZeRO partitions",
254+
"throughput": 42.50246201556991
255255
}
256256
},
257257
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-3-12b-it-1-False-True-False]": {
258258
"gaudi2": {
259-
"output": "DeepSpeed is a machine learning framework focused on improving the performance and scalability of training deep learning models. It's designed to handle very large models and datasets, often used in areas like Natural Language Processing (NLP) and Computer Vision. Here's a breakdown of what DeepSpeed offers, how it works, and why it's valuable.\n\n**Key Features and Benefits**\n\n* **Scalability:** DeepSpeed's primary goal is to allow you to train models that are too large to fit on a single GPU or even",
260-
"throughput": 58.022545212763546
259+
"output": "DeepSpeed is a machine learning framework focused on improving the performance and scalability of training deep learning models. It's designed to handle very large models and datasets, often using techniques like model parallelism, data parallelism, and optimization techniques like ZeRO.\n\nHere's a breakdown of its core concepts and how it works, along with common use cases and examples:\n\n**Key Concepts & Techniques**\n\n* **ZeRO (Zero Redundancy Optimizer):** The cornerstone of DeepSpeed. ZeRO dramatically reduces memory consumption by partitioning model",
260+
"throughput": 68.30196577764131
261261
},
262262
"gaudi3": {
263-
"output": "DeepSpeed is a machine learning framework focused on improving the performance and scalability of training deep learning models. It's designed to handle very large models and datasets, often using techniques like model parallelism, data parallelism, and optimization techniques like ZeRO (Zero Redundancy Optimizer).\n\nHere's a breakdown of key concepts and functionalities within DeepSpeed:\n\n**1. Key Goals & Benefits**\n\n* **Scalability:** DeepSpeed's primary goal is to enable training extremely large models (billions or even trillions of parameters",
264-
"throughput": 69.15032921221514
263+
"output": "DeepSpeed is a machine learning framework focused on improving the performance and scalability of training deep learning models. It's designed to handle very large models and datasets, often used in areas like Natural Language Processing (NLP) and Computer Vision. Here's a breakdown of its key features and how it works:\n\n**1. Core Technologies & Benefits**\n\n* **ZeRO (Zero Redundancy Optimizer):** This is the *key* innovation in DeepSpeed. ZeRO tackles the memory bottleneck that arises when training massive",
264+
"throughput": 81.15713149929978
265265
}
266266
},
267267
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-3-4b-it-1-False-True-False]": {
268268
"gaudi2": {
269-
"output": "DeepSpeed is a machine learning framework focused on efficient training and inference of large models. It is built on PyTorch and aims to overcome the memory and computational limitations that often arise when training large neural networks.\n\nHere's a breakdown of DeepSpeed's key features, benefits, and how it works:\n\n**1. Core Technologies & Techniques:**\n\n* **ZeRO (Zero Redundancy Optimizer):** This is the cornerstone of DeepSpeed. It dramatically reduces memory consumption by partitioning the optimizer states, gradients, and parameters",
270-
"throughput": 112.10840313346671
269+
"output": "DeepSpeed is a machine learning framework focused on efficient training and inference of large models. It is developed by Microsoft and offers various optimizations like ZeRO, which is a memory partitioning technique to reduce memory footprint, and other features for faster training.\n\nHere's a breakdown of key aspects of DeepSpeed:\n\n**1. Core Technologies:**\n\n* **ZeRO (Zero Redundancy Optimizer):** This is the heart of DeepSpeed. It's a memory optimization technique that breaks down model parameters, gradients, and optimizer states",
270+
"throughput": 141.28462701651054
271271
},
272272
"gaudi3": {
273-
"output": "DeepSpeed is a machine learning framework focused on efficient training and serving of large models. It is built on PyTorch and aims to overcome the memory and computational limitations that often arise when training large neural networks.\n\nHere's a breakdown of DeepSpeed's key features, benefits, and how it works:\n\n**1. Core Technologies & Techniques:**\n\n* **ZeRO (Zero Redundancy Optimizer):** This is the cornerstone of DeepSpeed. It drastically reduces memory consumption by partitioning the optimizer states, gradients, and parameters",
274-
"throughput": 125.14846550177148
273+
"output": "DeepSpeed is a machine learning framework focused on efficient training and inference of large models. It is developed by Microsoft and offers various optimizations like ZeRO, which is a memory partitioning technique to reduce memory footprint, and other features such as data parallelism and pipeline parallelism.\n\nHere's a breakdown of its key features and how it compares to other frameworks like PyTorch and TensorFlow:\n\n**Key Features of DeepSpeed:**\n\n* **ZeRO (Zero Redundancy Optimizer):** This is the core of DeepSpeed. It comes in",
274+
"throughput": 153.30181217767333
275275
}
276276
},
277-
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-2-27b-1-False-True-False]": {
277+
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-2-27b-1-False-False-False]": {
278278
"gaudi2": {
279-
"output": "DeepSpeed is a machine learning framework that enables you to train models with trillions of parameters and beyond, using model parallelism to partition large models over multiple GPUs.\n\nThe following is a brief introduction to the DeepSpeed model parallel training.\n\n<h2>1. Introduction</h2>\n\nThe DeepSpeed model parallel training is a simple and effective way to train large models. It is a framework that enables you to train models with trillions of parameters and beyond.\n\nDeepSpeed is a distributed deep learning optimization toolkit that makes it easy and efficient",
280-
"throughput": 36.578709544111
279+
"output": "DeepSpeed is a machine learning framework that is designed to help you train your models faster and more efficiently. DeepSpeed is a system that allows you to train your models faster and more efficiently. DeepSpeed is a machine learning framework that is designed to help you train your models faster and more efficiently. DeepSpeed is a system that allows you to train your models faster and more efficiently.\n\nI’m going to repeat myself. DeepSpeed is a machine learning framework that is designed to help you train your models faster and more efficiently.\n\n",
280+
"throughput": 38.30642095055842
281281
},
282282
"gaudi3": {
283-
"output": "DeepSpeed is a machine learning framework that enables you to train models with trillions of parameters and tera-scale datasets, while also providing the flexibility to customize the training process.\n\nThe DeepSpeed library is a system and deep learning optimization toolkit that makes it possible to efficiently train deep learning models with hundreds of billions of parameters and beyond. It includes a set of techniques that can be used together or separately to achieve the best possible performance for training deep learning models.\n\nDeepSpeed is a library that enables you to use these techniques in",
284-
"throughput": 46.04685368495098
283+
"output": "DeepSpeed is a machine learning framework that enables you to train deep learning models at any scale. It is a system and deep learning software optimization system that makes distributed deep learning practical. DeepSpeed allows researchers and engineers to train deep learning models with terabyte-scale with hundreds of billions of parameters with acceptable speed and accuracy.\n\nDeepSpeed is a deep learning optimization and parallelism library from Microsoft Research optimized for efficiency, developed to train large models efficiently on multiple GPUs.\n\nDeepSpeed is a system and library effort ongoing for over three years.",
284+
"throughput": 48.842401849049224
285285
}
286286
},
287287
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-2-9b-1-False-True-False]": {
288288
"gaudi2": {
289289
"output": "DeepSpeed is a machine learning framework that enables training of large-scale deep learning models on a single GPU or across multiple GPUs. It is designed to be easy to use and highly scalable, making it a popular choice for training large-scale models such as GPT-3 and BERT.\n\nDeepSpeed is built on top of PyTorch, a popular deep learning framework, and provides a set of tools and libraries that make it easy to train large-scale models. It includes features such as zero-shot learning, which allows models to",
290-
"throughput": 92.302359446567
290+
"throughput": 99.98690579203925
291291
},
292292
"gaudi3": {
293293
"output": "DeepSpeed is a machine learning framework that enables training of large-scale deep learning models on a single GPU or across multiple GPUs. It is designed to be easy to use and to provide high performance.\n\nDeepSpeed is built on top of PyTorch, a popular deep learning framework. It provides a number of features that make it easy to train large-scale models, including:\n\n* Automatic model parallelism: DeepSpeed automatically parallelizes the model across multiple GPUs, so you don’t have to worry about how to do it yourself",
294-
"throughput": 111.60209707224463
294+
"throughput": 117.69951320588835
295295
}
296296
},
297297
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[google/gemma-7b-1-False-False]": {
@@ -837,4 +837,4 @@
837837
"throughput": 0.7583387
838838
}
839839
}
840-
}
840+
}

tests/test_text_generation_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@
5050
# ("Qwen/Qwen1.5-7B", 4, False, False, False),
5151
("google/gemma-7b", 1, False, True, False),
5252
("google/gemma-2-9b", 1, False, True, False),
53-
("google/gemma-2-27b", 1, False, True, False),
53+
("google/gemma-2-27b", 1, False, False, False),
5454
("google/gemma-3-4b-it", 1, False, True, False),
5555
("google/gemma-3-12b-it", 1, False, True, False),
56-
("google/gemma-3-27b-it", 1, False, True, False),
56+
("google/gemma-3-27b-it", 1, False, False, False),
5757
pytest.param(
5858
"state-spaces/mamba-130m-hf", 1536, False, False, False, marks=pytest.mark.skip("Deprecated")
5959
),

0 commit comments

Comments
 (0)