BaseModel.load: support local checkpoint directories without xturing.json

What this patch does (this text precedes the first "diff --git" header and
is skipped by patch tools):

  * BaseModel.load() gains an optional ``model_name`` argument and ``**kwargs``.
      + Existing directory WITH an xTuring config file: loaded through
        load_from_local(), which now receives the Path object instead of the
        original string. ``model_name`` and ``kwargs`` are not used here.
      + Existing directory WITHOUT the config file and no ``model_name``:
        raises ValueError ("No xturing.json found in local directory ...").
      + Existing directory WITHOUT the config file and a ``model_name``:
        delegates to _load_local_path_for_model_name(path, model_name,
        **kwargs).
      + Anything that is not an existing directory: prints
        "Loading model from xTuring hub" and calls load_from_hub() with the
        original argument. ``model_name`` and ``kwargs`` are not used here.

  * The tail of load_from_local() moves into the new classmethod
    _load_local_path_for_model_name(), which converts the path to a Path,
    asserts that ``model_name`` is in ``cls.registry``, and calls
    ``cls.create``. The ``model_name=`` keyword is passed to ``cls.create``
    only when the name contains "generic"; ``kwargs`` are forwarded in both
    branches.

  * New tests cover the two "directory without xturing.json" cases.

NOTE(review): the registry check is still an ``assert`` (disabled under
``python -O``) although ``model_name`` can now come straight from the caller
of load() -- confirm that this is acceptable.

diff --git a/src/xturing/models/base.py b/src/xturing/models/base.py
index 698c298..c8bc6c5 100644
--- a/src/xturing/models/base.py
+++ b/src/xturing/models/base.py
@@ -12,16 +12,28 @@ class BaseModel(BaseParent):
     registry = {}
 
     @classmethod
-    def load(cls, weights_dir_or_model_name):
+    def load(cls, weights_dir_or_model_name, model_name=None, **kwargs):
         path_weights_dir_or_model_name = Path(weights_dir_or_model_name)
 
-        if path_weights_dir_or_model_name.is_dir() and exists_xturing_config_file(
-            path_weights_dir_or_model_name
-        ):
-            return cls.load_from_local(weights_dir_or_model_name)
-        else:
-            print("Loading model from xTuring hub")
-            return cls.load_from_hub(weights_dir_or_model_name)
+        if path_weights_dir_or_model_name.is_dir():
+            if exists_xturing_config_file(path_weights_dir_or_model_name):
+                return cls.load_from_local(path_weights_dir_or_model_name)
+
+            if model_name is None:
+                raise ValueError(
+                    "No xturing.json found in local directory '{}'. "
+                    "Pass model_name=... to BaseModel.load(...) for local non-xTuring "
+                    "checkpoints, or use a GenericModel class directly.".format(
+                        str(path_weights_dir_or_model_name)
+                    )
+                )
+
+            return cls._load_local_path_for_model_name(
+                path_weights_dir_or_model_name, model_name, **kwargs
+            )
+
+        print("Loading model from xTuring hub")
+        return cls.load_from_hub(weights_dir_or_model_name)
 
     @classmethod
     def load_from_hub(cls, model_name):
@@ -49,16 +61,25 @@ def load_from_local(cls, weights_dir_path):
             model_name is not None
         ), "The xturing.json file is not correct. model_name is not available in the configuration"
 
+        return cls._load_local_path_for_model_name(weights_dir_path, model_name)
+
+    @classmethod
+    def _load_local_path_for_model_name(cls, weights_dir_path, model_name, **kwargs):
+        weights_dir_path = Path(weights_dir_path)
+
         assert (
             cls.registry.get(model_name) is not None
         ), "The model_name {} is not valid".format(model_name)
 
         if "generic" in model_name:
             model = cls.create(
-                model_name, model_name=model_name, weights_path=weights_dir_path
+                model_name,
+                model_name=model_name,
+                weights_path=weights_dir_path,
+                **kwargs,
             )
         else:
-            model = cls.create(model_name, weights_path=weights_dir_path)
+            model = cls.create(model_name, weights_path=weights_dir_path, **kwargs)
 
         return model
 
diff --git a/tests/xturing/models/test_base_model_load.py b/tests/xturing/models/test_base_model_load.py
new file mode 100644
index 0000000..8f84b2a
--- /dev/null
+++ b/tests/xturing/models/test_base_model_load.py
@@ -0,0 +1,34 @@
+import pytest
+
+from xturing.models import BaseModel
+
+
+class _DummyModel:
+    def __init__(self, weights_path=None, model_name=None, **kwargs):
+        self.weights_path = weights_path
+        self.model_name = model_name
+        self.kwargs = kwargs
+
+
+def test_load_local_dir_without_xturing_config_with_model_name(tmp_path, monkeypatch):
+    local_weights = tmp_path / "hf-local-model"
+    local_weights.mkdir()
+
+    monkeypatch.setitem(BaseModel.registry, "dummy_model", _DummyModel)
+
+    loaded = BaseModel.load(
+        str(local_weights), model_name="dummy_model", revision="main"
+    )
+
+    assert isinstance(loaded, _DummyModel)
+    assert loaded.weights_path == local_weights
+    assert loaded.model_name is None
+    assert loaded.kwargs["revision"] == "main"
+
+
+def test_load_local_dir_without_xturing_config_requires_model_name(tmp_path):
+    local_weights = tmp_path / "hf-local-model"
+    local_weights.mkdir()
+
+    with pytest.raises(ValueError, match="No xturing.json found"):
+        BaseModel.load(str(local_weights))