-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Expand file tree
/
Copy pathtext_generation_fnet.py
More file actions
405 lines (323 loc) · 13.8 KB
/
text_generation_fnet.py
File metadata and controls
405 lines (323 loc) · 13.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
"""
Title: Text Generation using FNet
Author: [Darshan Deshpande](https://twitter.com/getdarshan)
Date created: 2021/10/05
Last modified: 2021/10/05
Description: FNet transformer for text generation in Keras.
Accelerator: GPU
"""
"""
## Introduction
The original transformer implementation (Vaswani et al., 2017) was one of the major
breakthroughs in Natural Language Processing, giving rise to important architectures such as BERT and GPT.
However, the drawback of these architectures is
that the self-attention mechanism they use is computationally expensive. The FNet
architecture proposes to replace this self-attention with a leaner mechanism:
a Fourier transformation-based linear mixer for input tokens.
The FNet model was able to achieve 92-97% of BERT's accuracy while training 80% faster on
GPUs and almost 70% faster on TPUs. This type of design provides an efficient and small
model size, leading to faster inference times.
In this example, we will implement and train this architecture on the Cornell Movie
Dialog corpus to show the applicability of this model to text generation.
"""
"""
## Imports
"""
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax" , "torch"
import keras
from keras import layers, ops
import tensorflow as tf
# Defining hyperparameters
# Maximum vocabulary size given to the TextVectorization layer
VOCAB_SIZE = 8192
# Upper bound on the number of (question, answer) pairs loaded from the corpus
MAX_SAMPLES = 50000
# Shuffle buffer size used for the training dataset
BUFFER_SIZE = 20000
# Output sequence length of the vectorizer; also the size of the position embedding table
MAX_LENGTH = 40
# Width of the token/position embeddings and of the encoder/decoder outputs
EMBED_DIM = 256
# Hidden width of the feed-forward projections inside the encoder and decoder
LATENT_DIM = 512
# Number of heads in the decoder's multi-head attention layers
NUM_HEADS = 8
# Number of samples per batch in the training and validation datasets
BATCH_SIZE = 64
"""
## Loading data
We will be using the Cornell Dialog Corpus. We will parse the movie conversations into
question and answer sets.
"""
# Download the Cornell Movie-Dialogs archive and extract it
path_to_zip = keras.utils.get_file(
"cornell_movie_dialogs.zip",
origin="http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip",
extract=True,
)
# The corpus files are looked up in an extraction directory that sits next to the
# path returned by `get_file`.
# NOTE(review): this directory layout depends on how the installed Keras version
# extracts archives - confirm if the files below are not found.
path_to_dataset = os.path.join(
os.path.dirname(path_to_zip),
"cornell_movie_dialogs_extracted",
"cornell movie-dialogs corpus",
)
# File mapping each line ID to its utterance text
path_to_movie_lines = os.path.join(path_to_dataset, "movie_lines.txt")
# File listing, per conversation, the ordered line IDs that make it up
path_to_movie_conversations = os.path.join(path_to_dataset, "movie_conversations.txt")
def load_conversations():
    """Load consecutive utterance pairs from the Cornell Movie-Dialogs corpus.

    Reads `path_to_movie_lines` to map every line ID to its text, then walks
    `path_to_movie_conversations` and emits each pair of adjacent utterances
    of a conversation as one (input, output) sample.

    Returns:
        A tuple `(inputs, outputs)` of two equally long lists of strings,
        where `outputs[k]` is the utterance that directly follows
        `inputs[k]`. Loading stops as soon as `MAX_SAMPLES` pairs have been
        collected.

    Raises:
        IndexError: if a line has fewer fields than expected.
        KeyError: if a conversation references an unknown line ID.
    """
    separator = " +++$+++ "

    # Map line ID (field 0) to utterance text (field 4). Undecodable bytes
    # are dropped rather than aborting the load.
    id2line = {}
    with open(path_to_movie_lines, errors="ignore") as file:
        for raw_line in file:
            parts = raw_line.rstrip("\n").split(separator)
            id2line[parts[0]] = parts[4]

    inputs, outputs = [], []
    with open(path_to_movie_conversations, "r") as file:
        for raw_line in file:
            parts = raw_line.rstrip("\n").split(separator)
            # Field 3 looks like "['L1', 'L2', ...]": strip the brackets,
            # split on ", " and strip the quotes around each line ID.
            conversation = [
                line_id[1:-1] for line_id in parts[3][1:-1].split(", ")
            ]
            # Every utterance is paired with the one that follows it.
            for current_id, next_id in zip(conversation, conversation[1:]):
                inputs.append(id2line[current_id])
                outputs.append(id2line[next_id])
                if len(inputs) >= MAX_SAMPLES:
                    return inputs, outputs
    return inputs, outputs
# Load the (question, answer) pairs from the corpus
questions, answers = load_conversations()
# Splitting training and validation sets
# The first 40000 pairs form the training set; every pair after that is used for validation
train_dataset = tf.data.Dataset.from_tensor_slices((questions[:40000], answers[:40000]))
val_dataset = tf.data.Dataset.from_tensor_slices((questions[40000:], answers[40000:]))
"""
### Preprocessing and Tokenization
"""
def preprocess_text(sentence):
    """Standardize a sentence and wrap it in `[start]` / `[end]` markers.

    The text is lowercased, the punctuation marks `? . ! ,` are separated
    from neighbouring words, every other character outside `a-z` is replaced
    by a space, and surrounding whitespace is stripped.

    Args:
        sentence: A string tensor (or value accepted by `tf.strings` ops).

    Returns:
        A string tensor of the form `"[start] <cleaned text> [end]"`.
    """
    lowered = tf.strings.lower(sentence)
    # Put spaces around ? . ! , so each mark becomes a token of its own
    spaced = tf.strings.regex_replace(lowered, r"([?.!,])", r" \1 ")
    # Collapse runs of whitespace into a single space
    collapsed = tf.strings.regex_replace(spaced, r"\s\s+", " ")
    # Replace every run of characters other than a-z and ? . ! , with one space
    cleaned = tf.strings.regex_replace(collapsed, r"[^a-z?.!,]+", " ")
    trimmed = tf.strings.strip(cleaned)
    return tf.strings.join(["[start]", trimmed, "[end]"], separator=" ")
# Text-to-integer layer: standardizes with `preprocess_text`, is given VOCAB_SIZE as its
# vocabulary limit and emits integer sequences of length MAX_LENGTH
vectorizer = layers.TextVectorization(
VOCAB_SIZE,
standardize=preprocess_text,
output_mode="int",
output_sequence_length=MAX_LENGTH,
)
# We will adapt the vectorizer to both the questions and answers
# This dataset is batched to parallelize and speed up the process
vectorizer.adapt(tf.data.Dataset.from_tensor_slices((questions + answers)).batch(128))
"""
### Tokenizing and padding sentences using `TextVectorization`
"""
def vectorize_text(inputs, outputs):
    """Turn one (question, answer) text pair into model features and targets.

    Both texts are tokenized with the module-level `vectorizer`. The answer
    tokens are padded by one position on the right so that they can be split
    into two views of equal length: the decoder input (all but the last
    token) and the target (all but the first token), i.e. the target is the
    decoder input shifted left by one step.

    Args:
        inputs: Question text.
        outputs: Answer text.

    Returns:
        A tuple `(features, targets)` where `features` is a dict with the
        keys `"encoder_inputs"` and `"decoder_inputs"`.
    """
    encoder_tokens = vectorizer(inputs)
    decoder_tokens = vectorizer(outputs)
    # The single-axis padding spec means this expects unbatched (1-D) token sequences
    decoder_tokens = tf.pad(decoder_tokens, [[0, 1]])
    features = {
        "encoder_inputs": encoder_tokens,
        "decoder_inputs": decoder_tokens[:-1],
    }
    targets = decoder_tokens[1:]
    return features, targets
# Tokenize every (question, answer) pair; the mapping runs in parallel
train_dataset = train_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
# Training pipeline: cache the tokenized samples, shuffle them, batch and prefetch
train_dataset = (
train_dataset.cache()
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.prefetch(tf.data.AUTOTUNE)
)
# Validation pipeline: same as above but without shuffling
val_dataset = val_dataset.cache().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
"""
## Creating the FNet Encoder
The FNet paper proposes a replacement for the standard attention mechanism used by the
Transformer architecture (Vaswani et al., 2017).

The outputs of the FFT layer are complex numbers. To avoid dealing with complex layers,
only the real part is kept and the imaginary part is discarded.
The dense layers that follow the Fourier transformation act as convolutions applied on
the frequency domain.
"""
class FNetEncoder(layers.Layer):
    """Encoder block that mixes tokens with a Fourier transform.

    The block adds the real part of a 2D FFT of its input to the input itself,
    normalizes the sum, and then applies a two-layer feed-forward projection
    with a second residual connection and normalization.

    Args:
        embed_dim: Output width of the feed-forward projection.
        dense_dim: Hidden width of the feed-forward projection.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, embed_dim, dense_dim, **kwargs):
        super().__init__(**kwargs)
        self.dense_dim = dense_dim
        self.embed_dim = embed_dim
        # Feed-forward projection: expand to `dense_dim`, then back to `embed_dim`
        hidden = layers.Dense(dense_dim, activation="relu")
        projection = layers.Dense(embed_dim)
        self.dense_proj = keras.Sequential([hidden, projection])
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs):
        """Apply Fourier mixing followed by the feed-forward projection.

        Args:
            inputs: Input tensor (presumably `(batch, sequence, embed_dim)`
                embeddings - confirm against the caller).

        Returns:
            The normalized output of the second residual connection.
        """
        # `ops.fft2` takes a (real, imaginary) pair; the input is purely real
        real_part = ops.cast(inputs, "float32")
        imag_part = ops.zeros_like(real_part)
        # Only the real component of the transform is kept
        fft_real, _ = ops.fft2((real_part, imag_part))
        mixed = self.layernorm_1(inputs + fft_real)
        return self.layernorm_2(mixed + self.dense_proj(mixed))
"""
## Creating the Decoder
The decoder architecture remains the same as the one proposed by Vaswani et al. (2017)
in the original transformer architecture, consisting of an embedding, positional
encoding, two multi-head attention layers (a causally masked self-attention followed by
an attention over the encoder outputs) and finally the dense output layers.
The architecture that follows is taken from
[Deep Learning with Python, second edition, chapter 11](https://www.manning.com/books/deep-learning-with-python-second-edition).
"""
class PositionalEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    Args:
        sequence_length: Number of positions in the position embedding table.
        vocab_size: Number of entries in the token embedding table.
        embed_dim: Width of both embeddings.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )

    def call(self, inputs):
        """Embed token IDs and add the embedding of each token's position.

        Args:
            inputs: Integer token IDs; the last axis is the sequence axis.

        Returns:
            Token embeddings plus position embeddings.
        """
        seq_len = ops.shape(inputs)[-1]
        position_ids = ops.arange(0, seq_len, 1)
        return self.token_embeddings(inputs) + self.position_embeddings(position_ids)

    def compute_mask(self, inputs, mask=None):
        """Return a boolean mask that is False wherever the token ID is 0."""
        return ops.not_equal(inputs, 0)
class FNetDecoder(layers.Layer):
    """Transformer decoder block with self-attention and encoder attention.

    The block applies causally masked self-attention, then attention over the
    encoder outputs, then a two-layer feed-forward projection. Each stage is
    followed by a residual connection and layer normalization.

    Args:
        embed_dim: Output width of the block; also used as `key_dim` of both
            attention layers.
        latent_dim: Hidden width of the feed-forward projection.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.latent_dim = latent_dim
        self.embed_dim = embed_dim
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        # Feed-forward projection: expand to `latent_dim`, then back to `embed_dim`
        hidden = layers.Dense(latent_dim, activation="relu")
        projection = layers.Dense(embed_dim)
        self.dense_proj = keras.Sequential([hidden, projection])
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        """Run the decoder block.

        Args:
            inputs: Embedded decoder sequence.
            encoder_outputs: Encoder outputs used as key and value of the
                second attention layer.
            mask: Optional mask; when given it is cast to int32 with an
                extra middle axis and used as the attention mask of the
                second attention layer.

        Returns:
            The normalized output of the feed-forward stage.
        """
        causal_mask = self.get_causal_attention_mask(inputs)
        # NOTE(review): `mask` presumably describes padding in the decoder
        # inputs, yet it is applied to attention over `encoder_outputs`; that
        # only lines up when both sequences have the same length - confirm.
        padding_mask = (
            None if mask is None else ops.cast(mask[:, None, :], "int32")
        )

        # Stage 1: self-attention restricted to current and earlier positions
        self_attended = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + self_attended)

        # Stage 2: attention over the encoder outputs
        cross_attended = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + cross_attended)

        # Stage 3: position-wise feed-forward projection
        return self.layernorm_3(out_2 + self.dense_proj(out_2))

    def get_causal_attention_mask(self, inputs):
        """Build a lower-triangular int32 mask for every batch element.

        Args:
            inputs: Tensor whose first two axes are read as batch size and
                sequence length.

        Returns:
            An int32 tensor of shape `(batch, length, length)` whose entry
            `[b, i, j]` is 1 when `i >= j` and 0 otherwise.
        """
        shape = ops.shape(inputs)
        batch_size, seq_len = shape[0], shape[1]
        rows = ops.arange(seq_len)[:, None]
        cols = ops.arange(seq_len)
        lower_triangular = ops.cast(rows >= cols, dtype="int32")
        lower_triangular = ops.reshape(lower_triangular, (1, seq_len, seq_len))
        # Repeat the single mask once per batch element
        return ops.tile(lower_triangular, [batch_size, 1, 1])
def create_model():
    """Assemble the FNet encoder-decoder model.

    The encoder embeds the `encoder_inputs` tokens and passes them through an
    `FNetEncoder`. The decoder embeds the `decoder_inputs` tokens, attends to
    the encoder outputs with an `FNetDecoder`, applies dropout and predicts a
    softmax distribution over `VOCAB_SIZE` tokens for every position.

    Returns:
        A `keras.Model` named `"fnet"` taking `[encoder_inputs,
        decoder_inputs]` and returning the decoder's token probabilities.
    """
    # Encoder: token IDs -> embeddings -> Fourier-mixed representation
    encoder_inputs = keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
    encoder_embedded = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(
        encoder_inputs
    )
    encoder_outputs = FNetEncoder(EMBED_DIM, LATENT_DIM)(encoder_embedded)
    # Standalone encoder model; it is not used further in this function
    encoder = keras.Model(encoder_inputs, encoder_outputs)

    # Decoder: built as its own model with a placeholder for the encoder outputs
    decoder_inputs = keras.Input(shape=(None,), dtype="int32", name="decoder_inputs")
    encoded_seq_inputs = keras.Input(
        shape=(None, EMBED_DIM), name="decoder_state_inputs"
    )
    decoder_embedded = PositionalEmbedding(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM)(
        decoder_inputs
    )
    decoded = FNetDecoder(EMBED_DIM, LATENT_DIM, NUM_HEADS)(
        decoder_embedded, encoded_seq_inputs
    )
    regularized = layers.Dropout(0.5)(decoded)
    token_probabilities = layers.Dense(VOCAB_SIZE, activation="softmax")(regularized)
    decoder = keras.Model(
        [decoder_inputs, encoded_seq_inputs], token_probabilities, name="outputs"
    )

    # Wire the decoder to the real encoder outputs to form the full model
    decoder_outputs = decoder([decoder_inputs, encoder_outputs])
    return keras.Model([encoder_inputs, decoder_inputs], decoder_outputs, name="fnet")
"""
## Creating and Training the model
"""
# Build the model and configure it with the Adam optimizer; targets are integer token
# IDs, hence the sparse categorical cross-entropy loss
fnet = create_model()
fnet.compile("adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
"""
The model as configured here uses a simplified architecture to keep training time manageable for a tutorial. The text generation quality
will be limited - outputs may be generic.
Although accuracy is not a good measure for this task, we will use it just to get a hint of the improvement
of the network.
"""
# Train for a single epoch, evaluating on the validation set at the end of it
fnet.fit(train_dataset, epochs=1, validation_data=val_dataset)
"""
## Performing inference
"""
# Vocabulary learned by the vectorizer; used to map between token indices and words
VOCAB = vectorizer.get_vocabulary()
def decode_sentence(input_sentence):
    """Greedily generate a reply to `input_sentence` with the trained model.

    The sentence is tokenized with the module-level `vectorizer`. Starting
    from the `[start]` token, the model is queried repeatedly and the most
    probable next token is appended, until `[end]` is predicted or
    `MAX_LENGTH` tokens have been produced.

    Args:
        input_sentence: Raw input text as a Python string.

    Returns:
        The generated words joined by single spaces (without the `[start]`
        and `[end]` markers).
    """
    # `vectorizer` already runs `preprocess_text` as its `standardize` step,
    # which cleans the text and wraps it in [start]/[end]. Pre-processing the
    # sentence here as well would standardize it twice and feed the encoder
    # stray "start"/"end" tokens that never occur in the training inputs.
    tokenized_input_sentence = vectorizer(tf.constant(input_sentence))
    # The encoder input does not change between decoding steps
    encoder_inputs = ops.expand_dims(tokenized_input_sentence, axis=0)

    start_token_index = VOCAB.index("[start]")
    end_token_index = VOCAB.index("[end]")
    # The target sequence generated so far, starting with the start token
    tokenized_target_sentence = ops.expand_dims(start_token_index, axis=0)
    decoded_sentence = []
    for i in range(MAX_LENGTH):
        # The decoder input is the target so far, right-padded to MAX_LENGTH
        decoder_inputs = ops.expand_dims(
            ops.pad(
                tokenized_target_sentence,
                [[0, MAX_LENGTH - ops.shape(tokenized_target_sentence)[0]]],
            ),
            axis=0,
        )
        predictions = fnet.predict(
            {
                "encoder_inputs": encoder_inputs,
                "decoder_inputs": decoder_inputs,
            }
        )
        # Position i holds the prediction for the token that follows the
        # i + 1 tokens decoded so far; take the most probable one
        sampled_token_index = int(ops.argmax(predictions[0, i, :]))
        if sampled_token_index == end_token_index:
            break
        decoded_sentence.append(VOCAB[sampled_token_index])
        tokenized_target_sentence = ops.concatenate(
            [tokenized_target_sentence, ops.expand_dims(sampled_token_index, axis=0)],
            axis=0,
        )
    return " ".join(decoded_sentence)
# Example: generate a reply for a sample question
decode_sentence("Where have you been all this time?")
"""
## Conclusion
This example shows how to train and perform inference using the FNet model.
For getting insight into the architecture or for further reading, you can refer to:
1. [FNet: Mixing Tokens with Fourier Transforms](https://arxiv.org/abs/2105.03824v3)
(Lee-Thorp et al., 2021)
2. [Attention Is All You Need](https://arxiv.org/abs/1706.03762v5) (Vaswani et al.,
2017)
Thanks to François Chollet for his Keras example on
[English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)
from which the decoder implementation was extracted.
"""
"""
## Relevant Chapters from Deep Learning with Python
- [Chapter 15: Language models and the Transformer](https://deeplearningwithpython.io/chapters/chapter15_language-models-and-the-transformer)
- [Chapter 16: Text generation](https://deeplearningwithpython.io/chapters/chapter16_text-generation)
"""