Skip to content

Add PatchSchema #1450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions ninja/patch_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import (
Any,
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel, create_model, model_validator

ModelT = TypeVar("ModelT", bound=BaseModel)

# Type alias for patched models to help with type checking
# This allows using cast(PatchedModel, PatchSchema[SomeModel]()) to properly type the model_dump() method
PatchedModel = BaseModel


class PatchSchema(Generic[ModelT]):
"""Generate a patchable version of a Pydantic model.

Makes all fields optional, but doesn't allow None values unless the field was originally defined as Optional.
This allows for partial updates where fields can be omitted or provided with legitimate values.

Example:
Given a schema:
class ExampleSchema(BaseModel):
example_field: str
optional_field: Optional[str] = None

PatchSchema[ExampleSchema] will allow:
- PatchSchema[ExampleSchema]() (no fields provided)
- PatchSchema[ExampleSchema](example_field="example") (field provided)
- PatchSchema[ExampleSchema](optional_field=None) (None allowed for originally optional fields)

But will not allow:
- PatchSchema[ExampleSchema](example_field=None) (None is not allowed for non-optional fields)

Usage:
# Define a regular schema
class UserSchema(BaseModel):
name: str
email: str
avatar_url: Optional[str] = None

# Create a patchable version that allows partial updates
PatchUserSchema = PatchSchema[UserSchema]

# Use the patched schema for partial updates
patch_data = PatchUserSchema(name="New Name") # Only updates the name field
patch_data = PatchUserSchema(avatar_url=None) # Can set avatar_url to None since it's optional
"""

def __new__(
cls,
*args: Any,
**kwargs: Any,
) -> "PatchSchema[ModelT]":
"""Cannot instantiate directly."""
raise TypeError("Cannot instantiate abstract PatchSchema class.")

def __init_subclass__(
cls,
*args: Any,
**kwargs: Any,
) -> None:
"""Cannot subclass."""
raise TypeError(f"Cannot subclass {cls.__module__}.PatchSchema")

@classmethod
def _is_optional_type(cls, annotation):
"""Check if a type annotation is Optional[X]."""
if get_origin(annotation) is Union:
args = get_args(annotation)
return type(None) in args
return False

@classmethod
def __class_getitem__(
cls,
wrapped_class: Type[ModelT],
) -> Type[PatchedModel]:
"""Convert model to a patchable model where fields are optional but can't be None unless originally Optional."""

# Create field definitions for the new model
fields = {}
# Keep track of which fields were originally optional
originally_optional_fields = set()

# Access model_fields through instance property, not class attribute
model_fields = getattr(wrapped_class, "model_fields", {})
for field_name, field_info in model_fields.items():
# Make the field optional by setting a default value
annotation = field_info.annotation

# Check if the field was originally optional
if cls._is_optional_type(annotation):
originally_optional_fields.add(field_name)

fields[field_name] = (Optional[annotation], None)

# Create the new model class
class PatchModel(BaseModel):
model_config = {"extra": "ignore", "arbitrary_types_allowed": True}

@model_validator(mode="before")
@classmethod
def validate_no_none_values(cls, data):
if isinstance(data, dict):
# Check for explicit None values and raise error for non-optional fields
for key, value in list(data.items()):
if value is None and key not in originally_optional_fields:
raise ValueError(f"Field '{key}' cannot be None")
# Keep only non-None values
return {
k: v
for k, v in data.items()
if v is not None or k in originally_optional_fields
}
return data

# We don't need a custom schema generator anymore since Pydantic v2 uses anyOf for optional fields

def model_dump(self, **kwargs) -> Dict[str, Any]:
# Filter out None values from the serialized object
# Only include fields that were explicitly set (not default None values)
dump = super().model_dump(**kwargs)
# Get fields that were explicitly set (excluding default None values)
fields_set = self.model_fields_set
return {k: v for k, v in dump.items() if k in fields_set}

patched_model = create_model(
f"Patched{wrapped_class.__name__}", __base__=PatchModel, **fields
)

# Pass the originally_optional_fields to the patched model
patched_model._originally_optional_fields = originally_optional_fields

# Fix return type by using explicit cast to match the declared return type
return patched_model # type: ignore
186 changes: 186 additions & 0 deletions tests/test_patch_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from typing import Optional, cast

import pytest
from pydantic import BaseModel, ValidationError

from ninja import NinjaAPI, Schema
from ninja.patch_schema import PatchedModel, PatchSchema
from ninja.testing import TestClient


class UserSchema(BaseModel):
name: str
email: str
age: int
avatar_url: Optional[str] = None


class UserSchemaWithDefault(BaseModel):
name: str = "Default"
email: str
age: int
avatar_url: Optional[str] = None


def test_patch_schema_basic():
PatchUserSchema = PatchSchema[UserSchema]

# Test empty initialization
patch_data = cast(PatchedModel, PatchUserSchema())
assert patch_data.model_dump() == {}

# Test field initialization
patch_data = cast(PatchedModel, PatchUserSchema(name="New Name"))
assert patch_data.model_dump() == {"name": "New Name"}

# Test multiple fields
patch_data = cast(
PatchedModel, PatchUserSchema(name="New Name", email="[email protected]")
)
assert patch_data.model_dump() == {"name": "New Name", "email": "[email protected]"}


def test_patch_schema_with_optional_fields():
PatchUserSchema = PatchSchema[UserSchema]

# Test setting optional field to None
patch_data = cast(PatchedModel, PatchUserSchema(avatar_url=None))
assert patch_data.model_dump() == {"avatar_url": None}

# Test setting optional field to value
patch_data = cast(
PatchedModel, PatchUserSchema(avatar_url="https://example.com/avatar.png")
)
assert patch_data.model_dump() == {"avatar_url": "https://example.com/avatar.png"}


def test_patch_schema_none_validation():
PatchUserSchema = PatchSchema[UserSchema]

# Non-optional fields should not allow None
with pytest.raises(ValidationError) as exc_info:
PatchUserSchema(name=None)

assert "Field 'name' cannot be None" in str(exc_info.value)

with pytest.raises(ValidationError) as exc_info:
PatchUserSchema(email=None)

assert "Field 'email' cannot be None" in str(exc_info.value)


def test_patch_schema_with_defaults():
PatchUserSchemaWithDefault = PatchSchema[UserSchemaWithDefault]

# Default values should not be included in the output unless explicitly set
patch_data = cast(PatchedModel, PatchUserSchemaWithDefault())
assert patch_data.model_dump() == {}

patch_data = cast(PatchedModel, PatchUserSchemaWithDefault(name="Custom Name"))
assert patch_data.model_dump() == {"name": "Custom Name"}


# API integration tests
api = NinjaAPI()
client = TestClient(api)


class UserSchemaAPI(Schema):
name: str
email: str
age: int
avatar_url: Optional[str] = None


@api.post("/users")
def create_user(request, data: UserSchemaAPI):
return data


@api.patch("/users/{user_id}")
def update_user(request, user_id: int, data: PatchSchema[UserSchemaAPI]):
# Return the data and its type to verify it's working correctly
return {"id": user_id, "data": data.model_dump(), "type": str(type(data).__name__)}


def test_api_integration():
# First create a user
create_response = client.post(
"/users", json={"name": "Test User", "email": "[email protected]", "age": 30}
)
assert create_response.status_code == 200

# Test partial update with patch
patch_response = client.patch("/users/1", json={"name": "Updated Name"})
assert patch_response.status_code == 200
assert patch_response.json() == {
"id": 1,
"data": {"name": "Updated Name"},
"type": f"Patched{UserSchemaAPI.__name__}",
}

# Test multiple fields update
patch_response = client.patch("/users/1", json={"name": "New Name", "age": 31})
assert patch_response.status_code == 200
assert patch_response.json() == {
"id": 1,
"data": {"name": "New Name", "age": 31},
"type": f"Patched{UserSchemaAPI.__name__}",
}

# Test optional field set to null
patch_response = client.patch("/users/1", json={"avatar_url": None})
assert patch_response.status_code == 200
assert patch_response.json() == {
"id": 1,
"data": {"avatar_url": None},
"type": f"Patched{UserSchemaAPI.__name__}",
}

# Test validation error when setting non-optional field to null
error_response = client.patch("/users/1", json={"name": None})
assert error_response.status_code == 422 # Validation error


def test_direct_instantiation_error():
with pytest.raises(TypeError) as exc_info:
PatchSchema()

assert "Cannot instantiate abstract PatchSchema class" in str(exc_info.value)


def test_subclass_error():
with pytest.raises(TypeError) as exc_info:

class MyPatchSchema(PatchSchema):
pass

assert "Cannot subclass" in str(exc_info.value)


def test_openapi_schema():
"""Test that the OpenAPI schema for a patched model is correctly generated."""
schema = api.get_openapi_schema()
patched_schema = schema["components"]["schemas"][f"Patched{UserSchemaAPI.__name__}"]

assert patched_schema["type"] == "object"
assert "properties" in patched_schema

# Check that name is optional in the schema
assert "name" in patched_schema["properties"]

# In Pydantic v2, optional fields use anyOf with multiple types including null
name_prop = patched_schema["properties"]["name"]
assert "anyOf" in name_prop
assert any(item["type"] == "string" for item in name_prop["anyOf"])
assert any(item["type"] == "null" for item in name_prop["anyOf"])

# No required fields in patched schema
assert "required" not in patched_schema or "name" not in patched_schema["required"]

# Check that avatar_url is still optional
assert "avatar_url" in patched_schema["properties"]
avatar_prop = patched_schema["properties"]["avatar_url"]
assert "anyOf" in avatar_prop
assert any(item["type"] == "string" for item in avatar_prop["anyOf"])
assert any(item["type"] == "null" for item in avatar_prop["anyOf"])
Loading