Commit 70457b7

MUSA: Use monkey patch for auto converting CUDA backend to MUSA backend
Signed-off-by: Xiaodong Ye <[email protected]>
Parent: 7b2a669

23 files changed (+388, -313 lines)
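Every diff below follows the same pattern: the file imports a CUDA constant from util.torch_auto_backend and uses it wherever the device string "cuda" was hard-coded, so the MUSA backend (Moore Threads GPUs via torch_musa) can be substituted without editing each call site. The torch_auto_backend module itself is not shown in this excerpt; the following is a minimal sketch of what such a monkey patch could look like, assuming torch_musa registers a torch.musa namespace that mirrors torch.cuda. It is an illustration, not the code shipped in this commit.

# util/torch_auto_backend.py -- hypothetical sketch, not the file from this commit.
# Assumption: torch_musa, when importable, provides a torch.musa namespace that
# mirrors torch.cuda and accepts the "musa" device string.
import torch

try:
    import torch_musa  # noqa: F401
    _HAS_MUSA = True
except ImportError:
    _HAS_MUSA = False

if _HAS_MUSA:
    # Call sites compare tensor.device.type == CUDA and pass device=CUDA,
    # so on a MUSA machine the constant simply resolves to "musa".
    CUDA = "musa"
    # Monkey patch: route torch.cuda.* to torch.musa so CUDA-oriented code
    # (streams, memory queries, device counts) keeps working unchanged.
    torch.cuda = torch.musa
else:
    CUDA = "cuda"

With a module like this in place, each model or operator file only needs the added import plus the literal-to-constant swap, which is exactly what the per-file diffs show.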

ktransformers/models/modeling_deepseek.py

Lines changed: 8 additions & 7 deletions
@@ -1,14 +1,14 @@
# coding=utf-8
'''
-Description :
+Description :
Author : Boxin Zhang
Version : 0.1.0
'''
# Adapted from
# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/modeling_deepseek.py
# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-#
+#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
@@ -31,6 +31,7 @@
from typing import List, Optional, Tuple, Union

import torch
+from util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -145,7 +146,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
-return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2
class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
@@ -322,7 +323,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()* self._mscale
sin = emb.sin()* self._mscale
-return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
@@ -1112,7 +1113,7 @@ def _flash_attention_forward(
cache_seqlens=position_ids,
softmax_scale=softmax_scale,
causal=causal,
-)
+)
else:
attn_output = flash_attn_func(
query_states,
@@ -1557,7 +1558,7 @@ def forward(
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
-
+
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
@@ -1629,7 +1630,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
-and attention_mask.device.type == "cuda"
+and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when

ktransformers/models/modeling_deepseek_v3.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
from typing import List, Optional, Tuple, Union

import torch
+from util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -1587,7 +1588,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
-and attention_mask.device.type == "cuda"
+and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when

ktransformers/models/modeling_llama.py

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@
from typing import List, Optional, Tuple, Union

import torch
+from util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -709,7 +710,7 @@ def forward(

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
-if query_states.device.type == "cuda" and causal_mask is not None:
+if query_states.device.type == CUDA and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
@@ -1220,7 +1221,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
-and attention_mask.device.type == "cuda"
+and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when

ktransformers/models/modeling_mixtral.py

Lines changed: 11 additions & 10 deletions
@@ -1,14 +1,14 @@
# coding=utf-8
'''
-Description :
+Description :
Author : kkk1nak0
Date : 2024-07-29 02:58:57
Version : 1.0.0
LastEditors : kkk1nak0
LastEditTime : 2024-08-02 06:08:34
'''

-# Adapted from
+# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
@@ -31,11 +31,12 @@
# limitations under the License.
"""PyTorch Mixtral model."""

-import inspect
+import inspect
import math
from typing import List, Optional, Tuple, Union

import torch
+from util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -201,7 +202,7 @@ def extra_repr(self):
class MixtralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
-
+
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@@ -544,7 +545,7 @@ def forward(
attn_weights = None

return attn_output, attn_weights, past_key_value
-
+

def _flash_attention_forward(
self,
@@ -575,9 +576,9 @@ def _flash_attention_forward(
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
-
+
"""
-
+
# Decide whether to use SWA or not by layer index.
# if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
# use_sliding_windows = False
@@ -633,7 +634,7 @@ def _flash_attention_forward(
cache_seqlens=position_ids,
softmax_scale=softmax_scale,
causal=is_causal,
-)
+)
else:
attn_output = flash_attn_func(
query_states,
@@ -766,7 +767,7 @@ def forward(

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
-if query_states.device.type == "cuda" and attention_mask is not None:
+if query_states.device.type == CUDA and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
@@ -1323,7 +1324,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
-and attention_mask.device.type == "cuda"
+and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when

ktransformers/models/modeling_qwen2_moe.py

Lines changed: 7 additions & 6 deletions
@@ -1,14 +1,14 @@
# coding=utf-8
'''
-Description :
+Description :
Author : Boxin Zhang
Version : 0.1.0
-'''
+'''
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.42.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
-#
+#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
@@ -32,6 +32,7 @@
from typing import List, Optional, Tuple, Union

import torch
+from util.torch_auto_backend import CUDA
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
@@ -636,7 +637,7 @@ def _flash_attention_forward(
cache_seqlens=position_ids,
softmax_scale=softmax_scale,
causal=causal,
-)
+)
else:
attn_output = flash_attn_func(
query_states,
@@ -766,7 +767,7 @@ def forward(

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
-if query_states.device.type == "cuda" and attention_mask is not None:
+if query_states.device.type == CUDA and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
@@ -1314,7 +1315,7 @@ def _update_causal_mask(
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
-and attention_mask.device.type == "cuda"
+and attention_mask.device.type == CUDA
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when

ktransformers/operators/RoPE.py

Lines changed: 19 additions & 18 deletions
@@ -1,8 +1,8 @@
"""
-Description :
+Description :
Author : Boxin Zhang
Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""

from torch import nn
@@ -27,6 +27,7 @@
from ktransformers.util.utils import InferenceState
from transformers.configuration_utils import PretrainedConfig
import torch
+from util.torch_auto_backend import CUDA

# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
@@ -37,8 +38,8 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
-generate_device: str = "cuda",
-prefill_device: str = "cuda",
+generate_device: str = CUDA,
+prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
@@ -67,16 +68,16 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
-generate_device: str = "cuda",
-prefill_device: str = "cuda",
+generate_device: str = CUDA,
+prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
-
+
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
@@ -91,7 +92,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
-return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def load(self):
self._init(
@@ -117,8 +118,8 @@ def __init__(
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
-generate_device: str = "cuda",
-prefill_device: str = "cuda",
+generate_device: str = CUDA,
+prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
@@ -155,8 +156,8 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
-generate_device: str = "cuda",
-prefill_device: str = "cuda",
+generate_device: str = CUDA,
+prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
@@ -225,16 +226,16 @@ def __init__(
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
-generate_device: str = "cuda",
-prefill_device: str = "cuda",
+generate_device: str = CUDA,
+prefill_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
-
+
def load(self):
kwargs = {
key: self.config.rope_scaling[key]
@@ -270,7 +271,7 @@ def forward(self, x, position_ids):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()* self._mscale
sin = emb.sin()* self._mscale
-return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def _init(
self,
@@ -332,8 +333,8 @@ def __init__(
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
-prefill_device: str = "cuda",
-generate_device: str = "cuda",
+prefill_device: str = CUDA,
+generate_device: str = CUDA,
**kwargs,
):
BaseInjectedModule.__init__(
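The RoPE.py changes also swap the default values of generate_device and prefill_device from "cuda" to the CUDA constant, so injected modules built without an explicit device land on whichever backend the patch selected. A quick check of the intended behaviour (illustrative only, and it assumes torch_musa is installed whenever CUDA resolves to "musa"):

from util.torch_auto_backend import CUDA
import torch

# CUDA resolves to "musa" on a Moore Threads machine with torch_musa installed
# and to "cuda" elsewhere, so one comparison covers both backends.
x = torch.zeros(4, device=CUDA)
assert x.device.type == CUDA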
