Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# CI lint workflow: runs Ruff's lint and format checks on every push / PR to main.
name: Lint

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      # Lint rules (ruff check)
      - uses: astral-sh/ruff-action@v3
        with:
          args: check
      # Formatting check: fails if any file would be reformatted
      - uses: astral-sh/ruff-action@v3
        with:
          args: format --check
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Pre-commit hooks: Ruff linter (with autofix) followed by the Ruff formatter.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.15.1
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,30 @@ Use `--yes` to skip the confirmation prompt.

Contributions are welcome. Please read our contributing guidelines before submitting pull requests.

### Development Setup

1. Install the package with dev dependencies:

```bash
pip install -e ".[dev]"
```

2. Install pre-commit hooks:

```bash
pre-commit install
```

This enables automatic linting and formatting checks (via [Ruff](https://docs.astral.sh/ruff/)) on every commit.

To run the checks manually against all files:

```bash
pre-commit run --all-files
```

### Submitting Changes

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add amazing feature'`)
Expand Down
40 changes: 22 additions & 18 deletions examples/demo_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import jax
import keras
import numpy as np

from keras_remote import core as keras_remote


@keras_remote.run(accelerator='v2-8')
@keras_remote.run(accelerator="v2-8")
def train_keras_jax_model():
host = socket.gethostname()
print(f"Running on host: {host}")
Expand All @@ -22,40 +23,43 @@ def train_keras_jax_model():
input_shape = (28, 28, 1)

model = keras.Sequential(
[
keras.layers.Input(shape=input_shape),
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dropout(0.5),
keras.layers.Dense(num_classes, activation="softmax"),
]
[
keras.layers.Input(shape=input_shape),
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dropout(0.5),
keras.layers.Dense(num_classes, activation="softmax"),
]
)
print("Model defined.")

model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="acc"),
],
loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(learning_rate=1e-3),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="acc"),
],
)
print("Model compiled.")

# Dummy data
num_samples = 128
x_train = np.random.rand(num_samples, *input_shape).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=(num_samples,)).astype(np.int32)
y_train = np.random.randint(0, num_classes, size=(num_samples,)).astype(
np.int32
)

print("Starting model.fit...")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2)
print("Model.fit finished.")

return f"Keras JAX training complete on {host}"


if __name__ == "__main__":
print("Starting Keras JAX demo...")
result = train_keras_jax_model()
Expand Down
187 changes: 94 additions & 93 deletions examples/example_gke.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,123 +42,124 @@

import keras_remote


# Example 1: CPU-only execution (works with default cluster)
@keras_remote.run(accelerator="cpu")
def simple_computation(x, y):
    """Simple addition that runs on remote CPU.

    Args:
        x: First addend.
        y: Second addend.

    Returns:
        The sum ``x + y`` computed on the remote worker.
    """
    result = x + y
    print(f"Computing {x} + {y} = {result}")
    return result


# Example 2: Keras model training on CPU
@keras_remote.run(accelerator="cpu")
def train_simple_model_cpu():
    """Train a simple Keras model on remote CPU.

    Builds a small dense network, trains it on random dummy data for a few
    epochs, and returns the final training loss.

    Returns:
        The last recorded training loss (float).
    """
    # Create a simple model
    model = keras.Sequential(
        [
            keras.layers.Dense(64, activation="relu", input_shape=(10,)),
            keras.layers.Dense(64, activation="relu"),
            keras.layers.Dense(1),
        ]
    )

    model.compile(optimizer="adam", loss="mse")

    # Generate some dummy data
    x_train = np.random.randn(1000, 10)
    y_train = np.random.randn(1000, 1)

    # Train the model
    print("Training model on CPU...")
    history = model.fit(x_train, y_train, epochs=5, batch_size=32, verbose=1)

    print(f"Final loss: {history.history['loss'][-1]}")
    return history.history["loss"][-1]


# Example 3: GPU training (requires GPU node pool)
@keras_remote.run(accelerator="nvidia-tesla-t4")
def train_model_gpu():
    """Train a Keras model on remote GPU. Requires T4 GPU node pool.

    Returns:
        The last recorded training loss (float).
    """
    model = keras.Sequential(
        [
            keras.layers.Dense(128, activation="relu", input_shape=(20,)),
            keras.layers.Dense(128, activation="relu"),
            keras.layers.Dense(1),
        ]
    )

    model.compile(optimizer="adam", loss="mse")

    # Dummy data; larger than the CPU example to exercise the GPU.
    x_train = np.random.randn(5000, 20)
    y_train = np.random.randn(5000, 1)

    print("Training model on T4 GPU...")
    history = model.fit(x_train, y_train, epochs=10, batch_size=64, verbose=1)

    return history.history["loss"][-1]


def main():
    """Run examples.

    Executes the CPU examples; the GPU example is left commented out because
    it requires a T4 GPU node pool in the target cluster.
    """
    print("=" * 60)
    print("Keras Remote - GKE Examples")
    print("=" * 60)

    # Example 1: Simple computation (CPU)
    print("\n--- Example 1: Simple Computation (CPU) ---")
    print("Running simple_computation(10, 20) on GKE...")
    result = simple_computation(10, 20)
    print(f"Result: {result}")

    # Example 2: Model training on CPU
    print("\n--- Example 2: Keras Model Training (CPU) ---")
    print("Training a simple model on CPU...")
    final_loss = train_simple_model_cpu()
    print(f"Model trained. Final loss: {final_loss}")

    # Example 3: GPU training (requires GPU node pool)
    # Uncomment to run if you have T4 GPU nodes available
    # print("\n--- Example 3: Model Training on T4 GPU ---")
    # final_loss = train_model_gpu()
    # print(f"Model trained. Final loss: {final_loss}")

    print("\n" + "=" * 60)
    print("Examples completed!")
    print("=" * 60)


if __name__ == "__main__":
    # Prerequisites:
    # 1. Set KERAS_REMOTE_PROJECT environment variable to your GCP project ID
    # 2. Configure kubectl: gcloud container clusters get-credentials <cluster> --zone <zone>
    # 3. Ensure your GKE cluster has GPU nodes with the required accelerator type
    if not os.environ.get("KERAS_REMOTE_PROJECT"):
        print("ERROR: KERAS_REMOTE_PROJECT environment variable not set")
        print("Please set it to your GCP project ID:")
        print("  export KERAS_REMOTE_PROJECT=your-project-id")
        exit(1)

    # Verify kubectl is configured; list-form argv avoids shell injection.
    try:
        result = subprocess.run(
            ["kubectl", "cluster-info"], capture_output=True, text=True, timeout=10
        )
        if result.returncode != 0:
            print("ERROR: kubectl is not configured or cluster is not accessible")
            print("Please configure kubectl:")
            print(
                "  gcloud container clusters get-credentials <cluster-name> --zone <zone>"
            )
            exit(1)
    except FileNotFoundError:
        print("ERROR: kubectl not found. Please install kubectl.")
        exit(1)
    except subprocess.TimeoutExpired:
        print("ERROR: kubectl timed out. Check your cluster connectivity.")
        exit(1)

    main()
Loading