@@ -70,3 +70,38 @@ def test_embedding_tile_layout(device, batch_size, sentence_size, vocabulary_siz
     assert [node.target for node in nodes].count(ttnn.embedding) == 1
     # Check inference result
     assert torch.allclose(result_before, result_after)
+
+
+@pytest.mark.parametrize(
+    "batch, sentence_size, vocabulary_size, hidden_embedding_dim, converted",
+    [
+        (1, 384, 160, 1024, True),
+        (8, 384, 256, 512, True),
+        # TODO(TODO): vocabulary size > 256 is not supported yet
+        (8, 384, 512, 1024, False),
+    ],
+)
+def test_embedding_backward_tile_layout(device, batch, sentence_size, vocabulary_size, hidden_embedding_dim, converted):
+    m = EmbeddingTileLayoutModule()
+    input = torch.randint(0, vocabulary_size, (batch, sentence_size), dtype=torch.int64)
+    weights = torch.rand((vocabulary_size, hidden_embedding_dim), dtype=torch.bfloat16)
+    grad_data = torch.rand((batch, sentence_size, hidden_embedding_dim))
+
+    weights_before = weights.clone().detach().requires_grad_(True)
+    forward_output = m.forward(input, weights_before)
+    forward_output.backward(gradient=grad_data)
+
+    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    weights_after = weights.clone().detach().requires_grad_(True)
+    forward_output = m.forward(input, weights_after)
+    forward_output.backward(gradient=grad_data)
+
+    # Check the graph has been rewritten
+    nodes = list(option._out_fx_graphs[-1].nodes)
+    assert [node.target for node in nodes].count(ttnn.embedding_bw) == (1 if converted else 0)
+    # Check gradient results
+    assert weights_before.grad.shape == weights_after.grad.shape
+    # Multiple float multiplications need a higher tolerance
+    assert torch.allclose(weights_before.grad, weights_after.grad, rtol=0.1)
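
For readers of this diff, EmbeddingTileLayoutModule is defined earlier in this test file. The sketch below is only an assumption inferred from how the test calls it, namely a thin torch.nn.Module wrapper around torch.nn.functional.embedding that is differentiable with respect to the weights; it is not the repository's actual definition.

import torch

# Hypothetical stand-in for the module under test; the real class is defined
# earlier in this file and may differ.
class EmbeddingTileLayoutModule(torch.nn.Module):
    def forward(self, input, weights):
        # The embedding lookup is differentiable w.r.t. weights, which is why
        # forward_output.backward(gradient=grad_data) populates weights.grad.
        return torch.nn.functional.embedding(input, weights)

Note that backward(gradient=grad_data) passes an explicit seed gradient because the forward output is a non-scalar tensor; PyTorch requires it in that case, and grad_data must match the output shape (batch, sentence_size, hidden_embedding_dim).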