@@ -191,6 +191,7 @@ def __init__(
191191
192192 codebook_dim = default (codebook_dim , dim )
193193 codebook_input_dim = codebook_dim * heads
194+ self .codebook_dim = codebook_dim
194195
195196 requires_projection = codebook_input_dim != dim
196197 self .project_in = nn .Linear (dim , codebook_input_dim ) if requires_projection else nn .Identity ()
@@ -223,6 +224,7 @@ def __init__(
223224 self .num_quantizers = num_quantizers
224225
225226 self .codebook_sizes = codebook_sizes
227+
226228 self .uniform_codebook_size = len (unique (codebook_sizes )) == 1
227229
228230 # define vq across layers
@@ -287,10 +289,6 @@ def __init__(
287289 @property
288290 def codebook_size (self ):
289291 return self .layers [0 ].codebook_size
290-
291- @property
292- def codebook_dim (self ):
293- return self .layers [0 ].codebook_dim
294292
295293 @property
296294 def codebooks (self ):
@@ -423,7 +421,7 @@ def forward(
423421
424422 # save all inputs across layers, for use during expiration at end under shared codebook setting, or ema update during beam search
425423
426- all_residuals = torch .empty ((* input_shape [:- 1 ], 0 , input_shape [ - 1 ] ), dtype = residual .dtype , device = device )
424+ all_residuals = torch .empty ((* input_shape [:- 1 ], 0 , self . codebook_dim ), dtype = residual .dtype , device = device )
427425
428426 # maybe prepare beam search
429427
0 commit comments