-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsimple_demo.py
More file actions
63 lines (50 loc) · 1.72 KB
/
simple_demo.py
File metadata and controls
63 lines (50 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
os.environ["KERAS_BACKEND"] = "jax"
import jax
import keras
import numpy as np
import keras_remote
@keras_remote.run(accelerator="cpu")
def train_keras_jax_model(*, num_samples=128, epochs=1, batch_size=32, seed=None):
    """Build, compile and briefly fit a small CNN on random data.

    The function is wrapped by ``keras_remote.run(accelerator="cpu")``, so a
    call goes through keras_remote rather than running directly (presumably
    on a remote CPU worker -- confirm against keras_remote's docs).

    Args:
        num_samples: Number of random training examples to generate.
        epochs: Number of training epochs; must be at least 1 so that a final
            loss value exists to return.
        batch_size: Mini-batch size passed to ``model.fit``.
        seed: Optional integer seed. When given, ``keras.utils.set_random_seed``
            seeds Python's ``random``, NumPy's global RNG (used for the dummy
            data below) and the Keras backend RNG, making the run
            reproducible. ``None`` (the default) keeps the unseeded behavior.

    Returns:
        The training loss recorded for the last epoch, as a Python float.

    Raises:
        ValueError: If ``num_samples``, ``epochs`` or ``batch_size`` is < 1.
    """
    # Validate up front: with no samples or zero epochs, fit() records no
    # loss and the history lookup at the end would fail obscurely.
    for arg_name, arg_value in (
        ("num_samples", num_samples),
        ("epochs", epochs),
        ("batch_size", batch_size),
    ):
        if arg_value < 1:
            raise ValueError(f"{arg_name} must be >= 1, got {arg_value!r}")

    if seed is not None:
        keras.utils.set_random_seed(seed)

    print(f"Keras version: {keras.__version__}")
    print(f"Keras backend: {keras.config.backend()}")
    print(f"JAX version: {jax.__version__}")
    print(f"JAX devices: {jax.devices()}")

    num_classes = 10
    input_shape = (28, 28, 1)
    model = keras.Sequential(
        [
            keras.layers.Input(shape=input_shape),
            keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
            keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dropout(0.5),
            keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )
    print("Model defined.")

    # Integer class labels + softmax outputs -> sparse categorical loss.
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(),
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
        ],
    )
    print("Model compiled.")

    # Dummy data: uniform [0, 1) images and uniformly random labels.
    x_train = np.random.rand(num_samples, *input_shape).astype(np.float32)
    y_train = np.random.randint(0, num_classes, size=(num_samples,)).astype(
        np.int32
    )

    print("Starting model.fit...")
    history = model.fit(
        x_train, y_train, epochs=epochs, batch_size=batch_size, verbose=2
    )
    print("Model.fit finished.")
    # Return a plain float so the result is a simple, portable value.
    return float(history.history["loss"][-1])
if __name__ == "__main__":
    # Script entry point: announce the demo, run the (keras_remote-wrapped)
    # training function, then report the loss value it hands back.
    print("Starting Keras JAX demo...")
    final_loss = train_keras_jax_model()
    print(f"Final training loss: {final_loss}")