ROCm: disable cache in generate and fix GPT-OSS dtype #494

Open

danielhanchen wants to merge 5 commits into main from rocm-generate-cache-gpt-oss-dtype

Conversation

@danielhanchen (Contributor) commented Feb 10, 2026:

Summary

  • Force use_cache=False in GenerationMixin.generate on HIP to avoid ROCm HSA exceptions (HIP detection is sketched below this list)
  • Cast GPT-OSS expert inputs to the expert weight dtype on HIP to avoid matmul type mismatches
  • Disable AITER and the ROCm RoPE backend by default on HIP
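A minimal sketch of how the HIP gating behind these bullets can be derived, assuming the DEVICE_TYPE name from unsloth_zoo/device_type.py; torch.version.hip is non-None only on ROCm builds of PyTorch, and the exact unsloth-zoo logic may differ:

import torch

# ROCm/HIP builds of PyTorch expose a non-None torch.version.hip.
IS_HIP = torch.version.hip is not None
DEVICE_TYPE = "hip" if IS_HIP else ("cuda" if torch.cuda.is_available() else "cpu")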

Testing

  • Llama3.2_(1B_and_3B)-Conversational.ipynb (60 steps)
  • Gemma3_(4B)-Vision.ipynb (30 steps)
  • gpt-oss-(20B)-GRPO.ipynb (30 steps)

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances compatibility and stability for Unsloth users running on ROCm (AMD GPUs). It addresses critical issues by disabling caching during generation, ensuring correct data type handling for GPT-OSS models, and deactivating certain experimental backends by default, all aimed at preventing common ROCm-specific errors and improving overall reliability.

Highlights

  • ROCm Cache Management: Disabled use_cache in GenerationMixin.generate specifically for HIP (ROCm) environments to prevent HSA exceptions.
  • GPT-OSS Data Type Handling: Implemented a cast for GPT-OSS expert inputs to match the expert weight data type on HIP, resolving matmul type mismatches.
  • ROCm Backend Defaults: Disabled AITER and the ROCm RoPE backend by default on HIP to avoid JIT build locks and runtime faults, with an option for users to override.


Changelog
  • unsloth_zoo/device_type.py
    • Imported the os module.
    • Added default environment variable settings (AITER_DISABLE=1, USE_ROCM_AITER_ROPE_BACKEND=0) when DEVICE_TYPE is "hip" to disable specific ROCm features.
  • unsloth_zoo/temporary_patches/gpt_oss.py
    • Introduced logic within the forward method to explicitly cast hidden_states to the target_dtype of the expert weights when running on HIP, preventing type mismatch errors during matrix multiplication.
  • unsloth_zoo/temporary_patches/misc.py
    • Added a new patch function patch_rocm_disable_generate_cache that modifies transformers.generation.utils.GenerationMixin.generate to always set use_cache=False when the environment is HIP (ROCm).
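A minimal sketch of the device_type.py change described above, using the environment-variable names from that changelog entry; setdefault leaves any value the user has already exported untouched, and the exact code in the file may differ:

import os

# DEVICE_TYPE is resolved in unsloth_zoo/device_type.py (e.g. as in the sketch in the PR description above).
if DEVICE_TYPE == "hip":
    # Opt out of AITER and the ROCm RoPE backend unless the user has overridden these.
    os.environ.setdefault("AITER_DISABLE", "1")
    os.environ.setdefault("USE_ROCM_AITER_ROPE_BACKEND", "0")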

@gemini-code-assist (bot) left a comment:

Code Review

This pull request introduces several important fixes and compatibility improvements specifically for ROCm (HIP) environments. The changes disable AITER and ROCm RoPE backend by default, ensure correct data type handling for GPT-OSS expert inputs, and force use_cache=False during generation to prevent HSA exceptions. These adjustments are crucial for enhancing the stability and functionality of the system on ROCm hardware. The patches are well-isolated and include mechanisms to prevent re-patching, contributing to overall maintainability.

@chatgpt-codex-connector (bot) left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: db69274f88


Excerpt from the patched generate wrapper (unsloth_zoo/temporary_patches/misc.py):

from transformers.generation import utils as generation_utils

original_generate = generation_utils.GenerationMixin.generate

def generate(self, *args, **kwargs):
    kwargs["use_cache"] = False
    ...
    return original_generate(self, *args, **kwargs)

P2: Keep cache enabled for assisted generation on HIP

Overwriting kwargs["use_cache"] to False unconditionally makes assisted decoding fail on ROCm whenever callers use assistant_model, prompt_lookup_num_tokens, or assistant_early_exit. In transformers (checked 4.57.6), the assisted path raises ValueError("assisted generate requires use_cache=True") when model_kwargs["use_cache"] is false, so this patch turns those valid generate() calls into hard failures instead of just applying a perf workaround.
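One way to keep assisted decoding working while still defaulting the cache off on HIP, along the lines Codex suggests (a sketch, not the PR's code; the kwarg names are the ones cited above from transformers' generate API):

ASSISTED_KWARGS = ("assistant_model", "prompt_lookup_num_tokens", "assistant_early_exit")

def generate(self, *args, **kwargs):
    # Assisted generation requires use_cache=True, so only force the cache off otherwise.
    if not any(kwargs.get(key) is not None for key in ASSISTED_KWARGS):
        kwargs["use_cache"] = False
    return original_generate(self, *args, **kwargs)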


@danielhanchen (Contributor, Author) commented:

Quick status check:

  • Reviewed current local validation results after latest ROCm reruns.
  • No new unsloth-zoo changes were independently validated in this pass, so I did not push additional commits to this PR yet.
  • The only change pushed in this pass was the verified trainer-init robustness fix in unsloth PR #4021.

@danielhanchen (Contributor, Author) commented:

Validation update with additional ROCm rechecks and a follow-up commit.

What changed

  • Added router dtype alignment in generated standalone classes (file: unsloth_zoo/compiler.py) for ROCm GPT-OSS compiled paths.
  • Made GPT-OSS expert dtype alignment mismatch-based (cast only when needed), not HIP-exclusive (file: unsloth_zoo/temporary_patches/gpt_oss.py).
  • Removed temporary compile-disable additions from this iteration and kept only the validated fix path.

Why

  • Reproduced a clean-vs-patched delta in which GRPO failed with a Float vs BFloat16 mismatch in the compiled GPT-OSS router path.
  • A compile-disable-only attempt did not resolve it; rewriting the router source did.

Evidence (ROCm, single GPU)

  • temp/run_232_clean_recheck_gpt_oss_grpo: FAILED (ModuleNotFoundError: kernels on non-BF16 path).
  • temp/run_233_patchon_recheck_gpt_oss_grpo: FAILED (Float vs BFloat16 in router path).
  • temp/run_237_patchon_routerrewriteonly_gpt_oss_grpo: SUCCESS.
  • temp/run_238_hwagnostic_dtype_gpt_oss_grpo: SUCCESS (RUN_EXIT=0).

Notes on compatibility

  • Compiler rewrite is HIP-gated, so NVIDIA code paths are unchanged.
  • Expert dtype cast is generic mismatch handling and runs only when dtypes differ.

Commit

  • e000231 pushed to rocm-generate-cache-gpt-oss-dtype.

@danielhanchen (Contributor, Author) commented:

Follow-up tweak on the compiler router cast.

Change

  • In unsloth_zoo/compiler.py, the HIP router input cast was changed from an unconditional
    hidden_states.to(self.weight.dtype) to a mismatch-only cast:
    • hidden_states if hidden_states.dtype == self.weight.dtype else hidden_states.to(self.weight.dtype)

Why

  • Keeps the ROCm safeguard for Float/BF16 router matmul mismatch.
  • Avoids unnecessary casts when dtype already matches.
  • Preserves NVIDIA behavior unchanged (still HIP-gated).

Validation

  • temp/run_244_router_conditionalcast_gpt_oss_grpo: SUCCESS
    • losses: [0.0, 0.0]
    • grad norms: [0.264846..., 3.079287...]
    • reward columns present.

Commit

  • cdb92b0

@danielhanchen (Contributor, Author) commented:

Added a targeted GRPO VLM fix for Gemma3 Vision on ROCm.

Root cause

  • chunked_hidden_states_selective_log_softmax assumed inputs were always hidden states and always projected with @ lm_head.t().
  • In VLM GRPO paths, TRL/Unsloth can pass precomputed logits ([..., vocab]) instead.
  • That caused shape mismatch failures like RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x262208 and 2560x262208).

Change

  • In unsloth_zoo/rl_replacements.py, detect whether the incoming tensor is logits-shaped (last_dim == vocab_dim) vs hidden-state-shaped (last_dim == hidden_dim).
  • Use logits directly for the former and only project for the latter.
  • Keep fallback behavior unchanged for genuinely incompatible dimensions.
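A hedged sketch of the shape-based dispatch described above (illustrative names, not the exact rl_replacements.py code):

import torch

def select_logits(chunk: torch.Tensor, lm_head: torch.Tensor) -> torch.Tensor:
    # lm_head has shape [vocab_dim, hidden_dim].
    vocab_dim, hidden_dim = lm_head.shape
    if chunk.shape[-1] == vocab_dim:
        # Precomputed logits (e.g. the VLM GRPO path): use them directly.
        return chunk
    elif chunk.shape[-1] == hidden_dim:
        # Hidden states: project onto the vocabulary.
        return chunk.to(lm_head.dtype) @ lm_head.t()
    else:
        # Fallback: attempt the projection and let the matmul raise a precise
        # error if the dimensions are genuinely incompatible.
        return chunk.to(lm_head.dtype) @ lm_head.t()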

Validation

  • Reproduced failure in temp/run_266_Gemma3_4B_Vision_GRPO_nocompile_no_flex/output.log.
  • Re-ran with this patch in temp/run_267_Gemma3_4B_Vision_GRPO_nocompile_no_flex_logitsfix.
  • Run completed end-to-end (30/30 GRPO steps), with metrics/rewards/completions logged and no shape-mismatch traceback.

@danielhanchen (Contributor, Author) commented:

Added follow-up ROCm notebook fixes for audio/CSM paths.

  • Commit: 6dfd726
  • File: unsloth_zoo/temporary_patches/misc.py

What changed:

  • Added torchcodec fallback handling when shared library load fails, allowing safe decode fallback paths.
  • Added datasets audio decode fallback via soundfile for environments where torchcodec decode is unavailable.
  • Added Deepseek OCR masked scatter guard patch to tolerate runtime shape mismatches encountered during notebook runs.
  • Hardened CSM forward/attention patches for ROCm eager path stability (including attention shape/rope alignment handling).
  • Added Whisper/ffmpeg fallback helpers for audio preprocessing edge cases seen during the sweep.
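A minimal sketch of the audio decode fallback described above, assuming soundfile is installed; decode_audio and torchcodec_decode are hypothetical names, and the real patch hooks the datasets decoder rather than exposing a helper like this:

import soundfile as sf

def decode_audio(path, torchcodec_decode = None):
    """Try the torchcodec-backed decoder first; fall back to soundfile on failure."""
    if torchcodec_decode is not None:
        try:
            return torchcodec_decode(path)
        except Exception:
            pass  # e.g. the torchcodec shared library failed to load on this ROCm box
    data, sr = sf.read(path)
    return {"array": data, "sampling_rate": sr}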

Validation evidence (ROCm runs):

  • CSM failure path: temp/run_278_Sesame_CSM_1B_TTS / temp/run_295_Sesame_CSM_1B_TTS_retry_ropealign / temp/run_297_Sesame_CSM_1B_TTS_retry_ropealign_v2
  • CSM fixed: temp/run_301_Sesame_CSM_1B_TTS_retry_ropealign_v3 (success)
  • Deepseek OCR failure/fix: temp/run_247_Deepseek_OCR_3B (fail) -> temp/run_249_Deepseek_OCR_3B_rerun (success)
  • Full tracked notebook status is now green in RUN_DETAILS.csv (latest row per notebook = SUCCESS for all 25 notebooks)

(Inline diff excerpts attached to this commit: torchcodec module cleanup and the soundfile decode fallback returning {"array": data, "sampling_rate": sr}, registration of a Deepseek OCR import hook via sys.meta_path, the _unsloth_rocm_generate_patched flag on GenerationMixin, the CSM attention output path through o_proj, and the Whisper resample fallback via _resample_array.)
@danielhanchen (Contributor, Author) commented:

Superseding my prior malformed CLI comment with corrected content.

Re-review update for PR #494:

  • Re-checked ROCm patch paths against failing run traces and repeated A/B runs in amd_test.
  • Deepseek OCR masked_scatter patch remains required:
    • removal reproduced failure in temp/run_316_Deepseek_OCR_prunecheck
    • restoring patch fixed it in temp/run_317_Deepseek_OCR_prunecheck_retry
  • Whisper resample fallback remains required to handle the 22.05 kHz input path:
    • temp/run_314_Whisper_prunecheck showed immediate sampling-rate failure without fallback.

No new code changes were pushed to this branch in this pass. The only pushed cleanup in this re-review cycle was in unsloth PR #4021 to remove a duplicate call site.
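For context, a simple sketch of the kind of resample fallback involved, using linear interpolation via numpy; the actual _resample_array helper in the patch may be implemented differently:

import numpy as np

def resample_array(audio: np.ndarray, orig_sr: int, target_sr: int = 16000) -> np.ndarray:
    """Resample a 1-D waveform (e.g. 22050 Hz) to the feature extractor's expected rate."""
    if orig_sr == target_sr:
        return audio
    duration = audio.shape[-1] / orig_sr
    new_len = int(round(duration * target_sr))
    old_t = np.linspace(0.0, duration, num = audio.shape[-1], endpoint = False)
    new_t = np.linspace(0.0, duration, num = new_len, endpoint = False)
    return np.interp(new_t, old_t, audio).astype(audio.dtype)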

Review thread on the GPT-OSS expert forward (unsloth_zoo/temporary_patches/gpt_oss.py):

) -> torch.Tensor:
    """Forward using grouped_mm or loop fallback with LoRA support."""
    # Keep activations aligned with expert weights to avoid mixed-dtype matmul errors.
    target_dtype = getattr(getattr(self.down_proj, "weight", None), "dtype", None)
    if target_dtype is not None and hidden_states.dtype != target_dtype:
        hidden_states = hidden_states.to(target_dtype)
Collaborator left a comment:

Are we still seeing any errors? I remember checking both fp16 and bf16

kwargs["use_cache"] = False
# HIP-safe generation: drop cache-only kwargs that can route into
# unsupported codepaths and trigger assert_async failures.
for key in (
Collaborator left a comment:

If we're disabling the KV cache, we should at least warn when people call .generate.
Otherwise we might see a barrage of complaints that inference is slow.

PS: How does vLLM handle the KV cache for HIP?
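A sketch of the kind of one-time warning that could address this (not part of the PR as it stands):

import warnings

_warned_rocm_cache = False

def _warn_rocm_cache_disabled():
    global _warned_rocm_cache
    if not _warned_rocm_cache:
        warnings.warn(
            "Unsloth: KV cache is disabled in generate() on ROCm to avoid HSA faults; "
            "inference will be slower than with the cache enabled.",
            stacklevel = 2,
        )
        _warned_rocm_cache = True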

Review thread on the chunked log-softmax projection (unsloth_zoo/rl_replacements.py):

    chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()
else:
    # Fallback: try the projection path and let the underlying matmul raise a
    # precise error if the dimensions are genuinely incompatible.
Collaborator left a comment:

Can the elif and else be combined into a single call?
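If the else branch just repeats the projection, the two branches could indeed collapse into one (a sketch of the suggestion, assuming the fallback body is the same matmul and reusing the logits/hidden-state check described earlier):

if chunk_hidden_states.shape[-1] == vocab_dim:
    chunk_logits = chunk_hidden_states  # already logits
else:
    # Covers both the hidden-state case and the "let the matmul raise" fallback.
    chunk_logits = chunk_hidden_states.to(lm_head.dtype) @ lm_head.t()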
