Skip to content

Commit 1ad50b0

Browse files
authored
feat: enhance Auto Class to support remote and local path (#5)
* refactor: rename parameter pretrained_model_name_or_path * feat: enhance Auto Class to support remote and local path
1 parent 4226db0 commit 1ad50b0

24 files changed

+365
-218
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,5 @@ cython_debug/
170170

171171
*.mp4
172172
*.mkv
173+
174+
/cccv/cache_models/

README.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pip install cccv
2020

2121
### Start
2222

23-
#### cv2
23+
#### Load a registered model in cccv
2424

2525
A simple example of using the SISR (Single Image Super-Resolution) model to process an image
2626

@@ -37,6 +37,22 @@ img = model.inference_image(img)
3737
cv2.imwrite("test_out.jpg", img)
3838
```
3939

40+
#### Load a custom model from remote repository or local path
41+
42+
A simple example of loading a custom model from a [remote repository](https://github.com/EutropicAI/cccv_demo_remote_model) or a local path; the model is registered automatically and then loaded
43+
44+
```python
45+
import cv2
46+
import numpy as np
47+
48+
from cccv import AutoModel, SRBaseModel
49+
50+
# remote repo
51+
model: SRBaseModel = AutoModel.from_pretrained("https://github.com/EutropicAI/cccv_demo_remote_model")
52+
# local path
53+
model: SRBaseModel = AutoModel.from_pretrained("/path/to/cccv_demo_model")
54+
```
55+
4056
#### VapourSynth
4157

4258
A simple example of using VapourSynth to process a video

cccv/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@
2626

2727
from cccv.arch import ARCH_REGISTRY
2828
from cccv.auto import AutoConfig, AutoModel
29-
from cccv.config import CONFIG_REGISTRY, BaseConfig, SRBaseConfig, VFIBaseConfig, VSRBaseConfig
29+
from cccv.config import CONFIG_REGISTRY, AutoBaseConfig, BaseConfig, SRBaseConfig, VFIBaseConfig, VSRBaseConfig
3030
from cccv.model import MODEL_REGISTRY, AuxiliaryBaseModel, CCBaseModel, SRBaseModel, VFIBaseModel, VSRBaseModel
3131
from cccv.type import ArchType, BaseModelInterface, ConfigType, ModelType

cccv/arch/sr/dat_arch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,7 @@ def __init__(
365365
elif idx == 1:
366366
W_sp, H_sp = self.split_size[0], self.split_size[1]
367367
else:
368-
print("ERROR MODE", idx)
369-
exit(0)
368+
raise ValueError(f"[CCCV] ERROR MODE: invalid idx {idx}, expected 0 or 1")
370369
self.H_sp = H_sp
371370
self.W_sp = W_sp
372371

cccv/arch/sr/upcunet_arch.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,7 @@ def forward(self, x, tile_mode, cache_mode, alpha, pro):
376376
t2 = tile_mode * 2
377377
crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode)
378378
else:
379-
print("tile_mode config error")
380-
os._exit(233)
379+
raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value")
381380

382381
ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
383382
pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
@@ -526,8 +525,7 @@ def forward_gap_sync(self, x, tile_mode, alpha, pro):
526525
t2 = tile_mode * 2
527526
crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode)
528527
else:
529-
print("tile_mode config error")
530-
os._exit(233)
528+
raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value")
531529
ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
532530
pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
533531
x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
@@ -767,8 +765,7 @@ def forward(self, x, tile_mode, cache_mode, alpha, pro):
767765
t4 = tile_mode * 4
768766
crop_size = (((h0 - 1) // t4 * t4 + t4) // tile_mode, ((w0 - 1) // t4 * t4 + t4) // tile_mode)
769767
else:
770-
print("tile_mode config error")
771-
os._exit(233)
768+
raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value")
772769
ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
773770
pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
774771
x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
@@ -916,8 +913,7 @@ def forward_gap_sync(self, x, tile_mode, alpha, pro):
916913
t4 = tile_mode * 4
917914
crop_size = (((h0 - 1) // t4 * t4 + t4) // tile_mode, ((w0 - 1) // t4 * t4 + t4) // tile_mode)
918915
else:
919-
print("tile_mode config error")
920-
os._exit(233)
916+
raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value")
921917
ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
922918
pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
923919
x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
@@ -1162,8 +1158,7 @@ def forward(self, x, tile_mode, cache_mode, alpha, pro):
11621158
t2 = tile_mode * 2
11631159
crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode)
11641160
else:
1165-
print("tile_mode config error")
1166-
os._exit(233)
1161+
raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value")
11671162
ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
11681163
pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
11691164
x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
@@ -1323,8 +1318,7 @@ def forward_gap_sync(self, x, tile_mode, alpha, pro):
13231318
t2 = tile_mode * 2
13241319
crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode) # 5.6G
13251320
else:
1326-
print("tile_mode config error")
1327-
os._exit(233)
1321+
raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value")
13281322
ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0]
13291323
pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1]
13301324
x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")

cccv/arch/vfi/drba_arch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# type: ignore
2+
import warnings
3+
24
import numpy as np
35
import torch
46
import torch.nn as nn
@@ -61,7 +63,7 @@ def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fa
6163
torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i]
6264
)
6365
if ensemble:
64-
print("warning: ensemble is not supported since RIFEv4.21")
66+
warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2)
6567
else:
6668
wf0 = warp(f0, flow[:, :2])
6769
wf1 = warp(f1, flow[:, 2:4])
@@ -71,7 +73,7 @@ def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fa
7173
scale=scale_list[i],
7274
)
7375
if ensemble:
74-
print("warning: ensemble is not supported since RIFEv4.21")
76+
warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2)
7577
else:
7678
mask = m0
7779
flow = flow + fd
@@ -83,7 +85,7 @@ def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fa
8385
mask = torch.sigmoid(mask)
8486
merged[4] = warped_img0 * mask + warped_img1 * (1 - mask)
8587
if not fastmode:
86-
print("contextnet is removed")
88+
warnings.warn("[CCCV] contextnet is removed", stacklevel=2)
8789
"""
8890
c0 = self.contextnet(img0, flow[:, :2])
8991
c1 = self.contextnet(img1, flow[:, 2:4])

cccv/arch/vfi/ifnet_arch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# type: ignore
2+
import warnings
3+
24
import torch
35
import torch.nn as nn
46
import torch.nn.functional as F
@@ -43,7 +45,7 @@ def forward(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fals
4345
torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i]
4446
)
4547
if ensemble:
46-
print("warning: ensemble is not supported since RIFEv4.21")
48+
warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2)
4749
else:
4850
wf0 = warp(f0, flow[:, :2])
4951
wf1 = warp(f1, flow[:, 2:4])
@@ -53,7 +55,7 @@ def forward(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fals
5355
scale=scale_list[i],
5456
)
5557
if ensemble:
56-
print("warning: ensemble is not supported since RIFEv4.21")
58+
warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2)
5759
else:
5860
mask = m0
5961
flow = flow + fd
@@ -65,7 +67,7 @@ def forward(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fals
6567
mask = torch.sigmoid(mask)
6668
merged[4] = warped_img0 * mask + warped_img1 * (1 - mask)
6769
if not fastmode:
68-
print("contextnet is removed")
70+
warnings.warn("[CCCV] contextnet is removed", stacklevel=2)
6971
"""
7072
c0 = self.contextnet(img0, flow[:, :2])
7173
c1 = self.contextnet(img1, flow[:, 2:4])

cccv/arch/vfi/vfi_utils/softsplat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
6161
strKey += str(objValue.stride())
6262

6363
elif True:
64-
print(strVariable, type(objValue))
64+
print(f"[CCCV] {strVariable}, {type(objValue)}")
6565

6666
# end
6767
# end
@@ -106,10 +106,10 @@ def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict):
106106
strKernel = strKernel.replace("{{type}}", "long")
107107

108108
elif isinstance(objValue, torch.Tensor):
109-
print(strVariable, objValue.dtype)
109+
print(f"[CCCV] {strVariable}, {objValue.dtype}")
110110

111111
elif True:
112-
print(strVariable, type(objValue))
112+
print(f"[CCCV] {strVariable}, {type(objValue)}")
113113

114114
# end
115115
# end

cccv/auto/config.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,87 @@
1+
import importlib.util
2+
import json
3+
import warnings
4+
from pathlib import Path
15
from typing import Any, Optional, Union
26

3-
from cccv.config import CONFIG_REGISTRY, BaseConfig
7+
from cccv.config import CONFIG_REGISTRY, AutoBaseConfig
48
from cccv.type import ConfigType
9+
from cccv.util.remote import git_clone
510

611

712
class AutoConfig:
813
@staticmethod
914
def from_pretrained(
10-
pretrained_model_name: Union[ConfigType, str],
15+
pretrained_model_name_or_path: Union[ConfigType, str, Path],
16+
*,
17+
model_dir: Optional[Union[Path, str]] = None,
1118
**kwargs: Any,
1219
) -> Any:
1320
"""
14-
Get a config instance of a pretrained model configuration.
21+
Get a config instance of a pretrained model configuration, can be a registered config name or a local path or a git url.
1522
16-
:param pretrained_model_name: The name of the pretrained model configuration
23+
:param pretrained_model_name_or_path:
24+
:param model_dir: The path to cache the downloaded model configuration. Should be a full path. If None, use default cache path.
1725
:return:
1826
"""
19-
return CONFIG_REGISTRY.get(pretrained_model_name)
27+
if "pretrained_model_name" in kwargs:
28+
warnings.warn(
29+
"[CCCV] 'pretrained_model_name' is deprecated, please use 'pretrained_model_name_or_path' instead.",
30+
DeprecationWarning,
31+
stacklevel=2,
32+
)
33+
pretrained_model_name_or_path = kwargs.pop("pretrained_model_name")
2034

21-
@staticmethod
22-
def register(config: Union[BaseConfig, Any], name: Optional[str] = None) -> None:
23-
"""
24-
Register the given config class instance under the name BaseConfig.name or the given name.
25-
Can be used as a function call. See docstring of this class for usage.
35+
# 1. check if it's a registered config name, early return if found
36+
if isinstance(pretrained_model_name_or_path, ConfigType):
37+
pretrained_model_name_or_path = pretrained_model_name_or_path.value
38+
if str(pretrained_model_name_or_path) in CONFIG_REGISTRY:
39+
return CONFIG_REGISTRY.get(str(pretrained_model_name_or_path))
2640

27-
:param config: The config class instance to register
28-
:param name: The name to register the config class instance under. If None, use BaseConfig.name
29-
:return:
30-
"""
31-
# used as a function call
32-
CONFIG_REGISTRY.register(obj=config, name=name)
41+
# 2. check is a url or not, if it's a url, git clone it to model_dir then replace pretrained_model_name_or_path with the local path (Path)
42+
if str(pretrained_model_name_or_path).startswith("http"):
43+
pretrained_model_name_or_path = git_clone(
44+
git_url=str(pretrained_model_name_or_path),
45+
model_dir=model_dir,
46+
**kwargs,
47+
)
48+
49+
# 3. check if it's a real path
50+
dir_path = Path(str(pretrained_model_name_or_path))
51+
52+
if not dir_path.exists() or not dir_path.is_dir():
53+
raise ValueError(f"[CCCV] model configuration '{dir_path}' is not a valid config name or path")
54+
55+
# load config.json from the directory
56+
config_path = dir_path / "config.json"
57+
# check if config.json exists
58+
if not config_path.exists():
59+
raise FileNotFoundError(f"[CCCV] no valid config.json not found in {dir_path}")
60+
61+
with open(config_path, "r", encoding="utf-8") as f:
62+
config_dict = json.load(f)
63+
64+
for k in ["arch", "model", "name"]:
65+
if k not in config_dict:
66+
raise KeyError(
67+
f"[CCCV] no key '{k}' in config.json in {dir_path}, you should provide a valid config.json contain a key '{k}'"
68+
)
69+
70+
# auto import all .py files in the directory to register the arch, model and config
71+
try:
72+
for py_file in dir_path.glob("*.py"):
73+
spec = importlib.util.spec_from_file_location(py_file.stem, py_file)
74+
if spec is None or spec.loader is None:
75+
continue
76+
module = importlib.util.module_from_spec(spec)
77+
spec.loader.exec_module(module)
78+
except Exception as e:
79+
raise ImportError(f"[CCCV] failed register model from {dir_path}, error: {e}, please check your .py files")
80+
81+
if "path" not in config_dict or config_dict["path"] is None or config_dict["path"] == "":
82+
# add the path to the config_dict
83+
config_dict["path"] = str(dir_path / config_dict["name"])
84+
85+
# convert config_dict to pydantic model
86+
cfg = AutoBaseConfig.model_validate(config_dict)
87+
return cfg

0 commit comments

Comments
 (0)