Skip to content

Commit 13e48df

Browse files
authored
update the moe to support arch < 90 (#123)
* Update the MoE example to support arch < 90
* Revert mlp_output
* Fix the name of the 4th argument of the flux_rs_op call; the original name, routing_idx, was probably wrong
1 parent 4854650 commit 13e48df

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

examples/moe.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,13 @@ def __init__(self, ctx):
141141
input_dtype=ctx.data_type,
142142
output_dtype=ctx.data_type,
143143
)
144-
self.flux_ag_op = flux.GemmGroupedV3AGScatter(tp_env=tp_env, moe_args=moe_args)
145-
self.flux_rs_op = flux.GemmGroupedV3GatherRS(ctx.nexperts, flux_m_max, ctx.h, ctx.topk, RANK, WORLD_SIZE, ctx.ffn_tp_size, ctx.ep_size, 1)
144+
145+
if flux.util.get_arch() >= 90:
146+
self.flux_ag_op = flux.GemmGroupedV3AGScatter(tp_env=tp_env, moe_args=moe_args)
147+
self.flux_rs_op = flux.GemmGroupedV3GatherRS(ctx.nexperts, flux_m_max, ctx.h, ctx.topk, RANK, WORLD_SIZE, ctx.ffn_tp_size, ctx.ep_size, 1)
148+
else:
149+
self.flux_ag_op = flux.GemmGroupedV2AGScatterOp(tp_env=tp_env, moe_args=moe_args)
150+
self.flux_rs_op = flux.GemmGroupedV2GatherRSOp(TP_GROUP, ctx.nexperts, flux_m_max, ctx.h, ctx.topk, ctx.data_type, ctx.ffn_tp_size, ctx.ep_size, 1)
146151

147152
def forward(self):
148153

@@ -165,7 +170,7 @@ def forward(self):
165170
input=self.ctx.intermediate_output,
166171
weight=self.ctx.weight1,
167172
splits_cpu=self.ctx.splits_cpu,
168-
routing_idx=self.ctx.scatter_index.view(-1),
173+
scatter_idx=self.ctx.scatter_index.view(-1),
169174
)
170175

171176
return mlp_output

0 commit comments

Comments
 (0)