Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ test = [
"pytest-env>=1.1.5",
"pytest-rerunfailures>=15.0",
"chromadb>=0.6.0",
"dirty-equals>=0.9.0",
]
dev = ["marvin[test]", "copychat>=0.5.2", "ipython>=8.12.3", "pdbpp>=0.10.3"]

Expand Down
9 changes: 8 additions & 1 deletion src/marvin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from marvin.fns.classify import classify, classify_async
from marvin.fns.extract import extract, extract_async
from marvin.fns.cast import cast, cast_async
from marvin.fns.generate import generate, generate_async
from marvin.fns.generate import (
generate,
generate_async,
generate_schema,
generate_schema_async,
)
from marvin.fns.fn import fn
from marvin.fns.say import say, say_async
from marvin.fns.summarize import summarize, summarize_async
Expand All @@ -45,6 +50,8 @@
"fn",
"generate",
"generate_async",
"generate_schema",
"generate_schema_async",
"instructions",
"run",
"run_async",
Expand Down
69 changes: 52 additions & 17 deletions src/marvin/fns/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,11 @@
from marvin.agents.agent import Agent
from marvin.thread import Thread
from marvin.utilities.asyncio import run_sync
from marvin.utilities.jsonschema import JSONSchema
from marvin.utilities.types import TargetType

T = TypeVar("T")

# Base prompt describing the data-generator role. The text is sent to the
# model verbatim, so wording changes here affect generation behavior.
PROMPT = """
You are an expert data generator that always creates high-quality, random
examples of a description or type. The data you produce is relied on for
testing, examples, demonstrations, and more. You use inference or deduction
whenever necessary to supply missing or omitted data. You will be given
instructions or a type format, as well as a number of entities to generate.

Unless the user explicitly says otherwise, assume they are requesting a VARIED and
REALISTIC selection of useful outputs that meet their criteria. However, you
should prefer common responses to uncommon ones.

If the user provides additional instructions or a description, assume they are
looking for examples that satisfy the description. Do not provide more
information than the user requests. For example, if they ask for various
technologies, give their names but do not explain what each technology is."""


async def generate_async(
target: TargetType[T] = str,
Expand Down Expand Up @@ -61,7 +46,23 @@ async def generate_async(

task_context = context or {}
task_context["Number to generate"] = n
prompt = PROMPT

prompt = """
You are an expert data generator that always creates high-quality, random
examples of a description or type. The data you produce is relied on for
testing, examples, demonstrations, and more. You use inference or deduction
whenever necessary to supply missing or omitted data. You will be given
instructions or a type format, as well as a number of entities to generate.

Unless the user explicitly says otherwise, assume they are request a VARIED and
REALISTIC selection of useful outputs that meet their criteria. However, you
should prefer common responses to uncommon ones.

If the user provides additional instructions or a description, assume they are
looking for examples that satisfy the description. Do not provide more
information than the user requests. For example, if they ask for various
technologies, give their names but do not explain what each technology is.
"""
if instructions:
prompt += f"\n\nYou must follow these instructions for your generation:\n{instructions}"

Expand Down Expand Up @@ -115,3 +116,37 @@ def generate(
context=context,
),
)


async def generate_schema_async(
    instructions: str,
    agent: Agent | None = None,
    thread: Thread | str | None = None,
    context: dict[str, Any] | None = None,
) -> JSONSchema:
    """Produce a JSON schema for the thing described by `instructions`.

    A `marvin.Task` with `JSONSchema` as its result type is built from the
    description and then run asynchronously.

    Args:
        instructions: Natural-language description the schema should match.
        agent: Optional agent for the task. When falsy, the task is created
            with `agents=None`.
        thread: Forwarded unchanged to `task.run_async`.
        context: Forwarded unchanged to the task as its `context`.

    Returns:
        Whatever the task run yields, typed as `JSONSchema`.
    """
    task_instructions = f"""
Generate a JSON schema that matches the following description:
{instructions}
"""

    # Only wrap the agent in a list when one was actually supplied.
    assigned_agents = None
    if agent:
        assigned_agents = [agent]

    schema_task = marvin.Task[JSONSchema](
        name="JSONSchema Generation",
        instructions=task_instructions,
        context=context,
        result_type=JSONSchema,
        agents=assigned_agents,
    )

    # An explicit empty handler list is passed, as in the original call.
    return await schema_task.run_async(thread=thread, handlers=[])


def generate_schema(
    instructions: str,
    agent: Agent | None = None,
    thread: Thread | str | None = None,
    context: dict[str, Any] | None = None,
) -> JSONSchema:
    """Synchronous wrapper around `generate_schema_async`.

    Args:
        instructions: Natural-language description the schema should match.
        agent: Optional agent, forwarded to the async implementation.
        thread: Optional thread (or thread identifier string), forwarded as-is.
        context: Optional context mapping, forwarded as-is.

    Returns:
        The value produced by running the async implementation via `run_sync`.
    """
    pending = generate_schema_async(
        instructions=instructions,
        agent=agent,
        thread=thread,
        context=context,
    )
    return run_sync(pending)
55 changes: 45 additions & 10 deletions src/marvin/utilities/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
Literal,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
Expand All @@ -60,8 +61,9 @@
StringConstraints,
model_validator,
)
from typing_extensions import NotRequired

__all__ = ["jsonschema_to_type", "merge_defaults"]
__all__ = ["jsonschema_to_type", "JSONSchema"]

T = TypeVar("T")

Expand Down Expand Up @@ -287,12 +289,10 @@ def sanitize_name(name: str) -> str:
cleaned = re.sub(r"[^0-9a-zA-Z_]", "_", name)
# Step 2: deduplicate underscores
cleaned = re.sub(r"__+", "_", cleaned)
# Step 3: lowercase
cleaned = cleaned.lower()
# Step 4: if the first char of original name isn't a letter, prepend field_
# Step 3: if the first char of original name isn't a letter, prepend field_
if not name or not re.match(r"[a-zA-Z]", name[0]):
cleaned = f"field_{cleaned}"
# Step 5: deduplicate again and strip trailing underscores
# Step 4: deduplicate again and strip trailing underscores
cleaned = re.sub(r"__+", "_", cleaned).strip("_")
return cleaned

Expand Down Expand Up @@ -329,15 +329,17 @@ def create_dataclass(
) -> type:
"""Create dataclass from object schema."""
name = name or schema.get("title", "Root")
# Sanitize name for class creation
sanitized_name = sanitize_name(name)
schema_hash = hash_schema(schema)
cache_key = (schema_hash, name)
cache_key = (schema_hash, sanitized_name)
original_schema = schema.copy() # Store copy for validator

# Return existing class if already built
if cache_key in _classes:
existing = _classes[cache_key]
if existing is None:
return ForwardRef(name)
return ForwardRef(sanitized_name)
return existing

# Place placeholder for recursive references
Expand All @@ -346,7 +348,7 @@ def create_dataclass(
if "$ref" in schema:
ref = schema["$ref"]
if ref == "#":
return ForwardRef(name)
return ForwardRef(sanitized_name)
schema = resolve_ref(ref, schemas or {})

properties = schema.get("properties", {})
Expand All @@ -358,7 +360,7 @@ def create_dataclass(

# Check for self-reference in property
if prop_schema.get("$ref") == "#":
field_type = ForwardRef(name)
field_type = ForwardRef(sanitized_name)
else:
field_type = schema_to_type(prop_schema, schemas)

Expand Down Expand Up @@ -388,7 +390,7 @@ def create_dataclass(
else:
fields.append((field_name, Optional[field_type], field_def))

cls = make_dataclass(name, fields, kw_only=True)
cls = make_dataclass(sanitized_name, fields, kw_only=True)

# Add model validator for defaults
@model_validator(mode="before")
Expand Down Expand Up @@ -464,3 +466,36 @@ def merge_defaults(
)

return result


class JSONSchema(TypedDict):
    """Typed dictionary describing the shape of a JSON Schema document.

    Every key is declared `NotRequired`, so an empty dict is a valid value.
    Nested schemas refer back to this type through the "JSONSchema" forward
    reference, which makes the definition recursive.
    """

    # Core structure: the value's type and, for objects/arrays, its members.
    type: NotRequired[Union[str, List[str]]]
    properties: NotRequired[Dict[str, "JSONSchema"]]
    required: NotRequired[List[str]]
    additionalProperties: NotRequired[Union[bool, "JSONSchema"]]
    items: NotRequired[Union["JSONSchema", List["JSONSchema"]]]
    # Fixed or suggested values.
    enum: NotRequired[List[Any]]
    const: NotRequired[Any]
    default: NotRequired[Any]
    # Annotations.
    description: NotRequired[str]
    title: NotRequired[str]
    examples: NotRequired[List[Any]]
    format: NotRequired[str]
    # Schema composition.
    allOf: NotRequired[List["JSONSchema"]]
    anyOf: NotRequired[List["JSONSchema"]]
    oneOf: NotRequired[List["JSONSchema"]]
    # NOTE(review): the JSON Schema keyword is spelled `not`; this key is
    # `not_` (with a trailing underscore) — confirm how/if it is mapped.
    not_: NotRequired["JSONSchema"]
    definitions: NotRequired[Dict[str, "JSONSchema"]]
    dependencies: NotRequired[Dict[str, Union["JSONSchema", List[str]]]]
    # String constraints.
    pattern: NotRequired[str]
    minLength: NotRequired[int]
    maxLength: NotRequired[int]
    # Numeric constraints.
    minimum: NotRequired[Union[int, float]]
    maximum: NotRequired[Union[int, float]]
    exclusiveMinimum: NotRequired[Union[int, float]]
    exclusiveMaximum: NotRequired[Union[int, float]]
    multipleOf: NotRequired[Union[int, float]]
    # Array constraints.
    uniqueItems: NotRequired[bool]
    minItems: NotRequired[int]
    maxItems: NotRequired[int]
    additionalItems: NotRequired[Union[bool, "JSONSchema"]]
32 changes: 32 additions & 0 deletions tests/ai/fns/test_generate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from dirty_equals import IsPartialDict
from pydantic import BaseModel, Field

import marvin
Expand Down Expand Up @@ -43,3 +44,34 @@ async def test_type_is_string_if_only_instructions_given(self):
result,
"a list of two major cities in California, both given as strings",
)


class TestGenerateSchema:
    """Checks on the schemas returned by `marvin.generate_schema_async`."""

    async def test_generate_list_of_integers_schema(self):
        """A fixed-length integer list must come back as this exact schema."""
        expected = {
            "type": "array",
            "items": {"type": "integer"},
            "minItems": 3,
            "maxItems": 3,
        }
        schema = await marvin.generate_schema_async(
            instructions="a list that contains exactly three integers",
        )
        assert schema == expected

    async def test_generate_schema_for_movie(self):
        """A movie description must yield at least the listed keys and properties."""
        # Partial matching: extra top-level keys and extra properties are allowed.
        expected_properties = IsPartialDict(
            {
                "title": {"type": "string"},
                "director": {"type": "string"},
                "release_year": {"type": "integer"},
            }
        )
        expected = IsPartialDict(
            {
                "type": "object",
                "properties": expected_properties,
                "required": ["title", "director", "release_year"],
            }
        )
        schema = await marvin.generate_schema_async(
            instructions="a movie with a title, director, and release_year",
        )
        assert schema == expected
19 changes: 19 additions & 0 deletions tests/basic/utilities/test_jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,3 +1144,22 @@ def test_name_caching_with_different_titles(self):
assert Type1 is not Type2
assert Type1.__name__ == "Type1"
assert Type2.__name__ == "Type2"

def test_recursive_schema_with_invalid_python_name(self):
    """Recursive schemas still work when the title is not a valid Python identifier."""
    recursive_schema = {
        "type": "object",
        "title": "My Complex Type!",
        "properties": {"name": {"type": "string"}, "child": {"$ref": "#"}},
    }
    model_cls = jsonschema_to_type(recursive_schema)

    # The generated class is expected to carry the sanitized form of the title.
    assert model_cls.__name__ == "My_Complex_Type"

    # Validating nested data exercises the self-referencing "child" field.
    payload = {"name": "parent", "child": {"name": "child", "child": None}}
    parent = TypeAdapter(model_cls).validate_python(payload)

    assert parent.name == "parent"
    assert parent.child.name == "child"
    assert parent.child.child is None
12 changes: 12 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading