Skip to content

Draft/RFC: Adds a Service mixin that can be used to get static type safety for classes #3077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .app import App, Stub
from .client import Client
from .cloud_bucket_mount import CloudBucketMount
from .cls import Cls, parameter
from .cls import Cls, Service, parameter
from .dict import Dict
from .exception import Error
from .file_pattern_matcher import FilePatternMatcher
Expand Down Expand Up @@ -77,6 +77,7 @@
"SandboxSnapshot",
"SchedulerPlacement",
"Secret",
"Service",
"Stub",
"Tunnel",
"Volume",
Expand Down
22 changes: 22 additions & 0 deletions modal/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Collection
from typing import Any, Callable, Optional, TypeVar, Union

import typing_extensions
from google.protobuf.message import Message
from grpclib import GRPCError, Status

Expand Down Expand Up @@ -804,3 +805,24 @@ class A:
"""
# has to return Any to be assignable to any annotation (https://github.com/microsoft/pyright/issues/5102)
return _Parameter(default=default, init=init)


class Service:
# Mixin to provide static types for "service" level methods
# The actual implementation of these methods are currently in the modal.Obj wrapper

def update_autoscaler(self, min_containers: Optional[int] = None): ...

def with_options(
self,
cpu: Optional[Union[float, tuple[float, float]]] = None,
memory: Optional[Union[int, tuple[int, int]]] = None,
gpu: GPU_T = None,
secrets: Collection[_Secret] = (),
volumes: dict[Union[str, os.PathLike], _Volume] = {},
retries: Optional[Union[int, Retries]] = None,
max_containers: Optional[int] = None, # Limit on the number of containers that can be concurrently running.
scaledown_window: Optional[int] = None, # Max amount of time a container can remain idle before scaling down.
timeout: Optional[int] = None,
allow_concurrent_inputs: Optional[int] = None,
) -> typing_extensions.Self: ...
18 changes: 11 additions & 7 deletions test/supports/type_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ async def async_typed_func(b: bool) -> str:
return ""


async_typed_func

should_be_str = async_typed_func.remote(False) # should be blocking without aio
assert_type(should_be_str, str)


@app.cls()
class Cls:
class UserCls(modal.Service):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So beautiful!

@method()
def foo(self, a: str) -> int:
return 1
Expand All @@ -54,20 +52,26 @@ async def bar(self, a: str) -> int:
return 1


instance = Cls()
should_be_int = instance.foo.remote("foo")
service = UserCls()
should_be_int = service.foo.remote("foo")
assert_type(should_be_int, int)

should_be_int = instance.bar.remote("bar")
should_be_int_2 = service.bar.remote("bar")
assert_type(should_be_int, int)

service.update_autoscaler(min_containers=10)
derived_service = service.with_options(cpu=10)
assert_type(derived_service, UserCls)
should_be_int_3 = derived_service.bar.remote(a="123")
assert_type(should_be_int_3, int)


async def async_block() -> None:
should_be_str_2 = await async_typed_func.remote.aio(True)
assert_type(should_be_str_2, str)
should_also_be_str = await async_typed_func.local(False) # local should be the original return type (!)
assert_type(should_also_be_str, str)
should_be_int = await instance.bar.local("bar")
should_be_int = await service.bar.local("bar")
assert_type(should_be_int, int)


Expand Down
Loading