Skip to content

Commit cb82e5c

Browse files
committed
Fixed accuracy issue with hpu graph and dynamicity
1 parent 2f09e05 commit cb82e5c

File tree

4 files changed

+13
-1
lines changed

4 files changed

+13
-1
lines changed

optimum/habana/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def eager_attention_forward(
135135
probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
136136

137137
if token_idx is not None:
138-
probs[..., token_idx] = 0
138+
# index_copy_() was used to avoid dynamicity in probs[..., token_idx]
139+
zeros = torch.zeros(probs.shape[:-1] + (1,), dtype=probs.dtype, device=probs.device)
140+
probs.index_copy_(-1, token_idx, zeros)
139141
scores = probs
140142
else:
141143
scores = probs[..., :-1]

optimum/habana/transformers/models/mistral/configuration_mistral.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
attention_dropout,
5656
**kwargs,
5757
)
58+
5859
self.rope_scaling = rope_scaling
5960

6061
# Validate the correctness of rotary position embeddings parameters

tests/baselines/fixture/tests/test_text_generation_example.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,14 @@
423423
"throughput": 45.90538768350833
424424
}
425425
},
426+
"tests/test_text_generation_example.py::test_text_generation_bf16_1x[unsloth/gpt-oss-20b-BF16-1-False-False]": {
427+
"gaudi2": {
428+
"throughput": 49.2845966607741
429+
},
430+
"gaudi3": {
431+
"throughput": 59.51780208740626
432+
}
433+
},
426434
"tests/test_text_generation_example.py::test_text_generation_contrastive_search[gpt2-xl-1-False]": {
427435
"gaudi1": {
428436
"throughput": 34.48141280163397

tests/test_text_generation_example.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
("moonshotai/Moonlight-16B-A3B", 1, False, False),
6565
("Qwen/Qwen3-8B", 1, False, False),
6666
("Qwen/Qwen3-30B-A3B", 1, False, False),
67+
("unsloth/gpt-oss-20b-BF16", 1, False, False),
6768
],
6869
"fp8": [
6970
pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, marks=pytest.mark.x4),

0 commit comments

Comments (0)