Skip to content

Make dspy.settings and dspy.context safe in async setup #8203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions dspy/dsp/utils/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import contextvars
import copy
import threading
from contextlib import contextmanager
Expand Down Expand Up @@ -36,13 +38,7 @@
# Global lock for settings configuration
global_lock = threading.Lock()


# Context-local overrides for `Settings`.
#
# A `contextvars.ContextVar` replaces the previous `threading.local` subclass
# so that overrides set via `dspy.context(...)` propagate into asyncio tasks
# (each task runs in a copy of the current Context) as well as into worker
# threads started with a copied context.
thread_local_overrides = contextvars.ContextVar("context_overrides", default=dotdict())


class Settings:
Expand Down Expand Up @@ -73,7 +69,7 @@ def lock(self):
return global_lock

def __getattr__(self, name):
overrides = getattr(thread_local_overrides, "overrides", dotdict())
overrides = thread_local_overrides.get()
if name in overrides:
return overrides[name]
elif name in main_thread_config:
Expand All @@ -94,7 +90,7 @@ def __setitem__(self, key, value):
self.__setattr__(key, value)

def __contains__(self, key):
    """Membership test: `key in settings` checks context-local overrides, then the global config."""
    overrides = thread_local_overrides.get()
    # Overrides take precedence, but a key present in either mapping counts.
    return key in overrides or key in main_thread_config

def get(self, key, default=None):
Expand All @@ -104,23 +100,62 @@ def get(self, key, default=None):
return default

def copy(self):
    """Return a merged snapshot of the global config plus context-local overrides.

    Override values win over `main_thread_config` entries on key collisions.
    """
    overrides = thread_local_overrides.get()
    return dotdict({**main_thread_config, **overrides})

@property
def config(self):
    # Read-only snapshot of the effective configuration (global config merged
    # with context-local overrides); delegates to `copy`.
    return self.copy()

def configure(self, **kwargs):
def _ensure_configure_allowed(self):
    """Raise if the current thread/task may not call `configure`; return otherwise.

    Rules enforced:
      - The first `configure` call (from any thread) is always allowed and
        records the calling thread as the config owner.
      - Later `configure` calls are only allowed from that owner thread.
      - Inside an async task, a second `configure` call is disallowed — use
        `dspy.context(...)` instead — except when running under IPython,
        where notebook cells execute inside an event loop.

    Raises:
        RuntimeError: if the call is made from a non-owner thread, or from an
            async task outside IPython after the first configuration.
    """
    global main_thread_config, config_owner_thread_id
    current_thread_id = threading.get_ident()

    # NOTE(review): only the ownership check needs the lock; the async/IPython
    # checks below read no shared state.
    with self.lock:
        if config_owner_thread_id is None:
            # First `configure` call is always allowed.
            config_owner_thread_id = current_thread_id
            return

        if config_owner_thread_id != current_thread_id:
            # Disallow a second `configure` call from other threads.
            raise RuntimeError("dspy.settings can only be changed by the thread that initially configured it.")

    # Async task doesn't allow a second `configure` call, must use dspy.context(...) instead.
    is_async_task = False
    try:
        if asyncio.current_task() is not None:
            is_async_task = True
    except RuntimeError:
        # This exception (e.g., "no current task") means we are not in an async loop/task,
        # or asyncio module itself is not fully functional in this specific sub-thread context.
        is_async_task = False

    if not is_async_task:
        return

    # We are in an async task. Now check for IPython and allow calling `configure` from IPython.
    in_ipython = False
    try:
        from IPython import get_ipython

        # get_ipython is a global injected by IPython environments.
        # We check its existence and type to be more robust.
        shell = get_ipython()
        if shell is not None and "InteractiveShell" in shell.__class__.__name__:
            in_ipython = True
    except Exception:
        # If `IPython` is not installed or `get_ipython` failed, we are not in an IPython environment.
        in_ipython = False

    if not in_ipython:
        raise RuntimeError(
            "dspy.settings.configure(...) cannot be called a second time from an async task. Use "
            "`dspy.context(...)` instead."
        )

def configure(self, **kwargs):
# If no exception is raised, the `configure` call is allowed.
self._ensure_configure_allowed()

# Update global config
for k, v in kwargs.items():
Expand All @@ -134,17 +169,17 @@ def context(self, **kwargs):
If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides.
"""

original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy()
original_overrides = thread_local_overrides.get().copy()
new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs})
thread_local_overrides.overrides = new_overrides
token = thread_local_overrides.set(new_overrides)

try:
yield
finally:
thread_local_overrides.overrides = original_overrides
thread_local_overrides.reset(token)

def __repr__(self):
    """Show the effective configuration: global config merged with context-local overrides."""
    overrides = thread_local_overrides.get()
    combined_config = {**main_thread_config, **overrides}
    return repr(combined_config)

Expand Down
7 changes: 0 additions & 7 deletions dspy/streaming/streamify.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,10 @@ def apply_sync_streaming(async_generator: AsyncGenerator) -> Generator:

# To propagate prediction request ID context to the child thread
context = contextvars.copy_context()
from dspy.dsp.utils.settings import thread_local_overrides

parent_overrides = thread_local_overrides.overrides.copy()

def producer():
"""Runs in a background thread to fetch items asynchronously."""

original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()

async def runner():
try:
async for item in async_generator:
Expand All @@ -241,7 +235,6 @@ async def runner():
queue.put(stop_sentinel)

context.run(asyncio.run, runner())
thread_local_overrides.overrides = original_overrides

# Start the producer in a background thread
thread = threading.Thread(target=producer, daemon=True)
Expand Down
8 changes: 4 additions & 4 deletions dspy/utils/asyncify.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ async def async_program(*args, **kwargs) -> Any:
# Capture the current overrides at call-time.
from dspy.dsp.utils.settings import thread_local_overrides

parent_overrides = thread_local_overrides.overrides.copy()
parent_overrides = thread_local_overrides.get().copy()

def wrapped_program(*a, **kw):
    # Runs in a worker thread via asyncer: re-install the caller's
    # context-local overrides (captured at call time as `parent_overrides`)
    # so the program sees the same settings as the awaiting task.
    from dspy.dsp.utils.settings import thread_local_overrides

    original_overrides = thread_local_overrides.get()
    token = thread_local_overrides.set({**original_overrides, **parent_overrides.copy()})
    try:
        return program(*a, **kw)
    finally:
        # Always restore the previous overrides, even if the program raises.
        thread_local_overrides.reset(token)

# Create a fresh asyncified callable each time, ensuring the latest context is used.
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())
Expand Down
8 changes: 4 additions & 4 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ def worker(parent_overrides, submission_id, index, item):
# Apply parent's thread-local overrides
from dspy.dsp.utils.settings import thread_local_overrides

original = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()
original = thread_local_overrides.get()
token = thread_local_overrides.set({**original, **parent_overrides.copy()})
if parent_overrides.get("usage_tracker"):
# Usage tracker needs to be deep copied across threads so that each thread tracks its own usage
thread_local_overrides.overrides["usage_tracker"] = copy.deepcopy(parent_overrides["usage_tracker"])

try:
return index, function(item)
finally:
thread_local_overrides.overrides = original
thread_local_overrides.reset(token)

# Handle Ctrl-C in the main thread
@contextlib.contextmanager
Expand All @@ -121,7 +121,7 @@ def handler(sig, frame):
with interrupt_manager():
from dspy.dsp.utils.settings import thread_local_overrides

parent_overrides = thread_local_overrides.overrides.copy()
parent_overrides = thread_local_overrides.get().copy()

futures_map = {}
futures_set = set()
Expand Down
5 changes: 2 additions & 3 deletions tests/adapters/test_two_step_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,8 @@ class TestSignature(dspy.Signature):
mock_extraction_lm.kwargs = {"temperature": 1.0}
mock_extraction_lm.model = "openai/gpt-4o"

dspy.configure(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm))

result = await program.acall(question="What is 5 + 7?")
with dspy.context(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm)):
result = await program.acall(question="What is 5 + 7?")

assert result.answer == 12

Expand Down
9 changes: 4 additions & 5 deletions tests/callback/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,12 @@ def test_callback_complex_module():
@pytest.mark.asyncio
async def test_callback_async_module():
    """An async `acall` through ChainOfThought fires callbacks set via `dspy.context`."""
    callback = MyCallback()
    # Use dspy.context instead of settings.configure: configure is disallowed
    # for a second call from inside an async task.
    with dspy.context(
        lm=DummyLM({"How are you?": {"answer": "test output", "reasoning": "No more responses"}}),
        callbacks=[callback],
    ):
        cot = dspy.ChainOfThought("question -> answer", n=3)
        result = await cot.acall(question="How are you?")
    assert result["answer"] == "test output"
    assert result["reasoning"] == "No more responses"

Expand Down
8 changes: 4 additions & 4 deletions tests/predict/test_chain_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_initialization_with_string_signature():
@pytest.mark.asyncio
async def test_async_chain_of_thought():
    """Async ChainOfThought call returns the LM's answer under a scoped context."""
    lm = DummyLM([{"reasoning": "find the number after 1", "answer": "2"}])
    # dspy.context scopes the LM override to this task instead of mutating
    # global settings (configure is disallowed inside async tasks).
    with dspy.context(lm=lm):
        program = ChainOfThought("question -> answer")
        result = await program.acall(question="What is 1+1?")
        assert result.answer == "2"
73 changes: 70 additions & 3 deletions tests/predict/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import pydantic
import pytest
import ujson
import os
import time
import asyncio
import types
from litellm import ModelResponse

import dspy
Expand Down Expand Up @@ -506,6 +510,69 @@ def test_lm_usage():
assert result.get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10


def test_lm_usage_with_parallel():
    """Usage tracking stays correct per result when a program runs under dspy.Parallel."""
    program = Predict("question -> answer")

    def run_program(question):
        # Sleep to make it possible to cause a race condition
        time.sleep(0.5)
        return program(question=question)

    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
    fake_response = ModelResponse(
        choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}],
        usage={"total_tokens": 10},
    )
    with patch("dspy.clients.lm.litellm_completion", return_value=fake_response):
        pairs = [
            (run_program, {"question": "What is the capital of France?"}),
            (run_program, {"question": "What is the capital of France?"}),
        ]
        results = dspy.Parallel()(pairs)
    for result in results:
        assert result.answer == "Paris"
        assert result.get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10


@pytest.mark.asyncio
async def test_lm_usage_with_async():
    """Each concurrent `acall` reports its own usage when track_usage is enabled."""
    program = Predict("question -> answer")

    original_aforward = program.aforward

    async def patched_aforward(self, **kwargs):
        # Delay so the concurrent calls overlap and would expose any
        # cross-task usage-tracker leakage.
        await asyncio.sleep(1)
        return await original_aforward(**kwargs)

    program.aforward = types.MethodType(patched_aforward, program)

    fake_response = ModelResponse(
        choices=[{"message": {"content": "[[ ## answer ## ]]\nParis"}}],
        usage={"total_tokens": 10},
    )
    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True):
        with patch("litellm.acompletion", return_value=fake_response):
            results = await asyncio.gather(
                *(program.acall(question="What is the capital of France?") for _ in range(4))
            )
    for result in results:
        assert result.answer == "Paris"
        assert result.get_lm_usage()["openai/gpt-4o-mini"]["total_tokens"] == 10


def test_positional_arguments():
program = Predict("question -> answer")
with pytest.raises(ValueError) as e:
Expand Down Expand Up @@ -569,9 +636,9 @@ class ConstrainedSignature(dspy.Signature):
@pytest.mark.asyncio
async def test_async_predict():
    """Async Predict call returns the LM's answer under a scoped context."""
    program = Predict("question -> answer")
    # dspy.context scopes the LM override to this task, keeping the async test
    # from mutating global settings (configure is disallowed in async tasks).
    with dspy.context(lm=DummyLM([{"answer": "Paris"}])):
        result = await program.acall(question="What is the capital of France?")
        assert result.answer == "Paris"


def test_predicted_outputs_piped_from_predict_to_lm_call():
Expand Down
24 changes: 11 additions & 13 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,15 @@ class InvitationSignature(dspy.Signature):
},
]
)
dspy.settings.configure(lm=lm)

outputs = await react.acall(
participant_name="Alice",
event_info=CalendarEvent(
name="Science Fair",
date="Friday",
participants={"Alice": "female", "Bob": "male"},
),
)
with dspy.context(lm=lm):
outputs = await react.acall(
participant_name="Alice",
event_info=CalendarEvent(
name="Science Fair",
date="Friday",
participants={"Alice": "female", "Bob": "male"},
),
)
assert outputs.invitation_letter == "It's my honor to invite Alice to the Science Fair event on Friday."

expected_trajectory = {
Expand Down Expand Up @@ -309,9 +308,8 @@ async def foo(a, b):
{"reasoning": "I added the numbers successfully", "c": 3},
]
)
dspy.settings.configure(lm=lm)

outputs = await react.acall(a=1, b=2, max_iters=2)
with dspy.context(lm=lm):
outputs = await react.acall(a=1, b=2, max_iters=2)
traj = outputs.trajectory

# Exact-match checks (thoughts + tool calls)
Expand Down
Loading