1 | 1 | """ |
2 | | -Title: GPT2 Text Generation with KerasNLP |
| 2 | +Title: GPT2 Text Generation with KerasHub |
3 | 3 | Author: Chen Qian |
4 | 4 | Date created: 2023/04/17 |
5 | 5 | Last modified: 2024/04/12 |
6 | | -Description: Use KerasNLP GPT2 model and `samplers` to do text generation. |
| 6 | +Description: Use KerasHub GPT2 model and `samplers` to do text generation. |
7 | 7 | Accelerator: GPU |
8 | 8 | """ |
9 | 9 |
10 | 10 | """ |
11 | | -In this tutorial, you will learn to use [KerasNLP](https://keras.io/keras_nlp/) to load a |
| 11 | +In this tutorial, you will learn to use [KerasHub](https://keras.io/keras_hub/) to load a |
12 | 12 | pre-trained Large Language Model (LLM) - [GPT-2 model](https://openai.com/research/better-language-models) |
13 | 13 | (originally invented by OpenAI), finetune it to a specific text style, and |
14 | 14 | generate text based on users' input (also known as a prompt). You will also learn |
25 | 25 | """ |
26 | 26 |
27 | 27 | """ |
28 | | -## Install KerasNLP, Choose Backend and Import Dependencies |
| 28 | +## Install KerasHub, Choose Backend and Import Dependencies |
29 | 29 |
30 | 30 | This example uses [Keras 3](https://keras.io/keras_3/) to work in any of |
31 | 31 | `"tensorflow"`, `"jax"` or `"torch"`. Support for Keras 3 is baked into |
32 | | -KerasNLP, simply change the `"KERAS_BACKEND"` environment variable to select |
| 32 | +KerasHub; simply change the `"KERAS_BACKEND"` environment variable to select |
33 | 33 | the backend of your choice. We select the JAX backend below. |
34 | 34 | """ |
35 | 35 |
36 | 36 | """shell |
37 | | -pip install git+https://github.com/keras-team/keras-nlp.git -q |
| 37 | +pip install git+https://github.com/keras-team/keras-hub.git -q |
38 | 38 | """ |
39 | 39 |
40 | 40 | import os |
41 | 41 |
42 | 42 | os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch" |
43 | 43 |
44 | | -import keras_nlp |
| 44 | +import keras_hub |
45 | 45 | import keras |
46 | 46 | import tensorflow as tf |
47 | 47 | import time |
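As an optional one-line sanity check (a sketch, not part of the tutorial itself), the active backend can be printed after the imports to confirm the environment variable took effect:

```python
# KERAS_BACKEND must be set before `keras` is imported, as done above;
# this simply confirms which backend Keras 3 picked up.
print(keras.config.backend())  # e.g. "jax"
```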
70 | 70 | """ |
71 | 71 |
72 | 72 | """ |
73 | | -## Introduction to KerasNLP |
| 73 | +## Introduction to KerasHub |
74 | 74 |
75 | 75 | Large Language Models are complex to build and expensive to train from scratch. |
76 | | -Luckily there are pretrained LLMs available for use right away. [KerasNLP](https://keras.io/keras_nlp/) |
| 76 | +Luckily there are pretrained LLMs available for use right away. [KerasHub](https://keras.io/keras_hub/) |
77 | 77 | provides a large number of pre-trained checkpoints that allow you to experiment |
78 | 78 | with SOTA models without needing to train them yourself. |
79 | 79 |
80 | | -KerasNLP is a natural language processing library that supports users through |
81 | | -their entire development cycle. KerasNLP offers both pretrained models and |
| 80 | +KerasHub is a natural language processing library that supports users through |
| 81 | +their entire development cycle. KerasHub offers both pretrained models and |
82 | 82 | modularized building blocks, so developers can easily reuse pretrained models |
83 | 83 | or stack their own LLM. |
84 | 84 |
85 | | -In a nutshell, for generative LLM, KerasNLP offers: |
| 85 | +In a nutshell, for generative LLMs, KerasHub offers: |
86 | 86 |
87 | 87 | - Pretrained models with a `generate()` method, e.g., |
88 | | - `keras_nlp.models.GPT2CausalLM` and `keras_nlp.models.OPTCausalLM`. |
| 88 | + `keras_hub.models.GPT2CausalLM` and `keras_hub.models.OPTCausalLM`. |
89 | 89 | - Sampler class that implements generation algorithms such as Top-K, Beam and |
90 | 90 | contrastive search. These samplers can be used to generate text with |
91 | 91 | custom models. |
94 | 94 | """ |
95 | 95 | ## Load a pre-trained GPT-2 model and generate some text |
96 | 96 |
97 | | -KerasNLP provides a number of pre-trained models, such as [Google |
| 97 | +KerasHub provides a number of pre-trained models, such as [Google |
98 | 98 | Bert](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html) |
99 | 99 | and [GPT-2](https://openai.com/research/better-language-models). You can see |
100 | | -the list of models available in the [KerasNLP repository](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/models). |
| 100 | +the list of models available in the [KerasHub repository](https://github.com/keras-team/keras-hub/tree/master/keras_hub/models). |
101 | 101 |
102 | 102 | It's very easy to load the GPT-2 model as you can see below: |
103 | 103 | """ |
104 | 104 |
105 | 105 | # To speed up training and generation, we use a preprocessor with a sequence |
106 | 106 | # length of 128 instead of the full length of 1024. |
107 | | -preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( |
| 107 | +preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset( |
108 | 108 | "gpt2_base_en", |
109 | 109 | sequence_length=128, |
110 | 110 | ) |
111 | | -gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( |
| 111 | +gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset( |
112 | 112 | "gpt2_base_en", preprocessor=preprocessor |
113 | 113 | ) |
114 | 114 |
|
150 | 150 | """ |
151 | 151 |
152 | 152 | """ |
153 | | -## More on the GPT-2 model from KerasNLP |
| 153 | +## More on the GPT-2 model from KerasHub |
154 | 154 |
155 | 155 | Next up, we will actually fine-tune the model to update its parameters, but |
156 | 156 | before we do, let's take a look at the full set of tools we have for working |
157 | 157 | with GPT2. |
158 | 158 |
159 | 159 | The code of GPT2 can be found |
160 | | -[here](https://github.com/keras-team/keras-nlp/blob/master/keras_nlp/models/gpt2/). |
| 160 | +[here](https://github.com/keras-team/keras-hub/blob/master/keras_hub/models/gpt2/). |
161 | 161 | Conceptually the `GPT2CausalLM` can be hierarchically broken down into several |
162 | | -modules in KerasNLP, all of which have a *from_preset()* function that loads a |
| 162 | +modules in KerasHub, all of which have a *from_preset()* function that loads a |
163 | 163 | pretrained model (a short sketch follows this list): |
164 | 164 |
165 | | -- `keras_nlp.models.GPT2Tokenizer`: The tokenizer used by GPT2 model, which is a |
| 165 | +- `keras_hub.models.GPT2Tokenizer`: The tokenizer used by the GPT2 model, which is a |
166 | 166 | [byte-pair encoder](https://huggingface.co/course/chapter6/5?fw=pt). |
167 | | -- `keras_nlp.models.GPT2CausalLMPreprocessor`: the preprocessor used by GPT2 |
| 167 | +- `keras_hub.models.GPT2CausalLMPreprocessor`: the preprocessor used for GPT2 |
168 | 168 | causal LM training. It does the tokenization along with other preprocessing |
169 | 169 | work such as creating the label and appending the end token. |
170 | | -- `keras_nlp.models.GPT2Backbone`: the GPT2 model, which is a stack of |
171 | | - `keras_nlp.layers.TransformerDecoder`. This is usually just referred as |
| 170 | +- `keras_hub.models.GPT2Backbone`: the GPT2 model, which is a stack of |
| 171 | +  `keras_hub.layers.TransformerDecoder`. This is usually just referred to as |
172 | 172 | `GPT2`. |
173 | | -- `keras_nlp.models.GPT2CausalLM`: wraps `GPT2Backbone`, it multiplies the |
| 173 | +- `keras_hub.models.GPT2CausalLM`: wraps `GPT2Backbone`; it multiplies the |
174 | 174 | output of `GPT2Backbone` by the embedding matrix to generate logits over |
175 | 175 | vocab tokens. |
176 | 176 | """ |
177 | 177 |
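To make the module list above concrete, here is a minimal sketch of loading two of the pieces on their own from the same `"gpt2_base_en"` preset; this is optional, since `GPT2CausalLM.from_preset()` already assembles everything:

```python
# Standalone tokenizer: maps text to token ids and back.
tokenizer = keras_hub.models.GPT2Tokenizer.from_preset("gpt2_base_en")
token_ids = tokenizer("A quick tokenizer check")
print(token_ids)
print(tokenizer.detokenize(token_ids))

# Standalone backbone: the bare Transformer decoder stack, without the
# language-model head on top.
backbone = keras_hub.models.GPT2Backbone.from_preset("gpt2_base_en")
backbone.summary()
```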
178 | 178 | """ |
179 | 179 | ## Finetune on Reddit dataset |
180 | 180 |
181 | | -Now you have the knowledge of the GPT-2 model from KerasNLP, you can take one |
| 181 | +Now that you are familiar with the GPT-2 model from KerasHub, you can take one |
182 | 182 | step further to finetune the model so that it generates text in a specific |
183 | 183 | style, short or long, strict or casual. In this tutorial, we will use the reddit |
184 | 184 | dataset as an example. |
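The data-loading code is not shown in this diff. As a rough sketch of what preparing such a dataset could look like, one option is the `reddit_tifu` dataset from TensorFlow Datasets; the dataset name, field layout, and the `train_ds` variable are illustrative assumptions here:

```python
import tensorflow_datasets as tfds

# Load reddit posts as (document, title) pairs and keep only the raw text.
reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
```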
217 | 217 | """ |
218 | 218 | Now you can finetune the model using the familiar *fit()* function. Note that |
219 | 219 | `preprocessor` will be automatically called inside the `fit` method since |
220 | | -`GPT2CausalLM` is a `keras_nlp.models.Task` instance. |
| 220 | +`GPT2CausalLM` is a `keras_hub.models.Task` instance. |
221 | 221 |
222 | 222 | This step takes quite a bit of GPU memory and a long time if we were to train |
223 | 223 | it all the way to a fully trained state. Here we just use part of the dataset |
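As a rough sketch of that finetuning step, assuming `train_ds` is a `tf.data.Dataset` of raw text documents (a hypothetical name here, e.g. one prepared as sketched above), and with hyperparameters chosen purely for illustration:

```python
# Use a small slice of the data and a single epoch to keep the demo cheap.
train_ds = train_ds.take(500)

gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    weighted_metrics=["accuracy"],
)
gpt2_lm.fit(train_ds, epochs=1)
```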
261 | 261 | """ |
262 | 262 | ## Into the Sampling Method |
263 | 263 |
264 | | -In KerasNLP, we offer a few sampling methods, e.g., contrastive search, |
| 264 | +In KerasHub, we offer a few sampling methods, e.g., contrastive search, |
265 | 265 | Top-K and beam sampling. By default, our `GPT2CausalLM` uses Top-K search, but |
266 | 266 | you can choose your own sampling method. |
267 | 267 |
270 | 270 |
271 | 271 | - Use a string identifier, such as "greedy"; this uses the default |
272 | 272 | configuration. |
273 | | -- Pass a `keras_nlp.samplers.Sampler` instance, you can use custom configuration |
| 273 | +- Pass a `keras_hub.samplers.Sampler` instance to use a custom sampling |
274 | 274 | configuration (a short sketch follows the code below). |
275 | 275 | """ |
276 | 276 |
281 | 281 | print(output) |
282 | 282 |
283 | 283 | # Use a `Sampler` instance. `GreedySampler` tends to repeat itself. |
284 | | -greedy_sampler = keras_nlp.samplers.GreedySampler() |
| 284 | +greedy_sampler = keras_hub.samplers.GreedySampler() |
285 | 285 | gpt2_lm.compile(sampler=greedy_sampler) |
286 | 286 |
287 | 287 | output = gpt2_lm.generate("I like basketball", max_length=200) |
288 | 288 | print("\nGPT-2 output:") |
289 | 289 | print(output) |
290 | 290 |
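As a sketch of the second option from the list above, passing a configured `Sampler` instance, a top-k sampler with custom settings is compiled in exactly the same way (the `k` and `temperature` values here are arbitrary):

```python
# Top-k sampling with a custom k and temperature.
topk_sampler = keras_hub.samplers.TopKSampler(k=50, temperature=0.7)
gpt2_lm.compile(sampler=topk_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
```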
291 | 291 | """ |
292 | | -For more details on KerasNLP `Sampler` class, you can check the code |
293 | | -[here](https://github.com/keras-team/keras-nlp/tree/master/keras_nlp/samplers). |
| 292 | +For more details on the KerasHub `Sampler` class, you can check the code |
| 293 | +[here](https://github.com/keras-team/keras-hub/tree/master/keras_hub/samplers). |
294 | 294 | """ |
295 | 295 |
296 | 296 | """ |