@@ -76,19 +76,6 @@ def lens_to_mask(lens, max_length):
76
76
seq = torch .arange (max_length , device = lens .device )
77
77
return seq < lens [:, None ]
78
78
79
- def efficient_rotation_trick_transform (u , q , e ):
80
- """
81
- 4.2 in https://arxiv.org/abs/2410.06424
82
- """
83
- e = rearrange (e , 'b d -> b 1 d' )
84
- w = l2norm (u + q , dim = 1 ).detach ()
85
-
86
- return (
87
- e -
88
- 2 * (e @ rearrange (w , 'b d -> b d 1' ) @ rearrange (w , 'b d -> b 1 d' )) +
89
- 2 * (e @ rearrange (u , 'b d -> b d 1' ).detach () @ rearrange (q , 'b d -> b 1 d' ).detach ())
90
- )
91
-
92
79
def uniform_init (* shape ):
93
80
t = torch .empty (shape )
94
81
nn .init .kaiming_uniform_ (t )
@@ -248,6 +235,39 @@ def kmeans(
248
235
249
236
return means , bins
250
237
238
+ # rotation trick related
239
+
240
def efficient_rotation_trick_transform(u, q, e):
    """
    Efficient form of the rotation trick transform, section 4.2 of https://arxiv.org/abs/2410.06424

    Evaluates, per row, `e - 2 (e w) w^T + 2 (e u) q^T` with `w = l2norm(u + q)`.

    `u`, `q` and `e` are 2d tensors of shape (b, d). `w`, `u` and `q` are all
    detached, so `e` is the only input that gradients flow through.

    Returns a tensor of shape (b, 1, d) - the singleton axis is left for the caller to drop.
    """
    e_row = rearrange(e, 'b d -> b 1 d')

    # l2norm of (u + q), treated as a constant
    w = l2norm(u + q, dim = 1).detach()
    w_col = rearrange(w, 'b d -> b d 1')
    w_row = rearrange(w, 'b d -> b 1 d')

    u_col = rearrange(u, 'b d -> b d 1').detach()
    q_row = rearrange(q, 'b d -> b 1 d').detach()

    # (b, 1, d) @ (b, d, 1) -> (b, 1, 1), then @ (b, 1, d) -> (b, 1, d)
    w_term = e_row @ w_col @ w_row
    uq_term = e_row @ u_col @ q_row

    return e_row - 2 * w_term + 2 * uq_term
252
+
253
def rotate_from_to(src, tgt):
    """
    Rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.

    `tgt` is the tensor handed to the transform as the differentiable input, so it is
    the only argument gradients flow back to; the normalized directions and the norm
    ratio are detached. The transformed `tgt` is rescaled by `|src| / |tgt|`.
    NOTE(review): per the paper the result should numerically match `src` in the
    forward pass - confirm against `l2norm` / `safe_div` semantics.

    Both inputs must share a trailing feature dimension `d`. The result is passed
    through the inverse of the packing applied to `tgt`.
    """
    # presumably flattens all leading dimensions into one, giving (b, d) - see `pack_one`
    tgt, inverse = pack_one(tgt, '* d')
    src, _ = pack_one(src, '* d')

    norm_tgt = tgt.norm(dim = -1, keepdim = True)
    norm_src = src.norm(dim = -1, keepdim = True)

    rotated_src = efficient_rotation_trick_transform(
        safe_div(tgt, norm_tgt),
        safe_div(src, norm_src),
        tgt
    )

    # the transform returns (b, 1, d): drop only that inserted axis.
    # a bare `.squeeze()` would also remove the feature axis when d == 1, leaving (b,),
    # which then broadcasts against the (b, 1) norm ratio below into (b, b)
    rotated_src = rotated_src.squeeze(1)

    rotated = rotated_src * safe_div(norm_src, norm_tgt).detach()

    return inverse(rotated)
270
+
251
271
# distributed helpers
252
272
253
273
@cache
@@ -1098,22 +1118,7 @@ def forward(
1098
1118
commit_quantize = maybe_detach (quantize )
1099
1119
1100
1120
if self .rotation_trick :
1101
- # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
1102
- x , inverse = pack_one (x , '* d' )
1103
- quantize , _ = pack_one (quantize , '* d' )
1104
-
1105
- norm_x = x .norm (dim = - 1 , keepdim = True )
1106
- norm_quantize = quantize .norm (dim = - 1 , keepdim = True )
1107
-
1108
- rot_quantize = efficient_rotation_trick_transform (
1109
- safe_div (x , norm_x ),
1110
- safe_div (quantize , norm_quantize ),
1111
- x
1112
- ).squeeze ()
1113
-
1114
- quantize = rot_quantize * safe_div (norm_quantize , norm_x ).detach ()
1115
-
1116
- x , quantize = inverse (x ), inverse (quantize )
1121
+ quantize = rotate_from_to (quantize , x )
1117
1122
else :
1118
1123
# standard STE to get gradients through VQ layer.
1119
1124
quantize = x + (quantize - x ).detach ()
0 commit comments