
feat: Integrate LiteLLM Router for advanced LLM management #8268

Open · wants to merge 2 commits into main

189 changes: 146 additions & 43 deletions dspy/clients/lm.py
@@ -28,15 +28,16 @@ class LM(BaseLM):

def __init__(
self,
model: str,
model: Optional[str] = None, # Made optional
router: Optional[litellm.Router] = None, # Added router parameter
model_type: Literal["chat", "text"] = "chat",
temperature: float = 0.0,
max_tokens: int = 4000,
cache: bool = True,
cache_in_memory: bool = True,
callbacks: Optional[List[BaseCallback]] = None,
num_retries: int = 3,
provider=None,
provider=None, # Ignored if 'router' is provided
finetuning_model: Optional[str] = None,
launch_kwargs: Optional[dict[str, Any]] = None,
train_kwargs: Optional[dict[str, Any]] = None,
@@ -47,7 +48,9 @@ def __init__(

Args:
model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
supported by LiteLLM. For example, ``"openai/gpt-4o"``.
supported by LiteLLM. For example, ``"openai/gpt-4o"``. This is optional if 'router' is provided.
router: A LiteLLM Router instance to use for model routing. If provided, 'model' can be None
or specify a default model/group for the router.
model_type: The type of the model, either ``"chat"`` or ``"text"``.
temperature: The sampling temperature to use when generating responses.
max_tokens: The maximum number of tokens to generate per response.
@@ -57,37 +60,67 @@ def __init__(
callbacks: A list of callback functions to run before and after each request.
num_retries: The number of times to retry a request if it fails transiently due to
network error, rate limiting, etc. Requests are retried with exponential
backoff.
backoff. Note: LiteLLM Router may have its own retry logic.
provider: The provider to use. If not specified, the provider will be inferred from the model.
This is ignored if 'router' is provided.
finetuning_model: The model to finetune. In some providers, the models available for finetuning are different
from the models available for inference.
"""
# Remember to update LM.copy() if you modify the constructor!
if router is None and model is None:
raise ValueError("Either 'model' or 'router' must be specified.")
# When a router is provided, 'model' may be None here or may name the
# model/group that the router should target by default; both are accepted.

self.router = router
self.model = model
self.model_type = model_type
self.cache = cache
self.cache_in_memory = cache_in_memory
self.provider = provider or self.infer_provider()

if self.router:
self.provider = None # Or a new GenericRouterProvider() if we define one. For now, None.
elif provider:
self.provider = provider
else:
# This check is important: self.model must exist if self.router is None.
if self.model is None: # Should have been caught by the initial check, but as a safeguard.
raise ValueError("If 'router' is not provided, 'model' must be specified.")
self.provider = self.infer_provider()

self.callbacks = callbacks or []
self.history = []
self.num_retries = num_retries
self.num_retries = num_retries # LiteLLM Router has its own retry logic; consider how this interacts
self.finetuning_model = finetuning_model
self.launch_kwargs = launch_kwargs or {}
self.train_kwargs = train_kwargs or {}

# Handle model-specific configuration for different model families
model_family = model.split("/")[-1].lower() if "/" in model else model.lower()

# Match pattern: o[1,3,4] at the start, optionally followed by -mini and anything else
model_pattern = re.match(r"^o([134])(?:-mini)?", model_family)

if model_pattern:
# Handle OpenAI reasoning models (o1, o3)
assert (
max_tokens >= 20_000 and temperature == 1.0
), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
# This part might need adjustment if 'model' is None when a router is used.
if self.model:
model_family = self.model.split("/")[-1].lower() if "/" in self.model else self.model.lower()
model_pattern = re.match(r"^o([134])(?:-mini)?", model_family)

if model_pattern:
assert (
max_tokens >= 20_000 and temperature == 1.0
), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"
self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
else:
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
elif self.router:
# If using a router, these defaults apply unless overridden in calls or router config.
# The router itself will handle specific model params.
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
else:
# This case should ideally not be reached if the initial check (router or model must be provided) is in place.
# However, as a fallback, initialize kwargs.
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
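
For context, a minimal usage sketch of the new constructor path. The router configuration below (the `gpt-4o-group` alias, deployments, and keys) is illustrative only and not part of this PR:

```python
import dspy
import litellm

# Hypothetical two-deployment group; model names and keys are placeholders.
router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-4o-group",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-..."},
        },
        {
            "model_name": "gpt-4o-group",
            "litellm_params": {
                "model": "azure/my-gpt-4o-deployment",
                "api_base": "https://example.openai.azure.com",
                "api_key": "...",
            },
        },
    ]
)

# With this change, `model` names the router group/alias to target,
# while the router decides which underlying deployment serves each call.
lm = dspy.LM(model="gpt-4o-group", router=router)
```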

def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache):
@@ -115,20 +148,42 @@ def _get_cached_completion_fn(self, completion_fn, cache, enable_memory_cache):

def forward(self, prompt=None, messages=None, **kwargs):
# Build the request.
cache = kwargs.pop("cache", self.cache)
enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)

messages = messages or [{"role": "user", "content": prompt}]
kwargs = {**self.kwargs, **kwargs}

if self.router:
# Ensure 'model' is not passed directly in kwargs to router.completion if self.model is the source
router_kwargs = {key: value for key, value in kwargs.items() if key != 'model'}
all_kwargs = {**self.kwargs, **router_kwargs}

if not self.model:
raise ValueError("self.model must be set when using a router to specify the target model/group.")

# LiteLLM Router handles its own caching and retries.
# DSPy's cache parameters and self.num_retries are not passed here.
results = self.router.completion(
model=self.model, # This tells the router which model/group to use
messages=messages,
**all_kwargs,
)
else:
# Existing logic for non-router path
cache = kwargs.pop("cache", self.cache)
enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
all_kwargs = {**self.kwargs, **kwargs} # ensure kwargs passed to method are combined

completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)

# self.model must be present here as router is None (constructor should ensure this)
if not self.model:
raise ValueError("self.model must be set when not using a router.")

results = completion(
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
cache=litellm_cache_args,
)

results = completion(
request=dict(model=self.model, messages=messages, **all_kwargs),
num_retries=self.num_retries,
cache=litellm_cache_args,
)

if any(c.finish_reason == "length" for c in results["choices"]):
logger.warning(
@@ -140,25 +195,50 @@ def forward(self, prompt=None, messages=None, **kwargs):
)

if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
settings.usage_tracker.add_usage(self.model, dict(results.usage))
# For router, self.model is the group/alias. For non-router, it's the specific model.
# Usage tracking model key might need adjustment if router provides specific model info in response.
# For now, using self.model for both.
model_key_for_usage = self.model
settings.usage_tracker.add_usage(model_key_for_usage, dict(results.usage))
return results
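
Because the router path skips DSPy's `num_retries` and cache arguments (per the comments above), retry and caching behavior would be configured on the Router itself. A hedged sketch using litellm's standard Router options:

```python
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-4o-group",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "sk-..."},
        },
    ],
    num_retries=3,          # retries handled by the router on this path, not by dspy.LM
    cache_responses=False,  # enable litellm-side response caching if desired
)
```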

async def aforward(self, prompt=None, messages=None, **kwargs):
# Build the request.
cache = kwargs.pop("cache", self.cache)
enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)

messages = messages or [{"role": "user", "content": prompt}]
kwargs = {**self.kwargs, **kwargs}

completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
if self.router:
# Ensure 'model' is not passed directly in kwargs to router.acompletion if self.model is the source
router_kwargs = {key: value for key, value in kwargs.items() if key != 'model'}
all_kwargs = {**self.kwargs, **router_kwargs}

results = await completion(
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
cache=litellm_cache_args,
)
if not self.model:
raise ValueError("self.model must be set when using a router to specify the target model/group.")

# LiteLLM Router handles its own caching and retries.
# DSPy's cache parameters and self.num_retries are not passed here.
results = await self.router.acompletion(
model=self.model, # This tells the router which model/group to use
messages=messages,
**all_kwargs,
)
else:
# Existing logic for non-router path
cache = kwargs.pop("cache", self.cache)
enable_memory_cache = kwargs.pop("cache_in_memory", self.cache_in_memory)
all_kwargs = {**self.kwargs, **kwargs} # ensure kwargs passed to method are combined

completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)

# self.model must be present here as router is None (constructor should ensure this)
if not self.model:
raise ValueError("self.model must be set when not using a router.")

results = await completion(
request=dict(model=self.model, messages=messages, **all_kwargs),
num_retries=self.num_retries,
cache=litellm_cache_args,
)

if any(c.finish_reason == "length" for c in results["choices"]):
logger.warning(
@@ -170,7 +250,11 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
)

if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
settings.usage_tracker.add_usage(self.model, dict(results.usage))
# For router, self.model is the group/alias. For non-router, it's the specific model.
# Usage tracking model key might need adjustment if router provides specific model info in response.
# For now, using self.model for both.
model_key_for_usage = self.model
settings.usage_tracker.add_usage(model_key_for_usage, dict(results.usage))
return results
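
A small sketch of the async path, assuming the router-backed `lm` from the earlier example; the litellm response object supports standard attribute access:

```python
import asyncio

async def main():
    # Routed through router.acompletion(...) because `lm` was built with a router.
    results = await lm.aforward(messages=[{"role": "user", "content": "Hello"}])
    print(results.choices[0].message.content)

asyncio.run(main())
```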

def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
@@ -242,7 +326,7 @@ def infer_provider(self) -> Provider:

def dump_state(self):
state_keys = [
"model",
"model", # If router used, this is the target model/group
"model_type",
"cache",
"cache_in_memory",
@@ -251,7 +335,26 @@ def dump_state(self):
"launch_kwargs",
"train_kwargs",
]
return {key: getattr(self, key) for key in state_keys} | self.kwargs

state = {}
for key in state_keys:
if hasattr(self, key):
state[key] = getattr(self, key)

state["router_is_configured"] = self.router is not None

if self.provider:
state["provider_name"] = self.provider.__class__.__name__
else:
state["provider_name"] = None

# Merge self.kwargs. Prioritize values in 'state' if collisions occur.
final_state = self.kwargs.copy() # Start with a copy of kwargs
final_state.update(state) # Update with collected state; 'state' values overwrite kwargs on collision.
# This means if 'model' was in kwargs, the one from state_keys wins.
# This is generally desirable.

return final_state
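
Roughly what `dump_state()` would return for a router-backed LM under this change (values illustrative):

```python
state = lm.dump_state()
# e.g.
# {
#     "temperature": 0.0, "max_tokens": 4000,        # merged in from self.kwargs
#     "model": "gpt-4o-group", "model_type": "chat",
#     "cache": True, "cache_in_memory": True,
#     "num_retries": 3, "finetuning_model": None,
#     "launch_kwargs": {}, "train_kwargs": {},
#     "router_is_configured": True, "provider_name": None,
# }
```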


def _get_stream_completion_fn(