Skip to content

Commit f8c56dd

Browse files
committed
Fix padded vocab size bug
1 parent b63c903 commit f8c56dd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

models/tt_transformers/tt/lm_head.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -32,7 +32,10 @@ def __init__(
         self.padded_vocab_size = args.padded_vocab_size
         self.num_devices = args.num_devices

-        if not self.padded_vocab_size or self.padded_vocab_size % 32 != 0:
+        if not self.padded_vocab_size:
+            self.padded_vocab_size = self.vocab_size
+
+        if self.padded_vocab_size % 32 != 0:
             # Pad vocab_size to be divisible by 32
             self.padded_vocab_size = math.ceil(self.padded_vocab_size / 32) * 32
```

0 commit comments

Comments (0)