diff --git a/docs/tutorials/transformer.ipynb b/docs/tutorials/transformer.ipynb index 8b1683d55..7061ecd22 100644 --- a/docs/tutorials/transformer.ipynb +++ b/docs/tutorials/transformer.ipynb @@ -975,6 +975,15 @@ " def __init__(self, vocab_size, d_model):\n", " super().__init__()\n", " self.d_model = d_model\n", + " try:\n", + " input_dim = int(vocab_size)\n", + " except: ValueError:\n", + " try: + " input_dim = int(vocab_size.numpy())\n", + " except: ValueError:\n", + " print(f\"Can't convert vocab_size (type = {type(vocab_size)} to int.\")\n", + " \n", + " self.embedding = tf.keras.layers.Embedding(input_dim=input_dim, output_dim=d_model, mask_zero=True) \n", " self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True) \n", " self.pos_encoding = positional_encoding(length=2048, depth=d_model)\n", "\n",