Skip to content

Commit 76b5898

Browse files
committed
format
1 parent 620ea83 commit 76b5898

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

csrc/moe/moe_ops.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
88
const int64_t n_topk_group, const c10::string_view scoring_func,
99
const c10::optional<torch::Tensor>& bias);
1010

11-
1211
void moe_sum(torch::Tensor& input, torch::Tensor& output);

tests/register_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def register_fake(fn):
1717
except ImportError:
1818
from torch.library import impl_abstract as register_fake
1919

20+
2021
# layer norm ops
2122
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
2223
epsilon: float) -> None:
@@ -127,6 +128,7 @@ def dynamic_per_token_scaled_fp8_quant(
127128
def moe_sum(input: torch.Tensor, output: torch.Tensor) -> None:
128129
torch.ops._moe_C.moe_sum(input, output)
129130

131+
130132
def grouped_topk(
131133
hidden_states: torch.Tensor,
132134
gating_output: torch.Tensor,
@@ -162,4 +164,4 @@ def _grouped_topk_fake(
162164
topk_indices = torch.empty((gating_output.size(0), topk),
163165
dtype=torch.int32,
164166
device=hidden_states.device)
165-
return topk_weights, topk_indices
167+
return topk_weights, topk_indices

0 commit comments

Comments (0)