
Commit 62c70c7

fix: recompute locations in topkgating after the no-drop capacity update; add a matching drop_tokens=False case to test_moe

1 parent 1ca83a6 commit 62c70c7
File tree: 2 files changed (+11 −0)


deepspeed/moe/sharded_moe.py (+1)

@@ -429,6 +429,7 @@ def topkgating(
         tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
         new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity
+        locations = torch.cumsum(mask, dim=0) - 1
 
     # normalize gates
     gates_masked = gates * mask
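For context on the one-line fix: `torch.cumsum(mask, dim=0) - 1` converts a token-to-expert routing mask into per-expert slot indices, so recomputing it after the no-drop capacity update keeps each token's position consistent with the mask. A minimal standalone sketch of that location rule (the mask below is an illustrative top-1 example, not DeepSpeed's own data):

import torch

# Illustrative routing mask (assumed shape, not from DeepSpeed):
# mask[s, e] == 1 when token s is routed to expert e; 4 tokens, 3 experts.
mask = torch.tensor([[1, 0, 0],
                     [1, 0, 0],
                     [0, 1, 0],
                     [1, 0, 0]])

# Running per-expert count along the token dimension, minus one, gives each
# routed token a 0-based slot in that expert's capacity buffer.
locations = torch.cumsum(mask, dim=0) - 1
print(locations * mask)
# Expert 0 receives tokens 0, 1, 3 at slots 0, 1, 2; expert 1 gets token 2 at slot 0.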

tests/unit/moe/test_moe.py (+10)

@@ -242,6 +242,15 @@ def check_equal(logits, cap, sparse_truth, res):
         check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)
 
 
+        #s=4 e=4 topk=2 drop_tokens=False
+        logits3 = torch.tensor([[0.95, 0.85, 0.90, 0.80], [0.70, 0.65, 0.75, 0.60], [0.50, 0.55, 0.45, 0.40],
+                                [0.35, 0.30, 0.25, 0.20]])
+        logits3 *= dist.get_rank() + 1
+        dispatch_res = topkgating(logits3, 2, 1, min_capacity=1, drop_tokens=False)[2]
+        sec_sparse = torch.tensor([[0, 0, 0], [0, 2, 0], [1, 0, 1], [1, 2, 1], [2, 0, 2], [2, 1, 0], [3, 0, 3],
+                                   [3, 1, 1]])
+        check_equal(logits3, 4, sec_sparse, dispatch_res)
+
 class TestExpertWeightGradWithZero(DistributedTest):
     world_size = 2

@@ -351,3 +360,4 @@ def _get_weight_bias(experts):
         assert len(expert_weight_grad_ep1) == len(expert_weight_grad_ep2)
         for grad_from_ep1, grad_from_ep2 in zip(expert_weight_grad_ep1, expert_weight_grad_ep2):
             assert torch.allclose(grad_from_ep1, grad_from_ep2, atol=0, rtol=1e-4)
+
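A quick cross-check of the new expectation, assuming each `sec_sparse` row is a (token, expert, slot) triple: multiplying `logits3` by `dist.get_rank() + 1` is positive and order-preserving, so the top-2 choices depend only on the base values, and the expected slots then follow from the cumulative-sum location rule added in this commit. A standalone sketch (plain `torch.topk`, not the test's own helpers) that reproduces the eight triples:

import torch

logits3 = torch.tensor([[0.95, 0.85, 0.90, 0.80],
                        [0.70, 0.65, 0.75, 0.60],
                        [0.50, 0.55, 0.45, 0.40],
                        [0.35, 0.30, 0.25, 0.20]])

# Top-2 expert indices per token: [[0, 2], [2, 0], [1, 0], [0, 1]].
top2 = torch.topk(logits3, k=2, dim=1).indices

# One-hot (tokens x experts) routing mask, then per-expert running slot counts.
mask = torch.zeros_like(logits3, dtype=torch.long).scatter_(1, top2, 1)
locations = torch.cumsum(mask, dim=0) - 1

for s in range(mask.size(0)):
    for e in range(mask.size(1)):
        if mask[s, e]:
            print(s, e, locations[s, e].item())
# -> (0,0,0) (0,2,0) (1,0,1) (1,2,1) (2,0,2) (2,1,0) (3,0,3) (3,1,1),
# matching sec_sparse; expert 0 fills four slots, which would explain the
# cap=4 passed to check_equal.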
