
MUSA: Use Monkey Patching to Automatically Convert CUDA Backend to MUSA #583

Open · wants to merge 1 commit into base: main
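
For context on the diff below: each modified file imports a `CUDA` string constant from `ktransformers.util.torch_auto_backend` and uses it in place of the hard-coded `"cuda"` literal. The auto-backend module itself is not shown in this view, so the following is only a minimal sketch of how such a module might work, assuming it detects Moore Threads' `torch_musa` extension at import time, exposes the matching device-type string, and monkey-patches `torch.cuda`; every name other than `CUDA` is illustrative.

```python
# Hypothetical sketch of ktransformers/util/torch_auto_backend.py (not part of
# this diff view). Assumption: the module picks the device-type string once at
# import time and monkey-patches torch.cuda so CUDA-flavoured call sites run
# unchanged on MUSA hardware.
import torch

try:
    # Moore Threads' PyTorch extension; importing it registers torch.musa.
    import torch_musa  # noqa: F401
    _HAS_MUSA = True
except ImportError:
    _HAS_MUSA = False

if _HAS_MUSA:
    CUDA = "musa"            # device.type string reported by MUSA tensors
    torch.cuda = torch.musa  # monkey patch: route torch.cuda calls to the MUSA backend
else:
    CUDA = "cuda"            # stock NVIDIA backend, nothing to patch
```
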
15 changes: 8 additions & 7 deletions ktransformers/models/modeling_deepseek.py
@@ -1,14 +1,14 @@
# coding=utf-8
'''
Description :
Author : Boxin Zhang
Version : 0.1.0
'''
# Adapted from
# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/modeling_deepseek.py
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
@@ -31,6 +31,7 @@
from typing import List, Optional, Tuple, Union

import torch
+ from ktransformers.util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -145,7 +146,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2
class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
@@ -322,7 +323,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()* self._mscale
sin = emb.sin()* self._mscale
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
@@ -1112,7 +1113,7 @@ def _flash_attention_forward(
cache_seqlens=position_ids,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output = flash_attn_func(
query_states,
@@ -1557,7 +1558,7 @@ def forward(
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
@@ -1629,7 +1630,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
- and attention_mask.device.type == "cuda"
+ and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
3 changes: 2 additions & 1 deletion ktransformers/models/modeling_deepseek_v3.py
@@ -23,6 +23,7 @@
from typing import List, Optional, Tuple, Union

import torch
+ from ktransformers.util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -1587,7 +1588,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
- and attention_mask.device.type == "cuda"
+ and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
5 changes: 3 additions & 2 deletions ktransformers/models/modeling_llama.py
@@ -21,6 +21,7 @@
from typing import List, Optional, Tuple, Union

import torch
+ from ktransformers.util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -709,7 +710,7 @@ def forward(

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and causal_mask is not None:
+ if query_states.device.type == CUDA and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
@@ -1220,7 +1221,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
- and attention_mask.device.type == "cuda"
+ and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
21 changes: 11 additions & 10 deletions ktransformers/models/modeling_mixtral.py
@@ -1,14 +1,14 @@
# coding=utf-8
'''
Description :
Author : kkk1nak0
Date : 2024-07-29 02:58:57
Version : 1.0.0
LastEditors : kkk1nak0
LastEditTime : 2024-08-02 06:08:34
'''

# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
@@ -31,11 +31,12 @@
# limitations under the License.
"""PyTorch Mixtral model."""

import inspect
import math
from typing import List, Optional, Tuple, Union

import torch
+ from ktransformers.util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -201,7 +202,7 @@ def extra_repr(self):
class MixtralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@@ -544,7 +545,7 @@ def forward(
attn_weights = None

return attn_output, attn_weights, past_key_value


def _flash_attention_forward(
self,
@@ -575,9 +576,9 @@ def _flash_attention_forward(
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout

"""

# Decide whether to use SWA or not by layer index.
# if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
# use_sliding_windows = False
@@ -633,7 +634,7 @@ def _flash_attention_forward(
cache_seqlens=position_ids,
softmax_scale=softmax_scale,
causal=is_causal,
)
else:
attn_output = flash_attn_func(
query_states,
@@ -766,7 +767,7 @@ def forward(

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and attention_mask is not None:
+ if query_states.device.type == CUDA and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
@@ -1323,7 +1324,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
- and attention_mask.device.type == "cuda"
+ and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
13 changes: 7 additions & 6 deletions ktransformers/models/modeling_qwen2_moe.py
@@ -1,14 +1,14 @@
# coding=utf-8
'''
Description :
Author : Boxin Zhang
Version : 0.1.0
'''
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.42.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
@@ -32,6 +32,7 @@
from typing import List, Optional, Tuple, Union

import torch
+ from ktransformers.util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -636,7 +637,7 @@ def _flash_attention_forward(
cache_seqlens=position_ids,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output = flash_attn_func(
query_states,
@@ -766,7 +767,7 @@ def forward(

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and attention_mask is not None:
+ if query_states.device.type == CUDA and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
@@ -1314,7 +1315,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
- and attention_mask.device.type == "cuda"
+ and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
37 changes: 19 additions & 18 deletions ktransformers/operators/RoPE.py
@@ -1,8 +1,8 @@
"""
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

from torch import nn
@@ -27,6 +27,7 @@
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
import torch
+ from ktransformers.util.torch_auto_backend import CUDA

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
@@ -37,8 +38,8 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = CUDA,
prefill_device: str = CUDA,

Review comment (Contributor):

prefill_device and generate_device do not need to change.
Change them in your custom YAML file; the same applies elsewhere.

Reply (Contributor Author, @yeahdongcn, Feb 22, 2025):

I just want to ensure consistency throughout the code; otherwise both CUDA and "cuda" would be present.
Would you like me to revert this change?
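
For reference, the reviewer is pointing at ktransformers' optimize-rule YAML files, where the device for each injected module is configured per rule rather than via the Python defaults. Below is a rough, illustrative sketch of such a rule; the match target and class path are taken from the files in this diff, but the exact keys of the real rule files may differ.

```yaml
# Illustrative only: select the backend in a custom optimize-rule YAML instead
# of editing the Python defaults; keys may differ from the project's real files.
- match:
    class: ktransformers.models.modeling_deepseek.DeepseekV2RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbedding
    kwargs:
      generate_device: "musa"
      prefill_device: "musa"
```
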

**kwargs,
):
BaseInjectedModule.__init__(
@@ -67,16 +68,16 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = CUDA,
prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device

@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
Expand All @@ -91,7 +92,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def load(self):
self._init(
Expand All @@ -117,8 +118,8 @@ def __init__(
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
generate_device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = CUDA,
prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
@@ -155,8 +156,8 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = CUDA,
prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
@@ -225,16 +226,16 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = CUDA,
prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device

def load(self):
kwargs = {
key: self.config.rope_scaling[key]
@@ -270,7 +271,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()* self._mscale
sin = emb.sin()* self._mscale
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def _init(
self,
@@ -332,8 +333,8 @@ def __init__(
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = CUDA,
generate_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
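
Taken together, the change in every file follows the same pattern: import `CUDA` from `ktransformers.util.torch_auto_backend` and compare or allocate against it instead of the `"cuda"` literal. A hedged usage sketch, assuming the auto-backend module behaves as outlined near the top of this page:

```python
# Sketch of a call site after this change. Assumes CUDA is "cuda" on NVIDIA
# machines and "musa" when torch_musa is installed, as in the earlier sketch.
import torch
from ktransformers.util.torch_auto_backend import CUDA

mask = torch.ones(4, 4, device=CUDA)   # allocated on whichever GPU backend is active
print(mask.device.type == CUDA)        # True on both backends; replaces == "cuda"
print(torch.cuda.is_available())       # if torch.cuda is patched, reports the active backend
```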