Keras Remote lets users execute Keras/JAX workloads on cloud TPUs and GPUs via a single decorator (`@keras_remote.run()`). It handles infrastructure provisioning, container building, job submission, and result retrieval on GCP.
```
keras_remote/
├── core/            # @run decorator, accelerator registry & parser
├── backend/         # Job execution backends (GKE, Pathways)
├── infra/           # Docker container building & caching
├── runner/          # Remote worker entrypoint (runs inside container)
├── utils/           # Serialization (packager) and Cloud Storage helpers
├── cli/             # CLI for infrastructure provisioning (Pulumi-based)
│   ├── commands/    # up, down, status, config
│   └── infra/       # Pulumi programs and stack management
├── credentials.py   # Credential verification & auto-setup (shared by core & CLI)
└── constants.py     # Zone/region utilities
```
```
@keras_remote.run() called
  → JobContext.from_params()   # Resolve config from args/env vars
  → ensure_credentials()       # Verify/auto-configure gcloud, ADC, kubeconfig
  → _prepare_artifacts()       # Serialize function (cloudpickle), zip working dir
  → _build_container()         # Build or retrieve cached Docker image
  → _upload_artifacts()        # Upload payload.pkl, context.zip to GCS
  → backend.submit_job()       # Create K8s Job (GKE) or LeaderWorkerSet (Pathways)
  → backend.wait_for_job()     # Poll until completion
  → _download_result()         # Fetch result.pkl from GCS
  → _cleanup_and_return()      # Delete artifacts, return result or re-raise exception
```

| Module | Responsibility |
|---|---|
| `core/core.py` | `@run()` decorator, backend routing, env var capture |
| `core/accelerators.py` | Accelerator registry (`GPUS`, `TPUS`), parser (`parse_accelerator`) |
| `credentials.py` | Credential verification & auto-setup (gcloud, ADC, kubeconfig) |
| `backend/execution.py` | `JobContext` dataclass, `BaseK8sBackend` base class, `execute_remote()` pipeline |
| `backend/gke_client.py` | K8s Job creation, status polling, pod log retrieval |
| `backend/pathways_client.py` | LeaderWorkerSet creation for multi-host TPUs |
| `infra/container_builder.py` | Content-hashed Docker image building via Cloud Build |
| `utils/packager.py` | `save_payload()` (cloudpickle), `zip_working_dir()` |
| `utils/storage.py` | GCS upload/download/cleanup for job artifacts |
| `runner/remote_runner.py` | Runs inside container: deserialize, execute, upload result |
| `cli/main.py` | CLI entry point (`keras-remote` command) |
- `JobContext` (`backend/execution.py`): Mutable dataclass carrying all job state through the pipeline: inputs, generated IDs, artifact paths, image URI.
- `BaseK8sBackend` (`backend/execution.py`): Base class with `submit_job`, `wait_for_job`, `cleanup_job`. Subclassed by `GKEBackend` and `PathwaysBackend`.
- `GpuConfig`/`TpuConfig` (`core/accelerators.py`): Frozen dataclasses for accelerator metadata. Single source of truth used by the runtime, container builder, and CLI.
- `InfraConfig` (`cli/config.py`): CLI provisioning configuration (project, zone, cluster, accelerator).
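The split between frozen config dataclasses and a mutable state object can be sketched as below. This is an illustrative sketch only: the field names are assumptions, not the actual definitions in `core/accelerators.py` or `backend/execution.py`.

```python
from dataclasses import dataclass, field

@dataclass(frozen=True)
class TpuConfig:
    """Immutable accelerator metadata (illustrative fields)."""
    name: str
    chips_per_node: int
    num_nodes: int = 1

@dataclass
class JobContext:
    """Mutable state threaded through the pipeline (illustrative fields)."""
    project: str
    zone: str
    job_id: str = ""
    image_uri: str = ""
    artifact_paths: list = field(default_factory=list)

cfg = TpuConfig(name="v5e-8", chips_per_node=8)
ctx = JobContext(project="my-proj", zone="us-central1-a")
ctx.job_id = "job-123"   # mutable state object: assignment is fine
# cfg.num_nodes = 2      # frozen config: would raise FrozenInstanceError
```

Freezing the config objects lets them be safely shared across the runtime, container builder, and CLI without defensive copying.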
- Formatter/linter: `ruff` (2-space indent, 80-char line length target, E501 ignored)
- Rules: B, E, F, N, PYI, T20, TID, SIM, W, I, NPY
- Dataclasses: Frozen for immutable configs, mutable for state objects
Every customizable resource name must follow the same resolution model across all usage paths:
- `@run()` decorator: explicit parameter → env var → error or default
- CLI commands: `--flag` (with `envvar=`) → env var → interactive prompt or default
- `config show`: displays current value and source for every configurable name
| Env Var | `@run()` param | CLI flag | `config show` | Default |
|---|---|---|---|---|
| `KERAS_REMOTE_PROJECT` | `project=` | `--project` | Yes | (required) |
| `KERAS_REMOTE_ZONE` | `zone=` | `--zone` | Yes | `us-central1-a` |
| `KERAS_REMOTE_CLUSTER` | `cluster=` | `--cluster-name` | Yes | `keras-remote-cluster` |
| `KERAS_REMOTE_GKE_NAMESPACE` | `namespace=` | (runtime only) | Yes | `default` |
When adding a new configurable resource name, ensure it is wired into all three paths (decorator, CLI flags on every relevant command, and config show). The GOOGLE_CLOUD_PROJECT env var is also accepted as a fallback for project ID (after KERAS_REMOTE_PROJECT).
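The resolution order described above (explicit parameter → env var → default, with `GOOGLE_CLOUD_PROJECT` as a secondary fallback for the project ID) can be sketched as follows. The `resolve` helper is hypothetical, not the project's actual implementation.

```python
import os

# Clear any ambient values so the demo below is deterministic.
os.environ.pop("KERAS_REMOTE_PROJECT", None)
os.environ.pop("KERAS_REMOTE_ZONE", None)

def resolve(explicit, env_vars, default=None):
    """Hypothetical helper: explicit argument wins, then each env var
    in order, then the default (None means unresolved)."""
    if explicit is not None:
        return explicit
    for var in env_vars:
        value = os.environ.get(var)
        if value:
            return value
    return default

# Project falls back to GOOGLE_CLOUD_PROJECT after KERAS_REMOTE_PROJECT.
os.environ["GOOGLE_CLOUD_PROJECT"] = "fallback-proj"
project = resolve(None, ["KERAS_REMOTE_PROJECT", "GOOGLE_CLOUD_PROJECT"])
zone = resolve(None, ["KERAS_REMOTE_ZONE"], default="us-central1-a")
```

An explicit argument always short-circuits the chain, so `resolve("my-proj", [...])` ignores the environment entirely.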
- Framework: `absl.testing` (not pytest)
- Location: colocated `*_test.py` files alongside source modules
- Patterns: `@parameterized.named_parameters` for multi-case tests, mocked GCP/K8s APIs, `tempfile.TemporaryDirectory()` for file ops
- Integration tests: `tests/integration/`
- E2E tests: `tests/e2e/` (requires live GCP resources)
- Run tests: use pytest (e.g., `/opt/miniconda3/envs/keras-remote-3.12/bin/python -m pytest`). Tests use `absl.testing` internally but should be run via pytest for better output.
Images are tagged with `SHA256(base_image + accelerator_type + requirements.txt + remote_runner.py)`. Identical inputs produce the same tag, skipping the rebuild.
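The content-addressed tag can be sketched as a hash over the build inputs (file contents, not paths). The function name, concatenation order, and truncation length here are assumptions:

```python
import hashlib

def image_tag(base_image: str, accelerator_type: str,
              requirements_txt: bytes, runner_src: bytes) -> str:
    """Deterministic tag: identical inputs -> identical tag -> cached
    image is reused instead of rebuilding."""
    h = hashlib.sha256()
    for part in (base_image.encode(), accelerator_type.encode(),
                 requirements_txt, runner_src):
        h.update(part)
    return h.hexdigest()[:16]  # truncation length is an assumption

tag_a = image_tag("python:3.12", "tpu-v5e", b"keras\n", b"print('run')\n")
tag_b = image_tag("python:3.12", "tpu-v5e", b"keras\n", b"print('run')\n")
```

Any change to the base image, accelerator type, requirements, or runner source flips the tag and forces a fresh build.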
`get_docker_image` requires digest-based names (`image@sha256:...`), not tag-based (`image:tag`). Use `get_tag` with the resource name `projects/{p}/locations/{l}/repositories/{r}/packages/{image}/tags/{tag}` to check tagged images.
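A small helper that builds the tag resource name expected by `get_tag` might look like this; the helper name is hypothetical, but the path format matches the note above:

```python
def tag_resource_name(project: str, location: str, repository: str,
                      image: str, tag: str) -> str:
    """Build the Artifact Registry tag resource name used with get_tag."""
    return (
        f"projects/{project}/locations/{location}"
        f"/repositories/{repository}/packages/{image}/tags/{tag}"
    )

name = tag_resource_name("my-proj", "us-central1", "keras-remote",
                         "runner", "abc123")
```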
- Build tool: hatchling
- Python: >=3.11
- Core deps: absl-py, cloudpickle, numpy, keras, google-cloud-{artifact-registry,storage,build}, kubernetes
- CLI deps (optional `[cli]`): click, rich, pulumi, pulumi-gcp
- Dev deps (optional `[dev]`): pre-commit, ruff
- Entry point: `keras-remote` → `keras_remote.cli.main:cli`
- CPU / single-node GPU / single-node TPU: GKE backend (K8s Job)
- Multi-node TPU (`TpuConfig.num_nodes > 1`): Pathways backend (LeaderWorkerSet)
- Explicit `backend=` parameter overrides auto-detection
```python
{
  "success": bool,
  "result": Any,           # if success=True
  "exception": Exception,  # if success=False
  "traceback": str,        # if success=False
}
```

Exceptions are re-raised locally with the original traceback.
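Producing this dict on the worker and re-raising on the client can be sketched as below; both helper names are hypothetical, and the real runner additionally pickles the dict to GCS as `result.pkl`:

```python
import traceback

def run_and_capture(fn, *args, **kwargs):
    """Worker side: never let the exception escape; record it in the
    result dict along with a formatted traceback string."""
    try:
        return {"success": True, "result": fn(*args, **kwargs)}
    except Exception as exc:
        return {"success": False, "exception": exc,
                "traceback": traceback.format_exc()}

def unwrap(result):
    """Client side: return the value, or re-raise the remote exception
    (the "traceback" string is available for display/logging)."""
    if result["success"]:
        return result["result"]
    raise result["exception"]
```

Storing the traceback as a string keeps the payload picklable even when the traceback object itself is not.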