Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions .github/workflows/e2e-tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: E2E Tests

on:
workflow_dispatch:
pull_request:
types: [labeled]

permissions:
contents: read
id-token: write
pull-requests: write

jobs:
e2e:
if: >
github.event_name == 'workflow_dispatch' ||
(github.event_name == 'pull_request' && github.event.label.name == 'run-e2e')
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Remove run-e2e label
if: github.event_name == 'pull_request'
uses: actions/github-script@v7
with:
script: |
await github.rest.issues.removeLabel({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
name: 'run-e2e'
});


- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Authenticate to Google Cloud
uses: google-github-actions/auth@v2
with:
workload_identity_provider: ${{ secrets.WIF_PROVIDER }}
service_account: ${{ secrets.WIF_SERVICE_ACCOUNT }}

- name: Set up gcloud
uses: google-github-actions/setup-gcloud@v2

- name: Get GKE credentials
uses: google-github-actions/get-gke-credentials@v2
with:
cluster_name: ${{ secrets.GKE_CLUSTER }}
location: ${{ secrets.GKE_ZONE }}
project_id: ${{ secrets.GCP_PROJECT }}

- name: Install dependencies
run: pip install -e ".[test,cli]"

- name: Run E2E tests
env:
E2E_TESTS: "1"
KERAS_REMOTE_PROJECT: ${{ secrets.GCP_PROJECT }}
KERAS_REMOTE_ZONE: ${{ secrets.GKE_ZONE }}
KERAS_REMOTE_GKE_CLUSTER: ${{ secrets.GKE_CLUSTER }}
run: python -m unittest discover -s tests/e2e -p "*_test.py" -v
16 changes: 15 additions & 1 deletion keras_remote/backend/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import cloudpickle
from absl import logging
from google.api_core import exceptions as google_exceptions

from keras_remote.backend import gke_client, pathways_client
from keras_remote.constants import get_default_zone, zone_to_region
Expand Down Expand Up @@ -294,13 +295,26 @@ def execute_remote(ctx: JobContext, backend: BackendClient) -> Any:
job = backend.submit_job(ctx)

# Phase 5: Wait for completion (with cleanup on failure)
job_error = None
try:
backend.wait_for_job(job, ctx)
except RuntimeError as e:
job_error = e
finally:
backend.cleanup_job(job, ctx)

# Phase 6: Download and deserialize result
result_payload = _download_result(ctx)
# Try even if the job failed — the runner may have captured a user
# exception and uploaded the result before exiting with non-zero.
if job_error is not None:
try:
result_payload = _download_result(ctx)
except google_exceptions.NotFound:
# Result wasn't uploaded (infrastructure failure), surface the
# original job error.
raise job_error from None
else:
result_payload = _download_result(ctx)

# Phase 7: Cleanup and return/raise
return _cleanup_and_return(ctx, result_payload)
5 changes: 5 additions & 0 deletions keras_remote/backend/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest.mock import MagicMock

from absl.testing import absltest
from google.api_core import exceptions as google_exceptions

from keras_remote.backend.execution import (
JobContext,
Expand Down Expand Up @@ -175,6 +176,10 @@ def test_cleanup_on_wait_failure(self):
with (
mock.patch("keras_remote.backend.execution._build_container"),
mock.patch("keras_remote.backend.execution._upload_artifacts"),
mock.patch(
"keras_remote.backend.execution._download_result",
side_effect=google_exceptions.NotFound("no result uploaded"),
),
):
ctx = self._make_ctx()
backend = MagicMock()
Expand Down
88 changes: 88 additions & 0 deletions tests/e2e/cpu_execution_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""E2E tests for remote execution with CPU accelerator.

These tests require a real GCP project with:
- A GKE cluster with a CPU node pool
- Cloud Storage, Cloud Build, and Artifact Registry APIs enabled
- Proper IAM permissions

Set E2E_TESTS=1 to enable.
"""

import os
from unittest import mock

from absl.testing import absltest

import keras_remote
from tests.e2e.e2e_utils import skip_unless_e2e


@skip_unless_e2e()
class TestCpuExecution(absltest.TestCase):
def setUp(self):
super().setUp()

def test_simple_function(self):
"""Execute a simple add function remotely and verify the result."""

@keras_remote.run(accelerator="cpu")
def add(a, b):
return a + b

result = add(2, 3)
self.assertEqual(result, 5)

def test_complex_return_type(self):
"""Verify complex return types survive serialization roundtrip."""

@keras_remote.run(accelerator="cpu")
def complex_return():
return {
"key": [1, 2, 3],
"nested": {"a": True, "b": None},
"tuple": (4, 5),
}

result = complex_return()
self.assertEqual(result["key"], [1, 2, 3])
self.assertTrue(result["nested"]["a"])
self.assertIsNone(result["nested"]["b"])
self.assertEqual(result["tuple"], (4, 5))

def test_function_that_raises(self):
"""Verify remote exceptions are re-raised locally."""

@keras_remote.run(accelerator="cpu")
def bad_func():
raise ValueError("intentional test error")

with self.assertRaisesRegex(ValueError, "intentional test error"):
bad_func()

def test_env_var_propagation(self):
"""Verify captured env vars are available in the remote environment."""
with mock.patch.dict(os.environ, {"E2E_TEST_VAR": "hello_from_local"}):

@keras_remote.run(
accelerator="cpu",
capture_env_vars=["E2E_TEST_VAR"],
)
def read_env():
return os.environ.get("E2E_TEST_VAR")

result = read_env()
self.assertEqual(result, "hello_from_local")

def test_function_with_args_and_kwargs(self):
"""Verify positional and keyword arguments are passed correctly."""

@keras_remote.run(accelerator="cpu")
def compute(x, y, scale=1.0, offset=0.0):
return (x + y) * scale + offset

result = compute(3, 4, scale=2.0, offset=1.0)
self.assertEqual(result, 15.0)


if __name__ == "__main__":
absltest.main()
8 changes: 0 additions & 8 deletions tests/e2e/e2e_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,3 @@
def skip_unless_e2e(reason="E2E_TESTS not set"):
"""Skip decorator for e2e tests unless E2E_TESTS env var is set."""
return unittest.skipUnless(os.environ.get("E2E_TESTS"), reason)


def get_gcp_project():
"""Return GCP project from env, skip test if not set."""
project = os.environ.get("KERAS_REMOTE_PROJECT")
if not project:
raise unittest.SkipTest("KERAS_REMOTE_PROJECT not set")
return project
Loading