Skip to content

Commit 35c4c84

Browse files
committed
only sync seed during training
1 parent cd35a68 commit 35c4c84

File tree

4 files changed

+4
-5
lines changed

4 files changed

+4
-5
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.19.2"
3+
version = "1.19.3"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

vector_quantize_pytorch/residual_fsq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def forward(
305305

306306
forward_kwargs = dict(
307307
return_all_codes = return_all_codes,
308-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
308+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) if self.training else None
309309
)
310310

311311
# invoke residual vq on each group

vector_quantize_pytorch/residual_lfq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def forward(
272272
forward_kwargs = dict(
273273
mask = mask,
274274
return_all_codes = return_all_codes,
275-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
275+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) if self.training else None
276276
)
277277

278278
# invoke residual vq on each group

vector_quantize_pytorch/residual_vq.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,6 @@ def forward(
328328

329329
rand = random.Random(rand_quantize_dropout_fixed_seed)
330330

331-
332331
rand_quantize_dropout_index = rand.randrange(self.quantize_dropout_cutoff_index, num_quant)
333332

334333
if quant_dropout_multiple_of != 1:
@@ -496,7 +495,7 @@ def forward(
496495
sample_codebook_temp = sample_codebook_temp,
497496
mask = mask,
498497
freeze_codebook = freeze_codebook,
499-
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
498+
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) if self.training else None
500499
)
501500

502501
# invoke residual vq on each group

0 commit comments

Comments
 (0)