@@ -17,6 +17,18 @@ class _ModelRegistry:
1717 # Keyed by model_arch
1818 models : Dict [str , Union [Type [nn .Module ], str ]] = field (default_factory = dict )
1919
20+ def register (self , package_name : str , overwrite : bool = False ):
21+ new_models = import_model_classes (package_name )
22+ if overwrite :
23+ self .models .update (new_models )
24+ else :
25+ for arch , cls in new_models .items ():
26+ if arch in self .models :
27+ raise ValueError (
28+ f"Model architecture { arch } already registered. Set overwrite=True to replace."
29+ )
30+ self .models [arch ] = cls
31+
2032 def get_supported_archs (self ) -> AbstractSet [str ]:
2133 return self .models .keys ()
2234
@@ -74,9 +86,8 @@ def resolve_model_cls(
7486
7587
7688@lru_cache ()
77- def import_model_classes ():
89+ def import_model_classes (package_name : str ):
7890 model_arch_name_to_cls = {}
79- package_name = "sglang.srt.models"
8091 package = importlib .import_module (package_name )
8192 for _ , name , ispkg in pkgutil .iter_modules (package .__path__ , package_name + "." ):
8293 if not ispkg :
@@ -104,4 +115,5 @@ def import_model_classes():
104115 return model_arch_name_to_cls
105116
106117
107- ModelRegistry = _ModelRegistry (import_model_classes ())
118+ ModelRegistry = _ModelRegistry ()
119+ ModelRegistry .register ("sglang.srt.models" )
0 commit comments