Commit 4232bba

Guide on how to load and inference keras hub model weights hosted on HF (#2124)
* Keras hub guide for loading and inference weights hosted on HF
* Address comments
* address nit comments
* Add ipynb and md files
1 parent 2967015 commit 4232bba

3 files changed: +805 -0 lines changed
Lines changed: 342 additions & 0 deletions
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"# Load a HuggingFace Transformers checkpoint into a multi-backend KerasHub model\n",
"\n",
"**Author:** [Laxma Reddy Patlolla](https://github.com/laxmareddyp), [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)<br>\n",
"**Date created:** 2025/06/17<br>\n",
"**Last modified:** 2025/06/17<br>\n",
"**Description:** How to load and run inference from KerasHub model checkpoints hosted on the HuggingFace Hub."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"## Introduction\n",
"\n",
"KerasHub has built-in converters for HuggingFace's `.safetensors` models.\n",
"Loading model weights from HuggingFace is therefore no more difficult than\n",
"using KerasHub's own presets.\n",
"\n",
"### KerasHub built-in HuggingFace Transformers converters\n",
"\n",
"KerasHub simplifies the use of HuggingFace Transformers models through its\n",
"built-in converters. These converters automatically handle the process of translating\n",
"HuggingFace model checkpoints into a format that's compatible with the Keras ecosystem.\n",
"This means you can seamlessly load a wide variety of pretrained models from the HuggingFace\n",
"Hub directly into KerasHub with just a few lines of code.\n",
"\n",
"Key advantages of using KerasHub converters:\n",
"\n",
"- **Ease of Use**: Load HuggingFace models without manual conversion steps.\n",
"- **Broad Compatibility**: Access a vast range of models available on the HuggingFace Hub.\n",
"- **Seamless Integration**: Work with these models using familiar Keras APIs for training,\n",
"evaluation, and inference.\n",
"\n",
"Fortunately, all of this happens behind the scenes, so you can focus on using\n",
"the models rather than managing the conversion process!\n",
"\n",
"## Setup\n",
"\n",
"Before you begin, make sure you have the necessary libraries installed.\n",
"You'll primarily need `keras` and `keras_hub`.\n",
"\n",
"**Note:** Changing the backend after Keras has been imported might not work as expected.\n",
"Ensure `KERAS_BACKEND` is set at the beginning of your script."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\"  # \"tensorflow\" or \"torch\"\n",
"\n",
"import keras\n",
"import keras_hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"KerasHub allows you to easily load models from HuggingFace Transformers.\n",
"Here's an example of how to load a Gemma causal language model.\n",
"In this particular case, you will need to consent to Google's license on\n",
"HuggingFace to be able to download the model weights, and provide your\n",
"`HF_TOKEN` as an environment variable or as a \"Colab secret\" when working with\n",
"Google Colab, as sketched below."
]
},
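{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"A minimal sketch of supplying the token via an environment variable (the\n",
"value below is a placeholder; generate your own token in your HuggingFace\n",
"account settings):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Placeholder token; replace with your own, or call huggingface_hub.login().\n",
"os.environ[\"HF_TOKEN\"] = \"hf_...\""
]
},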
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# This is not a Keras checkpoint; it is a HuggingFace Transformers checkpoint.\n",
"\n",
"gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(\"hf://google/gemma-2b\")"
]
},
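{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The result is a standard Keras model, so you can inspect it as a quick\n",
"sanity check (optional for the rest of this guide):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# Print the layer and parameter breakdown of the converted model.\n",
"gemma_lm.summary()"
]
},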
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Let's try running some inference:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"gemma_lm.generate(\"I want to say\", max_length=30)"
]
},
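{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Generation uses the model's preconfigured sampler. As a sketch, assuming the\n",
"`keras_hub.samplers` API, you can swap in a different sampling strategy and\n",
"generate again:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# Top-k sampling tends to produce more varied text than greedy decoding.\n",
"gemma_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=5))\n",
"gemma_lm.generate(\"I want to say\", max_length=30)"
]
},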
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Fine-tune a Gemma Transformer checkpoint using the Keras `model.fit(...)` API\n",
"\n",
"Once you have loaded HuggingFace weights, you can use the instantiated model\n",
"just like any other KerasHub model. For instance, you might fine-tune the model\n",
"on your own data like so:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"features = [\"The quick brown fox jumped.\", \"I forgot my homework.\"]\n",
"gemma_lm.fit(x=features, batch_size=2)"
]
},
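{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The call above relies on KerasHub's default compilation. If you prefer\n",
"explicit control, a sketch of recompiling first (the learning rate and other\n",
"choices here are illustrative, not tuned):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"gemma_lm.compile(\n",
"    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"    optimizer=keras.optimizers.Adam(learning_rate=5e-5),  # illustrative value\n",
"    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
")\n",
"gemma_lm.fit(x=features, batch_size=2)"
]
},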
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Saving and uploading the new checkpoint\n",
"\n",
"To store and share your fine-tuned model, KerasHub makes it easy to save or\n",
"upload it using standard methods. You can do this through familiar commands\n",
"such as:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"gemma_lm.save_to_preset(\"./gemma-2b-finetuned\")\n",
"keras_hub.upload_preset(\"hf://laxmareddyp/gemma-2b-finetuned\", \"./gemma-2b-finetuned\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"By uploading your preset, you can then load it from anywhere using:\n",
"`loaded_model = keras_hub.models.GemmaCausalLM.from_preset(\"hf://YOUR_HF_USERNAME/gemma-2b-finetuned\")`\n",
"\n",
"For a comprehensive, step-by-step guide on uploading your model, refer to the official KerasHub upload documentation.\n",
"You can find all the details here: [KerasHub Upload Guide](https://keras.io/keras_hub/guides/upload/)\n",
"\n",
"By integrating HuggingFace Transformers, KerasHub significantly expands your access to pretrained models.\n",
"The Hugging Face Hub now hosts over 750,000 model checkpoints across various domains such as NLP,\n",
"Computer Vision, Audio, and more. Of these, approximately 400,000 models are currently compatible with KerasHub,\n",
"giving you access to a vast and diverse selection of state-of-the-art architectures for your projects.\n",
"\n",
"With KerasHub, you can:\n",
"\n",
"- **Tap into State-of-the-Art Models**: Easily experiment with the latest\n",
"architectures and pretrained weights from the research community and industry.\n",
"- **Reduce Development Time**: Leverage existing models instead of training from scratch,\n",
"saving significant time and computational resources.\n",
"- **Enhance Model Capabilities**: Find specialized models for a wide array of tasks,\n",
"from text generation and translation to image segmentation and object detection.\n",
"\n",
"This seamless access empowers you to build more powerful and sophisticated AI applications with Keras.\n",
"\n",
"## Use a wider range of frameworks\n",
"\n",
"Keras 3, and by extension KerasHub, is designed for multi-framework compatibility.\n",
"This means you can run your models with different backend frameworks like JAX, TensorFlow, and PyTorch.\n",
"This flexibility allows you to:\n",
"\n",
"- **Choose the Best Backend for Your Needs**: Select a backend based on performance characteristics,\n",
"hardware compatibility (e.g., TPUs with JAX), or existing team expertise.\n",
"- **Interoperability**: More easily integrate KerasHub models into existing\n",
"workflows that might be built on TensorFlow or PyTorch.\n",
"- **Future-Proofing**: Adapt to evolving framework landscapes without\n",
"rewriting your core model logic.\n",
"\n",
"## Run transformer models on the JAX backend and on TPUs\n",
"\n",
"To experiment with a model using JAX, set the Keras backend to JAX.\n",
"Switch the backend before model construction, and ensure your environment is connected to a TPU runtime.\n",
"Keras will then automatically leverage JAX\u2019s TPU support,\n",
"allowing your model to train efficiently on TPU hardware without further code changes."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"# As noted in the setup section, the backend must be set before Keras is\n",
"# first imported; restart the runtime if Keras has already been imported.\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
"gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(\"hf://google/gemma-2b\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"## Additional Examples\n",
"\n",
"### Generation\n",
"\n",
"Here\u2019s an example using Llama: loading a PyTorch Hugging Face Transformers checkpoint into KerasHub and running it on the JAX backend."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
"\n",
"from keras_hub.models import Llama3CausalLM\n",
"\n",
"# Get the model\n",
"causal_lm = Llama3CausalLM.from_preset(\"hf://NousResearch/Hermes-2-Pro-Llama-3-8B\")\n",
"\n",
"prompts = [\n",
"    \"\"\"<|im_start|>system\n",
"You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>\n",
"<|im_start|>user\n",
"Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>\n",
"<|im_start|>assistant\"\"\",\n",
"]\n",
"\n",
"# Generate from the model\n",
"causal_lm.generate(prompts, max_length=30)[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Changing precision\n",
"\n",
"You can adjust your model\u2019s precision by configuring it through `keras.config` as follows:"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import keras\n",
"\n",
"# Set the global dtype policy before constructing the model.\n",
"keras.config.set_dtype_policy(\"bfloat16\")\n",
"\n",
"from keras_hub.models import Llama3CausalLM\n",
"\n",
"causal_lm = Llama3CausalLM.from_preset(\"hf://NousResearch/Hermes-2-Pro-Llama-3-8B\")"
]
},
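{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"To confirm the policy took effect, you can read it back (a quick check,\n",
"assuming the Keras 3 `keras.config.dtype_policy()` accessor):"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# Print the global dtype policy that newly constructed layers will use.\n",
"print(keras.config.dtype_policy())"
]
},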
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Go try loading other model weights! You can find more options on HuggingFace\n",
"and use them with `from_preset(\"hf://<namespace>/<model-name>\")`.\n",
"\n",
"Happy experimenting!"
]
}
],
"metadata": {
"accelerator": "None",
"colab": {
"collapsed_sections": [],
"name": "hugging_face_keras_integration",
"private_outputs": false,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
