Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions src/niftyone/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
"""Large-scale neuroimaging visualization using FiftyOne."""

# Import all plugins. In particular this should register user's custom generators
from . import plugins
from ._version import __version__, __version_tuple__
from pathlib import Path

# Register existing views
from .figures.dwi import (
DwiPerShell,
QSpaceShells,
SignalPerVolume,
)
from .figures.func import CarpetPlot, MeanStd
from .figures.multi_view import (
SliceVideo,
ThreeView,
ThreeViewVideo,
)
from ._version import __version__, __version_tuple__
from .figures.factory import __file__ as factory_path
from .figures.factory import register_views
from .runner import Runner

# Register all default views
register_views(search_path=str(Path(factory_path).parent))
1 change: 1 addition & 0 deletions src/niftyone/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def main() -> None:
sub=args.participant_label,
index_path=args.index,
qc_dir=args.qc_dir,
plugin_dir=args.plugin_dir,
config=args.config,
workers=args.workers,
overwrite=args.overwrite,
Expand Down
7 changes: 7 additions & 0 deletions src/niftyone/analysis_levels/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def participant(
sub: str | None = None,
index_path: Path | None = None,
qc_dir: Path | None = None,
plugin_dir: Path | None = None,
config: Path | None = None,
workers: int = 1,
overwrite: bool = False,
Expand All @@ -60,11 +61,17 @@ def participant(
f"\n\tsubject: {sub}"
f"\n\tindex: {index_path}"
f"\n\tqc: {qc_dir}"
f"\n\tplugin: {plugin_dir}"
f"\n\tconfig: {config}"
f"\n\tworkers: {workers}"
f"\n\toverwrite: {overwrite}"
)

# Register any plugin figure views
factory.register_views(
search_path=str(plugin_dir) if plugin_dir else None, plugin_prefix="niftyone_"
)

logging.info("Loading dataset index")
index = bids2table(bids_dir, index_path=index_path, workers=workers)

Expand Down
8 changes: 8 additions & 0 deletions src/niftyone/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ def _add_participant_args(self) -> None:
default=None,
help="pre-computed QC metrics if available",
)
self.participant_level.add_argument(
"--plugin-dir",
metavar="PATH",
type=Path,
default=None,
help="directory to search for plugins in;plugins should be "
"prepended with 'niftyone_'",
)
self.participant_level.add_argument(
"--config",
metavar="PATH",
Expand Down
8 changes: 4 additions & 4 deletions src/niftyone/figures/dwi.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
"""Factories associated with diffusion data."""

from niclips.figures import dwi
from niftyone.figures.factory import View, register
from niftyone.figures.factory import View


@register("qspace_shells")
class QSpaceShells(View):
entities = {"ext": ".mp4", "figure": "qspace"}
view_fn = staticmethod(dwi.visualize_qspace)
view_name = "qspace_shells"


@register("three_view_shell_video")
class DwiPerShell(View):
entities = {"ext": ".mp4", "figure": "bval"}
view_fn = staticmethod(dwi.three_view_per_shell)
view_name = "three_view_shell_video"


@register("signal_per_volume")
class SignalPerVolume(View):
entities = {"ext": ".mp4", "figure": "signalPerVolume"}
view_fn = staticmethod(dwi.signal_per_volume)
view_name = "signal_per_volume"
107 changes: 67 additions & 40 deletions src/niftyone/figures/factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Factory module for creating different figures."""

import importlib.util
import inspect
import logging
import pkgutil
from abc import ABC
from functools import reduce
from pathlib import Path
from types import MappingProxyType
from types import MappingProxyType, ModuleType
from typing import Any, Callable, Generic, TypeVar

import matplotlib.pyplot as plt
Expand All @@ -19,50 +22,12 @@
view_registry: dict[str, type["View"]] = {}


def register(name: str) -> Callable:
"""Function to register view to registry."""

def decorator(cls: type[T]) -> type[T]:
view_registry[name] = cls
return cls

return decorator


def create_view(
view: str,
view_kwargs: dict[str, Any] | None,
join_entities: list[str],
queries: list[str],
) -> "View":
"""Function to create view."""
view_kwargs = view_kwargs or {}
try:
view_cls = view_registry[view]
return view_cls(queries, join_entities, view_kwargs)
except KeyError:
raise KeyError(f"Factory for '{view}' for not found in registry.")


def create_views(config: dict[str, Any]) -> list["View"]:
"""Create selected views dynamically from config with default settings."""
return [
create_view(
view=view,
view_kwargs=view_kwargs,
join_entities=group.get("join_entities", ["sub", "ses"]),
queries=group.get("queries", []),
)
for group in config.get("figures", {}).values()
for view, view_kwargs in group.get("views", {}).items()
]


class View(ABC, Generic[T]):
"""Base view class."""

entities: dict[str, Any] | None = None
view_fn: Callable | None = None
view_name: str | None = None # Name for registry (defaults to class)

def __init__(
self,
Expand Down Expand Up @@ -153,3 +118,65 @@ def create(
self.view_fn(img, out_path, overlays=overlays, **self.view_kwargs)

plt.close("all")


def create_view(
view: str,
view_kwargs: dict[str, Any] | None,
join_entities: list[str],
queries: list[str],
) -> "View":
"""Function to create view."""
view_kwargs = view_kwargs or {}
try:
view_cls = view_registry[view]
return view_cls(queries, join_entities, view_kwargs)
except KeyError:
raise KeyError(f"Factory for '{view}' for not found in registry.")


def create_views(config: dict[str, Any]) -> list["View"]:
"""Create selected views dynamically from config with default settings."""
return [
create_view(
view=view,
view_kwargs=view_kwargs,
join_entities=group.get("join_entities", ["sub", "ses"]),
queries=group.get("queries", []),
)
for group in config.get("figures", {}).values()
for view, view_kwargs in group.get("views", {}).items()
]


def register(cls: type[T]) -> type[T]:
"""Function to add view to registry."""
view_registry[cls.view_name or cls.__name__] = cls
return cls


def register_views(search_path: str | None, plugin_prefix: str | None = None) -> None:
"""Register all views."""

def _import_module(module_name: str) -> ModuleType:
"""Import module."""
if search_path is not None:
module_path = Path(search_path) / f"{module_name}.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load {module_name}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
else:
module = importlib.import_module(module_name)
return module

for _, module_name, _ in pkgutil.iter_modules(
path=[search_path] if search_path else None
):
if plugin_prefix is None or module_name.startswith(plugin_prefix):
module = _import_module(module_name=module_name)

for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, View) and obj is not View:
register(obj)
4 changes: 1 addition & 3 deletions src/niftyone/figures/func.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
"""Factories associated with functional data."""

from niclips.figures import bold
from niftyone.figures.factory import View, register
from niftyone.figures.factory import View


@register("carpet_plot")
class CarpetPlot(View):
entities = {"ext": ".png", "figure": "carpet"}
view_fn = staticmethod(bold.carpet_plot)


@register("mean_std")
class MeanStd(View):
entities = {"ext": ".png", "figure": "meanStd"}
view_fn = staticmethod(bold.bold_mean_std)
8 changes: 4 additions & 4 deletions src/niftyone/figures/multi_view.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
"""Factories associated with multi-view."""

from niclips.figures import multi_view
from niftyone.figures.factory import View, register
from niftyone.figures.factory import View


@register("three_view")
class ThreeView(View):
entities = {"ext": ".png", "figure": "threeView"}
view_fn = staticmethod(multi_view.three_view_frame)
view_name = "three_view"


@register("slice_video")
class SliceVideo(View):
entities = {"ext": ".mp4", "figure": "sliceVideo"}
view_fn = staticmethod(multi_view.slice_video)
view_name = "slice_video"


@register("three_view_video")
class ThreeViewVideo(View):
entities = {"ext": ".mp4", "figure": "threeViewVideo"}
view_fn = staticmethod(multi_view.three_view_video)
view_name = "three_view_video"
13 changes: 0 additions & 13 deletions src/niftyone/plugins.py

This file was deleted.

4 changes: 3 additions & 1 deletion tests/unit/niftyone/figures/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def setup_registry():
view_registry.clear()

class TestView(View):
view_name = "test_view"

def create(
self,
records: pd.Series,
Expand All @@ -66,7 +68,7 @@ def create(
) -> None:
pass

register("test_view")(TestView)
register(TestView)
yield
view_registry.clear()

Expand Down