Commit f8ecb97

Add vllm guide
1 parent e330ef2 commit f8ecb97

File tree

1 file changed: +216 -0 lines changed

guides/keras_hub/vllm_guide.py

Lines changed: 216 additions & 0 deletions
"""
2+
Title: Serving Gemma with vLLM
3+
Author: Dhiraj
4+
Date created: 2025/08/16
5+
Last modified: 2025/08/18
6+
Description: Export Gemma models from KerasHub to Hugging Face and serve with vLLM for fast inference.
7+
Accelerator: TPU and GPU
8+
"""

"""
## Introduction

This guide demonstrates how to export Gemma models from KerasHub to the Hugging Face format and serve them with vLLM for efficient, high-throughput inference. We'll walk through the process step by step, from loading a pre-trained Gemma model in KerasHub to running inference with vLLM in a Google Colab environment.

vLLM is an optimized serving engine for large language models that leverages techniques like PagedAttention to enable continuous batching and high GPU utilization. By exporting KerasHub models to a compatible format, you can take advantage of vLLM's performance benefits while starting from the Keras ecosystem.

At present, export is supported only for Gemma 2 and its presets; coverage of more KerasHub models is planned.

**Note:** We'll perform the model export on a TPU runtime (for efficiency with larger models) and then switch to a GPU runtime for serving with vLLM, since vLLM [does not support TPU v2 on Colab](https://docs.vllm.ai/en/v0.5.5/getting_started/tpu-installation.html).
"""

"""
## Setup

First, install the required libraries. Select a TPU runtime in Colab before running these cells.
"""

"""shell
!pip install -q --upgrade keras-hub huggingface-hub
"""

import json
import os
import shutil

import keras_hub
from huggingface_hub import snapshot_download

"""
39+
## Loading and Exporting the Model
40+
41+
Load a pre-trained Gemma 2 model from KerasHub using the 'gemma2_instruct_2b_en' preset. This is an instruction-tuned variant suitable for conversational tasks.
42+
43+
**Note:** The export method needs to map the weights from Keras to safetensors, hence requiring double the RAM needed to load a preset. This is also the reason why we are running on a TPU instance in Colab as it offers more VRAM instead of GPU.
44+
"""
45+
46+
# Load the pre-trained Gemma model
47+
model_preset = 'gemma2_instruct_2b_en'
48+
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(model_preset)
49+
print("✅ Gemma model loaded successfully")
50+
51+
# Set export path
52+
export_path = "./gemma_exported"
53+
54+
# Export to Hugging Face format
55+
gemma_lm.export_to_transformers(export_path)
56+
print(f"Model exported successfully to {export_path}")
57+
58+
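"""
Before moving on, it's worth a quick sanity check that the export produced the files vLLM will need. The exact file list can vary by KerasHub version, so treat the names below as an assumption and this cell as an optional check.
"""

# Optional sanity check: confirm the key exported files are present.
expected_files = ["config.json", "model.safetensors"]  # assumed names of the exported artifacts
for name in expected_files:
    path = os.path.join(export_path, name)
    print(f"{name}: {'found' if os.path.exists(path) else 'MISSING'}")
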
"""
## Downloading Additional Metadata

vLLM requires the complete set of Hugging Face model configuration files, so download the JSON metadata from the original Gemma repository on Hugging Face. Note that the Gemma repositories are gated: you may need to log in with a Hugging Face token and accept the Gemma license before the download succeeds.
"""

SERVABLE_CKPT_DIR = "./gemma_exported"

# Download the JSON metadata files.
snapshot_download(
    repo_id="google/gemma-2-2b-it",
    allow_patterns="*.json",
    local_dir=SERVABLE_CKPT_DIR,
)
print("✅ Metadata files downloaded")

"""
## Updating the Model Index

The exported model uses a single safetensors file, unlike the original checkpoint, which may have multiple shards. Update the index file to reflect this.
"""

index_path = os.path.join(SERVABLE_CKPT_DIR, "model.safetensors.index.json")
with open(index_path, "r") as f:
    index_data = json.load(f)

# Replace shard references with the single exported file.
for key in index_data.get("weight_map", {}):
    index_data["weight_map"][key] = "model.safetensors"

with open(index_path, "w") as f:
    json.dump(index_data, f, indent=2)

print("✅ Model index updated")

# Verify the directory contents.
print("Directory contents:")
for file in os.listdir(SERVABLE_CKPT_DIR):
    size = os.path.getsize(os.path.join(SERVABLE_CKPT_DIR, file)) / (1024 * 1024)
    print(f"{file}: {size:.2f} MB")

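"""
As a quick extra check, every entry in the weight map should now point at the single exported file:
"""

# All weights should now map to the single exported safetensors file.
assert set(index_data["weight_map"].values()) == {"model.safetensors"}
print("✅ All weight map entries point to model.safetensors")
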
"""
## Saving to Google Drive

Save the files to Google Drive. This is needed because vLLM currently [does not support TPU v2 on Colab](https://docs.vllm.ai/en/v0.5.5/getting_started/tpu-installation.html) and cannot dynamically switch its backend to CPU, so after saving we switch to a Colab GPU instance for serving. If you are using a Cloud TPU or a GPU from the start, you may skip this step.

**Note:** The `model.safetensors` file is ~9.5 GB for Gemma 2 2B, so make sure you have enough free space in your Google Drive.
"""

from google.colab import drive

drive.mount("/content/drive")

drive_dir = "/content/drive/MyDrive/gemma_exported"

# Remove any existing export with the same name.
if os.path.exists(drive_dir):
    shutil.rmtree(drive_dir)
    print("✅ Existing export removed")

# Copy the exported model to Google Drive.
shutil.copytree(SERVABLE_CKPT_DIR, drive_dir)
print("✅ Model copied to Google Drive")

"""
Verify the file sizes to make sure nothing was corrupted during the copy. Here is how they should appear:
"""

print("Drive directory contents:")
for file in os.listdir(drive_dir):
    size = os.path.getsize(os.path.join(drive_dir, file)) / (1024 * 1024)
    print(f"{file}: {size:.2f} MB")

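"""
If you want a stronger check than eyeballing the sizes, you can compare each file in the local export against its copy on Drive. This is a simple size comparison sketch; hashing the ~9.5 GB weights would also work but is slow on Colab.
"""

# Compare the local export and the Drive copy file by file.
for file in os.listdir(SERVABLE_CKPT_DIR):
    src_size = os.path.getsize(os.path.join(SERVABLE_CKPT_DIR, file))
    dst_size = os.path.getsize(os.path.join(drive_dir, file))
    assert src_size == dst_size, f"Size mismatch for {file}"
print("✅ Local and Drive copies match")
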
"""
Disconnect the TPU runtime (if applicable) and reconnect with a T4 GPU runtime before proceeding. Then remount Google Drive and confirm the files are still there.
"""

from google.colab import drive

drive.mount("/content/drive")

SERVABLE_CKPT_DIR = "/content/drive/MyDrive/gemma_exported"

print("Drive directory contents:")
for file in os.listdir(SERVABLE_CKPT_DIR):
    size = os.path.getsize(os.path.join(SERVABLE_CKPT_DIR, file)) / (1024 * 1024)
    print(f"{file}: {size:.2f} MB")

"""
## Install vLLM
"""

"""shell
!pip install -q vllm
"""

"""
## Instantiating vLLM

Load the exported model into vLLM for serving.
"""

from vllm import LLM, SamplingParams

llm = LLM(model=SERVABLE_CKPT_DIR, load_format="safetensors", dtype="float32")
print("✅ vLLM engine initialized")

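"""
The engine above keeps the weights in `float32`. If you run into GPU memory pressure on a 16 GB T4, one possible fallback (a sketch, not used in the rest of this guide) is to load the weights in half precision and cap the context length, which shrinks both the weights and the KV cache:

```python
llm = LLM(
    model=SERVABLE_CKPT_DIR,
    load_format="safetensors",
    dtype="half",  # half precision; T4 GPUs do not support bfloat16
    max_model_len=2048,  # smaller maximum context -> smaller KV cache
    gpu_memory_utilization=0.9,
)
```
"""
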
"""
## Generating with vLLM

First, test with a simple prompt to verify the setup.
"""

simple_prompt = "Hello, what is vLLM?"
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
outputs = llm.generate(simple_prompt, sampling_params)

for output in outputs:
    print(f"Prompt: {output.prompt}\nGenerated text: {output.outputs[0].text}")

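"""
`llm.generate` also accepts a list of prompts, which is where vLLM's continuous batching pays off: the engine schedules all requests together instead of processing them one at a time. A small sketch (the prompts below are just illustrative):
"""

# Batched generation: pass a list of prompts so vLLM can batch them together.
batch_prompts = [
    "Explain PagedAttention in one sentence.",
    "List three practical uses of a small language model.",
    "What is the capital of France?",
]
batch_outputs = llm.generate(batch_prompts, sampling_params)
for output in batch_outputs:
    print(f"Prompt: {output.prompt}\nGenerated text: {output.outputs[0].text}\n")
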
"""
Since we loaded the instruction-tuned Gemma weights, let's format the prompt with the Gemma chat template and a system prompt that asks the model to show its working.
"""

reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

SYSTEM_PROMPT = f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start} and {solution_end}"""

TEMPLATE = """
<start_of_turn>user
{system_prompt}

{question}<end_of_turn>
<start_of_turn>model"""


question = (
    "Trevor and two of his neighborhood friends go to the toy shop every year "
    "to buy toys. Trevor always spends $20 more than his friend Reed on toys, "
    "and Reed spends 2 times as much money as their friend Quinn on the toys. "
    "If Trevor spends $80 every year to buy his toys, calculate how much money "
    "in total the three spend in 4 years."
)
prompts = [TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=question)]

sampling_params = SamplingParams(temperature=0.9, top_p=0.92, max_tokens=768)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print("===============================")
    print(f"Prompt: {output.prompt}\nGenerated text: {output.outputs[0].text}")

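"""
For reference, the expected answer is easy to compute directly, so we can check the model's solution against it:
"""

# Trevor spends $80, Reed spends $20 less, and Quinn spends half of what Reed spends.
trevor = 80
reed = trevor - 20  # $60
quinn = reed // 2  # $30
expected_total = (trevor + reed + quinn) * 4
print(f"Expected total over 4 years: ${expected_total}")  # $680
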
"""
## Conclusion

You've now successfully exported a KerasHub Gemma model to the Hugging Face format and served it with vLLM for efficient inference. This setup enables high-throughput generation, suitable for production or batch processing.

Experiment with different prompts, sampling parameters, or larger Gemma variants (ensure sufficient GPU memory). For deployment beyond Colab, consider Docker containers or cloud instances.

Happy experimenting!
"""
