
Commit 0d5b261

remove unused assert
1 parent f243fd1 commit 0d5b261

File tree

2 files changed: 81 additions & 12 deletions


src/optimum/rbln/ops/moe.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch import Tensor


@torch.library.custom_op(
    "rbln_custom_ops::custom_moe_glu",
    mutates_args=(),
)
def custom_moe_glu(
    hidden_states: Tensor,
    gate_proj_weight: Tensor,
    up_proj_weight: Tensor,
    down_proj_weight: Tensor,
    masked_routing_weight: Tensor,
    expert_select_count: Tensor,
    gate_proj_bias: Optional[Tensor] = None,
    up_proj_bias: Optional[Tensor] = None,
    down_proj_bias: Optional[Tensor] = None,
) -> Tensor:
    """
    Customized MoE GLU operation.
    Expected tensor shapes:
        - hidden_states: [batch*seq_len, hidden_size]
        - gate_proj_weight: [num_experts, hidden_size, intermediate_size]
        - up_proj_weight: [num_experts, hidden_size, intermediate_size]
        - down_proj_weight: [num_experts, intermediate_size, hidden_size]
        - masked_routing_weight: [batch * seq_len, num_experts]
        - gate_proj_bias: [num_experts, intermediate_size]
        - up_proj_bias: [num_experts, intermediate_size]
        - down_proj_bias: [num_experts, hidden_size]
    Returns:
        Tensor: [batch * seq_len, hidden_size]
    """

    out = torch.zeros_like(hidden_states)
    expert_cnt = gate_proj_weight.shape[0]
    for i in range(expert_cnt):
        gate_proj = torch.nn.functional.linear(hidden_states, gate_proj_weight[i])
        up_proj = torch.nn.functional.linear(hidden_states, up_proj_weight[i])
        mul = torch.nn.functional.silu(gate_proj) * up_proj
        down_proj = torch.nn.functional.linear(mul, down_proj_weight[i])
        out += down_proj * masked_routing_weight[:, i : i + 1]

    return out


@custom_moe_glu.register_fake
def custom_moe_glu_fake(
    hidden_states: Tensor,
    gate_proj_weight: Tensor,
    up_proj_weight: Tensor,
    down_proj_weight: Tensor,
    masked_routing_weight: Tensor,
    expert_select_count: Tensor,
    gate_proj_bias: Optional[Tensor] = None,
    up_proj_bias: Optional[Tensor] = None,
    down_proj_bias: Optional[Tensor] = None,
) -> Tensor:
    return torch.empty_like(hidden_states)
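
For reference, a minimal eager-mode sketch of exercising the new op with small hypothetical shapes (2 experts, 4 tokens, hidden_size 8, intermediate_size 16). The import path is assumed from the file location, and the per-expert weights below are laid out as [out_features, in_features] so the torch.nn.functional.linear calls in the reference loop run as written (the docstring lists the transposed layout); none of these sizes come from the commit itself.

import torch

from optimum.rbln.ops.moe import custom_moe_glu  # import path assumed from src/optimum/rbln/ops/moe.py

# Hypothetical sizes, for illustration only.
num_experts, num_tokens, hidden_size, intermediate_size = 2, 4, 8, 16

hidden_states = torch.randn(num_tokens, hidden_size)
# Per-expert weights in [out_features, in_features] layout so F.linear accepts them directly.
gate_proj_weight = torch.randn(num_experts, intermediate_size, hidden_size)
up_proj_weight = torch.randn(num_experts, intermediate_size, hidden_size)
down_proj_weight = torch.randn(num_experts, hidden_size, intermediate_size)
# Dense routing weights; non-selected experts would be masked to zero upstream.
masked_routing_weight = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)
expert_select_count = torch.tensor(num_experts)

out = custom_moe_glu(
    hidden_states,
    gate_proj_weight,
    up_proj_weight,
    down_proj_weight,
    masked_routing_weight,
    expert_select_count,
)
print(out.shape)  # torch.Size([4, 8])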

src/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

Lines changed: 6 additions & 12 deletions
@@ -289,18 +289,12 @@ def decode_forward(
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
-        else:
-            if self.rbln_config.use_local_attention:
-                local_block_tables = (
-                    local_block_tables
-                    if local_block_tables is not None
-                    else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
-                )
+        if self.rbln_config.use_local_attention:
+            local_block_tables = (
+                local_block_tables
+                if local_block_tables is not None
+                else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
+            )
 
         if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
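
For reference, a quick illustration of the fallback local_block_tables that the retained branch constructs: one row per batch index, using a hypothetical batch_size of 3.

import torch

batch_size = 3  # hypothetical
local_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
print(local_block_tables)
# tensor([[0],
#         [1],
#         [2]], dtype=torch.int16)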
