Skip to content

Commit d94e7bd

Browse files
authored
Merge pull request #32 from codelion/fix-load-local
Fix load local
2 parents 196f61c + b66fa15 commit d94e7bd

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setup(
1717
name="adaptive-classifier",
18-
version="0.0.11",
18+
version="0.0.12",
1919
author="codelion",
2020
author_email="[email protected]",
2121
description="A flexible, adaptive classification system for dynamic text classification",

src/adaptive_classifier/classifier.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ def _save_pretrained(
369369
@classmethod
370370
def _from_pretrained(
371371
cls,
372-
*,
373372
model_id: str,
374373
revision: Optional[str] = None,
375374
cache_dir: Optional[str] = None,
@@ -398,11 +397,13 @@ def _from_pretrained(
398397
"""
399398

400399
# Check if model_id is a local directory
401-
if os.path.isdir(model_id):
402-
model_path = Path(model_id)
403-
else:
404-
# Download config file from the Hub
405-
try:
400+
model_path = Path(model_id)
401+
try:
402+
if model_path.is_dir() and (model_path / "config.json").exists():
403+
# Local directory with required files
404+
pass
405+
else:
406+
# Download files from HuggingFace Hub
406407
config_file = hf_hub_download(
407408
repo_id=model_id,
408409
filename="config.json",
@@ -417,7 +418,7 @@ def _from_pretrained(
417418
model_path = Path(os.path.dirname(config_file))
418419

419420
# Download examples file
420-
examples_file = hf_hub_download(
421+
hf_hub_download(
421422
repo_id=model_id,
422423
filename="examples.json",
423424
revision=revision,
@@ -430,7 +431,7 @@ def _from_pretrained(
430431
)
431432

432433
# Download model file
433-
model_file = hf_hub_download(
434+
hf_hub_download(
434435
repo_id=model_id,
435436
filename="model.safetensors",
436437
revision=revision,
@@ -441,8 +442,8 @@ def _from_pretrained(
441442
token=token,
442443
local_files_only=local_files_only,
443444
)
444-
except Exception as e:
445-
raise ValueError(f"Error downloading model from {model_id}: {e}")
445+
except Exception as e:
446+
raise ValueError(f"Error loading model from {model_id}: {e}")
446447

447448
# Load configuration
448449
with open(model_path / "config.json", "r", encoding="utf-8") as f:
@@ -606,12 +607,15 @@ def _format_class_distribution(self, stats: Dict[str, Any]) -> str:
606607
# Keep existing save/load methods for backwards compatibility
607608
def save(self, save_dir: str):
608609
"""Legacy save method for backwards compatibility."""
609-
self._save_pretrained(save_dir)
610+
return self._save_pretrained(save_dir)
610611

611612
@classmethod
612613
def load(cls, save_dir: str, device: Optional[str] = None) -> 'AdaptiveClassifier':
613614
"""Legacy load method for backwards compatibility."""
614-
return cls._from_pretrained(save_dir, device=device)
615+
kwargs = {}
616+
if device is not None:
617+
kwargs['device'] = device
618+
return cls._from_pretrained(save_dir, **kwargs)
615619

616620
def to(self, device: str) -> 'AdaptiveClassifier':
617621
"""Move the model to specified device.

0 commit comments

Comments
 (0)