Fix Mistral4 tests #44827

Open
3outeille wants to merge 18 commits into huggingface:main from 3outeille:fix-aligned-data-ptr-grouped-mm

Conversation


@3outeille 3outeille commented Mar 18, 2026

@3outeille 3outeille changed the title Fix RuntimeError: expected data_ptr to be aligned to 16 bytes Fix RuntimeError 16 bytes alignment for Mistral Mar 18, 2026
@3outeille 3outeille changed the title Fix RuntimeError 16 bytes alignment for Mistral Fix RuntimeError 16 bytes alignment for Mistral4 Mar 18, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@3outeille 3outeille changed the title Fix RuntimeError 16 bytes alignment for Mistral4 Fix Mistral4 tests Mar 18, 2026
    else:
        # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim)
-       out = _grouped_mm(input, weight.transpose(-2, -1), offs=offs)
+       out = _grouped_mm(input, weight.transpose(-2, -1).contiguous(), offs=offs)
@3outeille 3outeille Mar 18, 2026


.contiguous() after .transpose(-2, -1) in _grouped_linear ensures the weight tensor's memory layout is contiguous before passing it to _grouped_mm, fixing RuntimeError: expected data_ptr to be aligned to 16 bytes. Alternatively, maybe we could force Forced16BytesAlignment in the weight converter, as with MoE? (that way no other model would hit this issue)
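The underlying issue is that `.transpose(-2, -1)` returns a strided view over the same buffer, not a contiguous tensor, and grouped-matmul kernels with strict layout/alignment requirements reject it. A minimal sketch of the view-vs-copy distinction, illustrated here with NumPy rather than torch (the shapes are made up for illustration):

```python
import numpy as np

# A transposed view shares the original buffer with swapped strides,
# so it is generally NOT contiguous in memory.
weight = np.zeros((4, 8, 16), dtype=np.float32)  # (num_experts, output_dim, input_dim)
view = weight.transpose(0, 2, 1)                 # swap the last two dims; no copy is made
assert not view.flags["C_CONTIGUOUS"]

# np.ascontiguousarray (the analogue of torch's .contiguous()) materializes
# a fresh contiguous copy that layout-sensitive kernels can accept.
packed = np.ascontiguousarray(view)
assert packed.flags["C_CONTIGUOUS"]
assert packed.shape == (4, 16, 8)
```

The trade-off of the `.contiguous()` fix is an extra copy of the weight on every call, which is why forcing alignment once at weight-conversion time could be attractive.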

partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

dim = int(dim * partial_rotary_factor) # Mixtral4 doesn't apply ROPE to the full attention head

this was failing test_model_rope_scaling_frequencies with AssertionError: The values for attribute 'shape' do not match: torch.Size([1, 64]) != torch.Size([1, 128]). Mistral4 does not apply RoPE to the full attention head (cf. qk_rope and qk_nope)
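The shape mismatch follows directly from the scaling: only a fraction of each head gets rotary embeddings. A quick sketch of the computation in the diff, with hypothetical config values chosen to reproduce the 64-vs-128 figures from the error message:

```python
# Hypothetical tiny config standing in for the real Mistral4 config.
class Config:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = 128
    rope_parameters = {"partial_rotary_factor": 0.5}

config = Config()
partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
dim = int(dim * partial_rotary_factor)  # RoPE only covers part of the head
assert dim == 64  # rotary dim is 64, not the full head_dim of 128
```

Without the `partial_rotary_factor` scaling, the test builds inverse frequencies for the full 128-dim head and the shapes no longer line up.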

@@ -44,12 +44,21 @@


class Mistral4ModelTester(CausalLMModelTester):

test_model_is_small fails because the config inherited from the common tester made the tiny test model too large (1,233,664 params); it needs fewer than 1,000,000

_supports_flex_attn = True

-   _can_compile_fullgraph = True
+   _can_compile_fullgraph = False

TorchInductor error so 🙈

@3outeille 3outeille requested a review from Cyrilvallez March 18, 2026 18:40
-   cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens
+   position_ids = kwargs.get("position_ids")
+   if position_ids is None:
+       position_ids = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens
@3outeille 3outeille Mar 18, 2026


we need to reuse RoPE's position_ids; otherwise, by recomputing torch.arange(seq_len) every time, we end up with token positions that differ from the ones RoPE used, which breaks generation
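A minimal plain-Python sketch of the divergence (hypothetical numbers): when the caller provides position_ids (e.g. because of left padding in the batch), blindly recomputing positions from the cache length disagrees with the positions RoPE actually rotated Q/K with.

```python
def positions(kwargs, seq_len, past_seen_tokens):
    """Mirror of the fix: prefer the caller's position_ids, fall back to arange."""
    position_ids = kwargs.get("position_ids")
    if position_ids is None:  # naive recomputation, only valid without padding
        position_ids = [past_seen_tokens + i for i in range(seq_len)]
    return position_ids

# Hypothetical decode step: the cache holds 5 entries, 2 of which are padding,
# so RoPE rotated the new token at position 3, not 5.
past_seen_tokens = 5
rope_positions = [3]  # what the caller used when applying RoPE

assert positions({"position_ids": rope_positions}, 1, past_seen_tokens) == [3]
assert positions({}, 1, past_seen_tokens) == [5]  # divergent without the reuse
```

The fix keeps the attention-side positions identical to the RoPE-side ones, so cached keys and new queries agree on where each token sits.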

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, mistral4

@3outeille
Member Author

run-slow: auto, mistral4

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/auto", "models/mistral4"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN b90678b2 workflow commit (merge commit)
PR 45f5bb7b branch commit (from PR)
main 21950930 base commit (on main)

Model CI Report

2 new failed tests from this PR 😭

  • mistral4:
    tests/models/mistral4/test_modeling_mistral4.py::Mistral4ModelTest::test_flex_attention_with_grads (✅ ⟹ ❌)
    tests/models/mistral4/test_modeling_mistral4.py::Mistral4ModelTest::test_sdpa_can_dispatch_on_flash (✅ ⟹ ❌)
