Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 58 additions & 46 deletions examples/generative/text_generation_fnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
## Imports
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os

os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax" , "torch"
import keras
from keras import layers, ops
import tensorflow as tf

# Defining hyperparameters

VOCAB_SIZE = 8192
Expand All @@ -59,7 +61,9 @@
)

path_to_dataset = os.path.join(
os.path.dirname(path_to_zip), "cornell movie-dialogs corpus"
os.path.dirname(path_to_zip),
"cornell_movie_dialogs_extracted",
"cornell movie-dialogs corpus",
)
path_to_movie_lines = os.path.join(path_to_dataset, "movie_lines.txt")
path_to_movie_conversations = os.path.join(path_to_dataset, "movie_conversations.txt")
Expand Down Expand Up @@ -136,7 +140,7 @@ def vectorize_text(inputs, outputs):
outputs = tf.pad(outputs, [[0, 1]])
return (
{"encoder_inputs": inputs, "decoder_inputs": outputs[:-1]},
{"outputs": outputs[1:]},
outputs[1:],
)


Expand Down Expand Up @@ -182,12 +186,14 @@ def __init__(self, embed_dim, dense_dim, **kwargs):
self.layernorm_2 = layers.LayerNormalization()

def call(self, inputs):
# Casting the inputs to complex64
inp_complex = tf.cast(inputs, tf.complex64)
# Projecting the inputs to the frequency domain using FFT2D and
# extracting the real part of the output
fft = tf.math.real(tf.signal.fft2d(inp_complex))
proj_input = self.layernorm_1(inputs + fft)
# Cast inputs to float32 and create imaginary component
inp_real = ops.cast(inputs, "float32")
inp_imag = ops.zeros_like(inp_real)

# Apply 2D FFT - returns tuple of (real, imaginary)
fft_real, fft_imag = ops.fft2((inp_real, inp_imag))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be ops.fft

# Use only the real component
proj_input = self.layernorm_1(inputs + fft_real)
proj_output = self.dense_proj(proj_input)
return self.layernorm_2(proj_input + proj_output)

Expand Down Expand Up @@ -218,14 +224,14 @@ def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
self.embed_dim = embed_dim

def call(self, inputs):
length = tf.shape(inputs)[-1]
positions = tf.range(start=0, limit=length, delta=1)
length = ops.shape(inputs)[-1]
positions = ops.arange(0, length, 1)
embedded_tokens = self.token_embeddings(inputs)
embedded_positions = self.position_embeddings(positions)
return embedded_tokens + embedded_positions

def compute_mask(self, inputs, mask=None):
return tf.math.not_equal(inputs, 0)
return ops.not_equal(inputs, 0)


class FNetDecoder(layers.Layer):
Expand Down Expand Up @@ -253,9 +259,11 @@ def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):

def call(self, inputs, encoder_outputs, mask=None):
causal_mask = self.get_causal_attention_mask(inputs)

if mask is not None:
padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
padding_mask = tf.minimum(padding_mask, causal_mask)
padding_mask = ops.cast(mask[:, None, :], "int32")
else:
padding_mask = None

attention_output_1 = self.attention_1(
query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
Expand All @@ -274,17 +282,14 @@ def call(self, inputs, encoder_outputs, mask=None):
return self.layernorm_3(out_2 + proj_output)

def get_causal_attention_mask(self, inputs):
input_shape = tf.shape(inputs)
input_shape = ops.shape(inputs)
batch_size, sequence_length = input_shape[0], input_shape[1]
i = tf.range(sequence_length)[:, tf.newaxis]
j = tf.range(sequence_length)
mask = tf.cast(i >= j, dtype="int32")
mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
mult = tf.concat(
[tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
axis=0,
)
return tf.tile(mask, mult)
i = ops.arange(sequence_length)[:, None]
j = ops.arange(sequence_length)
mask = ops.cast(i >= j, dtype="int32")
mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
multiples = [batch_size, 1, 1]
return ops.tile(mask, multiples)


def create_model():
Expand Down Expand Up @@ -316,9 +321,9 @@ def create_model():
fnet.compile("adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

"""
Here, the `epochs` parameter is set to a single epoch, but in practice the model will take around
**20-30 epochs** of training to start outputting comprehensible sentences. Although accuracy
is not a good measure for this task, we will use it just to get a hint of the improvement
The model as configured here uses a simplified architecture to keep training time manageable for a tutorial. As a result, text generation quality
will be limited, and outputs may be generic.
Although accuracy is not a good measure for this task, we will use it just to get a hint of the improvement
of the network.
"""

Expand All @@ -334,38 +339,45 @@ def create_model():
def decode_sentence(input_sentence):
# Mapping the input sentence to tokens and adding start and end tokens
tokenized_input_sentence = vectorizer(
tf.constant("[start] " + preprocess_text(input_sentence) + " [end]")
"[start] " + preprocess_text(input_sentence) + " [end]"
)
# Initializing the initial sentence consisting of only the start token.
tokenized_target_sentence = tf.expand_dims(VOCAB.index("[start]"), 0)
decoded_sentence = ""

# Start token
start_token_index = VOCAB.index("[start]")
end_token_index = VOCAB.index("[end]")

tokenized_target_sentence = ops.expand_dims(start_token_index, axis=0)
decoded_sentence = []

for i in range(MAX_LENGTH):
# Get the predictions
predictions = fnet.predict(
{
"encoder_inputs": tf.expand_dims(tokenized_input_sentence, 0),
"decoder_inputs": tf.expand_dims(
tf.pad(
"encoder_inputs": ops.expand_dims(tokenized_input_sentence, axis=0),
"decoder_inputs": ops.expand_dims(
ops.pad(
tokenized_target_sentence,
[[0, MAX_LENGTH - tf.shape(tokenized_target_sentence)[0]]],
[[0, MAX_LENGTH - ops.shape(tokenized_target_sentence)[0]]],
),
0,
axis=0,
),
}
)
# Calculating the token with maximum probability and getting the corresponding word
sampled_token_index = tf.argmax(predictions[0, i, :])
sampled_token = VOCAB[sampled_token_index.numpy()]
# If sampled token is the end token then stop generating and return the sentence
if tf.equal(sampled_token_index, VOCAB.index("[end]")):
sampled_token_index = ops.argmax(predictions[0, i, :])
sampled_token_index = int(sampled_token_index)

if sampled_token_index == end_token_index:
break
decoded_sentence += sampled_token + " "
tokenized_target_sentence = tf.concat(
[tokenized_target_sentence, [sampled_token_index]], 0

decoded_sentence.append(VOCAB[sampled_token_index])

tokenized_target_sentence = ops.concatenate(
[tokenized_target_sentence, ops.expand_dims(sampled_token_index, axis=0)],
axis=0,
)

return decoded_sentence
return " ".join(decoded_sentence)


decode_sentence("Where have you been all this time?")
Expand Down