{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Load a HuggingFace Transformers checkpoint into a multi-backend KerasHub model\n",
    "\n",
    "**Author:** [Laxma Reddy Patlolla](https://github.com/laxmareddyp), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)<br>\n",
    "**Date created:** 2025/06/17<br>\n",
    "**Last modified:** 2025/06/17<br>\n",
    "**Description:** How to load HuggingFace Transformers checkpoints into multi-backend KerasHub models and run inference with them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "KerasHub has built-in converters for HuggingFace's `.safetensors` models.\n",
    "Loading model weights from HuggingFace is therefore no more difficult than\n",
    "using KerasHub's own presets.\n",
    "\n",
    "### KerasHub built-in HuggingFace Transformers converters\n",
    "\n",
    "KerasHub simplifies the use of HuggingFace Transformers models through its\n",
    "built-in converters. These converters automatically translate HuggingFace\n",
    "model checkpoints into a format that's compatible with the Keras ecosystem.\n",
    "This means you can seamlessly load a wide variety of pretrained models from the HuggingFace\n",
    "Hub directly into KerasHub with just a few lines of code.\n",
    "\n",
    "Key advantages of using KerasHub converters:\n",
    "\n",
    "- **Ease of Use**: Load HuggingFace models without manual conversion steps.\n",
    "- **Broad Compatibility**: Access a vast range of models available on the HuggingFace Hub.\n",
    "- **Seamless Integration**: Work with these models using familiar Keras APIs for training,\n",
    "evaluation, and inference.\n",
    "\n",
    "Fortunately, all of this happens behind the scenes, so you can focus on using\n",
    "the models rather than managing the conversion process!\n",
    "\n",
    "## Setup\n",
    "\n",
    "Before you begin, make sure you have the necessary libraries installed.\n",
    "You'll primarily need `keras` and `keras_hub`.\n",
    "\n",
    "**Note:** Changing the backend after Keras has been imported might not work as expected.\n",
    "Ensure `KERAS_BACKEND` is set at the beginning of your script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # \"tensorflow\" or \"torch\"\n",
    "\n",
    "import keras\n",
    "import keras_hub"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "KerasHub allows you to easily load models from HuggingFace Transformers.\n",
    "Here's an example of how to load a Gemma causal language model.\n",
    "In this particular case, you will need to consent to Google's license on\n",
    "HuggingFace to be able to download the model weights, and provide your\n",
    "`HF_TOKEN` as an environment variable, or as a Colab secret when working with\n",
    "Google Colab."
   ]
  },
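  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a minimal sketch of the authentication step (the token value below is a\n",
    "placeholder, not a real credential), you can either export `HF_TOKEN` before\n",
    "loading the model, or log in interactively with the `huggingface_hub` library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Option 1: set the token as an environment variable (placeholder value,\n",
    "# never commit a real token to source control).\n",
    "# os.environ[\"HF_TOKEN\"] = \"hf_...\"\n",
    "\n",
    "# Option 2: log in interactively; huggingface_hub prompts for the token.\n",
    "# from huggingface_hub import login\n",
    "# login()"
   ]
  },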
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# This is a HuggingFace Transformers checkpoint, not a Keras preset;\n",
    "# the hf:// scheme tells KerasHub to convert it automatically.\n",
    "gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(\"hf://google/gemma-2b\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Let's try running some inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(\"I want to say\", max_length=30)"
   ]
  },
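  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "You can also change the decoding strategy: KerasHub tasks accept a sampler at\n",
    "compile time. A minimal sketch (the choice of `k=5` below is illustrative, not\n",
    "a recommendation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Swap the default decoding strategy for top-k sampling; k=5 is illustrative.\n",
    "gemma_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=5))\n",
    "gemma_lm.generate(\"I want to say\", max_length=30)"
   ]
  },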
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Fine-tune a Gemma Transformers checkpoint using the Keras `model.fit(...)` API\n",
    "\n",
    "Once you have loaded HuggingFace weights, you can use the instantiated model\n",
    "just like any other KerasHub model. For instance, you might fine-tune the model\n",
    "on your own data like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "features = [\"The quick brown fox jumped.\", \"I forgot my homework.\"]\n",
    "gemma_lm.fit(x=features, batch_size=2)"
   ]
  },
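  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "KerasHub task models come pre-compiled with sensible defaults, but you can\n",
    "recompile with your own loss and optimizer before calling `fit`. A minimal\n",
    "sketch (the learning rate below is illustrative, not a tuned value):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "\n",
    "# Recompile with an explicit loss and optimizer before fine-tuning.\n",
    "gemma_lm.compile(\n",
    "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer=keras.optimizers.AdamW(learning_rate=5e-5),\n",
    "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
    ")\n",
    "gemma_lm.fit(x=features, batch_size=2)"
   ]
  },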
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Saving and uploading the new checkpoint\n",
    "\n",
    "To store and share your fine-tuned model, KerasHub makes it easy to save or\n",
    "upload it using standard methods. You can do this through familiar commands\n",
    "such as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.save_to_preset(\"./gemma-2b-finetuned\")\n",
    "keras_hub.upload_preset(\"hf://laxmareddyp/gemma-2b-finetuned\", \"./gemma-2b-finetuned\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Once you have uploaded your preset, you can load it from anywhere using:\n",
    "`loaded_model = keras_hub.models.GemmaCausalLM.from_preset(\"hf://YOUR_HF_USERNAME/gemma-2b-finetuned\")`\n",
    "\n",
    "For a comprehensive, step-by-step guide on uploading your model, refer to the official KerasHub upload documentation.\n",
    "You can find all the details here: [KerasHub Upload Guide](https://keras.io/keras_hub/guides/upload/)\n",
    "\n",
    "By integrating HuggingFace Transformers, KerasHub significantly expands your access to pretrained models.\n",
    "The HuggingFace Hub hosts over 750k model checkpoints across domains such as NLP,\n",
    "computer vision, audio, and more. Of these, approximately 400k models are currently compatible with KerasHub,\n",
    "giving you access to a vast and diverse selection of state-of-the-art architectures for your projects.\n",
    "\n",
    "With KerasHub, you can:\n",
    "\n",
    "- **Tap into State-of-the-Art Models**: Easily experiment with the latest\n",
    "architectures and pretrained weights from the research community and industry.\n",
    "- **Reduce Development Time**: Leverage existing models instead of training from scratch,\n",
    "saving significant time and computational resources.\n",
    "- **Enhance Model Capabilities**: Find specialized models for a wide array of tasks,\n",
    "from text generation and translation to image segmentation and object detection.\n",
    "\n",
    "This seamless access empowers you to build more powerful and sophisticated AI applications with Keras.\n",
    "\n",
    "## Use a wider range of frameworks\n",
    "\n",
    "Keras 3, and by extension KerasHub, is designed for multi-framework compatibility.\n",
    "This means you can run your models on different backend frameworks, such as JAX, TensorFlow, and PyTorch.\n",
    "This flexibility allows you to:\n",
    "\n",
    "- **Choose the Best Backend for Your Needs**: Select a backend based on performance characteristics,\n",
    "hardware compatibility (e.g., TPUs with JAX), or existing team expertise.\n",
    "- **Interoperability**: More easily integrate KerasHub models into existing\n",
    "workflows that might be built on TensorFlow or PyTorch.\n",
    "- **Future-Proofing**: Adapt to evolving framework landscapes without\n",
    "rewriting your core model logic.\n",
    "\n",
    "## Run transformer models on the JAX backend and on TPUs\n",
    "\n",
    "To experiment with a model on JAX, set the Keras backend to JAX before\n",
    "constructing the model, and make sure your environment is connected to a TPU runtime.\n",
    "Keras will then automatically leverage JAX's TPU support,\n",
    "allowing your model to train efficiently on TPU hardware without further code changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# In a fresh runtime, set the backend before Keras is first imported.\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "\n",
    "gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(\"hf://google/gemma-2b\")"
   ]
  },
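  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "For multi-device training, Keras 3 also exposes a distribution API on the JAX\n",
    "backend. A minimal sketch, assuming your environment is a JAX TPU runtime\n",
    "(this API is currently JAX-only):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "\n",
    "# Shard data batches across all visible accelerators (JAX backend only).\n",
    "devices = keras.distribution.list_devices()\n",
    "keras.distribution.set_distribution(keras.distribution.DataParallel(devices=devices))"
   ]
  },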
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Additional Examples\n",
    "\n",
    "### Generation\n",
    "\n",
    "Here's an example using Llama: loading a PyTorch HuggingFace Transformers checkpoint into KerasHub and running it on the JAX backend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "\n",
    "from keras_hub.models import Llama3CausalLM\n",
    "\n",
    "# Get the model\n",
    "causal_lm = Llama3CausalLM.from_preset(\"hf://NousResearch/Hermes-2-Pro-Llama-3-8B\")\n",
    "\n",
    "prompts = [\n",
    "    \"\"\"<|im_start|>system\n",
    "You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>\n",
    "<|im_start|>user\n",
    "Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>\n",
    "<|im_start|>assistant\"\"\",\n",
    "]\n",
    "\n",
    "# Generate from the model\n",
    "causal_lm.generate(prompts, max_length=30)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Changing precision\n",
    "\n",
    "You can adjust your model's precision by configuring it through `keras.config` as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "\n",
    "keras.config.set_dtype_policy(\"bfloat16\")\n",
    "\n",
    "from keras_hub.models import Llama3CausalLM\n",
    "\n",
    "causal_lm = Llama3CausalLM.from_preset(\"hf://NousResearch/Hermes-2-Pro-Llama-3-8B\")"
   ]
  },
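  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "If you only want to change the precision of a single model rather than the\n",
    "global dtype policy, `from_preset` also accepts a `dtype` argument. A minimal\n",
    "sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras_hub.models import Llama3CausalLM\n",
    "\n",
    "# Per-model precision: pass dtype when loading the preset instead of\n",
    "# setting the global policy.\n",
    "causal_lm = Llama3CausalLM.from_preset(\n",
    "    \"hf://NousResearch/Hermes-2-Pro-Llama-3-8B\", dtype=\"bfloat16\"\n",
    ")"
   ]
  },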
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Go try loading other model weights! You can find more options on HuggingFace\n",
    "and use them with `from_preset(\"hf://<namespace>/<model-name>\")`.\n",
    "\n",
    "Happy experimenting!"
   ]
  }
 ],
313+ "metadata" : {
314+ "accelerator" : " None" ,
315+ "colab" : {
316+ "collapsed_sections" : [],
317+ "name" : " hugging_face_keras_integration" ,
318+ "private_outputs" : false ,
319+ "provenance" : [],
320+ "toc_visible" : true
321+ },
322+ "kernelspec" : {
323+ "display_name" : " Python 3" ,
324+ "language" : " python" ,
325+ "name" : " python3"
326+ },
327+ "language_info" : {
328+ "codemirror_mode" : {
329+ "name" : " ipython" ,
330+ "version" : 3
331+ },
332+ "file_extension" : " .py" ,
333+ "mimetype" : " text/x-python" ,
334+ "name" : " python" ,
335+ "nbconvert_exporter" : " python" ,
336+ "pygments_lexer" : " ipython3" ,
337+ "version" : " 3.7.0"
338+ }
339+ },
340+ "nbformat" : 4 ,
341+ "nbformat_minor" : 0
342+ }