Skip to content

Commit fd5fd14

Browse files
jopemachine and claude authored
refactor(BA-5978): introduce BackendAISchema for per-domain validation errors (#11514)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2649806 commit fd5fd14

80 files changed

Lines changed: 852 additions & 629 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

changes/11514.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Introduce `BackendAISchema`, a Pydantic base model whose `model_validate` / `model_validate_json` automatically convert validation failures into a domain-specific `BackendAIError` (HTTP 400) through an overridable `build_validation_error` classmethod. Each model can therefore surface its own 400 response with structured per-field error details instead of a raw `pydantic.ValidationError`.

src/ai/backend/agent/agent.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
import zmq.asyncio
5252
from async_timeout import timeout
5353
from cachetools import LRUCache, cached
54-
from pydantic import ValidationError
5554
from tenacity import (
5655
AsyncRetrying,
5756
RetryError,
@@ -158,7 +157,10 @@
158157
AbstractBroadcastEvent,
159158
AbstractEvent,
160159
)
161-
from ai.backend.common.exception import ConfigurationError, VolumeMountFailed
160+
from ai.backend.common.exception import (
161+
ConfigurationError,
162+
VolumeMountFailed,
163+
)
162164
from ai.backend.common.json import (
163165
dump_json,
164166
dump_json_str,
@@ -233,7 +235,6 @@
233235
ImagePullTimeoutError,
234236
ModelDefinitionEmptyError,
235237
ModelDefinitionNotFoundError,
236-
ModelDefinitionValidationError,
237238
ModelFolderNotSpecifiedError,
238239
PortConflictError,
239240
ReservedPortError,
@@ -3292,13 +3293,7 @@ async def _load_model_definition(
32923293
f" vFolder {model_folder.name} (ID {model_folder.vfid})",
32933294
)
32943295

3295-
try:
3296-
parsed = ModelDefinition.model_validate(inlined)
3297-
except ValidationError as e:
3298-
raise ModelDefinitionValidationError(
3299-
"Failed to validate model definition for vFolder"
3300-
f" {model_folder.name} (ID {model_folder.vfid})",
3301-
) from e
3296+
parsed = ModelDefinition.model_validate(inlined)
33023297
if not parsed.models:
33033298
raise ModelDefinitionEmptyError
33043299
model_definition = parsed.model_dump(mode="json")

src/ai/backend/agent/errors/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
ModelDefinitionEmptyError,
2121
ModelDefinitionInvalidYAMLError,
2222
ModelDefinitionNotFoundError,
23-
ModelDefinitionValidationError,
2423
ModelFolderNotSpecifiedError,
2524
PortConflictError,
2625
ReservedPortError,
@@ -64,7 +63,6 @@
6463
"ModelDefinitionEmptyError",
6564
"ModelDefinitionInvalidYAMLError",
6665
"ModelDefinitionNotFoundError",
67-
"ModelDefinitionValidationError",
6866
"ModelFolderNotSpecifiedError",
6967
"PortConflictError",
7068
"ReservedPortError",

src/ai/backend/agent/errors/agent.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,20 +169,6 @@ def error_code(self) -> ErrorCode:
169169
)
170170

171171

172-
class ModelDefinitionValidationError(BackendAIError, web.HTTPBadRequest):
173-
"""Raised when model definition validation fails."""
174-
175-
error_type = "https://api.backend.ai/probs/agent/model-definition-validation-failed"
176-
error_title = "Model definition validation failed."
177-
178-
def error_code(self) -> ErrorCode:
179-
return ErrorCode(
180-
domain=ErrorDomain.MODEL_SERVICE,
181-
operation=ErrorOperation.ACCESS,
182-
error_detail=ErrorDetail.INVALID_PARAMETERS,
183-
)
184-
185-
186172
class ModelFolderNotSpecifiedError(BackendAIError, web.HTTPBadRequest):
187173
"""Raised when no model virtual folder is specified."""
188174

src/ai/backend/common/api_handlers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
MiddlewareParamParsingFailed,
3333
ParameterNotParsedError,
3434
)
35+
from .types import BackendAISchema
3536

3637
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
3738

@@ -45,21 +46,21 @@ class Sentinel(enum.Enum):
4546
SENTINEL = Sentinel.TOKEN
4647

4748

48-
class BaseRequestModel(BaseModel):
49+
class BaseRequestModel(BackendAISchema):
4950
model_config = ConfigDict(
5051
arbitrary_types_allowed=True,
5152
validate_by_name=True,
5253
)
5354

5455

55-
class BaseFieldModel(BaseModel):
56+
class BaseFieldModel(BackendAISchema):
5657
model_config = ConfigDict(
5758
arbitrary_types_allowed=True,
5859
validate_by_name=True,
5960
)
6061

6162

62-
class BaseResponseModel(BaseModel):
63+
class BaseResponseModel(BackendAISchema):
6364
pass
6465

6566

src/ai/backend/common/config.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,22 @@
55
import sys
66
from collections.abc import Mapping, MutableMapping
77
from pathlib import Path
8-
from typing import Any
8+
from typing import Any, override
99

1010
import humps
1111
import tomli
1212
import trafaret as t
1313
from pydantic import (
1414
AliasChoices,
15-
BaseModel,
1615
ConfigDict,
1716
Field,
1817
field_validator,
1918
)
2019

2120
from . import validators as tx
2221
from .etcd import AsyncEtcd, ConfigScopes
23-
from .exception import ConfigurationError
24-
from .types import RedisHelperConfig
22+
from .exception import BackendAIError, ConfigurationError, ModelDefinitionValidationError
23+
from .types import BackendAISchema, RedisHelperConfig, SchemaValidationFailureInfo
2524

2625
__all__ = (
2726
"ConfigurationError",
@@ -40,7 +39,7 @@
4039
)
4140

4241

43-
class BaseConfigSchema(BaseModel):
42+
class BaseConfigSchema(BackendAISchema):
4443
@staticmethod
4544
def snake_to_kebab_case(string: str) -> str:
4645
return string.replace("_", "-")
@@ -53,7 +52,7 @@ def snake_to_kebab_case(string: str) -> str:
5352
)
5453

5554

56-
class BaseConfigModel(BaseModel):
55+
class BaseConfigModel(BackendAISchema):
5756
@staticmethod
5857
def snake_to_kebab_case(string: str) -> str:
5958
return string.replace("_", "-")
@@ -478,6 +477,14 @@ class ModelDefinition(BaseConfigModel):
478477
description="List of models in the model definition.",
479478
)
480479

480+
@override
481+
@classmethod
482+
def build_validation_error(cls, info: SchemaValidationFailureInfo) -> BackendAIError:
483+
return ModelDefinitionValidationError(
484+
extra_msg=info.summary,
485+
extra_data={"errors": info.errors},
486+
)
487+
481488
def merge(self, override: ModelDefinition) -> ModelDefinition:
482489
"""Merge the given override into this definition, returning a new instance."""
483490
return _merge_definition(self, override)
@@ -532,10 +539,10 @@ class ModelHealthCheckDraft(BaseConfigModel):
532539
initial_delay: float | None = None
533540

534541
def to_resolved(self) -> ModelHealthCheck:
535-
if self.path is None:
536-
raise ValueError("ModelHealthCheck.path is required")
537542
# Drop unset (None) fields so the strict type's ``Field(default=...)``
538543
# declarations remain the single source of truth for default values.
544+
# Missing required fields (e.g. ``path``) surface as the strict
545+
# type's ``BackendAISchemaValidationFailed`` via ``model_validate``.
539546
return ModelHealthCheck.model_validate(self.model_dump(exclude_none=True))
540547

541548

@@ -552,16 +559,15 @@ def _coerce_start_command(cls, value: Any) -> Any:
552559
return _normalize_start_command(value)
553560

554561
def to_resolved(self) -> ModelServiceConfig:
555-
if self.port is None:
556-
raise ValueError("ModelServiceConfig.port is required")
557562
# Drop unset (None) scalars so the strict type's ``Field(default=...)``
558563
# declarations remain the single source of truth for default values;
559564
# resolve the nested ``health_check`` draft explicitly so its own
560-
# required-field check (``path``) fires with a clear error message.
561-
return ModelServiceConfig(
562-
**self.model_dump(exclude_none=True, exclude={"health_check"}),
563-
health_check=self.health_check.to_resolved() if self.health_check else None,
564-
)
565+
# required-field check (``path``) fires through its own
566+
# ``model_validate``. Missing required fields (e.g. ``port``)
567+
# surface as ``BackendAISchemaValidationFailed``.
568+
payload = self.model_dump(exclude_none=True, exclude={"health_check"})
569+
payload["health_check"] = self.health_check.to_resolved() if self.health_check else None
570+
return ModelServiceConfig.model_validate(payload)
565571

566572

567573
class ModelConfigDraft(BaseConfigModel):
@@ -571,25 +577,18 @@ class ModelConfigDraft(BaseConfigModel):
571577
metadata: ModelMetadata | None = None # ModelMetadata is already all-Optional.
572578

573579
def to_resolved(self) -> ModelConfig:
574-
if self.name is None:
575-
raise ValueError("ModelConfig.name is required")
576-
if self.model_path is None:
577-
raise ValueError("ModelConfig.model_path is required")
578580
service = self.service.to_resolved() if self.service else None
579-
if service is not None and service.start_command:
581+
if service is not None and service.start_command and self.model_path is not None:
580582
# ``{model_path}`` placeholders in the variant baseline's
581583
# ``start_command`` are resolved here, at the same moment the
582584
# draft becomes a strict ``ModelConfig`` and ``model_path`` is
583585
# finalized. Placeholders therefore never propagate downstream.
584586
service.start_command = [
585587
token.replace("{model_path}", self.model_path) for token in service.start_command
586588
]
587-
return ModelConfig(
588-
name=self.name,
589-
model_path=self.model_path,
590-
service=service,
591-
metadata=self.metadata,
592-
)
589+
payload = self.model_dump(exclude_none=True, exclude={"service"})
590+
payload["service"] = service
591+
return ModelConfig.model_validate(payload)
593592

594593

595594
def _merge_health_check_draft(
@@ -665,6 +664,14 @@ class ModelDefinitionDraft(BaseConfigModel):
665664

666665
models: list[ModelConfigDraft] | None = None
667666

667+
@override
668+
@classmethod
669+
def build_validation_error(cls, info: SchemaValidationFailureInfo) -> BackendAIError:
670+
return ModelDefinitionValidationError(
671+
extra_msg=info.summary,
672+
extra_data={"errors": info.errors},
673+
)
674+
668675
def merge(self, override: ModelDefinitionDraft) -> ModelDefinitionDraft:
669676
"""Merge ``override`` over ``self`` and return a new draft.
670677
@@ -689,13 +696,9 @@ def merge(self, override: ModelDefinitionDraft) -> ModelDefinitionDraft:
689696
return ModelDefinitionDraft.model_construct(models=merged)
690697

691698
def to_resolved(self) -> ModelDefinition:
692-
"""Build the strict ``ModelDefinition`` from this draft.
693-
694-
Each child draft is converted via its own ``to_resolved`` and the
695-
strict type's constructor performs Pydantic validation; missing
696-
required fields propagate as ``pydantic.ValidationError``.
697-
"""
698-
return ModelDefinition(models=[m.to_resolved() for m in (self.models or [])])
699+
return ModelDefinition.model_validate({
700+
"models": [m.to_resolved() for m in (self.models or [])],
701+
})
699702

700703

701704
def find_config_file(daemon_name: str) -> Path:

src/ai/backend/common/dto/agent/request.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from pydantic import BaseModel, ConfigDict
1+
from pydantic import ConfigDict
22

3+
from ai.backend.common.types import BackendAISchema
34

4-
class BaseAgentRequestModel(BaseModel):
5+
6+
class BaseAgentRequestModel(BackendAISchema):
57
"""Base class for pydantic request payloads on agent RPC v3 methods.
68
79
Mirrors the role of ``ai.backend.common.api_handlers.BaseRequestModel``

src/ai/backend/common/dto/agent/response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
from pydantic import BaseModel, ConfigDict, Field
77

88
from ai.backend.common.dto.internal.health import ConnectivityCheckResponse, HealthStatus
9+
from ai.backend.common.types import BackendAISchema
910

1011
T = TypeVar("T")
1112

1213

13-
class BaseAgentResponseModel(BaseModel):
14+
class BaseAgentResponseModel(BackendAISchema):
1415
"""Base class for pydantic response payloads on agent RPC v3 methods.
1516
1617
Counterpart to ``BaseAgentRequestModel`` on the response side.

src/ai/backend/common/exception.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Self
88

99
from aiohttp import web
10+
from pydantic_core import ErrorDetails
1011

1112
from .json import dump_json
1213

@@ -443,6 +444,51 @@ def error_code(self) -> ErrorCode:
443444
)
444445

445446

447+
class BackendAISchemaValidationFailed(BackendAIError, web.HTTPBadRequest):
448+
"""Default 400 raised by :class:`BackendAISchema.build_validation_error`.
449+
450+
Kept distinct from :class:`InvalidAPIParameters` so handlers can
451+
catch one without picking up the other.
452+
"""
453+
454+
error_type = "https://api.backend.ai/probs/schema-validation-failed"
455+
error_title = "Schema validation failed."
456+
457+
def error_code(self) -> ErrorCode:
458+
return ErrorCode(
459+
domain=ErrorDomain.BACKENDAI,
460+
operation=ErrorOperation.PARSING,
461+
error_detail=ErrorDetail.INVALID_PARAMETERS,
462+
)
463+
464+
def errors(self) -> list[ErrorDetails]:
465+
"""Per-field errors in the same shape as
466+
``pydantic.ValidationError.errors()``. Empty when no
467+
``extra_data["errors"]`` is attached."""
468+
if not self.extra_data:
469+
return []
470+
return list(self.extra_data.get("errors") or [])
471+
472+
473+
class ModelDefinitionValidationError(BackendAIError, web.HTTPBadRequest):
474+
"""400 raised by ``ModelDefinition.model_validate`` (via its
475+
:meth:`BackendAISchema.build_validation_error` override).
476+
477+
Lives in ``common`` so ``ModelDefinition`` (also in ``common``) can
478+
construct it without an upward-layer import.
479+
"""
480+
481+
error_type = "https://api.backend.ai/probs/model-definition-validation-failed"
482+
error_title = "Model definition validation failed."
483+
484+
def error_code(self) -> ErrorCode:
485+
return ErrorCode(
486+
domain=ErrorDomain.MODEL_SERVICE,
487+
operation=ErrorOperation.PARSING,
488+
error_detail=ErrorDetail.INVALID_PARAMETERS,
489+
)
490+
491+
446492
class DeprecatedAPI(BackendAIError, web.HTTPBadRequest):
447493
error_type = "https://api.backend.ai/probs/deprecated"
448494
error_title = "This API is deprecated."

0 commit comments

Comments
 (0)