|
| 1 | +# AGENTS.md — Keras Remote |
| 2 | + |
| 3 | +## Project Overview |
| 4 | + |
| 5 | +Keras Remote lets users execute Keras/JAX workloads on cloud TPUs and GPUs via a single decorator (`@keras_remote.run()`). It handles infrastructure provisioning, container building, job submission, and result retrieval on GCP. |
| 6 | + |
| 7 | +## Architecture |
| 8 | + |
| 9 | +``` |
| 10 | +keras_remote/ |
| 11 | +├── core/ # @run decorator, accelerator registry & parser |
| 12 | +├── backend/ # Job execution backends (GKE, Pathways) |
| 13 | +├── infra/ # Docker container building & caching |
| 14 | +├── runner/ # Remote worker entrypoint (runs inside container) |
| 15 | +├── utils/ # Serialization (packager) and Cloud Storage helpers |
| 16 | +├── cli/ # CLI for infrastructure provisioning (Pulumi-based) |
| 17 | +│ ├── commands/ # up, down, status, config |
| 18 | +│ └── infra/ # Pulumi programs and stack management |
| 19 | +└── constants.py # Zone/region utilities |
| 20 | +``` |
| 21 | + |
| 22 | +## Execution Pipeline |
| 23 | + |
| 24 | +``` |
| 25 | +@keras_remote.run() called |
| 26 | + → JobContext.from_params() # Resolve config from args/env vars |
| 27 | + → _prepare_artifacts() # Serialize function (cloudpickle), zip working dir |
| 28 | + → _build_container() # Build or retrieve cached Docker image |
| 29 | + → _upload_artifacts() # Upload payload.pkl, context.zip to GCS |
| 30 | + → backend.submit_job() # Create K8s Job (GKE) or LeaderWorkerSet (Pathways) |
| 31 | + → backend.wait_for_job() # Poll until completion |
| 32 | + → _download_result() # Fetch result.pkl from GCS |
| 33 | + → _cleanup_and_return() # Delete artifacts, return result or re-raise exception |
| 34 | +``` |
| 35 | + |
| 36 | +## Key Modules |
| 37 | + |
| 38 | +| Module | Responsibility | |
| 39 | +| ---------------------------- | ----------------------------------------------------------------------------- | |
| 40 | +| `core/core.py` | `@run()` decorator, backend routing, env var capture | |
| 41 | +| `core/accelerators.py` | Accelerator registry (`GPUS`, `TPUS`), parser (`parse_accelerator`) | |
| 42 | +| `backend/execution.py` | `JobContext` dataclass, `BackendClient` protocol, `execute_remote()` pipeline | |
| 43 | +| `backend/gke_client.py` | K8s Job creation, status polling, pod log retrieval | |
| 44 | +| `backend/pathways_client.py` | LeaderWorkerSet creation for multi-host TPUs | |
| 45 | +| `infra/container_builder.py` | Content-hashed Docker image building via Cloud Build | |
| 46 | +| `utils/packager.py` | `save_payload()` (cloudpickle), `zip_working_dir()` | |
| 47 | +| `utils/storage.py` | GCS upload/download/cleanup for job artifacts | |
| 48 | +| `runner/remote_runner.py` | Runs inside container: deserialize, execute, upload result | |
| 49 | +| `cli/main.py` | CLI entry point (`keras-remote` command) | |
| 50 | + |
| 51 | +## Key Abstractions |
| 52 | + |
| 53 | +- **`JobContext`** (`backend/execution.py`): Mutable dataclass carrying all job state through the pipeline — inputs, generated IDs, artifact paths, image URI. |
| 54 | +- **`BackendClient`** protocol (`backend/execution.py`): Interface with `submit_job`, `wait_for_job`, `cleanup_job`. Implemented by `GKEBackend` and `PathwaysBackend`. |
| 55 | +- **`GpuConfig` / `TpuConfig`** (`core/accelerators.py`): Frozen dataclasses for accelerator metadata. Single source of truth used by runtime, container builder, and CLI. |
| 56 | +- **`InfraConfig`** (`cli/config.py`): CLI provisioning configuration (project, zone, cluster, accelerator). |
| 57 | + |
| 58 | +## Conventions |
| 59 | + |
| 60 | +### Code Style |
| 61 | + |
| 62 | +- **Formatter/linter**: `ruff` (2-space indent, 80-char line length target, E501 ignored) |
| 63 | +- **Rules**: B, E, F, N, PYI, T20, TID, SIM, W, I, NPY |
| 64 | +- **Dataclasses**: Frozen for immutable configs, mutable for state objects |
| 65 | +- **Protocols**: Used for backend abstraction (not ABCs) |
| 66 | + |
| 67 | +### Environment Variables |
| 68 | + |
| 69 | +- `KERAS_REMOTE_PROJECT` (required): GCP project ID |
| 70 | +- `KERAS_REMOTE_ZONE` (optional): GCP zone, defaults to `us-central1-a` |
| 71 | +- `KERAS_REMOTE_CLUSTER` (optional): GKE cluster name |
| 72 | +- `KERAS_REMOTE_GKE_NAMESPACE` (optional): K8s namespace, defaults to `default` |
| 73 | + |
| 74 | +### Testing |
| 75 | + |
| 76 | +- **Framework**: `absl.testing` for writing test cases (no pytest-style tests; pytest serves only as the runner) |
| 77 | +- **Location**: Colocated `*_test.py` files alongside source modules |
| 78 | +- **Patterns**: `@parameterized.named_parameters` for multi-case tests, mocked GCP/K8s APIs, `tempfile.TemporaryDirectory()` for file ops |
| 79 | +- **Integration tests**: `tests/integration/` |
| 80 | +- **E2E tests**: `tests/e2e/` (requires live GCP resources) |
| 81 | +- **Run tests**: Run the suite with pytest (e.g., `/opt/miniconda3/envs/keras-remote-3.12/bin/python -m pytest`). Test cases are written with `absl.testing`, but pytest gives better output. |
| 82 | + |
| 83 | +### Container Caching |
| 84 | + |
| 85 | +Images are tagged with `SHA256(base_image + accelerator_type + requirements.txt + remote_runner.py)`. Identical inputs produce the same tag, so the existing image is reused and the rebuild is skipped. |
| 86 | + |
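The tagging scheme could look something like the sketch below. The function name `image_tag` and the length-prefixing of each input are assumptions for illustration. This document does not specify how `container_builder.py` combines the four inputs before hashing.

```python
import hashlib


def image_tag(
  base_image: str,
  accelerator_type: str,
  requirements: bytes,
  runner_source: bytes,
) -> str:
  """Derive a deterministic tag from everything that affects the image."""
  digest = hashlib.sha256()
  parts = (
    base_image.encode(),
    accelerator_type.encode(),
    requirements,
    runner_source,
  )
  for part in parts:
    # Length-prefix each part so ("ab", "c") and ("a", "bc") hash differently.
    digest.update(len(part).to_bytes(8, "big"))
    digest.update(part)
  return digest.hexdigest()
```

A caller would read `requirements.txt` and `remote_runner.py` as bytes, compute the tag, and trigger a build only when no image with that tag exists yet.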
| 87 | +### Artifact Registry API |
| 88 | + |
| 89 | +`get_docker_image` requires digest-based names (`image@sha256:...`), not tag-based (`image:tag`). Use `get_tag` with resource name `projects/{p}/locations/{l}/repositories/{r}/packages/{image}/tags/{tag}` to check tagged images. |
| 90 | + |
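A small helper for building the tag resource name might look like this. The helper name `tag_resource_name` is hypothetical. The closing comment assumes the `google-cloud-artifact-registry` client exposes `get_tag` with a `name` argument, which should be checked against the installed client version.

```python
def tag_resource_name(
  project: str,
  location: str,
  repository: str,
  image: str,
  tag: str,
) -> str:
  """Build the Artifact Registry resource name for a tagged image."""
  return (
    f"projects/{project}/locations/{location}"
    f"/repositories/{repository}/packages/{image}/tags/{tag}"
  )


# Assumed usage with the Artifact Registry client (not executed here):
#   client.get_tag(name=tag_resource_name(...))
# A "not found" error from that call would mean the tag has not been built.
```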
| 91 | +## Build System |
| 92 | + |
| 93 | +- **Build tool**: hatchling |
| 94 | +- **Python**: >=3.11 |
| 95 | +- **Core deps**: absl-py, cloudpickle, numpy, keras, google-cloud-{artifact-registry,storage,build}, kubernetes |
| 96 | +- **CLI deps** (optional `[cli]`): click, rich, pulumi, pulumi-gcp |
| 97 | +- **Dev deps** (optional `[dev]`): pre-commit, ruff |
| 98 | +- **Entry point**: `keras-remote` → `keras_remote.cli.main:cli` |
| 99 | + |
| 100 | +## Backend Selection Logic |
| 101 | + |
| 102 | +- **CPU / single-node GPU / single-node TPU**: GKE backend (K8s Job) |
| 103 | +- **Multi-node TPU** (`TpuConfig.num_nodes > 1`): Pathways backend (LeaderWorkerSet) |
| 104 | +- Explicit `backend=` parameter overrides auto-detection |
| 105 | + |
| 106 | +## Result Serialization Format |
| 107 | + |
| 108 | +```python |
| 109 | +{ |
| 110 | + "success": bool, |
| 111 | + "result": Any, # if success=True |
| 112 | + "exception": Exception, # if success=False |
| 113 | + "traceback": str, # if success=False |
| 114 | +} |
| 115 | +``` |
| 116 | + |
| 117 | +Exceptions are re-raised locally with the original traceback. |
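One way to produce and consume this payload is sketched below. The helper names `run_and_capture`, `unwrap_result`, and `RemoteTraceback` are hypothetical, and attaching the remote traceback through exception chaining is an assumption about how the "original traceback" might be surfaced. The traceback travels as a string because pickling an exception does not preserve its `__traceback__`.

```python
import traceback
from typing import Any, Callable


class RemoteTraceback(Exception):
  """Carries the formatted traceback text captured on the remote worker."""


def run_and_capture(fn: Callable[..., Any], *args, **kwargs) -> dict[str, Any]:
  """Runner side: execute fn and wrap the outcome in the result format."""
  try:
    return {"success": True, "result": fn(*args, **kwargs)}
  except Exception as exc:
    return {
      "success": False,
      "exception": exc,
      "traceback": traceback.format_exc(),
    }


def unwrap_result(payload: dict[str, Any]) -> Any:
  """Client side: return the result or re-raise the remote exception."""
  if payload["success"]:
    return payload["result"]
  # Chain the remote traceback text so it appears in the local error output.
  raise payload["exception"] from RemoteTraceback(payload["traceback"])
```

In the real pipeline the payload would be serialized to `result.pkl` between the two steps; here both sides run in one process for brevity.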