
Commit 6e09bc5

Merge pull request databricks#41 from vchiley/enable_flat_inputs
Enable generic dimensionality for input
2 parents 059ae20 + 40e4918

3 files changed: +4 -7 lines

megablocks/layers/dmoe.py

Lines changed: 1 addition & 2 deletions
@@ -132,10 +132,9 @@ def sparse_forward_once(self, x, expert_weights, top_experts):
         with torch.no_grad():
             indices, bin_ids, bins, padded_bins, tokens_per_expert = (
                 self.indices_and_padded_bins(top_experts))
-        sl, bs, hs = x.size()

         # Route the tokens for MoE computation.
-        x = x.view(sl * bs, hs)
+        x = x.view(-1, x.shape[-1])
         x = ops.padded_gather(
             x,
             indices,
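
For illustration only (this snippet is not part of the diff): on a 3D (sequence_length, batch_size, hidden_size) input, the old and new flattening produce the same tensor, but the shape-agnostic form also accepts inputs with any number of leading dimensions. A minimal sketch:

    import torch

    x3d = torch.randn(4, 3, 16)  # (sl, bs, hs); sizes chosen arbitrarily
    sl, bs, hs = x3d.size()

    # The old and new flattening agree on the 3D case...
    assert torch.equal(x3d.view(sl * bs, hs), x3d.view(-1, x3d.shape[-1]))

    # ...but only the shape-agnostic form also handles already-flat input,
    # which the old three-way unpack would have rejected.
    x2d = torch.randn(12, 16)
    assert x2d.view(-1, x2d.shape[-1]).shape == (12, 16)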

megablocks/layers/moe.py

Lines changed: 2 additions & 3 deletions
@@ -419,13 +419,13 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         return x, tokens_per_expert.flatten()

     def forward(self, x, scores, expert_weights, top_experts):
-        sl, bs, hs = x.size()
+        in_shape = x.size()

         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
         save_load_balancing_loss((tokens_per_expert, scores))
-        x = x.view(sl, bs, hs)
+        x = x.view(in_shape)
         if self.bias is not None:
             if self.args.return_bias:
                 return x, self.bias
@@ -448,7 +448,6 @@ def forward(self, x):
         # NOTE: If we're going to cast the activations to lower precision
         # do it before we permute the tokens to save bandwidth.
         x = common.cast_if_autocast_enabled(x)
-        sl, bs, hs = x.size()

         # Compute the expert scores and assignments.
         scores, expert_weights, top_experts = self.router(x)
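
For illustration (not MegaBlocks code): recording the full input shape up front lets the expert computation run on a flat (num_tokens, hidden_size) view and still return a tensor with the caller's original dimensionality. A minimal sketch of the save/restore pattern, with a doubling stand-in for the expert computation:

    import torch

    def forward_like(x: torch.Tensor) -> torch.Tensor:
        in_shape = x.size()          # remember the caller's shape
        x = x.view(-1, x.shape[-1])  # flatten to (num_tokens, hidden)
        x = x * 2.0                  # stand-in for self.forward_fn(...)
        return x.view(in_shape)      # restore the original dimensionality

    assert forward_like(torch.randn(5, 2, 8)).shape == (5, 2, 8)  # (sl, bs, hs)
    assert forward_like(torch.randn(10, 8)).shape == (10, 8)      # already flat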

megablocks/layers/router.py

Lines changed: 1 addition & 2 deletions
@@ -53,8 +53,7 @@ def forward(self, x):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)

-        sl, bs, hs = x.size()
-        scores = self.layer(x.view(-1, hs)).softmax(dim=-1)
+        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)

         expert_indices = (
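
For illustration (the layer and sizes below are assumptions, not taken from the diff): flattening before the scoring layer yields one softmax-normalized row of expert probabilities per token, regardless of the input's leading dimensions:

    import torch

    layer = torch.nn.Linear(16, 4)  # hypothetical: hidden_size=16, 4 experts
    x = torch.randn(4, 3, 16)       # any leading dims; 4 * 3 = 12 tokens
    scores = layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
    assert scores.shape == (12, 4)  # one row of expert scores per token
    assert torch.allclose(scores.sum(dim=-1), torch.ones(12))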
