File tree 4 files changed +4
-5
lines changed
4 files changed +4
-5
lines changed Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " vector-quantize-pytorch"
3
- version = " 1.19.2 "
3
+ version = " 1.19.3 "
4
4
description = " Vector Quantization - Pytorch"
5
5
authors = [
6
6
{
name =
" Phil Wang" ,
email =
" [email protected] " }
Original file line number Diff line number Diff line change @@ -305,7 +305,7 @@ def forward(
305
305
306
306
forward_kwargs = dict (
307
307
return_all_codes = return_all_codes ,
308
- rand_quantize_dropout_fixed_seed = get_maybe_sync_seed (device )
308
+ rand_quantize_dropout_fixed_seed = get_maybe_sync_seed (device ) if self . training else None
309
309
)
310
310
311
311
# invoke residual vq on each group
Original file line number Diff line number Diff line change @@ -272,7 +272,7 @@ def forward(
272
272
forward_kwargs = dict (
273
273
mask = mask ,
274
274
return_all_codes = return_all_codes ,
275
- rand_quantize_dropout_fixed_seed = get_maybe_sync_seed (device )
275
+ rand_quantize_dropout_fixed_seed = get_maybe_sync_seed (device ) if self . training else None
276
276
)
277
277
278
278
# invoke residual vq on each group
Original file line number Diff line number Diff line change @@ -328,7 +328,6 @@ def forward(
328
328
329
329
rand = random .Random (rand_quantize_dropout_fixed_seed )
330
330
331
-
332
331
rand_quantize_dropout_index = rand .randrange (self .quantize_dropout_cutoff_index , num_quant )
333
332
334
333
if quant_dropout_multiple_of != 1 :
@@ -496,7 +495,7 @@ def forward(
496
495
sample_codebook_temp = sample_codebook_temp ,
497
496
mask = mask ,
498
497
freeze_codebook = freeze_codebook ,
499
- rand_quantize_dropout_fixed_seed = get_maybe_sync_seed (device )
498
+ rand_quantize_dropout_fixed_seed = get_maybe_sync_seed (device ) if self . training else None
500
499
)
501
500
502
501
# invoke residual vq on each group
You can’t perform that action at this time.
0 commit comments