-
Notifications
You must be signed in to change notification settings - Fork 291
[WC] RoPe ignored pattern without transpose #3930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # Copyright (c) 2026 Intel Corporation | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from nncf.common.graph.operator_metatypes import OperatorMetatype | ||
| from nncf.common.graph.patterns.patterns import GraphPattern | ||
|
|
||
|
|
||
| def create_rope_pattern( | ||
| mm_metatype: type[OperatorMetatype], | ||
| transpose_metatype: type[OperatorMetatype], | ||
| concat_metatype: type[OperatorMetatype], | ||
| cos_metatype: type[OperatorMetatype], | ||
| sin_metatype: type[OperatorMetatype], | ||
| ) -> GraphPattern: | ||
| """ | ||
| Creates Rotary Positional Embedding (RoPE) pattern. | ||
| Scheme: | ||
|
|
||
| (matmul) (matmul) | ||
| | | | ||
| (transpose) (concat) | ||
| | / \ | ||
| (concat) (cos) (sin) | ||
| / \ | ||
| (cos) (sin) | ||
|
|
||
| :param mm_metatype: MatMul metatype. | ||
| :param transpose_metatype: Transpose metatype. | ||
| :param concat_metatype: Concat metatype. | ||
| :param cos_metatype: Cos metatype. | ||
| :param sin_metatype: Sin metatype. | ||
| :return: The Rotary Positional Embedding (RoPE) pattern. | ||
| """ | ||
| ret_pattern = GraphPattern() | ||
| for with_transpose in [True, False]: | ||
| pattern = GraphPattern() | ||
| matmul_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "MATMUL", GraphPattern.METATYPE_ATTR: mm_metatype}) | ||
| concat_node = pattern.add_node( | ||
| **{GraphPattern.LABEL_ATTR: "CONCAT", GraphPattern.METATYPE_ATTR: concat_metatype} | ||
| ) | ||
| cos_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "COS", GraphPattern.METATYPE_ATTR: cos_metatype}) | ||
| sin_node = pattern.add_node(**{GraphPattern.LABEL_ATTR: "SIN", GraphPattern.METATYPE_ATTR: sin_metatype}) | ||
|
|
||
| if with_transpose: | ||
| transpose_node = pattern.add_node( | ||
| **{GraphPattern.LABEL_ATTR: "TRANSPOSE", GraphPattern.METATYPE_ATTR: transpose_metatype} | ||
| ) | ||
| pattern.add_edge(matmul_node, transpose_node) | ||
| pattern.add_edge(transpose_node, concat_node) | ||
| else: | ||
| pattern.add_edge(matmul_node, concat_node) | ||
| pattern.add_edge(concat_node, cos_node) | ||
| pattern.add_edge(concat_node, sin_node) | ||
| ret_pattern.add_pattern_alternative(pattern) | ||
| return ret_pattern | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -494,16 +494,24 @@ def forward(self, x): | |
| class RoPEModel(nn.Module): | ||
| INPUT_SIZE = [1, 10] | ||
|
|
||
| def __init__(self): | ||
| def __init__(self, with_transpose: bool, with_reshape: bool): | ||
| super().__init__() | ||
| self._with_transpose = with_transpose | ||
| self._with_reshape = with_reshape | ||
| data_shape = [5] if with_reshape else [1, 5, 1] | ||
| with set_torch_seed(): | ||
| self.data = torch.randn([5]) | ||
| self.data = nn.Parameter(torch.randn(data_shape)) | ||
|
|
||
| def forward(self, x): | ||
| x = torch.unsqueeze(x, dim=0) | ||
| reshape = torch.reshape(self.data, [1, 5, 1]) | ||
| x = torch.matmul(reshape, x) | ||
| x = torch.transpose(x, 2, 1) | ||
|
|
||
| if self._with_reshape: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this added? Is it part of pattern?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was changed because the model with reshape is not compressed by default, by the WC algorithm. It is not the part of the pattern, but with the reshape it is not possible to check that ignored pattern worked out properly. The reason why the reshape was added in the first place is unclear #3059 |
||
| data = torch.reshape(self.data, [1, 5, 1]) | ||
| else: | ||
| data = self.data | ||
| x = torch.matmul(data, x) | ||
| if self._with_transpose: | ||
| x = torch.transpose(x, 2, 1) | ||
| x = torch.cat([x], dim=2) | ||
| x1 = x.sin() | ||
| x2 = x.cos() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -122,7 +122,7 @@ def get_matmul_model() -> TModel: | |
|
|
||
| @staticmethod | ||
| @abstractmethod | ||
| def get_RoPE_model() -> TModel: | ||
| def get_RoPE_model(with_transpose: bool) -> TModel: | ||
| """Returns a backend model for test_rope_weight_compression.""" | ||
|
|
||
| @staticmethod | ||
|
|
@@ -407,6 +407,11 @@ def get_different_channel_size_model(channel_sizes: list[int]) -> TModel: | |
| def get_num_int4_nodes(model: TModel): | ||
| "Returns number of int4 nodes." | ||
|
|
||
| @staticmethod | ||
| @abstractmethod | ||
| def get_num_int8_nodes(model: TModel): | ||
| "Returns number of int8 nodes." | ||
|
Comment on lines
+410
to
+413
|
||
|
|
||
| @staticmethod | ||
| @abstractmethod | ||
| def get_num_int4_group_sizes(model: TModel) -> dict[int, int]: | ||
|
|
@@ -445,15 +450,20 @@ def test_awq_with_ignored_scope(self, mocker, is_3d_weights): | |
| int4_num_nodes = self.get_num_int4_nodes(compressed_model) | ||
| assert int4_num_nodes == int4_ref_num_compressed, int4_num_nodes | ||
|
|
||
| def test_rope_weight_compression(self): | ||
| model = self.get_RoPE_model() | ||
| @pytest.mark.parametrize("with_transpose", [True, False]) | ||
| def test_rope_weight_compression(self, with_transpose): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does `test_rope_weight_compression` check RoPE patterns?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| model = self.get_RoPE_model(with_transpose=with_transpose) | ||
| sz = 8 | ||
| n_samples = 10 | ||
|
|
||
| dataset = Dataset( | ||
| [self.to_tensor(np.ones([1, i + 1, sz], dtype=np.float32)) for i in range(n_samples)], | ||
| self.get_transform_func(), | ||
| ) | ||
| # The first matmul would normally be compressed in INT8 format. | ||
| # Since there is only one matmul in the target model, | ||
| # checking that the number of INT8 nodes == 0 verifies that | ||
| # the ignored RoPE pattern is being applied. | ||
|
Comment on lines
+463
to
+466
|
||
| compressed_model = compress_weights( | ||
| model, | ||
| mode=CompressWeightsMode.INT4_SYM, | ||
|
|
@@ -462,9 +472,9 @@ def test_rope_weight_compression(self): | |
| dataset=dataset, | ||
| ) | ||
|
|
||
| int4_ref_num_compressed = 0 | ||
| int4_num_nodes = self.get_num_int4_nodes(compressed_model) | ||
| assert int4_num_nodes == int4_ref_num_compressed | ||
| int8_ref_num_compressed = 0 | ||
| int8_num_nodes = self.get_num_int8_nodes(compressed_model) | ||
| assert int8_num_nodes == int8_ref_num_compressed | ||
|
|
||
| def test_sam_pe_weight_compression(self): | ||
| model = self.get_SAM_PE_model() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.