From 06c13a50b8067f390bc2096a8d2fc229eac44fad Mon Sep 17 00:00:00 2001 From: Elias Freider Date: Fri, 2 May 2025 15:17:00 +0200 Subject: [PATCH] Adds a Service mixin that can be used to get static type safety for class methods --- modal/__init__.py | 3 ++- modal/cls.py | 22 ++++++++++++++++++++++ test/supports/type_assertions.py | 18 +++++++++++------- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/modal/__init__.py b/modal/__init__.py index 70fd0aeda1..ff40af663d 100644 --- a/modal/__init__.py +++ b/modal/__init__.py @@ -14,7 +14,7 @@ from .app import App, Stub from .client import Client from .cloud_bucket_mount import CloudBucketMount - from .cls import Cls, parameter + from .cls import Cls, Service, parameter from .dict import Dict from .exception import Error from .file_pattern_matcher import FilePatternMatcher @@ -77,6 +77,7 @@ "SandboxSnapshot", "SchedulerPlacement", "Secret", + "Service", "Stub", "Tunnel", "Volume", diff --git a/modal/cls.py b/modal/cls.py index c76ed47469..767d841d46 100644 --- a/modal/cls.py +++ b/modal/cls.py @@ -6,6 +6,7 @@ from collections.abc import Collection from typing import Any, Callable, Optional, TypeVar, Union +import typing_extensions from google.protobuf.message import Message from grpclib import GRPCError, Status @@ -804,3 +805,24 @@ class A: """ # has to return Any to be assignable to any annotation (https://github.com/microsoft/pyright/issues/5102) return _Parameter(default=default, init=init) + + +class Service: + # Mixin to provide static types for "service" level methods + # The actual implementation of these methods are currently in the modal.Obj wrapper + + def update_autoscaler(self, min_containers: Optional[int] = None): ... + + def with_options( + self, + cpu: Optional[Union[float, tuple[float, float]]] = None, + memory: Optional[Union[int, tuple[int, int]]] = None, + gpu: GPU_T = None, + secrets: Collection[_Secret] = (), + volumes: dict[Union[str, os.PathLike], _Volume] = {}, + retries: Optional[Union[int, Retries]] = None, + max_containers: Optional[int] = None, # Limit on the number of containers that can be concurrently running. + scaledown_window: Optional[int] = None, # Max amount of time a container can remain idle before scaling down. + timeout: Optional[int] = None, + allow_concurrent_inputs: Optional[int] = None, + ) -> typing_extensions.Self: ... diff --git a/test/supports/type_assertions.py b/test/supports/type_assertions.py index 5ac9a910e6..d9d7757470 100644 --- a/test/supports/type_assertions.py +++ b/test/supports/type_assertions.py @@ -37,14 +37,12 @@ async def async_typed_func(b: bool) -> str: return "" -async_typed_func - should_be_str = async_typed_func.remote(False) # should be blocking without aio assert_type(should_be_str, str) @app.cls() -class Cls: +class UserCls(modal.Service): @method() def foo(self, a: str) -> int: return 1 @@ -54,20 +52,26 @@ async def bar(self, a: str) -> int: return 1 -instance = Cls() -should_be_int = instance.foo.remote("foo") +service = UserCls() +should_be_int = service.foo.remote("foo") assert_type(should_be_int, int) -should_be_int = instance.bar.remote("bar") +should_be_int_2 = service.bar.remote("bar") assert_type(should_be_int, int) +service.update_autoscaler(min_containers=10) +derived_service = service.with_options(cpu=10) +assert_type(derived_service, UserCls) +should_be_int_3 = derived_service.bar.remote(a="123") +assert_type(should_be_int_3, int) + async def async_block() -> None: should_be_str_2 = await async_typed_func.remote.aio(True) assert_type(should_be_str_2, str) should_also_be_str = await async_typed_func.local(False) # local should be the original return type (!) assert_type(should_also_be_str, str) - should_be_int = await instance.bar.local("bar") + should_be_int = await service.bar.local("bar") assert_type(should_be_int, int)