Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/marvin/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,17 @@ class UsageType(TypeDecorator):
impl = JSON
cache_ok = True

def process_bind_param(self, value: Usage | None, dialect) -> dict | None:
def process_bind_param(
self, value: Usage | None, dialect: Any
) -> dict[str, Any] | None:
"""Convert Usage to JSON before storing in DB."""
if value is None:
return None
return usage_adapter.dump_python(value, mode="json")

def process_result_value(self, value: dict | None, dialect) -> Usage | None:
def process_result_value(
self, value: dict[str, Any] | None, dialect: Any
) -> Usage | None:
"""Convert JSON back to Usage when loading from DB."""
if value is None:
return None
Expand Down
4 changes: 2 additions & 2 deletions src/marvin/fns/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def cast_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T:
"""Asynchronously transforms input data into a specific type using a language model.

Expand Down Expand Up @@ -91,7 +91,7 @@ def cast(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T:
"""Transforms input data into a specific type using a language model.

Expand Down
12 changes: 6 additions & 6 deletions src/marvin/fns/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T: ...


Expand All @@ -41,7 +41,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]: ...


Expand All @@ -52,7 +52,7 @@ async def classify_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T | list[T]:
"""Asynchronously classifies input data into one or more predefined labels using a language model.

Expand Down Expand Up @@ -138,7 +138,7 @@ def classify(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T: ...


Expand All @@ -151,7 +151,7 @@ def classify(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]: ...


Expand All @@ -162,7 +162,7 @@ def classify(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> T | list[T]:
"""Classifies input data into one or more predefined labels using a language model.

Expand Down
4 changes: 2 additions & 2 deletions src/marvin/fns/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def extract_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]:
"""Extracts entities of a specific type from the provided data.

Expand Down Expand Up @@ -80,7 +80,7 @@ def extract(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]:
"""Extracts entities of a specific type from the provided data.

Expand Down
6 changes: 3 additions & 3 deletions src/marvin/fns/generate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar, cast
from typing import Any, TypeVar, cast

from pydantic import conlist

Expand Down Expand Up @@ -34,7 +34,7 @@ async def generate_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]:
"""Generates examples of a specific type or matching a description asynchronously.

Expand Down Expand Up @@ -82,7 +82,7 @@ def generate(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> list[T]:
"""Generates examples of a specific type or matching a description.

Expand Down
4 changes: 2 additions & 2 deletions src/marvin/fns/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def run_async(
agents: list[Actor] | None = None,
handlers: list[Handler | AsyncHandler] | None = None,
raise_on_failure: bool = True,
**kwargs,
**kwargs: Any,
) -> T:
task = Task[result_type](
instructions=instructions,
Expand All @@ -79,7 +79,7 @@ def run(
agents: list[Actor] | None = None,
raise_on_failure: bool = True,
handlers: list[Handler | AsyncHandler] | None = None,
**kwargs,
**kwargs: Any,
) -> T:
return marvin.utilities.asyncio.run_sync(
run_async(
Expand Down
6 changes: 4 additions & 2 deletions src/marvin/fns/say.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import marvin
from marvin.agents.agent import Agent
from marvin.thread import Thread, get_thread
Expand All @@ -9,7 +11,7 @@ async def say_async(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> str:
"""Responds to a user message in a conversational way.

Expand Down Expand Up @@ -52,7 +54,7 @@ def say(
instructions: str | None = None,
agent: Agent | None = None,
thread: Thread | str | None = None,
context: dict | None = None,
context: dict[str, Any] | None = None,
) -> str:
"""Responds to a user message in a conversational way.

Expand Down
36 changes: 16 additions & 20 deletions src/marvin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import Literal

from pydantic import Field, field_validator, model_validator
from pydantic import Field, ValidationInfo, field_validator, model_validator
from pydantic_ai.models import KnownModelName
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import Self
Expand All @@ -21,7 +21,7 @@ class Settings(BaseSettings):
case_sensitive=False,
env_file=".env",
env_file_encoding="utf-8",
extra="forbid",
extra="ignore",
validate_assignment=True,
)

Expand All @@ -45,31 +45,27 @@ def validate_home_path(cls, v: Path) -> Path:
description="Path to the database file. Defaults to `home_path / 'marvin.db'`.",
)

@model_validator(mode="after")
def validate_database_url(self) -> Self:
@field_validator("database_url")
@classmethod
def validate_database_url(cls, v: str | None, info: ValidationInfo) -> str:
"""Set and validate the database path."""
home_path = info.data.get("home_path")

# Set default if not provided
if self.database_url is None:
self.__dict__["database_url"] = str(self.home_path / "marvin.db")
return self
if v is None:
if not home_path:
raise ValueError("home_path must be set before database_url")
return str(home_path / "marvin.db")

# Handle in-memory database
if self.database_url == ":memory:":
return self

# Convert to Path for validation
path = Path(self.database_url)
if v == ":memory:":
return v

# Expand user and resolve to absolute path
path = path.expanduser().resolve()

# Ensure parent directory exists
# Convert to Path for validation and ensure parent directory exists
path = Path(v).expanduser().resolve()
path.parent.mkdir(parents=True, exist_ok=True)

# Store result as string
self.__dict__["database_url"] = str(path)

return self
return str(path)

# ------------ Logging settings ------------

Expand Down
4 changes: 0 additions & 4 deletions tests/ai/fns/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ def test_str_target_if_only_instructions_provided(self):
assert isinstance(result, str)
assert result == "1"

def test_error_if_no_target_and_no_instructions(self):
with pytest.raises(ValueError):
marvin.cast("one")


class TestAsync:
async def test_cast_text_to_int(self):
Expand Down
1 change: 1 addition & 0 deletions tests/basic/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_simple_run_with_result_type(test_model: TestModel):
assert result == 1


@pytest.mark.skip(reason="TODO: what is the expected behavior here?")
def test_simple_run_with_wrong_result_type(test_model: TestModel):
task = marvin.Task("Test task", result_type=int)
test_model.custom_result_args = dict(task_id=task.id, result="hello world")
Expand Down
30 changes: 30 additions & 0 deletions tests/settings/test_settings_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from marvin.settings import Settings


def test_database_url_default():
settings = Settings()
assert settings.database_url is not None
assert settings.database_url.endswith("/.marvin/marvin.db")


@pytest.mark.parametrize(
"env_var_value, expected_ending",
[
(":memory:", ":memory:"),
("~/.marvin/test.db", "/.marvin/test.db"),
],
)
def test_database_url_set_from_env_var(
monkeypatch: pytest.MonkeyPatch,
env_var_value: str,
expected_ending: str,
):
monkeypatch.setenv("MARVIN_DATABASE_URL", env_var_value)
settings = Settings()
if expected_ending == ":memory:":
assert settings.database_url == expected_ending
else:
assert settings.database_url is not None
assert settings.database_url.endswith(expected_ending)