-
Notifications
You must be signed in to change notification settings - Fork 1
Add Pathways Integration #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
7453b29
20a54ab
01ac7b7
3a69ed8
08881d9
8085ae3
2769b1f
d4e97b6
83eca73
50cf0c2
b010eb6
364d81d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| import os | ||
|
|
||
| os.environ["KERAS_BACKEND"] = "jax" | ||
|
|
||
| import keras | ||
| import numpy as np | ||
| from keras import layers | ||
|
|
||
| import keras_remote | ||
|
|
||
|
|
||
| # A simple model that will be executed remotely | ||
| @keras_remote.run( | ||
| accelerator="v5litepod-1", | ||
| ) | ||
divyashreepathihalli marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def train_simple_model(): | ||
| print("Running Pathways job on JAX Backend!") | ||
|
|
||
| # Create a simple dataset | ||
| x = np.random.rand(1000, 10) | ||
| y = np.random.randint(0, 2, size=(1000, 1)) | ||
|
|
||
| # A simple sequential model | ||
| model = keras.Sequential( | ||
| [ | ||
| keras.Input(shape=(10,)), | ||
| layers.Dense(32, activation="relu"), | ||
| layers.Dense(16, activation="relu"), | ||
| layers.Dense(1, activation="sigmoid"), | ||
| ] | ||
| ) | ||
|
|
||
| model.compile( | ||
| optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"] | ||
| ) | ||
|
|
||
| print("Model Architecture:") | ||
| model.summary() | ||
|
|
||
| # Train the model | ||
| print("\nStarting Training...") | ||
| history = model.fit(x, y, epochs=5, batch_size=32, validation_split=0.2) | ||
|
|
||
| print("\nTraining completed successfully on Pathways!") | ||
| return history.history | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to confirm, there are no user code changes to run on Pathways within their remote function? All it needs is
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah if it is None. It auto detects if the user requested for a multi node TPU and picks the pathways backend |
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| print("Submitting Pathways training job...") | ||
| result_history = train_simple_model() | ||
| print("Final validation accuracy:", result_history["val_accuracy"][-1]) | ||
Uh oh!
There was an error while loading. Please reload this page.