Skip to content

Commit 3e92f99

Browse files
committed
Refactor plugin code for general view import
Repurposes the code from plugins.py (now deprecated) to register the different view factories. To maintain compatability, added a "view_name" attribute which is the name that gets added to the registry. Additionally, the import code tries to import from the path if a search path is provided, otherwise will search the sys.path for existing modules with the "niftyone_" prefix.
1 parent 95c0715 commit 3e92f99

10 files changed

Lines changed: 102 additions & 80 deletions

File tree

src/niftyone/__init__.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,11 @@
11
"""Large-scale neuroimaging visualization using FiftyOne."""
22

3-
# Import all plugins. In particular this should register user's custom generators
4-
from . import plugins
5-
from ._version import __version__, __version_tuple__
3+
from pathlib import Path
64

7-
# Register existing views
8-
from .figures.dwi import (
9-
DwiPerShell,
10-
QSpaceShells,
11-
SignalPerVolume,
12-
)
13-
from .figures.func import CarpetPlot, MeanStd
14-
from .figures.multi_view import (
15-
SliceVideo,
16-
ThreeView,
17-
ThreeViewVideo,
18-
)
5+
from ._version import __version__, __version_tuple__
6+
from .figures.factory import __file__ as factory_path
7+
from .figures.factory import register_views
198
from .runner import Runner
9+
10+
# Register all default views
11+
register_views(search_path=str(Path(factory_path).parent))

src/niftyone/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def main() -> None:
2828
sub=args.participant_label,
2929
index_path=args.index,
3030
qc_dir=args.qc_dir,
31+
plugin_dir=args.plugin_dir,
3132
config=args.config,
3233
workers=args.workers,
3334
overwrite=args.overwrite,

src/niftyone/analysis_levels/participant.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def participant(
3535
sub: str | None = None,
3636
index_path: Path | None = None,
3737
qc_dir: Path | None = None,
38+
plugin_dir: Path | None = None,
3839
config: Path | None = None,
3940
workers: int = 1,
4041
overwrite: bool = False,
@@ -60,11 +61,17 @@ def participant(
6061
f"\n\tsubject: {sub}"
6162
f"\n\tindex: {index_path}"
6263
f"\n\tqc: {qc_dir}"
64+
f"\n\tplugin: {plugin_dir}"
6365
f"\n\tconfig: {config}"
6466
f"\n\tworkers: {workers}"
6567
f"\n\toverwrite: {overwrite}"
6668
)
6769

70+
# Register any plugin figure views
71+
factory.register_views(
72+
search_path=str(plugin_dir) if plugin_dir else None, plugin_prefix="niftyone_"
73+
)
74+
6875
logging.info("Loading dataset index")
6976
index = bids2table(bids_dir, index_path=index_path, workers=workers)
7077

src/niftyone/cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ def _add_participant_args(self) -> None:
8686
default=None,
8787
help="pre-computed QC metrics if available",
8888
)
89+
self.participant_level.add_argument(
90+
"--plugin-dir",
91+
metavar="PATH",
92+
type=Path,
93+
default=None,
94+
help="directory to search for plugins in;plugins should be "
95+
"prepended with 'niftyone_'",
96+
)
8997
self.participant_level.add_argument(
9098
"--config",
9199
metavar="PATH",

src/niftyone/figures/dwi.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
"""Factories associated with diffusion data."""
22

33
from niclips.figures import dwi
4-
from niftyone.figures.factory import View, register
4+
from niftyone.figures.factory import View
55

66

7-
@register("qspace_shells")
87
class QSpaceShells(View):
98
entities = {"ext": ".mp4", "figure": "qspace"}
109
view_fn = staticmethod(dwi.visualize_qspace)
10+
view_name = "qspace_shells"
1111

1212

13-
@register("three_view_shell_video")
1413
class DwiPerShell(View):
1514
entities = {"ext": ".mp4", "figure": "bval"}
1615
view_fn = staticmethod(dwi.three_view_per_shell)
16+
view_name = "three_view_shell_video"
1717

1818

19-
@register("signal_per_volume")
2019
class SignalPerVolume(View):
2120
entities = {"ext": ".mp4", "figure": "signalPerVolume"}
2221
view_fn = staticmethod(dwi.signal_per_volume)
22+
view_name = "signal_per_volume"

src/niftyone/figures/factory.py

Lines changed: 67 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Factory module for creating different figures."""
22

3+
import importlib.util
4+
import inspect
35
import logging
6+
import pkgutil
47
from abc import ABC
58
from functools import reduce
69
from pathlib import Path
7-
from types import MappingProxyType
10+
from types import MappingProxyType, ModuleType
811
from typing import Any, Callable, Generic, TypeVar
912

1013
import matplotlib.pyplot as plt
@@ -19,50 +22,12 @@
1922
view_registry: dict[str, type["View"]] = {}
2023

2124

22-
def register(name: str) -> Callable:
23-
"""Function to register view to registry."""
24-
25-
def decorator(cls: type[T]) -> type[T]:
26-
view_registry[name] = cls
27-
return cls
28-
29-
return decorator
30-
31-
32-
def create_view(
33-
view: str,
34-
view_kwargs: dict[str, Any] | None,
35-
join_entities: list[str],
36-
queries: list[str],
37-
) -> "View":
38-
"""Function to create view."""
39-
view_kwargs = view_kwargs or {}
40-
try:
41-
view_cls = view_registry[view]
42-
return view_cls(queries, join_entities, view_kwargs)
43-
except KeyError:
44-
raise KeyError(f"Factory for '{view}' for not found in registry.")
45-
46-
47-
def create_views(config: dict[str, Any]) -> list["View"]:
48-
"""Create selected views dynamically from config with default settings."""
49-
return [
50-
create_view(
51-
view=view,
52-
view_kwargs=view_kwargs,
53-
join_entities=group.get("join_entities", ["sub", "ses"]),
54-
queries=group.get("queries", []),
55-
)
56-
for group in config.get("figures", {}).values()
57-
for view, view_kwargs in group.get("views", {}).items()
58-
]
59-
60-
6125
class View(ABC, Generic[T]):
6226
"""Base view class."""
6327

6428
entities: dict[str, Any] | None = None
6529
view_fn: Callable | None = None
30+
view_name: str | None = None # Name for registry (defaults to class)
6631

6732
def __init__(
6833
self,
@@ -153,3 +118,65 @@ def create(
153118
self.view_fn(img, out_path, overlays=overlays, **self.view_kwargs)
154119

155120
plt.close("all")
121+
122+
123+
def create_view(
124+
view: str,
125+
view_kwargs: dict[str, Any] | None,
126+
join_entities: list[str],
127+
queries: list[str],
128+
) -> "View":
129+
"""Function to create view."""
130+
view_kwargs = view_kwargs or {}
131+
try:
132+
view_cls = view_registry[view]
133+
return view_cls(queries, join_entities, view_kwargs)
134+
except KeyError:
135+
raise KeyError(f"Factory for '{view}' for not found in registry.")
136+
137+
138+
def create_views(config: dict[str, Any]) -> list["View"]:
139+
"""Create selected views dynamically from config with default settings."""
140+
return [
141+
create_view(
142+
view=view,
143+
view_kwargs=view_kwargs,
144+
join_entities=group.get("join_entities", ["sub", "ses"]),
145+
queries=group.get("queries", []),
146+
)
147+
for group in config.get("figures", {}).values()
148+
for view, view_kwargs in group.get("views", {}).items()
149+
]
150+
151+
152+
def register(cls: type[T]) -> type[T]:
153+
"""Function to add view to registry."""
154+
view_registry[cls.view_name or cls.__name__] = cls
155+
return cls
156+
157+
158+
def register_views(search_path: str | None, plugin_prefix: str | None = None) -> None:
159+
"""Register all views."""
160+
161+
def _import_module(module_name: str) -> ModuleType:
162+
"""Import module."""
163+
if search_path is not None:
164+
module_path = Path(search_path) / f"{module_name}.py"
165+
spec = importlib.util.spec_from_file_location(module_name, module_path)
166+
if spec is None or spec.loader is None:
167+
raise ImportError(f"Unable to load {module_name}")
168+
module = importlib.util.module_from_spec(spec)
169+
spec.loader.exec_module(module)
170+
else:
171+
module = importlib.import_module(module_name)
172+
return module
173+
174+
for _, module_name, _ in pkgutil.iter_modules(
175+
path=[search_path] if search_path else None
176+
):
177+
if plugin_prefix is None or module_name.startswith(plugin_prefix):
178+
module = _import_module(module_name=module_name)
179+
180+
for _, obj in inspect.getmembers(module, inspect.isclass):
181+
if issubclass(obj, View) and obj is not View:
182+
register(obj)

src/niftyone/figures/func.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
"""Factories associated with functional data."""
22

33
from niclips.figures import bold
4-
from niftyone.figures.factory import View, register
4+
from niftyone.figures.factory import View
55

66

7-
@register("carpet_plot")
87
class CarpetPlot(View):
98
entities = {"ext": ".png", "figure": "carpet"}
109
view_fn = staticmethod(bold.carpet_plot)
1110

1211

13-
@register("mean_std")
1412
class MeanStd(View):
1513
entities = {"ext": ".png", "figure": "meanStd"}
1614
view_fn = staticmethod(bold.bold_mean_std)

src/niftyone/figures/multi_view.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
"""Factories associated with multi-view."""
22

33
from niclips.figures import multi_view
4-
from niftyone.figures.factory import View, register
4+
from niftyone.figures.factory import View
55

66

7-
@register("three_view")
87
class ThreeView(View):
98
entities = {"ext": ".png", "figure": "threeView"}
109
view_fn = staticmethod(multi_view.three_view_frame)
10+
view_name = "three_view"
1111

1212

13-
@register("slice_video")
1413
class SliceVideo(View):
1514
entities = {"ext": ".mp4", "figure": "sliceVideo"}
1615
view_fn = staticmethod(multi_view.slice_video)
16+
view_name = "slice_video"
1717

1818

19-
@register("three_view_video")
2019
class ThreeViewVideo(View):
2120
entities = {"ext": ".mp4", "figure": "threeViewVideo"}
2221
view_fn = staticmethod(multi_view.three_view_video)
22+
view_name = "three_view_video"

src/niftyone/plugins.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/unit/niftyone/figures/test_factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def setup_registry():
5858
view_registry.clear()
5959

6060
class TestView(View):
61+
view_name = "test_view"
62+
6163
def create(
6264
self,
6365
records: pd.Series,
@@ -66,7 +68,7 @@ def create(
6668
) -> None:
6769
pass
6870

69-
register("test_view")(TestView)
71+
register(TestView)
7072
yield
7173
view_registry.clear()
7274

0 commit comments

Comments
 (0)