From 06c13a50b8067f390bc2096a8d2fc229eac44fad Mon Sep 17 00:00:00 2001
From: Elias Freider
Date: Fri, 2 May 2025 15:17:00 +0200
Subject: [PATCH] Adds a Service mixin that can be used to get static type
safety for class methods
---
modal/__init__.py | 3 ++-
modal/cls.py | 22 ++++++++++++++++++++++
test/supports/type_assertions.py | 18 +++++++++++-------
3 files changed, 35 insertions(+), 8 deletions(-)
diff --git a/modal/__init__.py b/modal/__init__.py
index 70fd0aeda1..ff40af663d 100644
--- a/modal/__init__.py
+++ b/modal/__init__.py
@@ -14,7 +14,7 @@
from .app import App, Stub
from .client import Client
from .cloud_bucket_mount import CloudBucketMount
- from .cls import Cls, parameter
+ from .cls import Cls, Service, parameter
from .dict import Dict
from .exception import Error
from .file_pattern_matcher import FilePatternMatcher
@@ -77,6 +77,7 @@
"SandboxSnapshot",
"SchedulerPlacement",
"Secret",
+ "Service",
"Stub",
"Tunnel",
"Volume",
diff --git a/modal/cls.py b/modal/cls.py
index c76ed47469..767d841d46 100644
--- a/modal/cls.py
+++ b/modal/cls.py
@@ -6,6 +6,7 @@
from collections.abc import Collection
from typing import Any, Callable, Optional, TypeVar, Union
+import typing_extensions
from google.protobuf.message import Message
from grpclib import GRPCError, Status
@@ -804,3 +805,24 @@ class A:
"""
# has to return Any to be assignable to any annotation (https://github.com/microsoft/pyright/issues/5102)
return _Parameter(default=default, init=init)
+
+
+class Service:
+ # Mixin to provide static types for "service" level methods
+ # The actual implementation of these methods are currently in the modal.Obj wrapper
+
+ def update_autoscaler(self, min_containers: Optional[int] = None): ...
+
+ def with_options(
+ self,
+ cpu: Optional[Union[float, tuple[float, float]]] = None,
+ memory: Optional[Union[int, tuple[int, int]]] = None,
+ gpu: GPU_T = None,
+ secrets: Collection[_Secret] = (),
+ volumes: dict[Union[str, os.PathLike], _Volume] = {},
+ retries: Optional[Union[int, Retries]] = None,
+ max_containers: Optional[int] = None, # Limit on the number of containers that can be concurrently running.
+ scaledown_window: Optional[int] = None, # Max amount of time a container can remain idle before scaling down.
+ timeout: Optional[int] = None,
+ allow_concurrent_inputs: Optional[int] = None,
+ ) -> typing_extensions.Self: ...
diff --git a/test/supports/type_assertions.py b/test/supports/type_assertions.py
index 5ac9a910e6..d9d7757470 100644
--- a/test/supports/type_assertions.py
+++ b/test/supports/type_assertions.py
@@ -37,14 +37,12 @@ async def async_typed_func(b: bool) -> str:
return ""
-async_typed_func
-
should_be_str = async_typed_func.remote(False) # should be blocking without aio
assert_type(should_be_str, str)
@app.cls()
-class Cls:
+class UserCls(modal.Service):
@method()
def foo(self, a: str) -> int:
return 1
@@ -54,20 +52,26 @@ async def bar(self, a: str) -> int:
return 1
-instance = Cls()
-should_be_int = instance.foo.remote("foo")
+service = UserCls()
+should_be_int = service.foo.remote("foo")
assert_type(should_be_int, int)
-should_be_int = instance.bar.remote("bar")
+should_be_int_2 = service.bar.remote("bar")
assert_type(should_be_int, int)
+service.update_autoscaler(min_containers=10)
+derived_service = service.with_options(cpu=10)
+assert_type(derived_service, UserCls)
+should_be_int_3 = derived_service.bar.remote(a="123")
+assert_type(should_be_int_3, int)
+
async def async_block() -> None:
should_be_str_2 = await async_typed_func.remote.aio(True)
assert_type(should_be_str_2, str)
should_also_be_str = await async_typed_func.local(False) # local should be the original return type (!)
assert_type(should_also_be_str, str)
- should_be_int = await instance.bar.local("bar")
+ should_be_int = await service.bar.local("bar")
assert_type(should_be_int, int)