diff --git a/projects/fal/src/fal/app.py b/projects/fal/src/fal/app.py index d97aa3fa..fc57ac1d 100644 --- a/projects/fal/src/fal/app.py +++ b/projects/fal/src/fal/app.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import inspect import json import os @@ -52,7 +53,7 @@ r"\.git/", r"\.DS_Store$", ] - +LOG_CONTEXT_PREFIX = "LOG_" EndpointT = TypeVar("EndpointT", bound=Callable[..., Any]) logger = get_logger(__name__) @@ -89,9 +90,28 @@ async def open_isolate_channel(address: str) -> async_grpc.Channel | None: return channel +def merge_contextvars(logger_labels: dict[str, str]) -> None: + for k, v in logger_labels.items(): + ContextVar(f"{LOG_CONTEXT_PREFIX}{k}").set(v) # type: ignore + + +def clear_contextvars() -> None: + ctx = contextvars.copy_context() + for k in ctx: + if k.name.startswith(LOG_CONTEXT_PREFIX): + k.set(Ellipsis) + + async def _set_logger_labels( logger_labels: dict[str, str], channel: async_grpc.Channel ): + # Set the labels into the current context so isolate agent can use them for + # associating log lines with the current request context + if not logger_labels: + clear_contextvars() + else: + merge_contextvars(logger_labels) + try: import sys diff --git a/projects/fal/tests/unit/test_app.py b/projects/fal/tests/unit/test_app.py index 4c532be6..3e06d1c2 100644 --- a/projects/fal/tests/unit/test_app.py +++ b/projects/fal/tests/unit/test_app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import os import pickle from contextvars import ContextVar @@ -8,6 +9,7 @@ import fal from fal import App, endpoint +from fal.app import LOG_CONTEXT_PREFIX, clear_contextvars, merge_contextvars from fal.container import ContainerImage @@ -257,6 +259,39 @@ class LeakCheckApp(App): assert "app_auth" not in hk +def test_merge_context_vars(): + labels = {"fal_request_id": "123", "fal_endpoint": "/"} + request_id_var = f"{LOG_CONTEXT_PREFIX}fal_request_id" + endpoint_var = f"{LOG_CONTEXT_PREFIX}fal_endpoint" + unrelated_var = "unrelated_key" + contextvars.ContextVar(unrelated_var).set("value") + + # We have to convert to dict and lookup by name because each ContextVar + # is a different object. Since merge_contextvars creates new ContextVars, + # we can't just do direct lookups. + vars = dict((k.name, v) for k, v in contextvars.copy_context().items()) + + assert vars.get(unrelated_var) == "value" + + assert vars.get(request_id_var) is None + assert vars.get(endpoint_var) is None + + merge_contextvars(labels) + vars = dict((k.name, v) for k, v in contextvars.copy_context().items()) + + assert vars.get(request_id_var) == "123" + assert vars.get(endpoint_var) == "/" + + clear_contextvars() + vars = dict((k.name, v) for k, v in contextvars.copy_context().items()) + + # Cleared contextvars are set to Ellipsis + assert vars.get(request_id_var) is Ellipsis + assert vars.get(endpoint_var) is Ellipsis + # Does not clear unrelated contextvars + assert vars.get("unrelated_key") == "value" + + @pytest.mark.asyncio async def test_runner_state_lifecycle_complete(): """Test that FAL_RUNNER_STATE transitions through all phases correctly"""