Skip to content

Commit d48d9cc

Browse files
authored
Merge pull request #235 from xju2/fix-3d
add an option for 3d images
2 parents c34a005 + 84c3b72 commit d48d9cc

2 files changed

Lines changed: 42 additions & 4 deletions

File tree

tests/test_readme.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def test_latent_q():
387387
quantizer = LatentQuantize(
388388
levels = [5, 5, 8], # number of levels per codebook dimension
389389
dim = 16, # input dim
390-
commitment_loss_weight=0.1,
390+
commitment_loss_weight=0.1,
391391
quantization_loss_weight=0.1,
392392
)
393393

@@ -488,3 +488,21 @@ def test_accum_ema_update():
488488
_ = vq(x)
489489

490490
assert not torch.allclose(codebook_before, vq.codebook, atol = 1e-6)
491+
492+
def test_vq_3d():
493+
from vector_quantize_pytorch import VectorQuantize
494+
495+
quantizer = VectorQuantize(
496+
dim = 64,
497+
codebook_size = 512, # codebook size
498+
accept_3d_fmap=True,
499+
)
500+
501+
x = torch.randn(1, 64, 16, 16, 16) # (B, C, D, H, W)
502+
quantizer.eval()
503+
504+
quantized, indices, commit_loss = quantizer(x)
505+
506+
assert x.shape == quantized.shape
507+
assert indices.shape == (1, 16, 16, 16)
508+
assert torch.allclose(quantized, quantizer.get_output_from_indices(indices))

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ def __init__(
770770
threshold_ema_dead_code = 0,
771771
channel_last = True,
772772
accept_image_fmap = False,
773+
accept_3d_fmap = False,
773774
commitment_weight = 1.,
774775
commitment_use_cross_entropy_loss = False,
775776
orthogonal_reg_weight = 0.,
@@ -906,6 +907,7 @@ def __init__(
906907
self.codebook_size = codebook_size
907908

908909
self.accept_image_fmap = accept_image_fmap
910+
self.accept_3d_fmap = accept_3d_fmap
909911
self.channel_last = channel_last
910912

911913
self.register_buffer('zero', tensor(0.), persistent = False)
@@ -953,7 +955,7 @@ def get_codes_from_indices(self, indices):
953955
codes = rearrange(codes, 'b h n d -> b n (h d)')
954956
codes = unpack_one(codes, 'b * d')
955957

956-
if not self.channel_last:
958+
if not self.channel_last or self.accept_image_fmap or self.accept_3d_fmap:
957959
codes = rearrange(codes, 'b ... d -> b d ...')
958960

959961
return codes
@@ -1000,7 +1002,11 @@ def update_ema_indices(self, x, indices, mask = None):
10001002
height, width = x.shape[-2:]
10011003
x = rearrange(x, 'b c h w -> b (h w) c')
10021004

1003-
if not self.channel_last and not self.accept_image_fmap:
1005+
if self.accept_3d_fmap:
1006+
assert not exists(mask)
1007+
x = rearrange(x, 'b c d h w -> b (d h w) c')
1008+
1009+
if not self.channel_last and not self.accept_image_fmap and not self.accept_3d_fmap:
10041010
x = rearrange(x, 'b d n -> b n d')
10051011

10061012
x = self.project_in(x)
@@ -1016,6 +1022,9 @@ def update_ema_indices(self, x, indices, mask = None):
10161022
if self.accept_image_fmap:
10171023
indices = rearrange(indices, 'b h w ... -> b (h w) ...')
10181024

1025+
if self.accept_3d_fmap:
1026+
indices = rearrange(indices, 'b d h w ... -> b (d h w) ...')
1027+
10191028
if x.ndim == 2: # only one token
10201029
indices = rearrange(indices, 'b ... -> b 1 ...')
10211030

@@ -1059,7 +1068,7 @@ def forward(
10591068

10601069
shape, dtype, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.dtype, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices)
10611070

1062-
need_transpose = not self.channel_last and not self.accept_image_fmap
1071+
need_transpose = not self.channel_last and not self.accept_image_fmap and not self.accept_3d_fmap
10631072
should_inplace_optimize = exists(self.in_place_codebook_optimizer)
10641073

10651074
# rearrange inputs
@@ -1069,6 +1078,11 @@ def forward(
10691078
height, width = x.shape[-2:]
10701079
x = rearrange(x, 'b c h w -> b (h w) c')
10711080

1081+
if self.accept_3d_fmap:
1082+
assert not exists(mask)
1083+
depth, height, width = x.shape[-3:]
1084+
x = rearrange(x, 'b c d h w -> b (d h w) c')
1085+
10721086
if need_transpose:
10731087
x = rearrange(x, 'b d n -> b n d')
10741088

@@ -1196,6 +1210,9 @@ def calculate_ce_loss(codes):
11961210
if self.accept_image_fmap:
11971211
embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width)
11981212

1213+
if self.accept_3d_fmap:
1214+
embed_ind = rearrange(embed_ind, 'b (d h w) ... -> b d h w ...', d = depth, h = height, w = width)
1215+
11991216
if only_one:
12001217
embed_ind = rearrange(embed_ind, 'b 1 ... -> b ...')
12011218

@@ -1289,6 +1306,9 @@ def calculate_ce_loss(codes):
12891306
if self.accept_image_fmap:
12901307
quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width)
12911308

1309+
if self.accept_3d_fmap:
1310+
quantize = rearrange(quantize, "b (d h w) c -> b c d h w", d=depth, h=height, w=width)
1311+
12921312
if only_one:
12931313
quantize = rearrange(quantize, 'b 1 d -> b d')
12941314

0 commit comments

Comments
 (0)