
Commit d6cc65d

Scary temporary commit for a hemorrhaging-edge release

* add concurrency to config
* this basically works!
* more descriptive names for predict functions
* maybe pass through prediction id and try to make cancelation do both?
* don't cancel from signal handler if a loop is running. expose worker busy state to runner
* move handle_event_stream to PredictionEventHandler
* make setup and canceling work
* drop some checks around cancelation
* try out eager_predict_state_change
* keep track of multiple runner prediction tasks to make idempotent endpoint return the same result and fix tests somewhat
* fix idempotent tests
* fix remaining errors?
* worker predict_generator shouldn't be eager
* wip: make the stuff that handles events and sends webhooks etc async
* drop Runner._result
* drop comments
* inline client code
* get started
* inline webhooks
* move clients into runner, switch to httpx, move create_event_handler into runner
* add some comments
* more notes
* rip out webhooks and most of files and put them in a new ClientManager that handles most of everything. inline upload_files for that
* move create_event_handler into PredictionEventHandler.__init__
* fix one test
* break out Path.validate into value_to_path and inline get_filename and File.validate
* split out URLPath into BackwardsCompatibleDataURLTempFilePath and URLThatCanBeConvertedToPath with the download part of URLFile inlined
* let's make DataURLTempFilePath also use convert and move value_to_path back to Path.validate
* use httpx for downloading input urls and follow redirects
* take get_filename back out for tests
* don't upload in http and delete cog/files.py
* drop should_cancel
* prediction->request
* split up predict/inner/prediction_ctx into enter_predict/exit_predict/prediction_ctx/inner_async_predict/predict/good_predict as one way to do it. however, exposing all of those for runner predict enter/coro exit still sucks, but this is still an improvement
* bigish change: inline predict_and_handle_errors
* inline make_error_handler into setup
* move runner.setup into runner.Runner.setup
* add concurrency to config in go
* try explicitly using prediction_ctx __enter__ and __exit__
* make runner setup more correct and marginally better
* fix a few tests
* notes
* wip ClientManager.convert
* relax setup argument requirement to str
* glom worker into runner
* add logging message
* fix prediction retry and improve logging
* split out handle_event
* use CURL_CA_BUNDLE for file upload
* clean up comments
* dubious upload fix
* small fixes
* attempt to add context logging?
* tweak names
* fix error for predictionOutputType(multi=False)
* improve comments
* fix lints
* skip worker and webhook tests since those were erroring on removed imports. fix or xfail runner tests
* upload in http instead of PredictionEventHandler. this makes tests pass and fixes some problems with validation, but also prevents streaming files and causes new problems. also xfail all the responses tests that need to be replaced with respx
* format
* fix some new-style type signatures and drop 3.8 support
* drop 3.7 in code

Signed-off-by: technillogue <[email protected]>
1 parent a700c7f commit d6cc65d

31 files changed, +1151 -875 lines changed

Diff for: .github/workflows/ci.yaml

+2-1
@@ -53,7 +53,8 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
+        # python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
     defaults:
       run:
         shell: bash

Diff for: Makefile

+1-1
@@ -11,7 +11,7 @@ GO := go
 GOOS := $(shell $(GO) env GOOS)
 GOARCH := $(shell $(GO) env GOARCH)
 
-PYTHON := python
+PYTHON ?= python
 PYTEST := $(PYTHON) -m pytest
 PYRIGHT := $(PYTHON) -m pyright
 RUFF := $(PYTHON) -m ruff

Diff for: pkg/config/config.go

+5-4
@@ -50,10 +50,11 @@ type Example struct {
 }
 
 type Config struct {
-	Build   *Build `json:"build" yaml:"build"`
-	Image   string `json:"image,omitempty" yaml:"image"`
-	Predict string `json:"predict,omitempty" yaml:"predict"`
-	Train   string `json:"train,omitempty" yaml:"train"`
+	Build       *Build `json:"build" yaml:"build"`
+	Image       string `json:"image,omitempty" yaml:"image"`
+	Predict     string `json:"predict,omitempty" yaml:"predict"`
+	Train       string `json:"train,omitempty" yaml:"train"`
+	Concurrency int    `json:"concurrency,omitempty" yaml:"concurrency"`
 }
 
 func DefaultConfig() *Config {

Diff for: pkg/config/data/config_schema_v1.0.json

+5
@@ -140,6 +140,11 @@
       "$id": "#/properties/train",
       "type": "string",
       "description": "The pointer to the `Predictor` object in your code, which defines how predictions are run on your model."
+    },
+    "concurrency": {
+      "$id": "#/properties/concurrency",
+      "type": "number",
+      "description": "Allowed concurrency."
     }
   },
   "additionalProperties": false

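The Go struct and the JSON schema above only declare the `concurrency` setting (note that the schema types it as `number` while the Go field is an `int`); the code that enforces it is not part of these hunks. As a rough illustration of what such a limit usually means in asyncio code, here is a minimal sketch that caps in-flight work with a semaphore. The names `run_with_limit` and `fake_predict` are hypothetical and do not come from the commit.

```python
import asyncio
from typing import List


async def run_with_limit(concurrency: int, durations: List[float]) -> int:
    """Run fake predictions, never more than `concurrency` at once.

    Returns the highest number of predictions observed running together.
    """
    limit = asyncio.Semaphore(concurrency)
    running = 0
    peak = 0

    async def fake_predict(seconds: float) -> None:
        nonlocal running, peak
        async with limit:  # wait here until a slot is free
            running += 1
            peak = max(peak, running)
            await asyncio.sleep(seconds)  # stand-in for real model work
            running -= 1

    await asyncio.gather(*(fake_predict(d) for d in durations))
    return peak


if __name__ == "__main__":
    # Five short jobs with two slots: the peak never exceeds the limit.
    print(asyncio.run(run_with_limit(2, [0.01] * 5)))
```

Whether the runner in this commit uses a semaphore, a task counter, or something else is not visible here, so treat this purely as a picture of the intended behaviour.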
Diff for: pyproject.toml

+4-2
@@ -10,11 +10,13 @@ authors = [{ name = "Replicate", email = "[email protected]" }]
 license.file = "LICENSE"
 urls."Source" = "https://github.com/replicate/cog"
 
-requires-python = ">=3.7"
+requires-python = ">=3.8"
 dependencies = [
   # intentionally loose. perhaps these should be vendored to not collide with user code?
   "attrs>=20.1,<24",
   "fastapi>=0.75.2,<0.99.0",
+  # this version specification is pretty arbitrary, and we may not need http2
+  "httpx[http2]>=0.25.0,<0.27",
   "pydantic>=1.9,<2",
   "PyYAML",
   "requests>=2,<3",
@@ -27,9 +29,9 @@ dependencies = [
 optional-dependencies = { "dev" = [
   "black",
   "build",
-  "httpx",
   'hypothesis<6.80.0; python_version < "3.8"',
   'hypothesis; python_version >= "3.8"',
+  "respx",
   'numpy<1.22.0; python_version < "3.8"',
   'numpy; python_version >= "3.8"',
   "pillow",

Diff for: python/cog/files.py

-75
This file was deleted.

Diff for: python/cog/json.py

+1-22
@@ -1,13 +1,10 @@
-import io
 from datetime import datetime
 from enum import Enum
 from types import GeneratorType
-from typing import Any, Callable
+from typing import Any
 
 from pydantic import BaseModel
 
-from .types import Path
-
 
 def make_encodeable(obj: Any) -> Any:
     """
@@ -39,21 +36,3 @@ def make_encodeable(obj: Any) -> Any:
         if isinstance(obj, np.ndarray):
             return obj.tolist()
     return obj
-
-
-def upload_files(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any:
-    """
-    Iterates through an object from make_encodeable and uploads any files.
-
-    When a file is encountered, it will be passed to upload_file. Any paths will be opened and converted to files.
-    """
-    if isinstance(obj, dict):
-        return {key: upload_files(value, upload_file) for key, value in obj.items()}
-    if isinstance(obj, list):
-        return [upload_files(value, upload_file) for value in obj]
-    if isinstance(obj, Path):
-        with obj.open("rb") as f:
-            return upload_file(f)
-    if isinstance(obj, io.IOBase):
-        return upload_file(obj)
-    return obj
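
For readers who want to see what the removed helper did before it was inlined elsewhere (the commit message mentions a new ClientManager), here is a standalone copy of the same recursive traversal. It is a sketch under two assumptions: `pathlib.Path` stands in for cog's own `Path` type, and `fake_upload` is a hypothetical uploader that only reports the payload size.

```python
import io
from pathlib import Path
from typing import Any, Callable


def upload_files(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any:
    """Recursively replace paths and open files in obj with upload_file's result."""
    if isinstance(obj, dict):
        return {key: upload_files(value, upload_file) for key, value in obj.items()}
    if isinstance(obj, list):
        return [upload_files(value, upload_file) for value in obj]
    if isinstance(obj, Path):  # assumption: pathlib.Path stands in for cog's Path
        with obj.open("rb") as f:
            return upload_file(f)
    if isinstance(obj, io.IOBase):
        return upload_file(obj)
    return obj  # plain values pass through unchanged


def fake_upload(fh: io.IOBase) -> str:
    """Hypothetical uploader: reports the payload size instead of uploading."""
    return f"uploaded:{len(fh.read())}"


if __name__ == "__main__":
    output = {"text": "hi", "files": [io.BytesIO(b"abc"), io.BytesIO(b"")]}
    print(upload_files(output, fake_upload))
```

The traversal only descends into dicts and lists, so a file nested inside a tuple or a custom object would be returned untouched.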

Diff for: python/cog/logging.py

+1
@@ -86,4 +86,5 @@ def setup_logging(*, log_level: int = logging.NOTSET) -> None:
 
     # Reconfigure log levels for some overly chatty libraries
     logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
+    # FIXME: no more urllib3(?)
     logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)

Diff for: python/cog/predictor.py

+37-14
@@ -36,10 +36,7 @@
 from .types import (
     File as CogFile,
 )
-from .types import (
-    Input,
-    URLPath,
-)
+from .types import Input
 from .types import (
     Path as CogPath,
 )
@@ -49,7 +46,7 @@
 
 class BasePredictor(ABC):
     def setup(
-        self, weights: Optional[Union[CogFile, CogPath]] = None
+        self, weights: Optional[Union[CogFile, CogPath, str]] = None
     ) -> Optional[Awaitable[None]]:
         """
         An optional method to prepare the model so multiple predictions run efficiently.
@@ -79,34 +76,40 @@ async def run_setup_async(predictor: BasePredictor) -> None:
         return await maybe_coro
 
 
-def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, None]:
+def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, str, None]:
     # by the time we get here we assume predictor has a setup method
     weights_type = get_weights_type(predictor.setup)
     if weights_type is None:
         return None
     weights_url = os.environ.get("COG_WEIGHTS")
-    weights_path = "weights"
+    weights_path = "weights"  # this is the source of a bug isn't it?
 
     # TODO: Cog{File,Path}.validate(...) methods accept either "real"
     # paths/files or URLs to those things. In future we can probably tidy this
     # up a little bit.
     # TODO: CogFile/CogPath should have subclasses for each of the subtypes
+
+    # this is a breaking change
+    # previously, CogPath wouldn't be converted; now it is
+    # essentially everyone needs to switch from Path to str (or a new URL type)
     if weights_url:
         if weights_type == CogFile:
             return cast(CogFile, CogFile.validate(weights_url))
         if weights_type == CogPath:
             # TODO: So this can be a url. evil!
             return cast(CogPath, CogPath.validate(weights_url))
+        if weights_type == str:
+            return weights_url
         raise ValueError(
-            f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
+            f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
         )
     if os.path.exists(weights_path):
         if weights_type == CogFile:
             return cast(CogFile, open(weights_path, "rb"))
         if weights_type == CogPath:
             return CogPath(weights_path)
         raise ValueError(
-            f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
+            f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported"
         )
     return None
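
The branching in `get_weights_argument` after this change can be condensed into a small standalone function. This is a simplified sketch, not the real implementation: `resolve_weights` is a hypothetical name, the `File` case is omitted, `pathlib.Path` stands in for `CogPath`, and the environment is passed in as a mapping so the logic can be exercised without touching `os.environ`. One thing the sketch makes visible is that, as written in the hunk, the local-path branch has no `str` case even though its error message now lists `str` as supported.

```python
from pathlib import Path
from typing import Mapping, Optional, Union


def resolve_weights(
    weights_type: Optional[type],
    env: Mapping[str, str],
    local_path: str = "weights",
) -> Union[Path, str, None]:
    """Simplified stand-in for get_weights_argument (File handling omitted)."""
    if weights_type is None:
        return None
    weights_url = env.get("COG_WEIGHTS")
    if weights_url:
        if weights_type is str:
            return weights_url  # new in this commit: the URL is passed through as-is
        if weights_type is Path:
            return Path(weights_url)  # the real code calls CogPath.validate here
        raise ValueError(f"unsupported weights type {weights_type}")
    if Path(local_path).exists():
        if weights_type is Path:
            return Path(local_path)
        # mirrors the diff: a str-typed argument is not handled on this side
        raise ValueError(f"unsupported weights type {weights_type}")
    return None
```

With `weights: str` and `COG_WEIGHTS` set, the predictor receives the raw URL and is responsible for downloading it, which matches the "switch from Path to str" note in the added comment.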

@@ -212,17 +215,37 @@ def cleanup(self) -> None:
         Cleanup any temporary files created by the input.
         """
         for _, value in self:
-            # Handle URLPath objects specially for cleanup.
-            if isinstance(value, URLPath):
-                value.unlink()
-            # Note this is pathlib.Path, which cog.Path is a subclass of. A pathlib.Path object shouldn't make its way here,
+            # # Handle URLPath objects specially for cleanup.
+            # if isinstance(value, URLPath):
+            #     value.unlink()
+            # Note this is pathlib.Path, of which cog.Path is a subclass of.
+            # A pathlib.Path object shouldn't make its way here,
             # but both have an unlink() method, so may as well be safe.
-            elif isinstance(value, Path):
+            #
+            # URLTempFile, DataURLTempFilePath, pathlib.Path, doesn't matter
+            # everyone can be unlinked
+            if isinstance(value, Path):
                 try:
                     value.unlink()
                 except FileNotFoundError:
                     pass
 
+    # if we had a separate method to traverse the input and apply some function to each value
+    # we could use something like these functions here
+
+    # def cleanup():
+    #     if isinstance(value, Path):
+    #         value.unlink()
+
+    # def get_tempfile():
+    #     if isinstance(value, URLTempFile):
+    #         return (value.url, value._path)
+
+    # # this one is very annoying because it's supposed to mutate
+    # def convert():
+    #     if isinstance(value, URLTempFile):
+    #         return value.convert()
+
 
 def validate_input_type(type: Type[Any], name: str) -> None:
     if type is inspect.Signature.empty:
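
The commented-out block above wishes for "a separate method to traverse the input and apply some function to each value". One possible shape for that helper is sketched below. Everything here is an assumption for illustration: `apply_to_values`, `unlink_if_path` and `demo` are hypothetical names, the input is modelled as plain `(name, value)` pairs rather than a pydantic model, and descending one level into lists is my addition, not something the diff does.

```python
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Tuple


def apply_to_values(
    items: Iterable[Tuple[str, Any]], fn: Callable[[Any], Any]
) -> Dict[str, Any]:
    """Call fn on every input value, descending one level into lists."""
    result: Dict[str, Any] = {}
    for name, value in items:
        if isinstance(value, list):
            result[name] = [fn(v) for v in value]
        else:
            result[name] = fn(value)
    return result


def unlink_if_path(value: Any) -> Any:
    """Delete value's file if it is a path; a missing file is not an error."""
    if isinstance(value, Path):
        value.unlink(missing_ok=True)  # missing_ok requires Python 3.8+
    return value


def demo() -> bool:
    """Create a temp file, clean it up via the helper, report whether it is gone."""
    with tempfile.NamedTemporaryFile(delete=False) as f:
        tmp = Path(f.name)
    apply_to_values([("image", tmp), ("prompt", "hello")], unlink_if_path)
    return not tmp.exists()
```

Because the helper returns a new mapping instead of mutating in place, a `convert`-style function (the case the comment calls "very annoying") could return replacement values and the caller could assign the result back.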

Diff for: python/cog/schema.py

+12-1
@@ -1,3 +1,4 @@
+import secrets
 import typing as t
 from datetime import datetime
 from enum import Enum
@@ -36,7 +37,15 @@ class PredictionBaseModel(pydantic.BaseModel, extra=pydantic.Extra.allow):
 
 
 class PredictionRequest(PredictionBaseModel):
-    id: t.Optional[str]
+    # there's a problem here where the idempotent endpoint is supposed to
+    # let you pass id in the route and omit it from the input
+    # however this fills in the default
+    # maybe it should be allowed to be optional without the factory initially
+    # and be filled in later
+    #
+    #
+    # actually, this changes the public api so we should really do this differently
+    id: str = pydantic.Field(default_factory=lambda: secrets.token_hex(4))
     created_at: t.Optional[datetime]
 
     # TODO: deprecate this
@@ -85,8 +94,10 @@ def with_types(cls, input_type: t.Type[t.Any], output_type: t.Type[t.Any]) -> t.
             output=(output_type, None),
         )
 
+
 class TrainingRequest(PredictionRequest):
     pass
 
+
 class TrainingResponse(PredictionResponse):
     pass
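
The comment on `PredictionRequest.id` describes a concrete problem: once the field has a default factory, a request that omits `id` no longer looks like it omitted it. The sketch below reproduces that effect with standard-library dataclasses standing in for the pydantic model. The class names and the `id_conflicts` check are hypothetical; in particular, it is an assumption that the idempotent endpoint compares the body id against the route id in this way.

```python
import secrets
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class EagerIdRequest:
    """Mirrors the new field: an id is generated whenever none is supplied."""

    id: str = field(default_factory=lambda: secrets.token_hex(4))


@dataclass
class LazyIdRequest:
    """Hypothetical alternative: leave id unset until the handler fills it in."""

    id: Optional[str] = None


def id_conflicts(route_id: str, body_id: Optional[str]) -> bool:
    """True when the request body carries an id that disagrees with the route."""
    return body_id is not None and body_id != route_id


if __name__ == "__main__":
    # The caller omitted the id in both cases, but only the eager model
    # ends up disagreeing with the id taken from the route.
    print(id_conflicts("abc", EagerIdRequest().id))
    print(id_conflicts("abc", LazyIdRequest().id))
```

`secrets.token_hex(4)` yields eight hexadecimal characters, so the generated id can never equal a route id of a different length; with the lazy variant the handler can copy the route id into the request before anything else reads it.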
