|
1 | 1 | """Factory module for creating different figures.""" |
2 | 2 |
|
| 3 | +import importlib.util |
| 4 | +import inspect |
3 | 5 | import logging |
| 6 | +import pkgutil |
4 | 7 | from abc import ABC |
5 | 8 | from functools import reduce |
6 | 9 | from pathlib import Path |
| 10 | +from types import ModuleType |
7 | 11 | from typing import Any, Callable, Generic, TypeVar |
8 | 12 |
|
9 | 13 | import matplotlib.pyplot as plt |
|
20 | 24 | view_registry: dict[str, type["View"]] = {} |
21 | 25 |
|
22 | 26 |
|
23 | | -def register(name: str) -> Callable[[type[T]], type[T]]: |
24 | | - """Function to register view to registry.""" |
25 | | - |
26 | | - def decorator(cls: type[T]) -> type[T]: |
27 | | - view_registry[name] = cls |
28 | | - return cls |
29 | | - |
30 | | - return decorator |
31 | | - |
32 | | - |
33 | | -def create_view( |
34 | | - view: str, |
35 | | - view_kwargs: dict[str, Any] | None, |
36 | | - join_entities: list[str], |
37 | | - queries: list[str], |
38 | | -) -> "View": |
39 | | - """Function to create view.""" |
40 | | - if view_kwargs is None: |
41 | | - view_kwargs = {} |
42 | | - |
43 | | - try: |
44 | | - view_cls = view_registry[view] |
45 | | - view_instance = view_cls(queries, join_entities, view_kwargs) |
46 | | - return view_instance |
47 | | - except KeyError: |
48 | | - msg = f"Factory for '{view}' for not found in registry." |
49 | | - raise KeyError(msg) |
50 | | - |
51 | | - |
52 | | -def create_views(config: dict[str, Any]) -> list["View"]: |
53 | | - """Create selected views dynamically from config with default settings.""" |
54 | | - views: list["View"] = [] |
55 | | - |
56 | | - for group in config.get("figures", {}).values(): |
57 | | - queries = group.get("queries", []) |
58 | | - join_entities = group.get("join_entities", ["sub", "ses"]) |
59 | | - views_map = group.get("views", {}) |
60 | | - |
61 | | - for view, view_kwargs in views_map.items(): |
62 | | - views.append( |
63 | | - create_view( |
64 | | - view=view, |
65 | | - view_kwargs=view_kwargs, |
66 | | - join_entities=join_entities, |
67 | | - queries=queries, |
68 | | - ) |
69 | | - ) |
70 | | - |
71 | | - return views |
72 | | - |
73 | | - |
74 | 27 | class View(ABC, Generic[T]): |
75 | 28 | """Base view class.""" |
76 | 29 |
|
77 | 30 | entities: dict[str, Any] | None = None |
78 | 31 | view_fn: Callable[[nib.Nifti1Image, Path], Image | Figure | None] | None = None |
| 32 | + view_name: str | None = None # Name for registry (defaults to class) |
79 | 33 |
|
80 | 34 | def __init__( |
81 | 35 | self, |
@@ -162,3 +116,78 @@ def create( |
162 | 116 | self.view_fn(img, out_path, **self.view_kwargs) |
163 | 117 |
|
164 | 118 | plt.close("all") |
| 119 | + |
| 120 | + |
| 121 | +def create_view( |
| 122 | + view: str, |
| 123 | + view_kwargs: dict[str, Any] | None, |
| 124 | + join_entities: list[str], |
| 125 | + queries: list[str], |
| 126 | +) -> "View": |
| 127 | + """Function to create view.""" |
| 128 | + if view_kwargs is None: |
| 129 | + view_kwargs = {} |
| 130 | + |
| 131 | + try: |
| 132 | + view_cls = view_registry[view] |
| 133 | + view_instance = view_cls(queries, join_entities, view_kwargs) |
| 134 | + return view_instance |
| 135 | + except KeyError: |
| 136 | + msg = f"Factory for '{view}' for not found in registry." |
| 137 | + raise KeyError(msg) |
| 138 | + |
| 139 | + |
| 140 | +def create_views(config: dict[str, Any]) -> list["View"]: |
| 141 | + """Create selected views dynamically from config with default settings.""" |
| 142 | + views: list["View"] = [] |
| 143 | + |
| 144 | + for group in config.get("figures", {}).values(): |
| 145 | + queries = group.get("queries", []) |
| 146 | + join_entities = group.get("join_entities", ["sub", "ses"]) |
| 147 | + views_map = group.get("views", {}) |
| 148 | + |
| 149 | + for view, view_kwargs in views_map.items(): |
| 150 | + views.append( |
| 151 | + create_view( |
| 152 | + view=view, |
| 153 | + view_kwargs=view_kwargs, |
| 154 | + join_entities=join_entities, |
| 155 | + queries=queries, |
| 156 | + ) |
| 157 | + ) |
| 158 | + |
| 159 | + return views |
| 160 | + |
| 161 | + |
| 162 | +def register(cls: type[T]) -> type[T]: |
| 163 | + """Function to add view to registry.""" |
| 164 | + view_registry[cls.view_name or cls.__name__] = cls |
| 165 | + return cls |
| 166 | + |
| 167 | + |
| 168 | +def register_views(search_path: str | None, plugin_prefix: str | None = None) -> None: |
| 169 | + """Register all views.""" |
| 170 | + |
| 171 | + def _import_module(module_name: str) -> ModuleType: |
| 172 | + """Import module.""" |
| 173 | + if search_path is not None: |
| 174 | + module_path = Path(search_path) / f"{module_name}.py" |
| 175 | + spec = importlib.util.spec_from_file_location(module_name, module_path) |
| 176 | + if spec is None or spec.loader is None: |
| 177 | + raise ImportError(f"Unable to load {module_name}") |
| 178 | + module = importlib.util.module_from_spec(spec) |
| 179 | + spec.loader.exec_module(module) |
| 180 | + else: |
| 181 | + module = importlib.import_module(module_name) |
| 182 | + return module |
| 183 | + |
| 184 | + for _, module_name, _ in pkgutil.iter_modules( |
| 185 | + path=[search_path] if search_path else None |
| 186 | + ): |
| 187 | + if plugin_prefix is None or module_name.startswith(plugin_prefix): |
| 188 | + print(plugin_prefix, module_name) |
| 189 | + module = _import_module(module_name=module_name) |
| 190 | + |
| 191 | + for _, obj in inspect.getmembers(module, inspect.isclass): |
| 192 | + if issubclass(obj, View) and obj is not View: |
| 193 | + register(obj) |
0 commit comments