@@ -141,8 +141,13 @@ def __init__(self, ctx):
141
141
input_dtype = ctx .data_type ,
142
142
output_dtype = ctx .data_type ,
143
143
)
144
- self .flux_ag_op = flux .GemmGroupedV3AGScatter (tp_env = tp_env , moe_args = moe_args )
145
- self .flux_rs_op = flux .GemmGroupedV3GatherRS (ctx .nexperts , flux_m_max , ctx .h , ctx .topk , RANK , WORLD_SIZE , ctx .ffn_tp_size , ctx .ep_size , 1 )
144
+
145
+ if flux .util .get_arch () >= 90 :
146
+ self .flux_ag_op = flux .GemmGroupedV3AGScatter (tp_env = tp_env , moe_args = moe_args )
147
+ self .flux_rs_op = flux .GemmGroupedV3GatherRS (ctx .nexperts , flux_m_max , ctx .h , ctx .topk , RANK , WORLD_SIZE , ctx .ffn_tp_size , ctx .ep_size , 1 )
148
+ else :
149
+ self .flux_ag_op = flux .GemmGroupedV2AGScatterOp (tp_env = tp_env , moe_args = moe_args )
150
+ self .flux_rs_op = flux .GemmGroupedV2GatherRSOp (TP_GROUP , ctx .nexperts , flux_m_max , ctx .h , ctx .topk , ctx .data_type , ctx .ffn_tp_size , ctx .ep_size , 1 )
146
151
147
152
def forward (self ):
148
153
@@ -165,7 +170,7 @@ def forward(self):
165
170
input = self .ctx .intermediate_output ,
166
171
weight = self .ctx .weight1 ,
167
172
splits_cpu = self .ctx .splits_cpu ,
168
- routing_idx = self .ctx .scatter_index .view (- 1 ),
173
+ scatter_idx = self .ctx .scatter_index .view (- 1 ),
169
174
)
170
175
171
176
return mlp_output
0 commit comments