Skip to content

Commit 31fb19a

Browse files
merrymercycctry
andauthored
[Auto Sync] Update registry.py (20250915) (sgl-project#10484)
Co-authored-by: cctry <shiyang@x.ai>
1 parent 3f41b48 commit 31fb19a

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

python/sglang/srt/models/registry.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ class _ModelRegistry:
1717
# Keyed by model_arch
1818
models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
1919

20+
def register(self, package_name: str, overwrite: bool = False):
21+
new_models = import_model_classes(package_name)
22+
if overwrite:
23+
self.models.update(new_models)
24+
else:
25+
for arch, cls in new_models.items():
26+
if arch in self.models:
27+
raise ValueError(
28+
f"Model architecture {arch} already registered. Set overwrite=True to replace."
29+
)
30+
self.models[arch] = cls
31+
2032
def get_supported_archs(self) -> AbstractSet[str]:
2133
return self.models.keys()
2234

@@ -74,9 +86,8 @@ def resolve_model_cls(
7486

7587

7688
@lru_cache()
77-
def import_model_classes():
89+
def import_model_classes(package_name: str):
7890
model_arch_name_to_cls = {}
79-
package_name = "sglang.srt.models"
8091
package = importlib.import_module(package_name)
8192
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
8293
if not ispkg:
@@ -104,4 +115,5 @@ def import_model_classes():
104115
return model_arch_name_to_cls
105116

106117

107-
ModelRegistry = _ModelRegistry(import_model_classes())
118+
ModelRegistry = _ModelRegistry()
119+
ModelRegistry.register("sglang.srt.models")

0 commit comments

Comments
 (0)