Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions src/middlewared/middlewared/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from .api.base.handler.dump_params import dump_params
from .api.base.handler.inspect import model_field_is_model, model_field_is_list_of_models
from .api.base.handler.model_provider import ModuleModelProvider, LazyModuleModelProvider
from .api.base.handler.result import serialize_result
from .api.base.handler.version import APIVersion, APIVersionsAdapter
Expand Down Expand Up @@ -205,6 +206,25 @@ def __init__(self, middleware: "Middleware"):
method.__func__.__method_name__ = method_name


def _coerce_mock_result(result, mock_return_model):
"""Convert dict mock results to Pydantic models for generic services."""
if mock_return_model is None:
return result

kind, entry_type = mock_return_model

if kind == 'model' and isinstance(result, dict):
return entry_type.model_construct(**result)

if kind == 'list' and isinstance(result, list):
return [
entry_type.model_construct(**item) if isinstance(item, dict) else item
for item in result
]

return result


class Middleware(LoadPluginsMixin, ServiceCallMixin, CallMixin):

CONSOLE_ONCE_PATH = f'{MIDDLEWARE_RUN_DIR}/.middlewared-console-once'
Expand Down Expand Up @@ -1504,20 +1524,53 @@ def set_mock(self, name: str, args: list, mock: typing.Callable):
raise ValueError(f'{name!r} is already mocked with {args!r}')

serviceobj, methodobj = self.get_method(name)
mock_return_model = self._get_mock_return_model(serviceobj, methodobj)

if inspect.iscoroutinefunction(mock):
async def f(*args, **kwargs):
return await mock(serviceobj, *args, **kwargs)
result = await mock(serviceobj, *args, **kwargs)
return _coerce_mock_result(result, mock_return_model)
else:
def f(*args, **kwargs):
return mock(serviceobj, *args, **kwargs)
result = mock(serviceobj, *args, **kwargs)
return _coerce_mock_result(result, mock_return_model)

if hasattr(methodobj, '_job'):
f._job = methodobj._job
copy_function_metadata(mock, f)

self.mocks[name].append((args, f))

@staticmethod
def _get_mock_return_model(serviceobj, methodobj):
"""Get the Pydantic model for auto-wrapping mock return values.

Only applies to generic services (``_config.generic = True``) where
methods return Pydantic model instances at runtime.
"""
config = getattr(serviceobj, '_config', None)
if not getattr(config, 'generic', False):
return None

if not hasattr(methodobj, 'new_style_returns'):
return None

try:
annotation = methodobj.new_style_returns.model_fields['result'].annotation
except (AttributeError, KeyError):
return None

if model := model_field_is_model(annotation):
return ('model', model)

if (
(model := model_field_is_list_of_models(annotation))
and isinstance(model, type) and issubclass(model, BaseModel)
):
return ('list', model)

return None

def remove_mock(self, name, args):
for i, (_args, _mock) in enumerate(self.mocks[name]):
if args == _args:
Expand Down
Loading