|
| 1 | +import importlib.util |
| 2 | +import json |
| 3 | +import warnings |
| 4 | +from pathlib import Path |
1 | 5 | from typing import Any, Optional, Union |
2 | 6 |
|
3 | | -from cccv.config import CONFIG_REGISTRY, BaseConfig |
| 7 | +from cccv.config import CONFIG_REGISTRY, AutoBaseConfig |
4 | 8 | from cccv.type import ConfigType |
| 9 | +from cccv.util.remote import git_clone |
5 | 10 |
|
6 | 11 |
|
class AutoConfig:
    """Resolve a model configuration from a registered name, a local directory or a git url."""

    @staticmethod
    def from_pretrained(
        pretrained_model_name_or_path: Optional[Union[ConfigType, str, Path]] = None,
        *,
        model_dir: Optional[Union[Path, str]] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Get a config instance of a pretrained model configuration.

        The identifier is resolved in this order:

        1. a name registered in ``CONFIG_REGISTRY`` (a ``ConfigType`` member is reduced to its ``.value``),
        2. an ``http://`` / ``https://`` git url, which is cloned with ``git_clone`` into ``model_dir``,
        3. a local directory containing a ``config.json`` and, optionally, ``.py`` files to import.

        :param pretrained_model_name_or_path: A registered config name, a local directory path or a git url.
        :param model_dir: The path to cache the downloaded model configuration. Should be a full path. If None, use default cache path.
        :param kwargs: Extra keyword arguments forwarded to ``git_clone`` when a url is given.
            The deprecated ``pretrained_model_name`` keyword is still accepted here and, when present,
            replaces ``pretrained_model_name_or_path``.
        :return: The registered config when the name is found in ``CONFIG_REGISTRY``, otherwise the
            result of ``AutoBaseConfig.model_validate`` on the loaded ``config.json``.
        :raises ValueError: If no identifier is given, or it is neither a registered name nor an existing directory.
        :raises FileNotFoundError: If the directory does not contain a ``config.json`` file.
        :raises KeyError: If ``config.json`` lacks one of the keys ``arch``, ``model`` or ``name``.
        :raises ImportError: If importing one of the ``.py`` files in the directory fails.
        """
        if "pretrained_model_name" in kwargs:
            warnings.warn(
                "[CCCV] 'pretrained_model_name' is deprecated, please use 'pretrained_model_name_or_path' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            pretrained_model_name_or_path = kwargs.pop("pretrained_model_name")

        # The first parameter defaults to None only so that the deprecated keyword can be used
        # on its own; an identifier is still mandatory.
        if pretrained_model_name_or_path is None:
            raise ValueError("[CCCV] 'pretrained_model_name_or_path' must be provided")

        # 1. check if it's a registered config name, early return if found
        if isinstance(pretrained_model_name_or_path, ConfigType):
            pretrained_model_name_or_path = pretrained_model_name_or_path.value
        identifier = str(pretrained_model_name_or_path)
        if identifier in CONFIG_REGISTRY:
            return CONFIG_REGISTRY.get(identifier)

        # 2. if it's a url, git clone it to model_dir and continue with the returned local path.
        # Match the full scheme so a local directory that merely starts with "http" is not cloned.
        if identifier.startswith(("http://", "https://")):
            identifier = str(
                git_clone(
                    git_url=identifier,
                    model_dir=model_dir,
                    **kwargs,
                )
            )

        # 3. check if it's a real directory (is_dir() is already False for a missing path)
        dir_path = Path(identifier)
        if not dir_path.is_dir():
            raise ValueError(f"[CCCV] model configuration '{dir_path}' is not a valid config name or path")

        # load config.json from the directory; it must be a regular file
        config_path = dir_path / "config.json"
        if not config_path.is_file():
            raise FileNotFoundError(f"[CCCV] no valid config.json found in {dir_path}")

        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        for k in ("arch", "model", "name"):
            if k not in config_dict:
                raise KeyError(
                    f"[CCCV] no key '{k}' in config.json in {dir_path}, you should provide a valid config.json contain a key '{k}'"
                )

        # auto import all .py files in the directory to register the arch, model and config.
        # NOTE(review): this executes arbitrary Python from the directory (possibly a freshly
        # cloned repository) -- only use it with trusted sources.
        try:
            # sorted() makes the import (and therefore registration) order deterministic
            for py_file in sorted(dir_path.glob("*.py")):
                spec = importlib.util.spec_from_file_location(py_file.stem, py_file)
                if spec is None or spec.loader is None:
                    continue
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
        except Exception as e:
            raise ImportError(
                f"[CCCV] failed register model from {dir_path}, error: {e}, please check your .py files"
            ) from e

        # fall back to <dir>/<name> when config.json gives no usable "path"
        if config_dict.get("path") in (None, ""):
            config_dict["path"] = str(dir_path / config_dict["name"])

        # convert config_dict to pydantic model
        return AutoBaseConfig.model_validate(config_dict)
0 commit comments