diff --git a/.github/workflows/e2e-tests.yaml b/.github/workflows/e2e-tests.yaml new file mode 100644 index 0000000..b536a95 --- /dev/null +++ b/.github/workflows/e2e-tests.yaml @@ -0,0 +1,65 @@ +name: E2E Tests + +on: + workflow_dispatch: + pull_request: + types: [labeled] + +permissions: + contents: read + id-token: write + pull-requests: write + +jobs: + e2e: + if: > + github.event_name == 'workflow_dispatch' || + (github.event_name == 'pull_request' && github.event.label.name == 'run-e2e') + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Remove run-e2e label + if: github.event_name == 'pull_request' + uses: actions/github-script@v7 + with: + script: | + await github.rest.issues.removeLabel({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + name: 'run-e2e' + }); + + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + workload_identity_provider: ${{ secrets.WIF_PROVIDER }} + service_account: ${{ secrets.WIF_SERVICE_ACCOUNT }} + + - name: Set up gcloud + uses: google-github-actions/setup-gcloud@v2 + + - name: Get GKE credentials + uses: google-github-actions/get-gke-credentials@v2 + with: + cluster_name: ${{ secrets.GKE_CLUSTER }} + location: ${{ secrets.GKE_ZONE }} + project_id: ${{ secrets.GCP_PROJECT }} + + - name: Install dependencies + run: pip install -e ".[test,cli]" + + - name: Run E2E tests + env: + E2E_TESTS: "1" + KERAS_REMOTE_PROJECT: ${{ secrets.GCP_PROJECT }} + KERAS_REMOTE_ZONE: ${{ secrets.GKE_ZONE }} + KERAS_REMOTE_GKE_CLUSTER: ${{ secrets.GKE_CLUSTER }} + run: python -m unittest discover -s tests/e2e -p "*_test.py" -v diff --git a/keras_remote/backend/execution.py b/keras_remote/backend/execution.py index cd5e544..755ffbf 100644 --- a/keras_remote/backend/execution.py +++ b/keras_remote/backend/execution.py @@ -13,6 +13,7 @@ import cloudpickle from absl import logging +from google.api_core import exceptions as google_exceptions from keras_remote.backend import gke_client, pathways_client from keras_remote.constants import get_default_zone, zone_to_region @@ -294,13 +295,26 @@ def execute_remote(ctx: JobContext, backend: BackendClient) -> Any: job = backend.submit_job(ctx) # Phase 5: Wait for completion (with cleanup on failure) + job_error = None try: backend.wait_for_job(job, ctx) + except RuntimeError as e: + job_error = e finally: backend.cleanup_job(job, ctx) # Phase 6: Download and deserialize result - result_payload = _download_result(ctx) + # Try even if the job failed — the runner may have captured a user + # exception and uploaded the result before exiting with non-zero. + if job_error is not None: + try: + result_payload = _download_result(ctx) + except google_exceptions.NotFound: + # Result wasn't uploaded (infrastructure failure), surface the + # original job error. + raise job_error from None + else: + result_payload = _download_result(ctx) # Phase 7: Cleanup and return/raise return _cleanup_and_return(ctx, result_payload) diff --git a/keras_remote/backend/execution_test.py b/keras_remote/backend/execution_test.py index 64d2405..d7c9088 100644 --- a/keras_remote/backend/execution_test.py +++ b/keras_remote/backend/execution_test.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock from absl.testing import absltest +from google.api_core import exceptions as google_exceptions from keras_remote.backend.execution import ( JobContext, @@ -175,6 +176,10 @@ def test_cleanup_on_wait_failure(self): with ( mock.patch("keras_remote.backend.execution._build_container"), mock.patch("keras_remote.backend.execution._upload_artifacts"), + mock.patch( + "keras_remote.backend.execution._download_result", + side_effect=google_exceptions.NotFound("no result uploaded"), + ), ): ctx = self._make_ctx() backend = MagicMock() diff --git a/tests/e2e/cpu_execution_test.py b/tests/e2e/cpu_execution_test.py new file mode 100644 index 0000000..fc7cb8c --- /dev/null +++ b/tests/e2e/cpu_execution_test.py @@ -0,0 +1,88 @@ +"""E2E tests for remote execution with CPU accelerator. + +These tests require a real GCP project with: +- A GKE cluster with a CPU node pool +- Cloud Storage, Cloud Build, and Artifact Registry APIs enabled +- Proper IAM permissions + +Set E2E_TESTS=1 to enable. +""" + +import os +from unittest import mock + +from absl.testing import absltest + +import keras_remote +from tests.e2e.e2e_utils import skip_unless_e2e + + +@skip_unless_e2e() +class TestCpuExecution(absltest.TestCase): + def setUp(self): + super().setUp() + + def test_simple_function(self): + """Execute a simple add function remotely and verify the result.""" + + @keras_remote.run(accelerator="cpu") + def add(a, b): + return a + b + + result = add(2, 3) + self.assertEqual(result, 5) + + def test_complex_return_type(self): + """Verify complex return types survive serialization roundtrip.""" + + @keras_remote.run(accelerator="cpu") + def complex_return(): + return { + "key": [1, 2, 3], + "nested": {"a": True, "b": None}, + "tuple": (4, 5), + } + + result = complex_return() + self.assertEqual(result["key"], [1, 2, 3]) + self.assertTrue(result["nested"]["a"]) + self.assertIsNone(result["nested"]["b"]) + self.assertEqual(result["tuple"], (4, 5)) + + def test_function_that_raises(self): + """Verify remote exceptions are re-raised locally.""" + + @keras_remote.run(accelerator="cpu") + def bad_func(): + raise ValueError("intentional test error") + + with self.assertRaisesRegex(ValueError, "intentional test error"): + bad_func() + + def test_env_var_propagation(self): + """Verify captured env vars are available in the remote environment.""" + with mock.patch.dict(os.environ, {"E2E_TEST_VAR": "hello_from_local"}): + + @keras_remote.run( + accelerator="cpu", + capture_env_vars=["E2E_TEST_VAR"], + ) + def read_env(): + return os.environ.get("E2E_TEST_VAR") + + result = read_env() + self.assertEqual(result, "hello_from_local") + + def test_function_with_args_and_kwargs(self): + """Verify positional and keyword arguments are passed correctly.""" + + @keras_remote.run(accelerator="cpu") + def compute(x, y, scale=1.0, offset=0.0): + return (x + y) * scale + offset + + result = compute(3, 4, scale=2.0, offset=1.0) + self.assertEqual(result, 15.0) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/e2e/e2e_utils.py b/tests/e2e/e2e_utils.py index 35e4b94..5e36900 100644 --- a/tests/e2e/e2e_utils.py +++ b/tests/e2e/e2e_utils.py @@ -7,11 +7,3 @@ def skip_unless_e2e(reason="E2E_TESTS not set"): """Skip decorator for e2e tests unless E2E_TESTS env var is set.""" return unittest.skipUnless(os.environ.get("E2E_TESTS"), reason) - - -def get_gcp_project(): - """Return GCP project from env, skip test if not set.""" - project = os.environ.get("KERAS_REMOTE_PROJECT") - if not project: - raise unittest.SkipTest("KERAS_REMOTE_PROJECT not set") - return project