
Commit fc4cfa6

Convert aten.embedding_dense_backward to ttnn.embedding_bw
1 parent e0abc51 commit fc4cfa6

File tree: 3 files changed, +60 -0 lines changed

tests/lowering/embedding/test_embedding.py

Lines changed: 35 additions & 0 deletions

@@ -70,3 +70,38 @@ def test_embedding_tile_layout(device, batch_size, sentence_size, vocabulary_siz
     assert [node.target for node in nodes].count(ttnn.embedding) == 1
     # Check inference result
     assert torch.allclose(result_before, result_after)
+
+
+@pytest.mark.parametrize(
+    "batch, sentence_size, vocabulary_size, hidden_embedding_dim, converted",
+    [
+        (1, 384, 160, 1024, True),
+        (8, 384, 256, 512, True),
+        # TODO(TODO): Vocabulary sizes > 256 are not supported
+        (8, 384, 512, 1024, False),
+    ],
+)
+def test_embedding_backward_tile_layout(device, batch, sentence_size, vocabulary_size, hidden_embedding_dim, converted):
+    m = EmbeddingTileLayoutModule()
+    input = torch.randint(0, vocabulary_size, (batch, sentence_size), dtype=torch.int64)
+    weights = torch.rand((vocabulary_size, hidden_embedding_dim), dtype=torch.bfloat16)
+    grad_data = torch.rand((batch, sentence_size, hidden_embedding_dim))
+
+    weights_before = weights.clone().detach().requires_grad_(True)
+    forward_output = m.forward(input, weights_before)
+    forward_output.backward(gradient=grad_data)
+
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    weights_after = weights.clone().detach().requires_grad_(True)
+    forward_output = m.forward(input, weights_after)
+    forward_output.backward(gradient=grad_data)
+
+    # Check that the graph has been rewritten
+    nodes = list(option._out_fx_graphs[-1].nodes)
+    assert [node.target for node in nodes].count(ttnn.embedding_bw) == (1 if converted else 0)
+    # Check inference result
+    assert weights_before.grad.shape == weights_after.grad.shape
+    # Multiple float multiplications need a higher tolerance
+    assert torch.allclose(weights_before.grad, weights_after.grad, rtol=0.1)
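Note: EmbeddingTileLayoutModule is defined earlier in this test file, outside the hunk shown above. As a rough orientation only (a hypothetical sketch, not the committed definition), a module of this kind is typically a thin wrapper around torch.nn.functional.embedding, so that tracing its forward produces an aten.embedding node and differentiating it produces the aten.embedding_dense_backward node this commit lowers:

import torch
import torch.nn.functional as F


class EmbeddingTileLayoutModule(torch.nn.Module):
    # Hypothetical sketch for illustration; the committed module lives above the diffed lines.
    def forward(self, input, weights):
        # F.embedding(indices, weight) traces to aten.embedding; the gradient with
        # respect to `weights` is computed by aten.embedding_dense_backward.
        return F.embedding(input, weights)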

torch_ttnn/passes/lowering/add_data_move_pass.py

Lines changed: 2 additions & 0 deletions

@@ -153,9 +153,11 @@ def is_tt_compute(node) -> bool:
         + TTNN_NORM_OPS
         + [
             ttnn.embedding,
+            ttnn.embedding_bw,
             ttnn.ones,
             ttnn.tril,
             ttnn.arange,
+            ttnn.zeros,
             ttnn.zeros_like,
             ttnn.mean,
             ttnn.global_avg_pool2d,

torch_ttnn/passes/lowering/to_tt_pass.py

Lines changed: 23 additions & 0 deletions

@@ -11,6 +11,7 @@
 import numpy as np
 from typing import Tuple
 import torch_ttnn.metrics as metrics
+import math

 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 import torch.fx.traceback as fx_traceback

@@ -623,6 +624,28 @@ def rewrite_node(node):
             input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
             return g.call_function(ttnn.pad, args=(input, full_pad, value))

+        if node.target == torch.ops.aten.embedding_dense_backward.default:
+            grad_output, indices, num_weights, padding_idx, scale_grad_by_freq = args
+            # TODO(TODO): padding_idx and scale_grad_by_freq are not supported
+            if padding_idx != -1 or scale_grad_by_freq:
+                return None
+            if num_weights > 256:
+                return None
+            # Change indices to row-major layout to support non-tile-aligned shapes
+            indices = g.call_function(ttnn.to_layout, args=(indices, TtnnRowMajorLayout()))
+            # Reconstruct the weight tensor solely for the vocabulary size
+            grad_shape = grad_output.meta["val"].size()
+            embedding_dim = grad_shape[-1]
+            weights = g.call_function(
+                ttnn.zeros, args=((num_weights, embedding_dim),), kwargs={"device": TtnnDevice()}
+            )
+            # Pack grad_output into (1, 1, x, embedding_dim)
+            new_grad_shape = (1, 1, math.prod(grad_shape[:-1]), embedding_dim)
+            grad_output = g.call_function(ttnn.reshape, args=(grad_output, new_grad_shape))
+
+            result = g.call_function(ttnn.embedding_bw, args=(indices, weights, grad_output))
+            return g.call_function(ttnn.reshape, args=(result, node.meta["val"].size()))
+
     with g.inserting_before(node):
         new_node = rewrite_node(node)
         if new_node is not None: