Skip to content

Commit 516107f

Browse files
Adds ruff config (#29)
* Adds ruff config * initial format * more fixes
1 parent 849872c commit 516107f

31 files changed

+2212
-2023
lines changed

.github/workflows/lint.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
```yaml
name: Lint

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/ruff-action@v3
        with:
          args: check
      - uses: astral-sh/ruff-action@v3
        with:
          args: format --check
```

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
```yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.15.1
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
```

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,30 @@ Use `--yes` to skip the confirmation prompt.
351351

352352
Contributions are welcome. Please read our contributing guidelines before submitting pull requests.
353353

### Development Setup

1. Install the package with dev dependencies:

   ```bash
   pip install -e ".[dev]"
   ```

2. Install pre-commit hooks:

   ```bash
   pre-commit install
   ```

This enables automatic linting and formatting checks (via [Ruff](https://docs.astral.sh/ruff/)) on every commit.

To run the checks manually against all files:

```bash
pre-commit run --all-files
```

### Submitting Changes
354378
1. Fork the repository
355379
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
356380
3. Commit your changes (`git commit -m 'Add amazing feature'`)

examples/demo_train.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import jax
77
import keras
88
import numpy as np
9+
910
from keras_remote import core as keras_remote
1011

1112

12-
@keras_remote.run(accelerator='v2-8')
13+
@keras_remote.run(accelerator="v2-8")
1314
def train_keras_jax_model():
1415
host = socket.gethostname()
1516
print(f"Running on host: {host}")
@@ -22,40 +23,43 @@ def train_keras_jax_model():
2223
input_shape = (28, 28, 1)
2324

2425
model = keras.Sequential(
25-
[
26-
keras.layers.Input(shape=input_shape),
27-
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
28-
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
29-
keras.layers.MaxPooling2D(pool_size=(2, 2)),
30-
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
31-
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
32-
keras.layers.GlobalAveragePooling2D(),
33-
keras.layers.Dropout(0.5),
34-
keras.layers.Dense(num_classes, activation="softmax"),
35-
]
26+
[
27+
keras.layers.Input(shape=input_shape),
28+
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
29+
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
30+
keras.layers.MaxPooling2D(pool_size=(2, 2)),
31+
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
32+
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
33+
keras.layers.GlobalAveragePooling2D(),
34+
keras.layers.Dropout(0.5),
35+
keras.layers.Dense(num_classes, activation="softmax"),
36+
]
3637
)
3738
print("Model defined.")
3839

3940
model.compile(
40-
loss=keras.losses.SparseCategoricalCrossentropy(),
41-
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
42-
metrics=[
43-
keras.metrics.SparseCategoricalAccuracy(name="acc"),
44-
],
41+
loss=keras.losses.SparseCategoricalCrossentropy(),
42+
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
43+
metrics=[
44+
keras.metrics.SparseCategoricalAccuracy(name="acc"),
45+
],
4546
)
4647
print("Model compiled.")
4748

4849
# Dummy data
4950
num_samples = 128
5051
x_train = np.random.rand(num_samples, *input_shape).astype(np.float32)
51-
y_train = np.random.randint(0, num_classes, size=(num_samples,)).astype(np.int32)
52+
y_train = np.random.randint(0, num_classes, size=(num_samples,)).astype(
53+
np.int32
54+
)
5255

5356
print("Starting model.fit...")
5457
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2)
5558
print("Model.fit finished.")
5659

5760
return f"Keras JAX training complete on {host}"
5861

62+
5963
if __name__ == "__main__":
6064
print("Starting Keras JAX demo...")
6165
result = train_keras_jax_model()

examples/example_gke.py

Lines changed: 94 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -42,123 +42,124 @@
4242

4343
import keras_remote
4444

45+
4546
# Example 1: CPU-only execution (works with default cluster)
@keras_remote.run(accelerator="cpu")
def simple_computation(x, y):
    """Simple addition that runs on remote CPU.

    Args:
        x: First addend.
        y: Second addend.

    Returns:
        The sum ``x + y`` computed on the remote worker.
    """
    result = x + y
    # Printed on the remote worker's stdout, not the local client.
    print(f"Computing {x} + {y} = {result}")
    return result
5253

5354

5455
# Example 2: Keras model training on CPU
@keras_remote.run(accelerator="cpu")
def train_simple_model_cpu():
    """Train a simple Keras model on remote CPU.

    Builds a small dense regression network, fits it on random dummy data
    for 5 epochs, and returns the final training loss.

    Returns:
        The last recorded training loss (float).
    """
    # Create a simple model
    model = keras.Sequential(
        [
            keras.layers.Dense(64, activation="relu", input_shape=(10,)),
            keras.layers.Dense(64, activation="relu"),
            keras.layers.Dense(1),
        ]
    )

    model.compile(optimizer="adam", loss="mse")

    # Generate some dummy data
    x_train = np.random.randn(1000, 10)
    y_train = np.random.randn(1000, 1)

    # Train the model
    print("Training model on CPU...")
    history = model.fit(x_train, y_train, epochs=5, batch_size=32, verbose=1)

    print(f"Final loss: {history.history['loss'][-1]}")
    return history.history["loss"][-1]
8081

8182

8283
# Example 3: GPU training (requires GPU node pool)
@keras_remote.run(accelerator="nvidia-tesla-t4")
def train_model_gpu():
    """Train a Keras model on remote GPU. Requires T4 GPU node pool.

    Fits a small dense regression network on random dummy data for
    10 epochs and returns the final training loss.

    Returns:
        The last recorded training loss (float).
    """
    model = keras.Sequential(
        [
            keras.layers.Dense(128, activation="relu", input_shape=(20,)),
            keras.layers.Dense(128, activation="relu"),
            keras.layers.Dense(1),
        ]
    )

    model.compile(optimizer="adam", loss="mse")

    x_train = np.random.randn(5000, 20)
    y_train = np.random.randn(5000, 1)

    print("Training model on T4 GPU...")
    history = model.fit(x_train, y_train, epochs=10, batch_size=64, verbose=1)

    return history.history["loss"][-1]
103104

104105

105106
def main():
    """Run examples.

    Executes the CPU examples in sequence; the GPU example is left
    commented out because it requires a T4 GPU node pool.
    """
    print("=" * 60)
    print("Keras Remote - GKE Examples")
    print("=" * 60)

    # Example 1: Simple computation (CPU)
    print("\n--- Example 1: Simple Computation (CPU) ---")
    print("Running simple_computation(10, 20) on GKE...")
    result = simple_computation(10, 20)
    print(f"Result: {result}")

    # Example 2: Model training on CPU
    print("\n--- Example 2: Keras Model Training (CPU) ---")
    print("Training a simple model on CPU...")
    final_loss = train_simple_model_cpu()
    print(f"Model trained. Final loss: {final_loss}")

    # Example 3: GPU training (requires GPU node pool)
    # Uncomment to run if you have T4 GPU nodes available
    # print("\n--- Example 3: Model Training on T4 GPU ---")
    # final_loss = train_model_gpu()
    # print(f"Model trained. Final loss: {final_loss}")

    print("\n" + "=" * 60)
    print("Examples completed!")
    print("=" * 60)
132133

133134

134135
if __name__ == "__main__":
    # Prerequisites:
    # 1. Set KERAS_REMOTE_PROJECT environment variable to your GCP project ID
    # 2. Configure kubectl: gcloud container clusters get-credentials <cluster> --zone <zone>
    # 3. Ensure your GKE cluster has GPU nodes with the required accelerator type
    if not os.environ.get("KERAS_REMOTE_PROJECT"):
        print("ERROR: KERAS_REMOTE_PROJECT environment variable not set")
        print("Please set it to your GCP project ID:")
        print("  export KERAS_REMOTE_PROJECT=your-project-id")
        exit(1)

    # Verify kubectl is configured before attempting any remote runs.
    try:
        result = subprocess.run(
            ["kubectl", "cluster-info"], capture_output=True, text=True, timeout=10
        )
        if result.returncode != 0:
            print("ERROR: kubectl is not configured or cluster is not accessible")
            print("Please configure kubectl:")
            print(
                "  gcloud container clusters get-credentials <cluster-name> --zone <zone>"
            )
            exit(1)
    except FileNotFoundError:
        print("ERROR: kubectl not found. Please install kubectl.")
        exit(1)
    except subprocess.TimeoutExpired:
        print("ERROR: kubectl timed out. Check your cluster connectivity.")
        exit(1)

    main()

0 commit comments

Comments
 (0)