|
1 | 1 | """Factory module for creating different figures.""" |
2 | 2 |
|
| 3 | +import importlib.util |
| 4 | +import inspect |
3 | 5 | import logging |
| 6 | +import pkgutil |
4 | 7 | from abc import ABC |
5 | 8 | from functools import reduce |
6 | 9 | from pathlib import Path |
7 | | -from types import MappingProxyType |
| 10 | +from types import MappingProxyType, ModuleType |
8 | 11 | from typing import Any, Callable, Generic, TypeVar |
9 | 12 |
|
10 | 13 | import matplotlib.pyplot as plt |
|
19 | 22 | view_registry: dict[str, type["View"]] = {} |
20 | 23 |
|
21 | 24 |
|
22 | | -def register(name: str) -> Callable: |
23 | | - """Function to register view to registry.""" |
24 | | - |
25 | | - def decorator(cls: type[T]) -> type[T]: |
26 | | - view_registry[name] = cls |
27 | | - return cls |
28 | | - |
29 | | - return decorator |
30 | | - |
31 | | - |
32 | | -def create_view( |
33 | | - view: str, |
34 | | - view_kwargs: dict[str, Any] | None, |
35 | | - join_entities: list[str], |
36 | | - queries: list[str], |
37 | | -) -> "View": |
38 | | - """Function to create view.""" |
39 | | - view_kwargs = view_kwargs or {} |
40 | | - try: |
41 | | - view_cls = view_registry[view] |
42 | | - return view_cls(queries, join_entities, view_kwargs) |
43 | | - except KeyError: |
44 | | - raise KeyError(f"Factory for '{view}' for not found in registry.") |
45 | | - |
46 | | - |
47 | | -def create_views(config: dict[str, Any]) -> list["View"]: |
48 | | - """Create selected views dynamically from config with default settings.""" |
49 | | - return [ |
50 | | - create_view( |
51 | | - view=view, |
52 | | - view_kwargs=view_kwargs, |
53 | | - join_entities=group.get("join_entities", ["sub", "ses"]), |
54 | | - queries=group.get("queries", []), |
55 | | - ) |
56 | | - for group in config.get("figures", {}).values() |
57 | | - for view, view_kwargs in group.get("views", {}).items() |
58 | | - ] |
59 | | - |
60 | | - |
61 | 25 | class View(ABC, Generic[T]): |
62 | 26 | """Base view class.""" |
63 | 27 |
|
64 | 28 | entities: dict[str, Any] | None = None |
65 | 29 | view_fn: Callable | None = None |
| 30 | + view_name: str | None = None # Name for registry (defaults to class) |
66 | 31 |
|
67 | 32 | def __init__( |
68 | 33 | self, |
@@ -153,3 +118,65 @@ def create( |
153 | 118 | self.view_fn(img, out_path, overlays=overlays, **self.view_kwargs) |
154 | 119 |
|
155 | 120 | plt.close("all") |
| 121 | + |
| 122 | + |
| 123 | +def create_view( |
| 124 | + view: str, |
| 125 | + view_kwargs: dict[str, Any] | None, |
| 126 | + join_entities: list[str], |
| 127 | + queries: list[str], |
| 128 | +) -> "View": |
| 129 | + """Function to create view.""" |
| 130 | + view_kwargs = view_kwargs or {} |
| 131 | + try: |
| 132 | + view_cls = view_registry[view] |
| 133 | + return view_cls(queries, join_entities, view_kwargs) |
| 134 | + except KeyError: |
| 135 | + raise KeyError(f"Factory for '{view}' for not found in registry.") |
| 136 | + |
| 137 | + |
| 138 | +def create_views(config: dict[str, Any]) -> list["View"]: |
| 139 | + """Create selected views dynamically from config with default settings.""" |
| 140 | + return [ |
| 141 | + create_view( |
| 142 | + view=view, |
| 143 | + view_kwargs=view_kwargs, |
| 144 | + join_entities=group.get("join_entities", ["sub", "ses"]), |
| 145 | + queries=group.get("queries", []), |
| 146 | + ) |
| 147 | + for group in config.get("figures", {}).values() |
| 148 | + for view, view_kwargs in group.get("views", {}).items() |
| 149 | + ] |
| 150 | + |
| 151 | + |
| 152 | +def register(cls: type[T]) -> type[T]: |
| 153 | + """Function to add view to registry.""" |
| 154 | + view_registry[cls.view_name or cls.__name__] = cls |
| 155 | + return cls |
| 156 | + |
| 157 | + |
| 158 | +def register_views(search_path: str | None, plugin_prefix: str | None = None) -> None: |
| 159 | + """Register all views.""" |
| 160 | + |
| 161 | + def _import_module(module_name: str) -> ModuleType: |
| 162 | + """Import module.""" |
| 163 | + if search_path is not None: |
| 164 | + module_path = Path(search_path) / f"{module_name}.py" |
| 165 | + spec = importlib.util.spec_from_file_location(module_name, module_path) |
| 166 | + if spec is None or spec.loader is None: |
| 167 | + raise ImportError(f"Unable to load {module_name}") |
| 168 | + module = importlib.util.module_from_spec(spec) |
| 169 | + spec.loader.exec_module(module) |
| 170 | + else: |
| 171 | + module = importlib.import_module(module_name) |
| 172 | + return module |
| 173 | + |
| 174 | + for _, module_name, _ in pkgutil.iter_modules( |
| 175 | + path=[search_path] if search_path else None |
| 176 | + ): |
| 177 | + if plugin_prefix is None or module_name.startswith(plugin_prefix): |
| 178 | + module = _import_module(module_name=module_name) |
| 179 | + |
| 180 | + for _, obj in inspect.getmembers(module, inspect.isclass): |
| 181 | + if issubclass(obj, View) and obj is not View: |
| 182 | + register(obj) |
0 commit comments