diff --git a/api/context/execution_context.py b/api/context/execution_context.py index ea3d996929021f..cf87d451641619 100644 --- a/api/context/execution_context.py +++ b/api/context/execution_context.py @@ -57,6 +57,10 @@ def user(self) -> Any: """Get user object.""" ... + def refresh_context_vars(self) -> None: + """Re-capture current context variables for propagation to worker threads.""" + ... + @final class ExecutionContext: @@ -93,6 +97,15 @@ def user(self) -> Any: """Get captured user object.""" return self._user + def refresh_context_vars(self) -> None: + """Re-capture current context variables. + + Call this after ContextVars have been updated in the current thread + (e.g. by a GraphEngine layer's ``on_graph_start``) so that worker + threads created afterwards receive the updated values. + """ + self._context_vars = contextvars.copy_context() + @contextmanager def enter(self) -> Generator[None, None, None]: """Enter this execution context.""" diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 1201bad041e0d4..fdd06fb45f7e57 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -135,6 +135,10 @@ def user(self) -> Any: """Get user object.""" return self._user + def refresh_context_vars(self) -> None: + """Re-capture current context variables for propagation to worker threads.""" + self._context_vars = contextvars.copy_context() + def __enter__(self) -> "FlaskExecutionContext": """Enter the Flask execution context.""" # Restore non-Flask context variables to avoid leaking Flask tokens across threads diff --git a/api/core/app/workflow/layers/__init__.py b/api/core/app/workflow/layers/__init__.py index 7d5841275db6dd..d8df448e8890a8 100644 --- a/api/core/app/workflow/layers/__init__.py +++ b/api/core/app/workflow/layers/__init__.py @@ -1,6 +1,7 @@ """Workflow-level GraphEngine layers that depend on outer infrastructure.""" from .llm_quota import LLMQuotaLayer +from .log_context import 
WorkflowLogContextLayer from .observability import ObservabilityLayer from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer @@ -8,5 +9,6 @@ "LLMQuotaLayer", "ObservabilityLayer", "PersistenceWorkflowInfo", + "WorkflowLogContextLayer", "WorkflowPersistenceLayer", ] diff --git a/api/core/app/workflow/layers/log_context.py b/api/core/app/workflow/layers/log_context.py new file mode 100644 index 00000000000000..2175ff07e8d5e2 --- /dev/null +++ b/api/core/app/workflow/layers/log_context.py @@ -0,0 +1,58 @@ +"""GraphEngine layer that manages node_id log context via ContextVars. + +This layer tracks ``node_id`` during node execution (set on +``on_node_run_start``, cleared on ``on_node_run_end``). The +``app_id`` / ``workflow_id`` / ``error_source`` lifecycle is managed by +``WorkflowEntry.run`` directly, which has full control over the try/finally +timing relative to ``logger.exception``. + +On ``on_graph_start``, this layer refreshes the ``execution_context`` +snapshot so that worker threads inherit the ContextVars that +``WorkflowEntry.run`` set just before starting the graph engine. +""" + +from typing import override + +from context import IExecutionContext +from core.logging.context import set_node_log_context +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase +from graphon.nodes.base.node import Node + + +class WorkflowLogContextLayer(GraphEngineLayer): + """Manage node_id log context lifecycle during graph execution.""" + + def __init__(self, *, execution_context: IExecutionContext | None = None) -> None: + super().__init__() + self._execution_context = execution_context + + @override + def on_graph_start(self) -> None: + # Refresh the execution context snapshot so that worker threads + # (started after on_graph_start) inherit the ContextVars that + # WorkflowEntry.run set before starting the graph engine. 
+ # Without this, workers would see stale default values because the + # snapshot was captured in WorkflowEntry.__init__ before run(). + if self._execution_context is not None: + self._execution_context.refresh_context_vars() + + @override + def on_node_run_start(self, node: Node) -> None: + set_node_log_context(node.id) + + @override + def on_node_run_end( + self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None + ) -> None: + set_node_log_context("") + + @override + def on_event(self, event: GraphEngineEvent) -> None: + _ = event + + @override + def on_graph_end(self, error: Exception | None) -> None: + # app_id / workflow_id / error_source are managed by WorkflowEntry.run. + # node_id is cleared in on_node_run_end after each node. + _ = error diff --git a/api/core/logging/__init__.py b/api/core/logging/__init__.py index db046cc9fa81f7..7f406751bce9ba 100644 --- a/api/core/logging/__init__.py +++ b/api/core/logging/__init__.py @@ -1,20 +1,45 @@ """Structured logging components for Dify.""" from core.logging.context import ( + ErrorSource, + clear_error_source, clear_request_context, + clear_workflow_log_context, + get_app_id, + get_error_source, + get_node_id, get_request_id, get_trace_id, + get_workflow_id, init_request_context, + set_error_source, + set_node_log_context, + set_workflow_log_context, +) +from core.logging.filters import ( + IdentityContextFilter, + TraceContextFilter, + WorkflowLogContextFilter, ) -from core.logging.filters import IdentityContextFilter, TraceContextFilter from core.logging.structured_formatter import StructuredJSONFormatter __all__ = [ + "ErrorSource", "IdentityContextFilter", "StructuredJSONFormatter", "TraceContextFilter", + "WorkflowLogContextFilter", + "clear_error_source", "clear_request_context", + "clear_workflow_log_context", + "get_app_id", + "get_error_source", + "get_node_id", "get_request_id", "get_trace_id", + "get_workflow_id", "init_request_context", + "set_error_source", + 
"set_node_log_context", + "set_workflow_log_context", ] diff --git a/api/core/logging/context.py b/api/core/logging/context.py index 18633a0b050765..8c2efe731bd8d3 100644 --- a/api/core/logging/context.py +++ b/api/core/logging/context.py @@ -6,10 +6,32 @@ import uuid from contextvars import ContextVar +from enum import StrEnum + + +class ErrorSource(StrEnum): + """Classification of error sources for structured logging. + + Used in the ``error_source`` field of ERROR+ log records to enable + differentiated alerting rules (e.g. workflow errors are user-caused, + system errors trigger on-call alerts). + """ + + WORKFLOW = "workflow" + SYSTEM = "system" + _request_id: ContextVar[str] = ContextVar("log_request_id", default="") _trace_id: ContextVar[str] = ContextVar("log_trace_id", default="") +# Workflow log context +_app_id: ContextVar[str] = ContextVar("log_app_id", default="") +_workflow_id: ContextVar[str] = ContextVar("log_workflow_id", default="") +_node_id: ContextVar[str] = ContextVar("log_node_id", default="") + +# Error source context (set by WorkflowEntry.run during workflow execution) +_error_source: ContextVar[ErrorSource] = ContextVar("log_error_source", default=ErrorSource.SYSTEM) + def get_request_id() -> str: """Get current request ID (10 hex chars).""" @@ -33,3 +55,85 @@ def clear_request_context() -> None: """Clear request context. 
Call at end of request (optional).""" _request_id.set("") _trace_id.set("") + + +# --------------------------------------------------------------------------- +# Workflow log context +# --------------------------------------------------------------------------- + + +def get_app_id() -> str: + """Get current workflow app_id for logging.""" + return _app_id.get() + + +def get_workflow_id() -> str: + """Get current workflow_id for logging.""" + return _workflow_id.get() + + +def get_node_id() -> str: + """Get current node_id for logging.""" + return _node_id.get() + + +def set_workflow_log_context(app_id: str, workflow_id: str) -> None: + """Set workflow-level log context (app_id, workflow_id). + + Call at graph start. Use ``clear_workflow_log_context`` at graph end. + """ + _app_id.set(app_id) + _workflow_id.set(workflow_id) + + +def set_node_log_context(node_id: str) -> None: + """Set or clear node-level log context. + + Pass empty string to clear node_id between node executions. + """ + _node_id.set(node_id) + + +def clear_workflow_log_context() -> None: + """Clear workflow log context (app_id, workflow_id, node_id). + + Call at graph end to ensure no stale context leaks to subsequent logs. + + Note: This does **not** reset ``error_source``. When ``on_graph_end`` + receives a non-None error, the subsequent ``logger.exception`` call in + ``WorkflowEntry.run`` still needs ``error_source == WORKFLOW`` to + correctly classify the error. ``error_source`` is reset separately + via ``clear_error_source`` after the error has been logged. + """ + _app_id.set("") + _workflow_id.set("") + _node_id.set("") + + +# --------------------------------------------------------------------------- +# Error source context +# --------------------------------------------------------------------------- + + +def get_error_source() -> ErrorSource: + """Get current error_source for logging. + + Defaults to ``ErrorSource.SYSTEM`` when no execution context is active. 
+ Set to ``ErrorSource.WORKFLOW`` by ``WorkflowEntry.run`` during + workflow graph execution. + """ + return _error_source.get() + + +def set_error_source(source: ErrorSource) -> None: + """Set error_source context. + + Typically called by ``WorkflowEntry.run`` with + ``ErrorSource.WORKFLOW`` before graph execution starts. + """ + _error_source.set(source) + + +def clear_error_source() -> None: + """Reset error_source context to the default (SYSTEM).""" + _error_source.set(ErrorSource.SYSTEM) diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py index 3f6c565e13164c..afc49dd2c63a23 100644 --- a/api/core/logging/filters.py +++ b/api/core/logging/filters.py @@ -6,7 +6,13 @@ import flask -from core.logging.context import get_request_id, get_trace_id +from core.logging.context import ( + get_app_id, + get_node_id, + get_request_id, + get_trace_id, + get_workflow_id, +) from core.logging.structured_formatter import IdentityDict @@ -97,3 +103,19 @@ def _extract_identity(self) -> IdentityDict: return identity except Exception: return {} + + +class WorkflowLogContextFilter(logging.Filter): + """Inject workflow log context (app_id, workflow_id, node_id) into log records. + + Values are read from ContextVars that are managed by ``WorkflowEntry.run`` + (app_id / workflow_id / error_source) and ``WorkflowLogContextLayer`` + (node_id). 
+ """ + + @override + def filter(self, record: logging.LogRecord) -> bool: + record.app_id = get_app_id() + record.workflow_id = get_workflow_id() + record.node_id = get_node_id() + return True diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py index 56ea748242717e..f3117a187be0b6 100644 --- a/api/core/logging/structured_formatter.py +++ b/api/core/logging/structured_formatter.py @@ -8,6 +8,7 @@ import orjson from configs import dify_config +from core.logging.context import get_error_source class IdentityDict(TypedDict, total=False): @@ -16,6 +17,12 @@ class IdentityDict(TypedDict, total=False): user_type: str +class LogContextDict(TypedDict, total=False): + app_id: str + workflow_id: str + node_id: str + + class LogDict(TypedDict): ts: str severity: str @@ -25,6 +32,8 @@ class LogDict(TypedDict): trace_id: NotRequired[str] span_id: NotRequired[str] identity: NotRequired[IdentityDict] + context: NotRequired[LogContextDict] + error_source: NotRequired[str] attributes: NotRequired[dict[str, Any]] stack_trace: NotRequired[str] @@ -93,6 +102,15 @@ def _build_log_dict(self, record: logging.LogRecord) -> LogDict: if identity: log_dict["identity"] = identity + # Workflow log context (from WorkflowLogContextFilter) + context = self._extract_log_context(record) + if context: + log_dict["context"] = context + + # Error source inference (ERROR and above only) + if record.levelno >= logging.ERROR: + log_dict["error_source"] = self._infer_error_source(record) + # Dynamic attributes attributes = getattr(record, "attributes", None) if attributes: @@ -121,6 +139,33 @@ def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None: identity["user_type"] = user_type return identity + def _extract_log_context(self, record: logging.LogRecord) -> LogContextDict | None: + """Extract workflow log context (app_id, workflow_id, node_id) from record.""" + app_id = getattr(record, "app_id", "") or "" + workflow_id = getattr(record, 
"workflow_id", "") or "" + node_id = getattr(record, "node_id", "") or "" + + if not any([app_id, workflow_id, node_id]): + return None + + context: LogContextDict = {} + if app_id: + context["app_id"] = app_id + if workflow_id: + context["workflow_id"] = workflow_id + if node_id: + context["node_id"] = node_id + return context + + def _infer_error_source(self, record: logging.LogRecord) -> str: + """Return the error_source for this ERROR+ log record. + + The value comes from the ``_error_source`` ContextVar, which defaults + to ``"system"`` and is set to ``"workflow"`` by ``WorkflowEntry.run`` + during workflow graph execution. + """ + return get_error_source().value + def _format_exception(self, exc_info: tuple[Any, ...]) -> str: if exc_info and exc_info[0] is not None: return "".join(traceback.format_exception(*exc_info)) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 9de26b8214b117..5a2c6fb9cc326e 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -9,7 +9,16 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.file_access import DatabaseFileAccessController from core.app.workflow.layers.llm_quota import LLMQuotaLayer +from core.app.workflow.layers.log_context import WorkflowLogContextLayer from core.app.workflow.layers.observability import ObservabilityLayer +from core.logging.context import ( + ErrorSource, + clear_error_source, + clear_workflow_log_context, + set_error_source, + set_node_log_context, + set_workflow_log_context, +) from core.workflow.node_factory import ( DifyGraphInitContext, DifyNodeFactory, @@ -61,6 +70,34 @@ def iter_dify_graph_engine_events(engine: GraphEngine) -> Generator[GraphEngineE ) +def _extract_failed_node_id(graph_engine: GraphEngine) -> str: + """Extract the node_id of the first failed node from graph_execution. 
+ + Node execution runs in worker threads where ContextVar values don't + propagate back to the main thread. This function recovers the + failed node_id from the domain model (``graph_execution.node_executions``) + so it can be set on the main thread's ContextVar before logging. + + Returns empty string if no failed node is found or if the graph + execution state is inaccessible. + """ + try: + graph_execution = graph_engine.graph_runtime_state.graph_execution + if graph_execution is None: + return "" + + node_executions = graph_execution.node_executions + for _node_id, node_exec in node_executions.items(): + # NodeExecutionProtocol doesn't declare `error`, but the concrete + # NodeExecution dataclass has it. Use getattr for safety. + if getattr(node_exec, "error", None) is not None: + return getattr(node_exec, "node_id", _node_id) + except (AttributeError, TypeError, KeyError): + pass + + return "" + + class _WorkflowChildEngineBuilder: tenant_id: str @@ -195,6 +232,8 @@ def __init__( command_channel = InMemoryChannel() self.command_channel = command_channel + self._app_id = app_id + self._workflow_id = workflow_id execution_context = capture_current_context() graph_runtime_state.execution_context = execution_context self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id) @@ -231,6 +270,9 @@ def __init__( self.graph_engine.layer(limits_layer) self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id)) + # Add workflow log context layer (node_id tracking in logs) + self.graph_engine.layer(WorkflowLogContextLayer(execution_context=execution_context)) + # Add observability layer when OTel is enabled if dify_config.ENABLE_OTEL or is_instrument_flag_enabled(): self.graph_engine.layer(ObservabilityLayer()) @@ -238,6 +280,13 @@ def __init__( def run(self) -> Generator[GraphEngineEvent, None, None]: graph_engine = self.graph_engine + # Set workflow log context before graph execution starts. 
+ # This ensures app_id / workflow_id / error_source are available + # to all log records, including the logger.exception call below. + # Context is cleared in the finally block after all logging is done. + set_workflow_log_context(self._app_id, self._workflow_id) + set_error_source(ErrorSource.WORKFLOW) + try: # Preserve Dify's response-stream semantics on top of Graphon 0.5.0. generator = iter_dify_graph_engine_events(graph_engine) @@ -245,9 +294,21 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: except GenerateTaskStoppedError: pass except Exception as e: + # Extract the failed node_id from graph_execution so that + # logger.exception carries it in the log context. Node + # execution runs in worker threads where ContextVar doesn't + # propagate back to this main thread, so we recover node_id + # from the domain model instead. + failed_node_id = _extract_failed_node_id(graph_engine) + if failed_node_id: + set_node_log_context(failed_node_id) + logger.exception("Unknown Error when workflow entry running") yield GraphRunFailedEvent(error=str(e)) return + finally: + clear_workflow_log_context() + clear_error_source() @classmethod def single_step_run( diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 5817e6a6cebe0b..eca8549df0320f 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -32,11 +32,16 @@ def init_app(app: DifyApp): log_handlers.append(sh) # Apply filters to all handlers - from core.logging.filters import IdentityContextFilter, TraceContextFilter + from core.logging.filters import ( + IdentityContextFilter, + TraceContextFilter, + WorkflowLogContextFilter, + ) for handler in log_handlers: handler.addFilter(TraceContextFilter()) handler.addFilter(IdentityContextFilter()) + handler.addFilter(WorkflowLogContextFilter()) # Configure formatter based on format type formatter = _create_formatter() diff --git a/api/tests/unit_tests/context/test_execution_context.py 
b/api/tests/unit_tests/context/test_execution_context.py new file mode 100644 index 00000000000000..67b69452564125 --- /dev/null +++ b/api/tests/unit_tests/context/test_execution_context.py @@ -0,0 +1,46 @@ +"""Tests for ExecutionContext.refresh_context_vars.""" + +import contextvars + +from context.execution_context import ExecutionContext + + +class TestRefreshContextVars: + """Tests for ExecutionContext.refresh_context_vars.""" + + def test_refresh_captures_current_contextvars(self): + """refresh_context_vars should re-capture the current ContextVar state.""" + test_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_refresh_var", default="initial") + + # Set a value so it appears in the copy_context snapshot + token1 = test_var.set("initial") + try: + ctx = ExecutionContext(context_vars=contextvars.copy_context()) + assert ctx.context_vars is not None + + # Change the ContextVar in the current thread + token2 = test_var.set("updated") + try: + # Before refresh, the snapshot still has "initial" + old_val = ctx.context_vars.get(test_var) + assert old_val == "initial" + + # After refresh, the snapshot should have "updated" + ctx.refresh_context_vars() + new_val = ctx.context_vars.get(test_var) + assert new_val == "updated" + finally: + test_var.reset(token2) + finally: + test_var.reset(token1) + + def test_refresh_replaces_context_vars(self): + """refresh_context_vars should replace the _context_vars attribute.""" + ctx = ExecutionContext(context_vars=contextvars.copy_context()) + original = ctx.context_vars + assert original is not None + + ctx.refresh_context_vars() + assert ctx.context_vars is not None + # Should be a new Context object (not the same reference) + assert ctx.context_vars is not original diff --git a/api/tests/unit_tests/core/app/workflow/test_log_context_layer.py b/api/tests/unit_tests/core/app/workflow/test_log_context_layer.py new file mode 100644 index 00000000000000..a05374767893c7 --- /dev/null +++ 
b/api/tests/unit_tests/core/app/workflow/test_log_context_layer.py @@ -0,0 +1,140 @@ +"""Tests for WorkflowLogContextLayer.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.workflow.layers.log_context import WorkflowLogContextLayer +from core.logging.context import ( + clear_workflow_log_context, + get_node_id, + set_workflow_log_context, +) + + +@pytest.fixture(autouse=True) +def _clean_context(): + """Ensure clean ContextVar state before and after each test.""" + clear_workflow_log_context() + yield + clear_workflow_log_context() + + +class TestWorkflowLogContextLayer: + """Tests for WorkflowLogContextLayer lifecycle management. + + The layer only manages node_id. app_id / workflow_id / error_source + are managed by WorkflowEntry.run directly. + """ + + def test_on_graph_start_does_not_set_app_or_workflow_id(self): + """app_id / workflow_id are set by WorkflowEntry.run, not the layer.""" + layer = WorkflowLogContextLayer() + + layer.on_graph_start() + + # Layer should not set app_id or workflow_id + # (they remain whatever they were before) + assert get_node_id() == "" + + def test_on_graph_start_refreshes_execution_context(self): + execution_context = MagicMock() + layer = WorkflowLogContextLayer(execution_context=execution_context) + + layer.on_graph_start() + + execution_context.refresh_context_vars.assert_called_once() + + def test_on_graph_start_without_execution_context_does_not_raise(self): + layer = WorkflowLogContextLayer(execution_context=None) + + layer.on_graph_start() # should not raise + + def test_on_node_run_start_sets_node_id(self): + layer = WorkflowLogContextLayer() + layer.on_graph_start() + + node = SimpleNamespace(id="node-abc") + layer.on_node_run_start(node) + + assert get_node_id() == "node-abc" + + def test_on_node_run_end_clears_node_id(self): + layer = WorkflowLogContextLayer() + layer.on_graph_start() + + node = 
SimpleNamespace(id="node-abc") + layer.on_node_run_start(node) + assert get_node_id() == "node-abc" + + layer.on_node_run_end(node, error=None) + assert get_node_id() == "" + + def test_on_graph_end_is_noop(self): + """on_graph_end does not clear context — WorkflowEntry.run handles that.""" + set_workflow_log_context("app-001", "wf-002") + layer = WorkflowLogContextLayer() + layer.on_graph_start() + + node = SimpleNamespace(id="node-abc") + layer.on_node_run_start(node) + + layer.on_graph_end(error=None) + + # Layer should not clear anything on_graph_end + assert get_node_id() == "node-abc" + + def test_node_id_switches_between_nodes(self): + layer = WorkflowLogContextLayer() + layer.on_graph_start() + + node1 = SimpleNamespace(id="node-1") + node2 = SimpleNamespace(id="node-2") + + layer.on_node_run_start(node1) + assert get_node_id() == "node-1" + + layer.on_node_run_end(node1, error=None) + assert get_node_id() == "" + + layer.on_node_run_start(node2) + assert get_node_id() == "node-2" + + layer.on_node_run_end(node2, error=None) + assert get_node_id() == "" + + def test_on_event_is_noop(self): + layer = WorkflowLogContextLayer() + # Should not raise + layer.on_event(object()) + + def test_full_lifecycle(self): + """Simulate a full graph run with two nodes.""" + # app_id / workflow_id are set by WorkflowEntry.run, not the layer. 
+ set_workflow_log_context("app-100", "wf-200") + + layer = WorkflowLogContextLayer() + + # Graph start (refreshes execution context only) + layer.on_graph_start() + + # Node 1 + node1 = SimpleNamespace(id="n1") + layer.on_node_run_start(node1) + assert get_node_id() == "n1" + layer.on_node_run_end(node1, error=None) + assert get_node_id() == "" + + # Node 2 + node2 = SimpleNamespace(id="n2") + layer.on_node_run_start(node2) + assert get_node_id() == "n2" + layer.on_node_run_end(node2, error=None) + assert get_node_id() == "" + + # Graph end — layer is a no-op; WorkflowEntry.run clears context + layer.on_graph_end(error=None) + assert get_node_id() == "" diff --git a/api/tests/unit_tests/core/logging/test_context.py b/api/tests/unit_tests/core/logging/test_context.py index f388a3a0b9f3f0..4956ccb51bfbdb 100644 --- a/api/tests/unit_tests/core/logging/test_context.py +++ b/api/tests/unit_tests/core/logging/test_context.py @@ -3,10 +3,20 @@ import uuid from core.logging.context import ( + ErrorSource, + clear_error_source, clear_request_context, + clear_workflow_log_context, + get_app_id, + get_error_source, + get_node_id, get_request_id, get_trace_id, + get_workflow_id, init_request_context, + set_error_source, + set_node_log_context, + set_workflow_log_context, ) @@ -77,3 +87,86 @@ def test_context_isolation(self): # IDs should be different assert id1 != id2 + + +class TestWorkflowLogContext: + """Tests for workflow log context functions.""" + + def setup_method(self): + clear_workflow_log_context() + + def teardown_method(self): + clear_workflow_log_context() + + def test_default_values_are_empty(self): + assert get_app_id() == "" + assert get_workflow_id() == "" + assert get_node_id() == "" + + def test_set_workflow_log_context(self): + set_workflow_log_context("app-001", "wf-002") + assert get_app_id() == "app-001" + assert get_workflow_id() == "wf-002" + # node_id should still be empty + assert get_node_id() == "" + + def test_set_node_log_context(self): + 
set_workflow_log_context("app-001", "wf-002") + set_node_log_context("node-abc") + assert get_node_id() == "node-abc" + + def test_clear_node_log_context(self): + set_workflow_log_context("app-001", "wf-002") + set_node_log_context("node-abc") + set_node_log_context("") + assert get_node_id() == "" + + def test_clear_workflow_log_context_clears_all(self): + set_workflow_log_context("app-001", "wf-002") + set_node_log_context("node-abc") + + clear_workflow_log_context() + assert get_app_id() == "" + assert get_workflow_id() == "" + assert get_node_id() == "" + + +class TestErrorSourceContext: + """Tests for error_source context functions.""" + + def setup_method(self): + clear_error_source() + + def teardown_method(self): + clear_error_source() + + def test_default_value_is_system(self): + assert get_error_source() == ErrorSource.SYSTEM + assert get_error_source().value == "system" + + def test_set_error_source(self): + set_error_source(ErrorSource.WORKFLOW) + assert get_error_source() == ErrorSource.WORKFLOW + assert get_error_source().value == "workflow" + + def test_clear_error_source(self): + set_error_source(ErrorSource.WORKFLOW) + clear_error_source() + assert get_error_source() == ErrorSource.SYSTEM + + def test_clear_workflow_log_context_does_not_reset_error_source(self): + """clear_workflow_log_context should NOT reset error_source. + + error_source is managed independently by WorkflowEntry.run: + it is set before graph execution and cleared in the finally block + after all logging is done. + """ + set_workflow_log_context("app-001", "wf-002") + set_error_source(ErrorSource.WORKFLOW) + + clear_workflow_log_context() + # error_source should remain WORKFLOW — clearing workflow context + # (app_id/workflow_id/node_id) must not reset error classification. 
+ assert get_error_source() == ErrorSource.WORKFLOW + + clear_error_source() diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py index a8b186ac8aa18e..c3dab522c3191d 100644 --- a/api/tests/unit_tests/core/logging/test_filters.py +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -292,3 +292,67 @@ class AnotherClass: assert log_record.tenant_id == "tenant_id" assert log_record.user_id == "end_user_id" assert log_record.user_type == "end_user" + + +class TestWorkflowLogContextFilter: + """Tests for WorkflowLogContextFilter.""" + + def test_sets_empty_context_by_default(self, log_record): + from core.logging.context import clear_workflow_log_context + from core.logging.filters import WorkflowLogContextFilter + + clear_workflow_log_context() + + filter = WorkflowLogContextFilter() + result = filter.filter(log_record) + + assert result is True + assert log_record.app_id == "" + assert log_record.workflow_id == "" + assert log_record.node_id == "" + + def test_sets_context_from_contextvars(self, log_record): + from core.logging.context import ( + clear_workflow_log_context, + set_workflow_log_context, + ) + from core.logging.filters import WorkflowLogContextFilter + + clear_workflow_log_context() + set_workflow_log_context("app-100", "wf-200") + + filter = WorkflowLogContextFilter() + filter.filter(log_record) + + assert log_record.app_id == "app-100" + assert log_record.workflow_id == "wf-200" + assert log_record.node_id == "" + + def test_sets_node_id_from_contextvar(self, log_record): + from core.logging.context import ( + clear_workflow_log_context, + set_node_log_context, + set_workflow_log_context, + ) + from core.logging.filters import WorkflowLogContextFilter + + clear_workflow_log_context() + set_workflow_log_context("app-100", "wf-200") + set_node_log_context("node-xyz") + + filter = WorkflowLogContextFilter() + filter.filter(log_record) + + assert log_record.app_id == "app-100" + assert 
# ---------------------------------------------------------------------------
# Workflow log context (app_id / workflow_id / node_id)
# ---------------------------------------------------------------------------


class TestLogContextExtraction:
    """Tests for workflow log context extraction in the formatter."""

    @staticmethod
    def _format_with_context(**context_attrs: str) -> dict:
        """Format an INFO record and return the decoded JSON payload.

        Every keyword in ``context_attrs`` is attached to the ``LogRecord`` as
        an attribute before formatting; names that are not passed are left
        unset on the record.
        """
        # Imported lazily, exactly as each individual test did before.
        from core.logging.structured_formatter import StructuredJSONFormatter

        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="Test",
            args=(),
            exc_info=None,
        )
        for attr_name, attr_value in context_attrs.items():
            setattr(record, attr_name, attr_value)

        return orjson.loads(StructuredJSONFormatter().format(record))

    def test_context_included_when_set(self):
        """All three identifiers show up under ``context`` when set."""
        log_dict = self._format_with_context(
            app_id="app-123",
            workflow_id="wf-456",
            node_id="node-789",
        )

        assert "context" in log_dict
        assert log_dict["context"]["app_id"] == "app-123"
        assert log_dict["context"]["workflow_id"] == "wf-456"
        assert log_dict["context"]["node_id"] == "node-789"

    def test_context_partial_fields(self):
        """Only the identifiers present on the record are emitted."""
        # workflow_id and node_id are deliberately not set.
        log_dict = self._format_with_context(app_id="app-123")

        assert "context" in log_dict
        assert log_dict["context"]["app_id"] == "app-123"
        assert "workflow_id" not in log_dict["context"]
        assert "node_id" not in log_dict["context"]

    def test_no_context_when_all_empty(self):
        """Empty-string identifiers must not produce a ``context`` object."""
        log_dict = self._format_with_context(app_id="", workflow_id="", node_id="")

        assert "context" not in log_dict


# ---------------------------------------------------------------------------
# error_source inference
# ---------------------------------------------------------------------------


class TestInferErrorSource:
    """Tests for _infer_error_source with ContextVar-based source."""

    def setup_method(self) -> None:
        """Start every test from a cleared error source."""
        clear_error_source()

    def teardown_method(self) -> None:
        """Clear the error source after every test.

        Doing this in teardown (rather than as the last statement of a test)
        guarantees the reset also happens when an assertion fails, so a
        WORKFLOW source set by one test cannot leak into the next one.
        """
        clear_error_source()

    def _make_record(self, level: int = logging.ERROR) -> logging.LogRecord:
        """Build a minimal ``LogRecord`` at the requested level."""
        return logging.LogRecord(
            name="test",
            level=level,
            pathname="test.py",
            lineno=1,
            msg="Test",
            args=(),
            exc_info=None,
        )

    def _format(self, level: int) -> dict:
        """Format a record at *level* and return the decoded JSON payload."""
        record = self._make_record(level=level)
        return orjson.loads(StructuredJSONFormatter().format(record))

    def test_default_is_system(self):
        """Without any context, default to 'system'."""
        formatter = StructuredJSONFormatter()

        assert formatter._infer_error_source(self._make_record()) == "system"

    def test_contextvar_workflow(self):
        """When ContextVar is set to WORKFLOW, error_source should be 'workflow'."""
        set_error_source(ErrorSource.WORKFLOW)
        formatter = StructuredJSONFormatter()

        assert formatter._infer_error_source(self._make_record()) == "workflow"

    def test_error_source_only_for_error_and_above(self):
        """error_source field should NOT appear for INFO/DEBUG logs."""
        log_dict = self._format(logging.INFO)

        assert "error_source" not in log_dict

    def test_error_source_present_for_error(self):
        """error_source field should appear for ERROR logs."""
        set_error_source(ErrorSource.WORKFLOW)

        log_dict = self._format(logging.ERROR)

        assert log_dict["error_source"] == "workflow"

    def test_error_source_present_for_critical(self):
        """error_source field should appear for CRITICAL logs."""
        log_dict = self._format(logging.CRITICAL)

        assert log_dict["error_source"] == "system"

    def test_error_source_default_system_in_output(self):
        """Without context, ERROR log should have error_source='system'."""
        log_dict = self._format(logging.ERROR)

        assert log_dict["error_source"] == "system"
b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 3ccfdf76f5a517..2787463788b000 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -322,18 +322,24 @@ def test_applies_debug_and_observability_layers(self): max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME, ) llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id") - assert graph_engine.layer.call_args_list == [ - ((debug_layer,), {}), - ((execution_limits_layer,), {}), - ((llm_quota_layer,), {}), - ((observability_layer,), {}), - ] + # Layers are registered in order: debug, limits, llm_quota, log_context, observability + layer_calls = graph_engine.layer.call_args_list + assert len(layer_calls) == 5 + assert layer_calls[0] == ((debug_layer,), {}) + assert layer_calls[1] == ((execution_limits_layer,), {}) + assert layer_calls[2] == ((llm_quota_layer,), {}) + # layer 3 is WorkflowLogContextLayer (instantiated directly, not mocked) + log_ctx_layer = layer_calls[3][0][0] + assert type(log_ctx_layer).__name__ == "WorkflowLogContextLayer" + assert layer_calls[4] == ((observability_layer,), {}) class TestWorkflowEntryRun: def test_run_swallows_generate_task_stopped_errors(self): entry = object.__new__(workflow_entry.WorkflowEntry) entry.graph_engine = MagicMock() + entry._app_id = "app-id" + entry._workflow_id = "workflow-id" entry.graph_engine.run.side_effect = GenerateTaskStoppedError() assert list(entry.run()) == [] @@ -373,6 +379,8 @@ def test_iter_dify_graph_engine_events_applies_response_stream_filter(self): def test_run_delegates_to_dify_event_iterator(self): entry = object.__new__(workflow_entry.WorkflowEntry) entry.graph_engine = sentinel.graph_engine + entry._app_id = "app-id" + entry._workflow_id = "workflow-id" with patch.object( workflow_entry, @@ -387,6 +395,8 @@ def test_run_delegates_to_dify_event_iterator(self): def 
class TestExtractFailedNodeId:
    """Tests for _extract_failed_node_id helper."""

    @staticmethod
    def _new_graph_execution():
        """Return a fresh ``GraphExecution`` for workflow ``wf-1``."""
        # Imported lazily, as the individual tests did.
        from graphon.graph_engine.domain.graph_execution import GraphExecution

        return GraphExecution(workflow_id="wf-1")

    @staticmethod
    def _engine_with(graph_execution):
        """Return a mock engine whose runtime state exposes *graph_execution*."""
        stub = MagicMock()
        stub.graph_runtime_state.graph_execution = graph_execution
        return stub

    def test_returns_empty_when_graph_execution_is_none(self):
        stub = self._engine_with(None)

        assert workflow_entry._extract_failed_node_id(stub) == ""

    def test_returns_empty_when_no_failed_nodes(self):
        execution = self._new_graph_execution()
        execution.get_or_create_node_execution("node-1").mark_taken()

        stub = self._engine_with(execution)

        assert workflow_entry._extract_failed_node_id(stub) == ""

    def test_returns_failed_node_id(self):
        execution = self._new_graph_execution()
        execution.get_or_create_node_execution("node-1").mark_taken()
        execution.get_or_create_node_execution("node-2").mark_failed("JSON parse error")

        stub = self._engine_with(execution)

        assert workflow_entry._extract_failed_node_id(stub) == "node-2"

    def test_returns_first_failed_when_multiple_failures(self):
        execution = self._new_graph_execution()
        execution.get_or_create_node_execution("node-a").mark_failed("error A")
        execution.get_or_create_node_execution("node-b").mark_failed("error B")

        stub = self._engine_with(execution)

        failed_id = workflow_entry._extract_failed_node_id(stub)
        assert failed_id in ("node-a", "node-b")

    def test_returns_empty_on_exception(self):
        stub = MagicMock()
        # Make iterating the node executions blow up inside the helper.
        stub.graph_runtime_state.graph_execution.node_executions.items.side_effect = TypeError("mock")

        assert workflow_entry._extract_failed_node_id(stub) == ""