|
36 | 36 | from .types import ( |
37 | 37 | File as CogFile, |
38 | 38 | ) |
39 | | -from .types import ( |
40 | | - Input, |
41 | | - URLPath, |
42 | | -) |
| 39 | +from .types import Input |
43 | 40 | from .types import ( |
44 | 41 | Path as CogPath, |
45 | 42 | ) |
|
49 | 46 |
|
50 | 47 | class BasePredictor(ABC): |
51 | 48 | def setup( |
52 | | - self, weights: Optional[Union[CogFile, CogPath]] = None |
| 49 | + self, weights: Optional[Union[CogFile, CogPath, str]] = None |
53 | 50 | ) -> Optional[Awaitable[None]]: |
54 | 51 | """ |
55 | 52 | An optional method to prepare the model so multiple predictions run efficiently. |
@@ -79,34 +76,40 @@ async def run_setup_async(predictor: BasePredictor) -> None: |
79 | 76 | return await maybe_coro |
80 | 77 |
|
81 | 78 |
|
82 | | -def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, None]: |
| 79 | +def get_weights_argument(predictor: BasePredictor) -> Union[CogFile, CogPath, str, None]: |
83 | 80 | # by the time we get here we assume predictor has a setup method |
84 | 81 | weights_type = get_weights_type(predictor.setup) |
85 | 82 | if weights_type is None: |
86 | 83 | return None |
87 | 84 | weights_url = os.environ.get("COG_WEIGHTS") |
88 | | - weights_path = "weights" |
| 85 | + weights_path = "weights" # this is the source of a bug isn't it? |
89 | 86 |
|
90 | 87 | # TODO: Cog{File,Path}.validate(...) methods accept either "real" |
91 | 88 | # paths/files or URLs to those things. In future we can probably tidy this |
92 | 89 | # up a little bit. |
93 | 90 | # TODO: CogFile/CogPath should have subclasses for each of the subtypes |
| 91 | + |
| 92 | + # this is a breaking change |
| 93 | + # previously, CogPath wouldn't be converted; now it is |
| 94 | + # essentially everyone needs to switch from Path to str (or a new URL type) |
94 | 95 | if weights_url: |
95 | 96 | if weights_type == CogFile: |
96 | 97 | return cast(CogFile, CogFile.validate(weights_url)) |
97 | 98 | if weights_type == CogPath: |
98 | 99 | # TODO: So this can be a url. evil! |
99 | 100 | return cast(CogPath, CogPath.validate(weights_url)) |
| 101 | + if weights_type == str: |
| 102 | + return weights_url |
100 | 103 | raise ValueError( |
101 | | - f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported" |
| 104 | + f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported" |
102 | 105 | ) |
103 | 106 | if os.path.exists(weights_path): |
104 | 107 | if weights_type == CogFile: |
105 | 108 | return cast(CogFile, open(weights_path, "rb")) |
106 | 109 | if weights_type == CogPath: |
107 | 110 | return CogPath(weights_path) |
108 | 111 | raise ValueError( |
109 | | - f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported" |
| 112 | + f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File, Path and str are supported" |
110 | 113 | ) |
111 | 114 | return None |
112 | 115 |
|
@@ -212,17 +215,37 @@ def cleanup(self) -> None: |
212 | 215 | Cleanup any temporary files created by the input. |
213 | 216 | """ |
214 | 217 | for _, value in self: |
215 | | - # Handle URLPath objects specially for cleanup. |
216 | | - if isinstance(value, URLPath): |
217 | | - value.unlink() |
218 | | - # Note this is pathlib.Path, which cog.Path is a subclass of. A pathlib.Path object shouldn't make its way here, |
| 218 | + # # Handle URLPath objects specially for cleanup. |
| 219 | + # if isinstance(value, URLPath): |
| 220 | + # value.unlink() |
| 221 | + # Note this is pathlib.Path, of which cog.Path is a subclass of. |
| 222 | + # A pathlib.Path object shouldn't make its way here, |
219 | 223 | # but both have an unlink() method, so may as well be safe. |
220 | | - elif isinstance(value, Path): |
| 224 | + # |
| 225 | + # URLTempFile, DataURLTempFilePath, pathlib.Path, doesn't matter |
| 226 | + # everyone can be unlinked |
| 227 | + if isinstance(value, Path): |
221 | 228 | try: |
222 | 229 | value.unlink() |
223 | 230 | except FileNotFoundError: |
224 | 231 | pass |
225 | 232 |
|
| 233 | + # if we had a separate method to traverse the input and apply some function to each value |
| 234 | + # we could use something like these functions here |
| 235 | + |
| 236 | + # def cleanup(): |
| 237 | + # if isinstance(value, Path): |
| 238 | + # value.unlink() |
| 239 | + |
| 240 | + # def get_tempfile(): |
| 241 | + # if isinstance(value, URLTempFile): |
| 242 | + # return (value.url, value._path) |
| 243 | + |
| 244 | + # # this one is very annoying because it's supposed to mutate |
| 245 | + # def convert(): |
| 246 | + # if isinstance(value, URLTempFile): |
| 247 | + # return value.convert() |
| 248 | + |
226 | 249 |
|
227 | 250 | def validate_input_type(type: Type[Any], name: str) -> None: |
228 | 251 | if type is inspect.Signature.empty: |
|
0 commit comments