Skip to content

Commit 723ea9f

Browse files
committed
Move the rotation operation out of the rotation-trick branch into a standalone function, for potential reuse in other research
1 parent e7ff7d7 commit 723ea9f

File tree

2 files changed

+35
-30
lines changed

2 files changed

+35
-30
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.19.4"
3+
version = "1.19.5"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

vector_quantize_pytorch/vector_quantize_pytorch.py

+34-29
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,6 @@ def lens_to_mask(lens, max_length):
7676
seq = torch.arange(max_length, device = lens.device)
7777
return seq < lens[:, None]
7878

79-
def efficient_rotation_trick_transform(u, q, e):
80-
"""
81-
4.2 in https://arxiv.org/abs/2410.06424
82-
"""
83-
e = rearrange(e, 'b d -> b 1 d')
84-
w = l2norm(u + q, dim = 1).detach()
85-
86-
return (
87-
e -
88-
2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
89-
2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
90-
)
91-
9279
def uniform_init(*shape):
9380
t = torch.empty(shape)
9481
nn.init.kaiming_uniform_(t)
@@ -248,6 +235,39 @@ def kmeans(
248235

249236
return means, bins
250237

238+
# rotation trick related
239+
240+
def efficient_rotation_trick_transform(u, q, e):
    """
    Efficient rotation trick transform, section 4.2 in https://arxiv.org/abs/2410.06424

    Computes `e - 2 * (e w) w^T + 2 * (e u) q^T` per batch element, with
    `w = l2norm(u + q)`, without materializing a (d, d) matrix.

    u, q, e are all expected as (b, d), per the rearrange patterns below.
    NOTE(review): u and q are presumably unit vectors and l2norm presumably
    returns unit vectors - confirm against the callers / helper.

    `w`, `u` and `q` are detached where they enter the product, so gradients
    flow through `e` only.

    Returns a tensor of shape (b, 1, d).
    """
    e = rearrange(e, 'b d -> b 1 d')

    # normalized bisector of u and q, treated as a constant
    bisector = l2norm(u + q, dim = 1).detach()

    bisector_col = rearrange(bisector, 'b d -> b d 1')
    bisector_row = rearrange(bisector, 'b d -> b 1 d')

    u_col = rearrange(u, 'b d -> b d 1').detach()
    q_row = rearrange(q, 'b d -> b 1 d').detach()

    # (b, 1, d) @ (b, d, 1) @ (b, 1, d) -> (b, 1, d), evaluated left to right
    reflect_term = e @ bisector_col @ bisector_row
    carry_term = e @ u_col @ q_row

    return e - 2 * reflect_term + 2 * carry_term
252+
253+
def rotate_from_to(src, tgt):
    """
    rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.

    Each vector of `tgt` is rotated towards the direction of the matching
    vector of `src`, then rescaled by `norm(src) / norm(tgt)`. The rotation
    (see `efficient_rotation_trick_transform`) and the rescaling factor are
    both detached, so gradients reach `tgt` only.

    In the VQ layer this is called as `rotate_from_to(quantize, x)`.

    src, tgt - tensors with the feature dimension last and matching leading
               dimensions (NOTE(review): presumably identical shapes - confirm)

    Returns a tensor unpacked back to the original leading dimensions of `tgt`.
    """
    # presumably flattens all leading dims into one; `inverse` undoes it
    tgt, inverse = pack_one(tgt, '* d')
    src, _ = pack_one(src, '* d')

    norm_tgt = tgt.norm(dim = -1, keepdim = True)
    norm_src = src.norm(dim = -1, keepdim = True)

    # the transform returns (b, 1, d) - squeeze only that singleton axis.
    # a bare .squeeze() would also drop a batch or feature axis of size 1,
    # which mis-broadcasts against the (b, 1) norm ratio below when d == 1
    rotated_tgt = efficient_rotation_trick_transform(
        safe_div(tgt, norm_tgt),
        safe_div(src, norm_src),
        tgt
    ).squeeze(1)

    # rescale to the norm of src; the ratio is a constant w.r.t. gradients
    rotated = rotated_tgt * safe_div(norm_src, norm_tgt).detach()

    return inverse(rotated)
270+
251271
# distributed helpers
252272

253273
@cache
@@ -1098,22 +1118,7 @@ def forward(
10981118
commit_quantize = maybe_detach(quantize)
10991119

11001120
if self.rotation_trick:
1101-
# rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
1102-
x, inverse = pack_one(x, '* d')
1103-
quantize, _ = pack_one(quantize, '* d')
1104-
1105-
norm_x = x.norm(dim = -1, keepdim = True)
1106-
norm_quantize = quantize.norm(dim = -1, keepdim = True)
1107-
1108-
rot_quantize = efficient_rotation_trick_transform(
1109-
safe_div(x, norm_x),
1110-
safe_div(quantize, norm_quantize),
1111-
x
1112-
).squeeze()
1113-
1114-
quantize = rot_quantize * safe_div(norm_quantize, norm_x).detach()
1115-
1116-
x, quantize = inverse(x), inverse(quantize)
1121+
quantize = rotate_from_to(quantize, x)
11171122
else:
11181123
# standard STE to get gradients through VQ layer.
11191124
quantize = x + (quantize - x).detach()

0 commit comments

Comments
 (0)