Skip to content

Commit 7155d1b

Browse files
[PR-556]: Adding usage of subclasses of ModelClass in Model Upload (#566)
* [PR-556]: Adding usage of subclasses of ModelClass in Model Upload * [PR-556]: Adding usage of subclasses of ModelClass in Model Upload
1 parent 7c3a269 commit 7155d1b

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

clarifai/runners/models/model_builder.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,19 @@ def _clear_line(n: int = 1) -> None:
4343
print(LINE_UP, end=LINE_CLEAR, flush=True)
4444

4545

46+
def is_related(object_class, main_class):
47+
# Check if the object_class is a subclass of main_class
48+
if issubclass(object_class, main_class):
49+
return True
50+
51+
# Check if the object_class is a subclass of any of the parent classes of main_class
52+
parent_classes = object_class.__bases__
53+
for parent in parent_classes:
54+
if main_class in parent.__bases__:
55+
return True
56+
return False
57+
58+
4659
class ModelBuilder:
4760
DEFAULT_CHECKPOINT_SIZE = 50 * 1024**3 # 50 GiB
4861

@@ -125,7 +138,7 @@ def custom_import(name, globals=None, locals=None, fromlist=(), level=0):
125138
# Find all classes in the model.py file that are subclasses of ModelClass
126139
classes = [
127140
cls for _, cls in inspect.getmembers(module, inspect.isclass)
128-
if issubclass(cls, ModelClass) and cls.__module__ == module.__name__
141+
if is_related(cls, ModelClass) and cls.__module__ == module.__name__
129142
]
130143
# Ensure there is exactly one subclass of BaseRunner in the model.py file
131144
if len(classes) != 1:

0 commit comments

Comments
 (0)