
Training on ROCm (gfx1151, Strix Halo) results in NaN losses with Gemma3 fine-tuning #3385

@kyuz0

Description

Summary

Fine-tuning Gemma-3 with Unsloth on AMD Strix Halo (gfx1151) produces a NaN training loss from the first step.
The NaNs appear to originate in the forward pass (logits/hidden states), not in the optimizer or backward pass.
The issue reproduces with FlashAttention/xformers disabled and even when forcing fp32.

System

  • Hardware: AMD Strix Halo (gfx1151)
  • Host: Fedora 42 toolbox; container Ubuntu 24.04
  • Base image: rocm/pytorch:rocm6.4.4_ubuntu24.04_py3.12_pytorch_release_2.7.1

Key package versions (from container)

Command:

python - <<'PY'
import torch, importlib
mods = ["unsloth","unsloth_zoo","transformers","trl","accelerate","peft","xformers","bitsandbytes","triton"]
for m in mods:
    try:
        print(m, importlib.import_module(m).__version__)
    except Exception as e:
        print(m, "not found")
print("torch:", torch.__version__, "HIP:", torch.version.hip)
print("cuda.is_available:", torch.cuda.is_available(), "bf16_supported:", torch.cuda.is_bf16_supported())
print("device:", torch.cuda.get_device_name(0))
PY

Output:

unsloth 2025.9.9
unsloth_zoo 2025.9.12
transformers 4.56.2
trl 0.23.0
accelerate 1.10.1
peft 0.17.1
xformers 0.0.30+13c93f39.d20250927
bitsandbytes 0.43.3.dev
triton 3.3.1
torch: 2.7.1+git99ccf24 HIP: 6.4.43484-123eb5128
cuda.is_available: True bf16_supported: True
device: AMD Radeon Graphics

Repro (trainer)

import unsloth
from unsloth import FastModel
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig

name = "unsloth/gemma-3-4b-it"
tok = AutoTokenizer.from_pretrained(name)
model, _ = FastModel.from_pretrained(
    name, max_seq_length=2048,
    load_in_4bit=False, load_in_8bit=False, full_finetuning=False,
)
trainer = SFTTrainer(
    model=model, tokenizer=tok,
    train_dataset=[{"text": "hello world"}]*16,
    args=SFTConfig(
        dataset_text_field="text",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5, max_steps=5,
        learning_rate=2e-4, logging_steps=1,
        optim="adamw_8bit", report_to="none",
    ),
)
trainer.train()

Observed: training loss logs as nan from step 1.

Evidence (forward path + toggles)

1) Forward loss is NaN; grads are not NaN

Command:

python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])
print("forward_loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
out.loss.backward()
has_nan = any(p.grad is not None and torch.isnan(p.grad).any() for p in m.parameters())
print("grad_has_nan:", has_nan)
PY

Output:

forward_loss_is_nan: True loss: nan
grad_has_nan: False
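
A cross-check that is not in the original report, but would help isolate the platform: run the same forward pass on CPU in fp32, which removes the ROCm kernels from the equation entirely. This is only a sketch; Unsloth's patched forward paths may insist on GPU/Triton kernels, in which case loading the same checkpoint with plain transformers on CPU is the fallback.

python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m, _ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m = m.float().cpu().eval()  # fp32 on CPU: slow for a 4B model, but removes the GPU stack entirely
b = tok(["hello world"] * 2, return_tensors="pt", padding=True)
with torch.no_grad():
    out = m(**b, labels=b["input_ids"])
print("cpu_fp32_loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
PY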

2) Disable FlashAttention/xformers → still NaN

Command:

FLASH_ATTENTION_DISABLE=1 XFORMERS_DISABLE_FLASH_ATTN=1 python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.train().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
out = m(**b, labels=b["input_ids"])
print("FA/xformers disabled -> loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
PY

Output:

FA/xformers disabled -> loss_is_nan: True loss: nan
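
Another check worth adding (not part of the original report): repeat the same forward pass with a non-Gemma checkpoint on the identical stack, to tell a Gemma-3-specific problem apart from a general gfx1151 one. A minimal sketch; the checkpoint name is only an example and was not tested by the reporter.

python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer

name = "unsloth/Llama-3.2-1B-Instruct"  # example checkpoint, not tested in this report
tok = AutoTokenizer.from_pretrained(name)
m, _ = FastModel.from_pretrained(name, load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.eval().cuda()
b = tok(["hello world"] * 2, return_tensors="pt").to("cuda")  # identical sequences, no padding needed
with torch.no_grad():
    out = m(**b, labels=b["input_ids"])
print("other_model_loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
PY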

3) Force fp32 → still NaN

Command:

python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m = m.to(dtype=torch.float32).cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True)
b = {k:(v.to("cuda").to(torch.float32) if v.dtype.is_floating_point else v.to("cuda")) for k,v in b.items()}
with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=False):
    out = m(**b, labels=b["input_ids"])
print("fp32 forced -> loss_is_nan:", torch.isnan(out.loss).item(), "loss:", float(out.loss))
PY

Output:

fp32 forced -> loss_is_nan: True loss: nan
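
A related check not in the report: scan the loaded weights for non-finite values, to rule out a corrupted download or a bad dtype conversion at load time. The evidence above points at compute rather than weights, but the check is cheap. A minimal sketch:

python - <<'PY'
import unsloth, torch
from unsloth import FastModel

m, _ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
# list parameters that contain NaN/inf straight after loading
bad = [n for n, p in m.named_parameters() if not torch.isfinite(p).all()]
print("params_with_nonfinite_values:", len(bad))
print(bad[:10])
PY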

4) Inference logits contain NaN (no labels)

Command:

python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.eval().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
with torch.no_grad():
    out = m(**b, return_dict=True)
logits = out.logits
print("logits_dtype:", logits.dtype, "shape:", tuple(logits.shape))
print("logits_has_nan:", torch.isnan(logits).any().item(), "has_inf:", torch.isinf(logits).any().item())
PY

Output:

logits_dtype: torch.bfloat16 shape: (2, 3, 262208)
logits_has_nan: True has_inf: False

5) Hidden states: embeddings OK; first transformer block outputs NaN

Command:

python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m,_ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.eval().cuda()
b = tok(["hello world"]*2, return_tensors="pt", padding=True).to("cuda")
with torch.no_grad():
    out = m(**b, return_dict=True, output_hidden_states=True, output_attentions=False)
hs = out.hidden_states
print("num_hidden_states:", len(hs))
for i,t in enumerate(hs[:5]):  # first few for brevity
    print(f"layer_{i}_nan:", bool(torch.isnan(t).any().item() or torch.isinf(t).any().item()), "dtype:", t.dtype, "shape:", tuple(t.shape))
print("first_bad_layer_index:", next((i for i,t in enumerate(hs) if torch.isnan(t).any().item() or torch.isinf(t).any().item()), None))
PY

Output:

num_hidden_states: 35
layer_0_nan: False dtype: torch.bfloat16 shape: (2, 3, 2560)
layer_1_nan: True dtype: torch.bfloat16 shape: (2, 3, 2560)
layer_2_nan: True dtype: torch.bfloat16 shape: (2, 3, 2560)
layer_3_nan: True dtype: torch.bfloat16 shape: (2, 3, 2560)
layer_4_nan: True dtype: torch.bfloat16 shape: (2, 3, 2560)
first_bad_layer_index: 1
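
Since the first decoder block already produces NaN, forward hooks on that block's submodules would narrow it down to attention vs. MLP vs. normalization. The sketch below is not from the original report; it avoids guessing the exact attribute path by filtering named_modules() for names containing ".layers.0." (an assumption about how the Gemma-3 decoder stack is registered), and if Unsloth's fused paths bypass some submodules, those hooks simply will not fire.

python - <<'PY'
import unsloth, torch
from unsloth import FastModel
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it")
m, _ = FastModel.from_pretrained("unsloth/gemma-3-4b-it", load_in_4bit=False, load_in_8bit=False, full_finetuning=False)
m.eval().cuda()

bad = []  # submodules of the first decoder layer whose output contains NaN/inf

def make_hook(name):
    def hook(module, inputs, output):
        t = output[0] if isinstance(output, tuple) else output
        if torch.is_tensor(t) and not torch.isfinite(t).all():
            bad.append(name)
    return hook

handles = [mod.register_forward_hook(make_hook(name))
           for name, mod in m.named_modules() if ".layers.0." in name]

b = tok(["hello world"] * 2, return_tensors="pt", padding=True).to("cuda")
with torch.no_grad():
    m(**b)
for h in handles:
    h.remove()
print("submodules_with_nonfinite_output:", bad)
PY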

6) Backend sanity (bf16 matmul is fine)

Command:

python - <<'PY'
import torch
x = torch.randn(2048, 2048, device="cuda", dtype=torch.bfloat16)
y = x @ x
print("bf16_matmul_nan:", torch.isnan(y).any().item())
PY

Output:

bf16_matmul_nan: False
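
Because the failure shows up in a transformer block, an additional backend sanity check worth adding (not in the original report) is bf16 scaled-dot-product attention via PyTorch's built-in SDPA, which exercises attention kernels that a plain matmul does not:

python - <<'PY'
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) in bf16 with a causal mask, like a decoder block
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print("bf16_sdpa_nan:", torch.isnan(out).any().item())
PY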

Notes

  • Same NaN behavior is also observed with load_in_4bit=True (tested separately).

Dockerfile

FROM rocm/pytorch:rocm6.4.4_ubuntu24.04_py3.12_pytorch_release_2.7.1

WORKDIR /opt/src

# bitsandbytes (ROCm)
RUN git clone -b rocm_enabled_multi_backend https://github.com/ROCm/bitsandbytes.git
WORKDIR /opt/src/bitsandbytes
RUN cmake -S . -DGPU_TARGETS="gfx1151" -DBNB_ROCM_ARCH="gfx1151" -DCOMPUTE_BACKEND=hip && \
    make -j && \
    python -m pip install --no-cache-dir .

# Python deps
RUN python -m pip install --no-cache-dir \
      'datasets>=3.4.1' \
      'sentencepiece>=0.2.0' \
      tqdm psutil 'wheel>=0.42.0' \
      'accelerate>=0.34.1' \
      'peft>=0.7.1,!=0.11.0' \
      einops packaging

# xformers (pinned)
WORKDIR /opt/src
RUN git clone https://github.com/ROCm/xformers.git
WORKDIR /opt/src/xformers
RUN git submodule update --init --recursive && \
    git checkout 13c93f3 && \
    PYTORCH_ROCM_ARCH=gfx1151 python setup.py install

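# flash-attention (Triton AMD backend)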
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
WORKDIR /root
RUN git clone https://github.com/ROCm/flash-attention.git
RUN cd flash-attention && git checkout v2.7.4-cktile && python setup.py install

# Unsloth (install first), then Zoo
WORKDIR /opt/src
RUN git clone https://github.com/unslothai/unsloth.git
WORKDIR /opt/src/unsloth
RUN python -m pip install --no-cache-dir .
RUN python -m pip install --no-cache-dir 'unsloth_zoo>=2025.5.7'

WORKDIR /opt/src
CMD ["/bin/bash"]
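
For completeness, a typical build-and-run invocation for this image (not from the report; the image tag is arbitrary and the device/group flags are the usual ones for ROCm containers, so adjust to your host):

docker build -t unsloth-rocm-gfx1151 .
docker run -it --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host unsloth-rocm-gfx1151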
