-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path gemma3_sft_demo.py
More file actions
37 lines (29 loc) · 1017 Bytes
/
gemma3_sft_demo.py
File metadata and controls
37 lines (29 loc) · 1017 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import keras_hub
from keras_remote import core as keras_remote
@keras_remote.run(
    accelerator="v5litepod-1", capture_env_vars=["KAGGLE_*", "GOOGLE_CLOUD_*"]
)
def train_gemma():
    """Fine-tune the ``gemma3_1b`` Gemma 3 preset on a two-example SFT set.

    The function is wrapped by ``keras_remote.run`` with a ``v5litepod-1``
    accelerator and ``capture_env_vars`` patterns ``KAGGLE_*`` and
    ``GOOGLE_CLOUD_*``. Presumably the decorator launches this function on a
    remote TPU and forwards matching environment variables; confirm against
    the keras_remote documentation.

    Progress is reported with ``print`` at each stage. Nothing is returned;
    the fine-tuned model is not saved.
    """
    print("Starting Gemma 3 SFT training...")

    # Supervised fine-tuning examples, kept as (prompt, response) pairs so
    # each question sits next to its answer.
    examples = [
        ("Capital of India?", "New Delhi"),
        ("Capital of South Africa?", "Pretoria"),
    ]
    # Transpose the pairs into the two parallel lists the CausalLM
    # preprocessor consumes under the "prompts" / "responses" keys.
    prompts, responses = ([*column] for column in zip(*examples))
    sft_data = {"prompts": prompts, "responses": responses}
    print("Data prepared.")

    model = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_1b")
    print("Model initialized.")

    # Keras `fit` defaults to a single epoch; one example per step.
    model.fit(x=sft_data, batch_size=1)
    print("Gemma 3 SFT training done")
if __name__ == "__main__":
    # Configure the environment, then launch training.
    #
    # `setdefault` is used instead of plain assignment so that values already
    # exported in the shell (real Kaggle credentials, the real GCP project)
    # are NOT overwritten by the placeholder strings below. When a variable is
    # unset, the placeholder is used exactly as before.

    # Keras backend for TPU.
    # NOTE(review): Keras picks its backend from KERAS_BACKEND when it is
    # first imported, and `keras_hub` (which imports Keras) is already imported
    # at the top of this file, so this line cannot change the backend of the
    # local process. KERAS_BACKEND also does not match the `capture_env_vars`
    # patterns ("KAGGLE_*", "GOOGLE_CLOUD_*") given to `keras_remote.run`.
    # If the backend matters, move this above `import keras_hub` or export it
    # in the shell before running.
    os.environ.setdefault("KERAS_BACKEND", "jax")

    # GCP project and zone. These are placeholders: export real values or edit
    # them here. They presumably reach the remote job via the GOOGLE_CLOUD_*
    # capture pattern (confirm against the keras_remote docs).
    os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "tpu-prod-123456")
    os.environ.setdefault("GOOGLE_CLOUD_ZONE", "us-central1-a")

    # Kaggle credentials used to download the Gemma preset. These are
    # placeholders: prefer exporting KAGGLE_USERNAME / KAGGLE_KEY in the shell
    # rather than committing secrets to source.
    os.environ.setdefault("KAGGLE_USERNAME", "your_kaggle_username")
    os.environ.setdefault("KAGGLE_KEY", "your_kaggle_key")

    train_gemma()