Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 43 additions & 31 deletions keras_remote/backend/gke_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from kubernetes import client, config
from kubernetes.client.rest import ApiException

from keras_remote.backend.log_streaming import LogStreamer
from keras_remote.core import accelerators
from keras_remote.core.accelerators import TpuConfig

Expand Down Expand Up @@ -108,37 +109,48 @@ def wait_for_job(job, namespace="default", timeout=3600, poll_interval=10):
start_time = time.time()
logged_running = False

while True:
# Check timeout
elapsed = time.time() - start_time
if elapsed > timeout:
raise RuntimeError(f"GKE job {job_name} timed out after {timeout}s")

# Get job status
try:
job_status = batch_v1.read_namespaced_job_status(job_name, namespace)
except ApiException as e:
raise RuntimeError(f"Failed to read job status: {e.reason}") from e

# Check completion conditions
if job_status.status.succeeded and job_status.status.succeeded >= 1:
logging.info("Job %s completed successfully", job_name)
return "success"

if job_status.status.failed and job_status.status.failed >= 1:
# Get pod logs for debugging
_print_pod_logs(core_v1, job_name, namespace)
raise RuntimeError(f"GKE job {job_name} failed")

# Check for pod scheduling issues
_check_pod_scheduling(core_v1, job_name, namespace)

# Job still running
if not logged_running:
logging.info("Job %s running...", job_name)
logged_running = True

time.sleep(poll_interval)
with LogStreamer(core_v1, namespace) as streamer:
while True:
# Check timeout
elapsed = time.time() - start_time
if elapsed > timeout:
raise RuntimeError(f"GKE job {job_name} timed out after {timeout}s")

# Get job status
try:
job_status = batch_v1.read_namespaced_job_status(job_name, namespace)
except ApiException as e:
raise RuntimeError(f"Failed to read job status: {e.reason}") from e

# Check completion conditions
if job_status.status.succeeded and job_status.status.succeeded >= 1:
logging.info("Job %s completed successfully", job_name)
return "success"

if job_status.status.failed and job_status.status.failed >= 1:
# Get pod logs for debugging
_print_pod_logs(core_v1, job_name, namespace)
raise RuntimeError(f"GKE job {job_name} failed")

# Check for pod scheduling issues
_check_pod_scheduling(core_v1, job_name, namespace)

# Start log streaming when pod is running
with suppress(ApiException):
pods = core_v1.list_namespaced_pod(
namespace, label_selector=f"job-name={job_name}"
)
for pod in pods.items:
if pod.status.phase == "Running":
streamer.start(pod.metadata.name)
break

# Job still running
if not logged_running:
logging.info("Job %s running...", job_name)
logged_running = True

time.sleep(poll_interval)


def cleanup_job(job_name, namespace="default"):
Expand Down
77 changes: 77 additions & 0 deletions keras_remote/backend/gke_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,16 @@ def setUp(self):
mock.patch("keras_remote.backend.gke_client._load_kube_config")
)

self.mock_streamer = MagicMock()
self.enterContext(
mock.patch(
"keras_remote.backend.gke_client.LogStreamer",
return_value=self.mock_streamer,
)
)
self.mock_streamer.__enter__ = MagicMock(return_value=self.mock_streamer)
self.mock_streamer.__exit__ = MagicMock(return_value=False)

def _make_mock_job(self):
job = MagicMock()
job.metadata.name = "keras-remote-job-abc"
Expand Down Expand Up @@ -194,6 +204,7 @@ def test_first_poll_success(self):
):
result = wait_for_job(self._make_mock_job())
self.assertEqual(result, "success")
self.mock_streamer.start.assert_not_called()

def test_first_poll_failure(self):
mock_batch = MagicMock()
Expand Down Expand Up @@ -264,6 +275,72 @@ def test_polls_until_success(self):
self.assertEqual(result, "success")
mock_sleep.assert_called_with(5)

def test_starts_streaming_when_pod_running(self):
    """Streaming begins once a pod owned by the job reaches Running."""
    batch_api = MagicMock()
    first_status = MagicMock()
    first_status.status.succeeded = None
    first_status.status.failed = None
    final_status = MagicMock()
    final_status.status.succeeded = 1
    final_status.status.failed = None
    batch_api.read_namespaced_job_status.side_effect = [first_status, final_status]

    # One pod already in the Running phase on the first poll.
    pod = MagicMock()
    pod.status.phase = "Running"
    pod.metadata.name = "keras-remote-job-abc-pod"
    core_api = MagicMock()
    core_api.list_namespaced_pod.return_value.items = [pod]

    with (
        mock.patch(
            "keras_remote.backend.gke_client.client.BatchV1Api",
            return_value=batch_api,
        ),
        mock.patch(
            "keras_remote.backend.gke_client.client.CoreV1Api",
            return_value=core_api,
        ),
        mock.patch("keras_remote.backend.gke_client.time.sleep"),
    ):
        outcome = wait_for_job(self._make_mock_job())

    self.assertEqual(outcome, "success")
    self.mock_streamer.start.assert_called_once_with("keras-remote-job-abc-pod")

def test_no_streaming_when_pod_pending(self):
    """No log streaming is started while the job's pod is still Pending."""
    batch_api = MagicMock()
    first_status = MagicMock()
    first_status.status.succeeded = None
    first_status.status.failed = None
    final_status = MagicMock()
    final_status.status.succeeded = 1
    final_status.status.failed = None
    batch_api.read_namespaced_job_status.side_effect = [first_status, final_status]

    # The only pod has not been scheduled yet.
    pod = MagicMock()
    pod.status.phase = "Pending"
    pod.status.conditions = None
    core_api = MagicMock()
    core_api.list_namespaced_pod.return_value.items = [pod]

    with (
        mock.patch(
            "keras_remote.backend.gke_client.client.BatchV1Api",
            return_value=batch_api,
        ),
        mock.patch(
            "keras_remote.backend.gke_client.client.CoreV1Api",
            return_value=core_api,
        ),
        mock.patch("keras_remote.backend.gke_client.time.sleep"),
    ):
        outcome = wait_for_job(self._make_mock_job())

    self.assertEqual(outcome, "success")
    self.mock_streamer.start.assert_not_called()


class TestLoadKubeConfig(absltest.TestCase):
def test_kubeconfig_fallback(self):
Expand Down
137 changes: 137 additions & 0 deletions keras_remote/backend/log_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Live log streaming from Kubernetes pods.

Provides utilities to stream pod logs to stdout in real-time using a
background daemon thread. Used by both GKE and Pathways backends during
job execution.
"""
# NOTE(review): this module may fit better under utils/ than backend/.

import sys
import threading
from collections import deque

import urllib3
from absl import logging
from kubernetes.client.rest import ApiException
from rich.console import Console
from rich.live import Live
from rich.panel import Panel

_MAX_DISPLAY_LINES = 25


def _stream_pod_logs(core_v1, pod_name, namespace):
    """Tail a pod's logs to stdout; intended as a daemon-thread target.

    Opens a follow-mode log stream through the Kubernetes API and renders
    it until the container exits and the stream closes on its own.

    Interactive terminals get a Rich Live panel; non-interactive contexts
    (piped output, CI) get raw lines framed by Rich Rule delimiters.

    Args:
        core_v1: Kubernetes CoreV1Api client.
        pod_name: Name of the pod to stream logs from.
        namespace: Kubernetes namespace.
    """
    console = Console()
    resp = None
    try:
        resp = core_v1.read_namespaced_pod_log(
            name=pod_name,
            namespace=namespace,
            follow=True,
            _preload_content=False,
        )
        renderer = _render_live_panel if console.is_terminal else _render_plain
        renderer(resp, pod_name, console)
    except ApiException:
        # Pod was deleted or never existed; nothing to stream.
        pass
    except urllib3.exceptions.ProtocolError:
        # Connection dropped mid-stream, typically because the pod terminated.
        pass
    except Exception:
        logging.warning(
            "Log streaming from %s failed unexpectedly", pod_name, exc_info=True
        )
    finally:
        if resp is not None:
            resp.release_conn()


def _render_live_panel(resp, pod_name, console):
    """Show the streaming logs in an auto-refreshing Rich Live panel.

    Keeps only the last ``_MAX_DISPLAY_LINES`` complete lines in view and
    re-renders the panel each time a new line arrives.
    """
    window = deque(maxlen=_MAX_DISPLAY_LINES)
    title = f"Remote logs \u2022 {pod_name}"
    pending = ""

    with Live(
        _make_log_panel(window, title),
        console=console,
        refresh_per_second=4,
    ) as live:
        for chunk in resp.stream(decode_content=True):
            pending += chunk.decode("utf-8", errors="replace")
            # Everything before the last "\n" is complete; the tail stays
            # buffered until its newline (or end of stream) arrives.
            *complete, pending = pending.split("\n")
            for line in complete:
                window.append(line)
                live.update(_make_log_panel(window, title))

        # Flush a trailing partial line, if any.
        if pending.strip():
            window.append(pending)
            live.update(_make_log_panel(window, title))


def _render_plain(resp, pod_name, console):
"""Render streaming logs as raw lines with Rule delimiters."""
console.rule(f"Remote logs ({pod_name})", style="blue")
for chunk in resp.stream(decode_content=True):
sys.stdout.write(chunk.decode("utf-8", errors="replace"))
sys.stdout.flush()
console.rule("End remote logs", style="blue")


def _make_log_panel(lines, title):
    """Wrap the buffered log lines in a bordered Panel renderable."""
    if lines:
        body = "\n".join(lines)
    else:
        body = "Waiting for output..."
    return Panel(body, title=title, border_style="blue")


class LogStreamer:
    """Context manager owning the log-streaming thread's lifecycle.

    Usage::

        with LogStreamer(core_v1, namespace) as streamer:
            while polling:
                ...
                if pod_is_running:
                    streamer.start(pod_name)  # idempotent
    """

    def __init__(self, core_v1, namespace):
        self._core_v1 = core_v1
        self._namespace = namespace
        self._thread = None

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Give an active streamer a moment to drain. The thread is a daemon,
        # so a stuck stream never blocks interpreter shutdown.
        worker = self._thread
        if worker is not None:
            worker.join(timeout=5)
        return False

    def start(self, pod_name):
        """Begin streaming from ``pod_name``; subsequent calls are no-ops."""
        if self._thread is not None:
            return
        logging.info("Streaming logs from %s...", pod_name)
        worker = threading.Thread(
            target=_stream_pod_logs,
            args=(self._core_v1, pod_name, self._namespace),
            daemon=True,
        )
        self._thread = worker
        worker.start()
Loading
Loading