Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/isolate/backends/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@ def log(
source: LogSource = LogSource.BUILDER,
) -> None:
"""Log a message."""
log_msg = Log(message, level=level, source=source, bound_env=self)
log_msg = Log(
message,
level=level,
source=source,
bound_env=self,
is_json=self.settings.json_logs,
)
self.settings.log(log_msg)


Expand Down
4 changes: 3 additions & 1 deletion src/isolate/backends/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

_SYSTEM_TEMP_DIR = Path(tempfile.gettempdir())
_STRICT_CACHE = os.getenv("ISOLATE_STRICT_CACHE", "0") == "1"
JSON_LOGS = os.getenv("ISOLATE_JSON_LOGS", "0") == "1"


@dataclass(frozen=True)
Expand All @@ -26,6 +27,7 @@ class IsolateSettings:
serialization_method: str = "pickle"
log_hook: Callable[[Log], None] = print
strict_cache: bool = _STRICT_CACHE
json_logs: bool = JSON_LOGS

def log(self, log: Log) -> None:
    """Run *log* through level inference, then dispatch it to the hook."""
    inferred = self._infer_log_level(log)
    self.log_hook(inferred)
Expand All @@ -39,7 +41,7 @@ def _infer_log_level(self, log: Log) -> Log:
if log.source in (LogSource.BUILDER, LogSource.BRIDGE):
return replace(log, level=LogLevel.TRACE)

line = log.message.lower()
line = log.message_str().lower()

if line.startswith("error") or "[error]" in line:
return replace(log, level=LogLevel.ERROR)
Expand Down
5 changes: 4 additions & 1 deletion src/isolate/connections/_local/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from isolate import __version__ as isolate_version
from isolate.backends.common import active_python, get_executable_path, logged_io
from isolate.backends.settings import JSON_LOGS
from isolate.connections.common import AGENT_SIGNATURE
from isolate.logs import LogLevel, LogSource

Expand Down Expand Up @@ -102,6 +103,7 @@ class PythonExecutionBase(Generic[ConnectionType]):
environment: BaseEnvironment
environment_path: Path
extra_inheritance_paths: list[Path] = field(default_factory=list)
json_logs: bool = True

@contextmanager
def start_process(
Expand Down Expand Up @@ -147,7 +149,7 @@ def start_process(
),
) as (stdout, stderr, log_fd):
yield subprocess.Popen(
self.get_python_cmd(python_executable, connection, log_fd),
self.get_python_cmd(python_executable, connection, log_fd, JSON_LOGS),
env=env,
stdout=stdout,
stderr=stderr,
Expand Down Expand Up @@ -191,6 +193,7 @@ def get_python_cmd(
executable: Path,
connection: ConnectionType,
log_fd: int,
json_logs: bool = False,
) -> list[str | Path]:
"""Return the command to run the agent process with."""
raise NotImplementedError
Expand Down
2 changes: 2 additions & 0 deletions src/isolate/connections/grpc/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def get_python_cmd(
executable: Path,
connection: str,
log_fd: int,
json_logs: bool = False,
) -> List[Union[str, Path]]:
return [
executable,
Expand All @@ -173,6 +174,7 @@ def get_python_cmd(
connection,
"--log-fd",
str(log_fd),
*(["--json-logs"] if json_logs else []),
]

def handle_agent_log(
Expand Down
131 changes: 126 additions & 5 deletions src/isolate/connections/grpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
from __future__ import annotations

import asyncio
import contextvars
import json
import os
import signal
import sys
import threading
import traceback
from argparse import ArgumentParser
from concurrent import futures
Expand All @@ -22,6 +25,7 @@
Any,
AsyncIterator,
Iterable,
TextIO,
)

from grpc import StatusCode, aio, local_server_credentials
Expand All @@ -41,18 +45,118 @@

IDLE_TIMEOUT_SECONDS = int(os.getenv("ISOLATE_AGENT_IDLE_TIMEOUT_SECONDS", "0"))

isolate_log_context: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
    "ISOLATE_CONTEXT_VAR_LOG", default={}
)


def get_log_context() -> dict[str, Any]:
    """Return the current log context, or an empty dict when unset/invalid."""
    value = isolate_log_context.get()
    if not isinstance(value, dict):
        return {}

    return value


class JsonStdoutProxy:
    """
    A proxy around a real text stream (usually ``sys.__stdout__``).

    - Intercepts ``write``/``writelines`` and emits one JSON object per
      completed line, merged with the ambient log context (contextvars).
    - Delegates everything else to the wrapped stream to preserve
      compatibility (``isatty``, ``fileno``, ``encoding``, ...).
    - Guards against re-entrancy: if formatting/encoding calls back into
      ``sys.stdout``, the text goes straight to the underlying stream.
    """

    def __init__(self, underlying):
        # The wrapped stream; all actual output ultimately goes through it.
        self._u = underlying
        # Per-thread state: partial-line buffer ("buf") and the re-entrancy
        # flag ("in_write").
        self._local = threading.local()

    def write(self, s: str) -> int:
        """Buffer *s*, emitting one JSON record per completed line."""
        # Many libs call write(""); keep that cheap.
        if not s:
            return 0

        # Prevent re-entrancy if something in encoding/IO calls back into
        # sys.stdout while we are formatting.
        if getattr(self._local, "in_write", False):
            return self._u.write(s)

        self._local.in_write = True
        try:
            # Preserve "print" behavior: print() typically writes text and
            # possibly a trailing '\n' in separate calls, so accumulate
            # partial lines per-thread until a newline completes them.
            buffered = getattr(self._local, "buf", "") + s
            pending = ""

            for chunk in buffered.splitlines(keepends=True):
                # "\r\n" also ends with "\n", so a single check covers both.
                if chunk.endswith("\n"):
                    payload = self._format_record(chunk.rstrip("\r\n"))
                    self._u.write(payload + "\n")
                else:
                    # Incomplete line: keep buffering.
                    pending += chunk

            self._local.buf = pending
            # Report the caller's input as fully consumed.
            return len(s)
        finally:
            self._local.in_write = False

    def flush(self) -> None:
        """Flush the wrapped stream, first emitting any buffered partial line
        as a final record (deliberate choice so no output is lost)."""
        if getattr(self._local, "in_write", False):
            return self._u.flush()

        buf = getattr(self._local, "buf", "")
        if buf:
            self._local.buf = ""
            self._u.write(self._format_record(buf) + "\n")
        self._u.flush()

    def writelines(self, lines: list[str]) -> None:
        """Write each element through :meth:`write` so JSON wrapping applies."""
        for line in lines:
            self.write(line)

    def _format_record(self, message: str) -> str:
        """Serialize one log line plus the ambient log context as JSON."""
        record: dict[str, Any] = {
            "line": message,
        }
        # Add the log context to the json so we propagate contextvars.
        record.update(get_log_context())
        return json.dumps(record, ensure_ascii=False)

    def __getattr__(self, name: str) -> Any:
        # Delegate missing attrs/methods to the underlying stream
        # (isatty, fileno, encoding, etc.)
        return getattr(self._u, name)

    def __iter__(self):
        return iter(self._u)

    def __enter__(self):
        self._u.__enter__()
        return self

    def __exit__(self, exc_type, exc, tb):
        return self._u.__exit__(exc_type, exc, tb)


@dataclass
class AbortException(Exception):
    """Exception carrying a human-readable message.

    NOTE(review): raised/handled outside this view — presumably used to
    abort the current agent run; confirm against the servicer code.
    """

    message: str


class AgentServicer(definitions.AgentServicer):
def __init__(self, log_fd: int | None = None):
def __init__(self, log_file: TextIO | None = None):
super().__init__()

self._run_cache: dict[str, Any] = {}
self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w")
self._log = log_file if log_file is not None else sys.stdout
self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)
self._idle_timeout_seconds = IDLE_TIMEOUT_SECONDS
self._is_running = asyncio.Event()
Expand Down Expand Up @@ -305,10 +409,24 @@ def create_server(address: str) -> aio.Server:
return server


async def run_agent(address: str, log_fd: int | None = None) -> int:
async def run_agent(
address: str, log_fd: int | None = None, json_logs: bool = False
) -> int:
"""Run the agent servicer on the given address."""
# Determine the base log file
if log_fd is None:
log_file: TextIO = sys.stdout
else:
log_file = os.fdopen(log_fd, "w")

# Apply JSON wrapper if requested
if json_logs:
sys.stdout = JsonStdoutProxy(sys.__stdout__) # type: ignore[assignment]
sys.stderr = JsonStdoutProxy(sys.__stderr__) # type: ignore[assignment]
log_file = JsonStdoutProxy(log_file) # type: ignore[assignment]

server = create_server(address)
servicer = AgentServicer(log_fd=log_fd)
servicer = AgentServicer(log_file=log_file)

# This function just calls some methods on the server
# and register a generic handler for the bridge. It does
async def main() -> int:
    """CLI entry point: parse arguments, then run the agent server."""
    parser = ArgumentParser()
    parser.add_argument("address", type=str)
    parser.add_argument("--log-fd", type=int)
    parser.add_argument("--json-logs", action="store_true", default=False)

    args = parser.parse_args()
    return await run_agent(
        args.address,
        log_fd=args.log_fd,
        json_logs=args.json_logs,
    )


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/isolate/connections/grpc/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _(message: definitions.Log) -> Log:
@to_grpc.register
def _(obj: Log) -> definitions.Log:
return definitions.Log(
message=obj.message,
message=obj.message_str(),
source=definitions.LogSource.Value(obj.source.name.upper()),
level=definitions.LogLevel.Value(obj.level.name.upper()),
timestamp=timestamp.from_datetime(obj.timestamp),
Expand Down
2 changes: 2 additions & 0 deletions src/isolate/connections/ipc/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def get_python_cmd(
executable: Path,
connection: AgentListener,
log_fd: int,
json_logs: bool = False,
) -> list[str | Path]:
assert isinstance(connection.address, tuple)
return [
Expand All @@ -217,6 +218,7 @@ def get_python_cmd(
self.environment.settings.serialization_method,
"--log-fd",
str(log_fd),
*(["--json-logs"] if json_logs else []),
]

def handle_agent_log(
Expand Down
9 changes: 8 additions & 1 deletion src/isolate/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ class IsolateLogger:
def __init__(self, log_labels: Dict[str, str]):
self.log_labels = log_labels

def log(self, level: LogLevel, message: str, source: LogSource) -> None:
def log(
self,
level: LogLevel,
message: str,
source: LogSource,
line_labels: Dict[str, str],
) -> None:
record = {
# Set the timestamp from source so we can be sure no buffering or
# latency is affecting the timestamp.
Expand All @@ -25,6 +31,7 @@ def log(self, level: LogLevel, message: str, source: LogSource) -> None:
"message": message,
**self.log_labels,
**self.extra_labels,
**line_labels,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool

}
print(json.dumps(record))

Expand Down
26 changes: 26 additions & 0 deletions src/isolate/logs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import tempfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
Expand Down Expand Up @@ -63,6 +64,8 @@ class Log:
level: LogLevel = LogLevel.INFO
bound_env: BaseEnvironment | None = field(default=None, repr=False)
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
is_json: bool = field(default=False)
_parsed_message: dict | None = field(default=None, init=False, repr=False)

def __str__(self) -> str:
parts = [self.timestamp.strftime("%m/%d/%Y %H:%M:%S")]
Expand All @@ -74,3 +77,26 @@ def __str__(self) -> str:
parts.append(f"[{self.source}]".ljust(10))
parts.append(f"[{self.level}]".ljust(10))
return " ".join(parts) + self.message

def message_str(self) -> str:
    """Return the human-readable log line, unwrapping the JSON payload
    when one is present; otherwise fall back to the raw message."""
    return self.from_json().get("line", self.message)

def message_meta(self) -> dict:
    """Return the structured metadata carried alongside the log line.

    The "line" key (the log text itself) is excluded. The result is
    always a fresh dict — never the internal parse cache — so callers
    mutating it cannot corrupt the cached parse in ``_parsed_message``
    (the original returned the cache object when "line" was absent).
    """
    meta = dict(self.from_json())
    meta.pop("line", None)
    return meta

def from_json(self) -> dict[str, str]:
    """Parse ``self.message`` as a JSON object, caching the result.

    Returns an empty dict when the log is not JSON-encoded, when the
    message is not valid JSON, or when it decodes to something other
    than an object (e.g. a bare number or list) — the original returned
    such values as-is, which broke ``"line" in parsed`` in callers.
    """
    if not self.is_json:
        return {}
    if self._parsed_message is not None:
        return self._parsed_message
    try:
        parsed = json.loads(self.message)
        # Guard against valid-but-non-object JSON ("5", "[1, 2]", ...).
        self._parsed_message = parsed if isinstance(parsed, dict) else {}
    except json.JSONDecodeError:
        self._parsed_message = {}
    return self._parsed_message
9 changes: 7 additions & 2 deletions src/isolate/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,13 @@ class LogHandler:
task: RunTask

def handle(self, log: Log) -> None:
    """Send *log* to the task logger (unless filtered) and enqueue it.

    SKIP_EMPTY_LOGS drops whitespace-only lines from the logger, but the
    record is still added to the outgoing queue regardless.
    """
    text = log.message_str()
    if text.strip() or not SKIP_EMPTY_LOGS:
        self.task.logger.log(
            log.level,
            text,
            source=log.source,
            line_labels=log.message_meta(),
        )
    self._add_log_to_queue(log)

def _add_log_to_queue(self, log: Log) -> None:
Expand Down
Loading