Skip to content

Commit 34275e4

Browse files
committed
refactor R1 rotation and partial_hadamard
1 parent a400d18 commit 34275e4

File tree

2 files changed

+15
-32
lines changed

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -266,33 +266,21 @@ def _create_r1_scheme(self) -> TransformScheme:
266266
location="weight_output",
267267
)
268268
)
269-
if getattr(self.mappings, "attn_v_is_kv_combined", False):
270-
apply_list.append(
271-
TransformArgs(
272-
targets=[
273-
self.mappings.attn_q,
274-
self.mappings.attn_k,
275-
*self.mappings.mlp_in,
276-
self.mappings.lm_head,
277-
],
278-
location="weight_input",
279-
inverse=True,
280-
)
281-
)
282-
else:
283-
apply_list.append(
284-
TransformArgs(
285-
targets=[
286-
self.mappings.attn_q,
287-
self.mappings.attn_k,
288-
self.mappings.attn_v,
289-
*self.mappings.mlp_in,
290-
self.mappings.lm_head,
291-
],
292-
location="weight_input",
293-
inverse=True,
294-
)
269+
r1_input_targets = [
270+
self.mappings.attn_q,
271+
self.mappings.attn_k,
272+
*self.mappings.mlp_in,
273+
self.mappings.lm_head,
274+
]
275+
if not getattr(self.mappings, "attn_v_is_kv_combined", False):
276+
r1_input_targets.append(self.mappings.attn_v)
277+
apply_list.append(
278+
TransformArgs(
279+
targets=r1_input_targets,
280+
location="weight_input",
281+
inverse=True,
295282
)
283+
)
296284
return TransformScheme(
297285
type=self.transform_type,
298286
randomize=self.randomize,

src/llmcompressor/modifiers/transform/spinquant/partial_hadamard.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,8 @@ def __init__(
9696
v_head_dim: int = 0,
9797
):
9898
super().__init__(weight, perm, scheme, args, module_type)
99-
self.weight = weight
100-
self.perm = perm
101-
self.scheme = scheme
102-
self.args = args
103-
self.module_type = module_type
10499
self.qk_nope_head_dim = qk_nope_head_dim
105100
self.v_head_dim = v_head_dim
106-
self._scale = torch.tensor(weight.size(0), dtype=torch.float64).sqrt()
107101

108102
def forward(self, value: Tensor) -> Tensor:
109103
weight = self.weight
@@ -154,6 +148,7 @@ def apply_partial_transform_weight(
154148
:return: value after transform_weight has been applied
155149
"""
156150
assert transform_weight.shape[0] == transform_weight.shape[1]
151+
assert qk_nope_head_dim > 0 and v_head_dim > 0
157152
if TransformLocation(location).is_online():
158153
return _multihead_matmul(value, transform_weight)
159154

0 commit comments

Comments (0)