Skip to content

Commit 6fa0af9

Browse files
committed
replace requests with httpx and factor out clients (#1574, #1707, #1714)
* input downloads, output uploads, and webhooks are now handled by ClientManager, which persists for the lifetime of runner, allowing us to reuse connections, which may significantly help with large uploads. * although I was originally going to drop output_file_prefix, it's not actually hard to maintain. the behavior is changed now and objects are uploaded as soon as they're outputted rather than after the prediction is completed. * there's an ugly hack with uploading an empty body to get the redirect instead of making api time out from trying to upload an 140GB file. that can be fixed by implemented an MPU endpoint and/or a "fetch upload url" endpoint. * the behavior of the non-indempotent endpoint is changed; the id is now randomly generated if it's not provided in the body. this isn't strictly required for this change alone, but is hard to carve out. * the behavior of Path is changed significantly. see https://www.notion.so/replicate/Cog-Setup-Path-Problem-2fc41d40bcaf47579ccd8b2f4c71ee24 Co-authored-by: Mattt <[email protected]> * format * stick a %s on line 190 clients.py (#1707) * local upload server can be called cluster.local in addition to .internal (#1714) Signed-off-by: technillogue <[email protected]>
1 parent f98468d commit 6fa0af9

26 files changed

+821
-651
lines changed

Diff for: pyproject.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ authors = [{ name = "Replicate", email = "[email protected]" }]
1010
license.file = "LICENSE"
1111
urls."Source" = "https://github.com/replicate/cog"
1212

13-
requires-python = ">=3.7"
13+
requires-python = ">=3.8"
1414
dependencies = [
1515
# intentionally loose. perhaps these should be vendored to not collide with user code?
1616
"attrs>=20.1,<24",
1717
"fastapi>=0.75.2,<0.99.0",
18+
# we may not need http2
19+
"httpx[http2]>=0.21.0,<1",
1820
"pydantic>=1.9,<2",
1921
"PyYAML",
2022
"requests>=2,<3",
@@ -27,9 +29,9 @@ dependencies = [
2729
optional-dependencies = { "dev" = [
2830
"black",
2931
"build",
30-
"httpx",
3132
'hypothesis<6.80.0; python_version < "3.8"',
3233
'hypothesis; python_version >= "3.8"',
34+
"respx",
3335
'numpy<1.22.0; python_version < "3.8"',
3436
'numpy; python_version >= "3.8"',
3537
"pillow",

Diff for: python/cog/files.py

-86
This file was deleted.

Diff for: python/cog/json.py

+1-25
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
import io
21
from datetime import datetime
32
from enum import Enum
43
from types import GeneratorType
5-
from typing import Any, Callable
4+
from typing import Any
65

76
from pydantic import BaseModel
87

9-
from .types import Path
10-
118

129
def make_encodeable(obj: Any) -> Any:
1310
"""
@@ -39,24 +36,3 @@ def make_encodeable(obj: Any) -> Any:
3936
if isinstance(obj, np.ndarray):
4037
return obj.tolist()
4138
return obj
42-
43-
44-
def upload_files(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any:
45-
"""
46-
Iterates through an object from make_encodeable and uploads any files.
47-
48-
When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
49-
"""
50-
# skip four isinstance checks for fast text models
51-
if type(obj) == str: # noqa: E721
52-
return obj
53-
if isinstance(obj, dict):
54-
return {key: upload_files(value, upload_file) for key, value in obj.items()}
55-
if isinstance(obj, list):
56-
return [upload_files(value, upload_file) for value in obj]
57-
if isinstance(obj, Path):
58-
with obj.open("rb") as f:
59-
return upload_file(f)
60-
if isinstance(obj, io.IOBase):
61-
return upload_file(obj)
62-
return obj

Diff for: python/cog/logging.py

+1
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,5 @@ def setup_logging(*, log_level: int = logging.NOTSET) -> None:
8686

8787
# Reconfigure log levels for some overly chatty libraries
8888
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
89+
# FIXME: no more urllib3(?)
8990
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)

Diff for: python/cog/predictor.py

+52-37
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import enum
22
import importlib.util
33
import inspect
4-
import io
54
import os.path
65
import sys
76
import types
@@ -11,6 +10,7 @@
1110
from pathlib import Path
1211
from typing import (
1312
Any,
13+
Awaitable,
1414
Callable,
1515
Dict,
1616
List,
@@ -22,34 +22,24 @@
2222
)
2323
from unittest.mock import patch
2424

25-
import structlog
26-
27-
import cog.code_xforms as code_xforms
28-
2925
try:
3026
from typing import get_args, get_origin
3127
except ImportError: # Python < 3.8
3228
from typing_compat import get_args, get_origin # type: ignore
3329

30+
import structlog
3431
import yaml
3532
from pydantic import BaseModel, Field, create_model
3633
from pydantic.fields import FieldInfo
3734

3835
# Added in Python 3.9. Can be from typing if we drop support for <3.9
3936
from typing_extensions import Annotated
4037

38+
from . import code_xforms
4139
from .errors import ConfigDoesNotExist, PredictorNotSet
42-
from .types import (
43-
CogConfig,
44-
Input,
45-
URLPath,
46-
)
47-
from .types import (
48-
File as CogFile,
49-
)
50-
from .types import (
51-
Path as CogPath,
52-
)
40+
from .types import CogConfig, Input, URLTempFile
41+
from .types import File as CogFile
42+
from .types import Path as CogPath
5343
from .types import Secret as CogSecret
5444

5545
log = structlog.get_logger("cog.server.predictor")
@@ -66,7 +56,7 @@
6656

6757

6858
class BasePredictor(ABC):
69-
def setup(self, weights: Optional[Union[CogFile, CogPath, str]] = None) -> None:
59+
def setup(self, weights: Optional[Union[CogFile, CogPath, str]] = None) -> Optional[Awaitable[None]]:
7060
"""
7161
An optional method to prepare the model so multiple predictions run efficiently.
7262
"""
@@ -81,51 +71,69 @@ def predict(self, **kwargs: Any) -> Any:
8171

8272

8373
def run_setup(predictor: BasePredictor) -> None:
74+
weights = get_weights_argument(predictor)
75+
if weights:
76+
predictor.setup(weights=weights)
77+
else:
78+
predictor.setup()
79+
80+
81+
async def run_setup_async(predictor: BasePredictor) -> None:
82+
weights = get_weights_argument(predictor)
83+
maybe_coro = predictor.setup(weights=weights) if weights else predictor.setup()
84+
if maybe_coro:
85+
return await maybe_coro
86+
87+
88+
def get_weights_argument(
89+
predictor: BasePredictor,
90+
) -> Union[CogFile, CogPath, str, None]:
91+
# by the time we get here we assume predictor has a setup method
8492
weights_type = get_weights_type(predictor.setup)
8593

8694
# No weights need to be passed, so just run setup() without any arguments.
8795
if weights_type is None:
8896
predictor.setup()
8997
return
9098

91-
weights: Union[io.IOBase, Path, str, None]
92-
9399
weights_url = os.environ.get("COG_WEIGHTS")
94-
weights_path = "weights"
100+
weights_path = "weights" # this is the source of a bug isn't it?
95101

96102
# TODO: Cog{File,Path}.validate(...) methods accept either "real"
97103
# paths/files or URLs to those things. In future we can probably tidy this
98104
# up a little bit.
99105
# TODO: CogFile/CogPath should have subclasses for each of the subtypes
106+
107+
# this is a breaking change
108+
# previously, CogPath wouldn't be converted in setup(); now it is
109+
# essentially everyone needs to switch from Path to str (or a new URL type)
100110
if weights_url:
101111
if weights_type == CogFile:
102-
weights = cast(CogFile, CogFile.validate(weights_url))
112+
return cast(CogFile, CogFile.validate(weights_url))
103113
elif weights_type == CogPath:
104114
# TODO: So this can be a url. evil!
105-
weights = cast(CogPath, CogPath.validate(weights_url))
115+
return cast(CogPath, CogPath.validate(weights_url))
106116
# allow people to download weights themselves
107117
elif weights_type == str: # noqa: E721
108-
weights = weights_url
118+
return weights_url
109119
else:
110120
raise ValueError(
111121
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
112122
)
113123
elif os.path.exists(weights_path):
114124
if weights_type == CogFile:
115-
weights = cast(CogFile, open(weights_path, "rb"))
125+
return cast(CogFile, open(weights_path, "rb"))
116126
elif weights_type == CogPath:
117-
weights = CogPath(weights_path)
118-
else:
119-
raise ValueError(
120-
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
121-
)
122-
else:
123-
weights = None
124-
125-
predictor.setup(weights=weights)
127+
return CogPath(weights_path)
128+
raise ValueError(
129+
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
130+
)
131+
return None
126132

127133

128-
def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]:
134+
def get_weights_type(
135+
setup_function: Callable[[Any], Optional[Awaitable[None]]],
136+
) -> Optional[Any]:
129137
signature = inspect.signature(setup_function)
130138
if "weights" not in signature.parameters:
131139
return None
@@ -266,12 +274,19 @@ def cleanup(self) -> None:
266274
Cleanup any temporary files created by the input.
267275
"""
268276
for _, value in self:
269-
# Handle URLPath objects specially for cleanup.
277+
# Handle URLTempFile objects specially for cleanup.
270278
# Also handle pathlib.Path objects, which cog.Path is a subclass of.
271279
# A pathlib.Path object shouldn't make its way here,
272280
# but both have an unlink() method, so we may as well be safe.
273-
if isinstance(value, (URLPath, Path)):
274-
value.unlink(missing_ok=True)
281+
if isinstance(value, (URLTempFile, Path)):
282+
try:
283+
value.unlink(missing_ok=True)
284+
except FileNotFoundError:
285+
pass
286+
287+
# if we had a separate method to traverse the input and apply some function to each value
288+
# we could have cleanup/get_tempfile/convert functions that operate on a single value
289+
# and do it that way. convert is supposed to mutate though, so it's tricky
275290

276291

277292
def validate_input_type(type: Type[Any], name: str) -> None:

Diff for: python/cog/schema.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import importlib.util
22
import os
33
import os.path
4+
import secrets
45
import sys
56
import typing as t
67
from datetime import datetime
@@ -43,7 +44,14 @@ class PredictionBaseModel(pydantic.BaseModel, extra=pydantic.Extra.allow):
4344

4445

4546
class PredictionRequest(PredictionBaseModel):
46-
id: t.Optional[str]
47+
# there's a problem here where the idempotent endpoint is supposed to
48+
# let you pass id in the route and omit it from the input
49+
# however this fills in the default
50+
# maybe it should be allowed to be optional without the factory initially
51+
# and be filled in later
52+
#
53+
# actually, this changes the public api so we should really do this differently
54+
id: str = pydantic.Field(default_factory=lambda: secrets.token_hex(4))
4755
created_at: t.Optional[datetime]
4856

4957
# TODO: deprecate this

0 commit comments

Comments
 (0)