Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions unsloth_zoo/device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
]

import torch
import os
import functools
from .utils import Version
import inspect
Expand Down Expand Up @@ -77,6 +78,10 @@ def get_device_count():
# HSA_STATUS_ERROR_EXCEPTION checks - sometimes AMD fails for BnB
ALLOW_BITSANDBYTES : bool = True
if DEVICE_TYPE == "hip":
# Disable AITER by default on ROCm to avoid JIT build locks and runtime faults.
# Users can override by explicitly setting env vars.
os.environ.setdefault("AITER_DISABLE", "1")
os.environ.setdefault("USE_ROCM_AITER_ROPE_BACKEND", "0")
try:
from bitsandbytes.nn.modules import Params4bit
if "blocksize = 64 if not HIP_ENVIRONMENT else 128" in inspect.getsource(Params4bit):
Expand Down
7 changes: 7 additions & 0 deletions unsloth_zoo/temporary_patches/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,13 @@ def forward(
self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None
) -> torch.Tensor:
"""Forward using grouped_mm or loop fallback with LoRA support."""
# ROCm: ensure hidden_states matches expert weight dtype to avoid matmul type errors
if getattr(getattr(torch, "version", None), "hip", None):
target_dtype = getattr(getattr(self.down_proj, "weight", None), "dtype", None)
if target_dtype is None:
target_dtype = self.dtype
if hidden_states is not None and hidden_states.dtype != target_dtype:
hidden_states = hidden_states.to(target_dtype)
# Use optimized grouped_mm if available
if _check_torch_grouped_mm_supported():
return forward_native_grouped_mm(self, hidden_states, router_indices, routing_weights)
Expand Down
24 changes: 24 additions & 0 deletions unsloth_zoo/temporary_patches/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,30 @@ def forward(
TEMPORARY_PATCHES.append(patch_CsmDepthDecoderForCausalLM_forward)


def patch_rocm_disable_generate_cache():
    """On ROCm/HIP builds, make `GenerationMixin.generate` default to
    `use_cache=False` as a workaround for HIP runtime faults.

    The patch is a no-op on non-HIP builds and is applied at most once
    (guarded by the `_unsloth_rocm_generate_patched` sentinel attribute).
    It only supplies a *default*: an explicit `use_cache=` from the caller
    is respected, and assisted-decoding calls are left untouched because
    transformers requires `use_cache=True` on that path and would raise
    `ValueError("assisted generate requires use_cache=True")` otherwise.
    """
    try:
        import transformers.generation.utils as generation_utils
    except Exception as e:
        return raise_error("GenerationMixin.generate", e)

    # torch.version.hip is a version string on ROCm builds, None elsewhere.
    if not getattr(getattr(torch, "version", None), "hip", None):
        return

    # Idempotency guard: never wrap generate() twice.
    if getattr(generation_utils.GenerationMixin, "_unsloth_rocm_generate_patched", False):
        return

    original_generate = generation_utils.GenerationMixin.generate

    def generate(self, *args, **kwargs):
        # Assisted decoding (assistant_model / prompt_lookup_num_tokens /
        # assistant_early_exit) requires use_cache=True in transformers;
        # forcing it off would turn valid calls into hard failures.
        assisted = (
            kwargs.get("assistant_model") is not None
            or kwargs.get("prompt_lookup_num_tokens") is not None
            or kwargs.get("assistant_early_exit") is not None
        )
        if not assisted:
            # setdefault: apply the ROCm workaround only when the caller
            # did not explicitly choose a use_cache value themselves.
            kwargs.setdefault("use_cache", False)
        return original_generate(self, *args, **kwargs)

    generation_utils.GenerationMixin.generate = generate
    generation_utils.GenerationMixin._unsloth_rocm_generate_patched = True
pass
TEMPORARY_PATCHES.append(patch_rocm_disable_generate_cache)


def patch_CsmForConditionalGeneration_forward():
try:
import transformers.models.csm.modeling_csm
Expand Down