AutoModelForSequenceClassification with attn_implementation="flash_attention_3" causes degenerate training (loss increases, model predicts all-one-class) #44829

@Jantory

Description

System Info

When fine-tuning Qwen3ForSequenceClassification (loaded via AutoModelForSequenceClassification) with attn_implementation="flash_attention_3", training completely fails: loss increases instead of decreasing, and the model collapses to predicting all examples as one class. Removing attn_implementation="flash_attention_3" (falling back to default attention) fixes the issue immediately.

Environment:

Hardware: NVIDIA H100 (Hopper)
transformers version: (your version)
flash-attn version: (your version)
Model: Qwen/Qwen3-Embedding-8B
PEFT / LoRA applied on top

Observed: loss increases (e.g. 0.35 → 0.41), eval_recall=1.0 with threshold≈0 (all predicted positive), F1 stuck at positive-class base rate.
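As a sanity check (my own addition, not from the report): when a binary classifier collapses to predicting every example positive, precision equals the positive base rate p, recall is 1.0, and F1 = 2p/(1+p) — so a flat F1 pinned near the base rate together with recall=1.0 is exactly the signature described above. A minimal sketch:

```python
def all_positive_metrics(labels):
    """Metrics when the model predicts class 1 for every example.

    labels: list of 0/1 ground-truth labels.
    """
    p = sum(labels) / len(labels)  # positive base rate
    precision = p                  # TP / (TP + FP) = positives / total
    recall = 1.0                   # every true positive is predicted positive
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Example: 30% positive class -> precision 0.3, recall 1.0, F1 ~ 0.46
precision, recall, f1 = all_positive_metrics([1] * 30 + [0] * 70)
print(precision, recall, round(f1, 4))  # 0.3 1.0 0.4615
```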

Note: The issue appears specific to Qwen3ForSequenceClassification + FA3. The same model backbone with FA3 works correctly in other use cases (e.g. feature extraction / embedding), suggesting the problem lies in the last-token pooling or score head gradient path under FA3.
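For context (an assumption on my part, based on how the `*ForSequenceClassification` heads in transformers typically pool decoder models): the score head reads the hidden state of the last non-pad token, located from `pad_token_id` or the attention mask. If the attention kernel and this pooling logic disagree about padding side or mask handling, the head can end up reading the wrong position, which would corrupt the gradient path exactly as described. A toy sketch of the index selection with a hypothetical helper:

```python
def last_token_index(attention_mask):
    """Toy analogue of last-non-pad-token pooling (hypothetical helper).

    Assumes right padding: counts the real tokens and takes the last one.
    attention_mask: list of ints, 1 = real token, 0 = padding.
    """
    return sum(attention_mask) - 1

# Right padding: correct, pools index 3 (the last real token).
print(last_token_index([1, 1, 1, 1, 0, 0]))  # 3

# Left padding: the same rule still returns index 3, but the last real
# token now sits at index 5 -- the score head reads a mid-sequence
# hidden state instead, so the classification loss trains on the
# wrong representation.
print(last_token_index([0, 0, 1, 1, 1, 1]))  # 3 (wrong position)
```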

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    num_labels=2,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",  # remove this line → training works
)
# train with the HF Trainer on a binary classification task (PEFT/LoRA on top)

Expected behavior

Normal convergence: the loss decreases and the model learns to separate both classes, matching the behavior observed with the default attention implementation.
