Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions src/xturing/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,28 @@ class BaseModel(BaseParent):
registry = {}

@classmethod
def load(cls, weights_dir_or_model_name, model_name=None, **kwargs):
    """Load a model from a local directory or from the xTuring hub.

    Resolution order:

    1. ``weights_dir_or_model_name`` is an existing directory containing an
       xTuring config file: delegate to ``load_from_local``. ``model_name``
       and ``kwargs`` are not forwarded on this path.
    2. It is an existing directory without an xTuring config file: the
       caller must supply ``model_name``; the directory is then loaded via
       ``_load_local_path_for_model_name`` and ``kwargs`` are forwarded.
    3. Otherwise the value is treated as a hub model name and passed to
       ``load_from_hub``. ``model_name`` and ``kwargs`` are not forwarded
       on this path.

    Args:
        weights_dir_or_model_name: Path to a local weights directory, or
            the name of a model on the xTuring hub.
        model_name: Registry key to use when the local directory has no
            xTuring config file. Ignored otherwise.
        **kwargs: Extra keyword arguments forwarded only in case 2.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If the directory exists, has no xTuring config file,
            and ``model_name`` was not given.
    """
    path_weights_dir_or_model_name = Path(weights_dir_or_model_name)

    if path_weights_dir_or_model_name.is_dir():
        if exists_xturing_config_file(path_weights_dir_or_model_name):
            return cls.load_from_local(path_weights_dir_or_model_name)

        # A plain checkpoint directory carries no hint about which registered
        # model class to build, so the caller has to name it explicitly.
        if model_name is None:
            raise ValueError(
                "No xturing.json found in local directory '{}'. "
                "Pass model_name=... to BaseModel.load(...) for local non-xTuring "
                "checkpoints, or use a GenericModel class directly.".format(
                    str(path_weights_dir_or_model_name)
                )
            )

        return cls._load_local_path_for_model_name(
            path_weights_dir_or_model_name, model_name, **kwargs
        )

    # Not a local directory: fall back to the hub, passing the original
    # (un-Path-ified) argument through unchanged.
    print("Loading model from xTuring hub")
    return cls.load_from_hub(weights_dir_or_model_name)

@classmethod
def load_from_hub(cls, model_name):
Expand Down Expand Up @@ -49,16 +61,25 @@ def load_from_local(cls, weights_dir_path):
model_name is not None
), "The xturing.json file is not correct. model_name is not available in the configuration"

return cls._load_local_path_for_model_name(weights_dir_path, model_name)

@classmethod
def _load_local_path_for_model_name(cls, weights_dir_path, model_name, **kwargs):
    """Build the registered model ``model_name`` from a local weights directory.

    Args:
        weights_dir_path: Directory holding the weights; converted to a
            ``Path`` before being handed to ``cls.create``.
        model_name: Key that must be present in ``cls.registry``.
        **kwargs: Extra keyword arguments forwarded to ``cls.create``.

    Returns:
        The object returned by ``cls.create``.

    Raises:
        AssertionError: If ``model_name`` has no entry in ``cls.registry``.
            NOTE: this check is an ``assert`` and is therefore skipped when
            Python runs with ``-O``.
    """
    weights_dir_path = Path(weights_dir_path)

    assert (
        cls.registry.get(model_name) is not None
    ), "The model_name {} is not valid".format(model_name)

    if "generic" in model_name:
        # Names containing "generic" additionally receive the name as a
        # keyword argument (presumably so the generic class can pick its
        # underlying checkpoint; confirm against the generic model classes).
        model = cls.create(
            model_name,
            model_name=model_name,
            weights_path=weights_dir_path,
            **kwargs,
        )
    else:
        model = cls.create(model_name, weights_path=weights_dir_path, **kwargs)

    return model

Expand Down
34 changes: 34 additions & 0 deletions tests/xturing/models/test_base_model_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

from xturing.models import BaseModel


class _DummyModel:
    """Stand-in model that simply records the arguments it was built with."""

    def __init__(self, weights_path=None, model_name=None, **kwargs):
        # Keep every argument verbatim so tests can inspect what was forwarded.
        self.weights_path, self.model_name = weights_path, model_name
        self.kwargs = kwargs


def test_load_local_dir_without_xturing_config_with_model_name(tmp_path, monkeypatch):
    """A config-less local directory loads when ``model_name`` is supplied."""
    checkpoint_dir = tmp_path / "hf-local-model"
    checkpoint_dir.mkdir()

    # Register the stub under a throwaway key; monkeypatch undoes it afterwards.
    monkeypatch.setitem(BaseModel.registry, "dummy_model", _DummyModel)

    extra_options = {"revision": "main"}
    model = BaseModel.load(
        str(checkpoint_dir), model_name="dummy_model", **extra_options
    )

    assert isinstance(model, _DummyModel)
    # The directory is handed over as a Path equal to the one created above.
    assert model.weights_path == checkpoint_dir
    # "dummy_model" does not contain "generic", so no model_name kwarg is passed.
    assert model.model_name is None
    assert model.kwargs["revision"] == "main"


def test_load_local_dir_without_xturing_config_requires_model_name(tmp_path):
    """A config-less local directory without ``model_name`` raises ValueError."""
    local_weights = tmp_path / "hf-local-model"
    local_weights.mkdir()

    # ``match`` is a regular expression, so the dot is escaped to require a
    # literal "xturing.json" rather than any character in that position.
    with pytest.raises(ValueError, match=r"No xturing\.json found"):
        BaseModel.load(str(local_weights))