diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 1ac2fb544..8e2de7638 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -11,6 +11,7 @@ from types import GenericAlias from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union +from anyio.to_thread import run_sync from pydantic import BaseModel from pydantic.json_schema import JsonSchemaValue from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict @@ -31,11 +32,8 @@ async def run_in_executor(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: - if kwargs: - # noinspection PyTypeChecker - return await asyncio.get_running_loop().run_in_executor(None, partial(func, *args, **kwargs)) - else: - return await asyncio.get_running_loop().run_in_executor(None, func, *args) # type: ignore + wrapped_func = partial(func, *args, **kwargs) + return await run_sync(wrapped_func) def is_model_like(type_: Any) -> bool: diff --git a/tests/test_utils.py b/tests/test_utils.py index af6f452aa..2de515bda 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import asyncio +import contextvars import os from collections.abc import AsyncIterator from importlib.metadata import distributions @@ -9,7 +10,7 @@ from inline_snapshot import snapshot from pydantic_ai import UserError -from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal +from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal, run_in_executor from .models.mock_async_stream import MockAsyncStream @@ -136,3 +137,19 @@ def test_package_versions(capsys: pytest.CaptureFixture[str]): packages = sorted((package.metadata['Name'], package.version) for package in distributions()) for name, version in packages: print(f'{name:30} {version}') + + +async def test_run_in_executor_with_contextvars() -> None: + ctx_var = contextvars.ContextVar('test_var', default='default') + ctx_var.set('original_value') + + result = await run_in_executor(ctx_var.get) + assert result == ctx_var.get() + + ctx_var.set('new_value') + result = await run_in_executor(ctx_var.get) + assert result == ctx_var.get() + + # show that the old version did not work + old_result = asyncio.get_running_loop().run_in_executor(None, ctx_var.get) + assert old_result != ctx_var.get()