@@ -369,7 +369,6 @@ def _save_pretrained(
369
369
@classmethod
370
370
def _from_pretrained (
371
371
cls ,
372
- * ,
373
372
model_id : str ,
374
373
revision : Optional [str ] = None ,
375
374
cache_dir : Optional [str ] = None ,
@@ -398,11 +397,13 @@ def _from_pretrained(
398
397
"""
399
398
400
399
# Check if model_id is a local directory
401
- if os .path .isdir (model_id ):
402
- model_path = Path (model_id )
403
- else :
404
- # Download config file from the Hub
405
- try :
400
+ model_path = Path (model_id )
401
+ try :
402
+ if model_path .is_dir () and (model_path / "config.json" ).exists ():
403
+ # Local directory with required files
404
+ pass
405
+ else :
406
+ # Download files from HuggingFace Hub
406
407
config_file = hf_hub_download (
407
408
repo_id = model_id ,
408
409
filename = "config.json" ,
@@ -417,7 +418,7 @@ def _from_pretrained(
417
418
model_path = Path (os .path .dirname (config_file ))
418
419
419
420
# Download examples file
420
- examples_file = hf_hub_download (
421
+ hf_hub_download (
421
422
repo_id = model_id ,
422
423
filename = "examples.json" ,
423
424
revision = revision ,
@@ -430,7 +431,7 @@ def _from_pretrained(
430
431
)
431
432
432
433
# Download model file
433
- model_file = hf_hub_download (
434
+ hf_hub_download (
434
435
repo_id = model_id ,
435
436
filename = "model.safetensors" ,
436
437
revision = revision ,
@@ -441,8 +442,8 @@ def _from_pretrained(
441
442
token = token ,
442
443
local_files_only = local_files_only ,
443
444
)
444
- except Exception as e :
445
- raise ValueError (f"Error downloading model from { model_id } : { e } " )
445
+ except Exception as e :
446
+ raise ValueError (f"Error loading model from { model_id } : { e } " )
446
447
447
448
# Load configuration
448
449
with open (model_path / "config.json" , "r" , encoding = "utf-8" ) as f :
@@ -606,12 +607,15 @@ def _format_class_distribution(self, stats: Dict[str, Any]) -> str:
606
607
# Keep existing save/load methods for backwards compatibility
607
608
def save (self , save_dir : str ):
608
609
"""Legacy save method for backwards compatibility."""
609
- self ._save_pretrained (save_dir )
610
+ return self ._save_pretrained (save_dir )
610
611
611
612
@classmethod
612
613
def load (cls , save_dir : str , device : Optional [str ] = None ) -> 'AdaptiveClassifier' :
613
614
"""Legacy load method for backwards compatibility."""
614
- return cls ._from_pretrained (save_dir , device = device )
615
+ kwargs = {}
616
+ if device is not None :
617
+ kwargs ['device' ] = device
618
+ return cls ._from_pretrained (save_dir , ** kwargs )
615
619
616
620
def to (self , device : str ) -> 'AdaptiveClassifier' :
617
621
"""Move the model to specified device.
0 commit comments