Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
b1a2135
Enhance CellMapDataSplit and sampling utilities; initialize dataset l…
rhoadesScholar Oct 30, 2025
0ffcaf3
Refactor sampling utilities to enhance memory efficiency and redundan…
rhoadesScholar Oct 31, 2025
d90d5ad
Increase MAX_SIZE to 512 million for improved handling of larger data…
rhoadesScholar Oct 31, 2025
f0f46eb
Enhance CellMapDataset to improve ThreadPoolExecutor management; add …
rhoadesScholar Oct 31, 2025
b5a80cc
Fix index generation in CellMapDataset to handle non-positive chunk s…
rhoadesScholar Oct 31, 2025
f82b928
Initial plan
Copilot Oct 31, 2025
79dc891
Update src/cellmap_data/datasplit.py
rhoadesScholar Oct 31, 2025
06b3377
Initial plan
Copilot Oct 31, 2025
e91f776
Merge branch 'hot_fix' into copilot/sub-pr-43
rhoadesScholar Oct 31, 2025
89422bb
Add timeout parameter to ThreadPoolExecutor.shutdown() in __del__ to …
Copilot Oct 31, 2025
e45e6b9
Move dataset list initialization to avoid redundancy
Copilot Oct 31, 2025
7d499b0
Update src/cellmap_data/datasplit.py
rhoadesScholar Oct 31, 2025
c1cfbcc
Merge pull request #44 from janelia-cellmap/copilot/sub-pr-43
rhoadesScholar Oct 31, 2025
c894189
Merge branch 'hot_fix' into copilot/sub-pr-43-again
rhoadesScholar Oct 31, 2025
1a85b4e
Merge pull request #45 from janelia-cellmap/copilot/sub-pr-43-again
rhoadesScholar Oct 31, 2025
af6dc5d
Initial plan
Copilot Oct 31, 2025
9da0c51
Replace custom dataloader with PyTorch's optimized DataLoader
Copilot Oct 31, 2025
e3bae69
Update tests to work with PyTorch DataLoader backend
Copilot Oct 31, 2025
6e55ba3
Add comprehensive documentation for dataloader optimizations
Copilot Oct 31, 2025
3c99eeb
Add performance verification guide and optimization summary
Copilot Oct 31, 2025
1afa66d
Add parameter validation for pin_memory and prefetch_factor
Copilot Oct 31, 2025
2ff3fac
Improve error message for prefetch_factor validation
Copilot Oct 31, 2025
7a395ec
Update src/cellmap_data/dataloader.py
rhoadesScholar Nov 3, 2025
ce843fb
Update src/cellmap_data/dataloader.py
rhoadesScholar Nov 3, 2025
b60ebc3
Delete docs/DATALOADER_OPTIMIZATION.md
rhoadesScholar Nov 3, 2025
8f5846e
Delete docs/performance_verification.md
rhoadesScholar Nov 3, 2025
018360f
Delete OPTIMIZATION_SUMMARY.md
rhoadesScholar Nov 3, 2025
e38938c
Remove redundant device transfers and fix unused variable warnings
Copilot Nov 3, 2025
c1199b5
Merge pull request #46 from janelia-cellmap/copilot/review-dataloader…
rhoadesScholar Nov 3, 2025
964ef62
Initial plan
Copilot Nov 4, 2025
9c99702
Update test_refactored_integration.py to use _pytorch_loader instead …
Copilot Nov 4, 2025
ae38773
Merge pull request #47 from janelia-cellmap/copilot/sub-pr-43
rhoadesScholar Nov 4, 2025
cbe5ba7
Initial plan
Copilot Nov 4, 2025
9790a49
Fix black and ruff formatting issues
Copilot Nov 4, 2025
c50a1b3
Update Python version to 3.11 in CI workflow
rhoadesScholar Nov 5, 2025
596fb76
Remove DEFAULT_TIMEOUT and related timeout logic
rhoadesScholar Nov 5, 2025
8b37eb7
Fix failing tests - update expectations for CPU pin_memory, prefetch_…
Copilot Nov 5, 2025
45d6054
Refactor dataset writer and related modules for improved clarity and …
rhoadesScholar Nov 7, 2025
81877dd
Refactor MockDatasetWithArrays - move class definition outside of tes…
rhoadesScholar Nov 7, 2025
ca3c47a
Merge branch 'hot_fix' into copilot/sub-pr-43
rhoadesScholar Nov 7, 2025
10da029
Refactor code for improved readability and consistency across multipl…
rhoadesScholar Nov 7, 2025
8ed1345
Add method to retrieve random subset indices from the dataset
rhoadesScholar Nov 7, 2025
b671919
Merge pull request #48 from janelia-cellmap/copilot/sub-pr-43
rhoadesScholar Nov 7, 2025
a9b82d5
Remove obsolete test files for GPU transfer, image classes, performan…
rhoadesScholar Nov 7, 2025
8f7eccb
Initial plan
Copilot Nov 7, 2025
4e4301d
Add comprehensive test files for core components
Copilot Nov 7, 2025
3b615f3
Add comprehensive tests for loaders, multi-dataset, writer, and integ…
Copilot Nov 7, 2025
8f6267c
Add comprehensive test README documentation
Copilot Nov 7, 2025
446d2ce
Fix core API mismatches in tests
Copilot Nov 13, 2025
e68c5ed
Add force_has_data and target_bounds to fix more tests
Copilot Nov 13, 2025
7120fd5
Refactor tests for MutableSubsetRandomSampler and augmentation transf…
rhoadesScholar Nov 18, 2025
1fdf6b2
Refactor dataset and image classes for improved structure and functio…
rhoadesScholar Nov 25, 2025
e0b77f6
Merge branch 'main' into feature/base-classes
rhoadesScholar Nov 25, 2025
53e453d
Add numpy for random sampling in MutableSubsetRandomSampler tests
rhoadesScholar Nov 25, 2025
dd6a985
Update src/cellmap_data/dataset.py
rhoadesScholar Nov 25, 2025
57703de
Remove unused imports from test_helpers.py
rhoadesScholar Nov 25, 2025
cec38f7
Merge branch 'feature/base-classes' of github.com:janelia-cellmap/cel…
rhoadesScholar Nov 25, 2025
9fd50bc
Rename target_class to label_class in ImageWriter for clarity
rhoadesScholar Nov 25, 2025
d5affc5
Fix path separator in ImageWriter tests for cross-platform compatibility
rhoadesScholar Nov 26, 2025
af12ea5
Refactor ImageWriter tests to use temporary UPath fixtures for improv…
rhoadesScholar Nov 26, 2025
dcea99d
Fix path handling in ImageWriter tests for improved compatibility wit…
rhoadesScholar Nov 26, 2025
c79601c
Normalize path handling in ImageWriter tests for improved consistency…
rhoadesScholar Nov 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
python-version: '3.11'
- name: Install docs dependencies
run: |
pip install -U pip
Expand All @@ -38,7 +38,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
python-version: '3.11'
- name: Install black
run: pip install black
- name: Check formatting
Expand Down Expand Up @@ -116,7 +116,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.x"
python-version: "3.11"
- name: Install build tools
run: pip install -U pip hatch twine
- name: Build sdist and wheel
Expand Down
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,10 @@ scratch/
# PyPi builds
dist/
build/
clean/
clean/

# VS Code settings, etc.
.vscode/
.pytest_cache/
__pycache__/
mypy_cache/
26 changes: 13 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ repos:
# - id: conventional-pre-commit
# stages: [commit-msg]

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [--fix]
# - repo: https://github.com/charliermarsh/ruff-pre-commit
# rev: v0.3.0
# hooks:
# - id: ruff
# args: [--fix]

- repo: https://github.com/psf/black
rev: 24.2.0
Expand All @@ -28,11 +28,11 @@ repos:
hooks:
- id: validate-pyproject

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
files: "^src/"
# # you have to add the things you want to type check against here
# additional_dependencies:
# - numpy
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.8.0
# hooks:
# - id: mypy
# files: "^src/"
# # # you have to add the things you want to type check against here
# # additional_dependencies:
# # - numpy
20 changes: 13 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,28 +211,34 @@ sampler = multi_dataset.get_weighted_sampler(batch_size=4)

### CellMapDataLoader

High-performance data loader with optimization features:
High-performance data loader built on PyTorch's optimized DataLoader:

```python
loader = CellMapDataLoader(
dataset,
batch_size=16,
batch_size=32,
num_workers=12,
weighted_sampler=True,
device="cuda",
prefetch_factor=4, # Preload batches for better GPU utilization
persistent_workers=True, # Keep workers alive between epochs
pin_memory=True, # Fast CPU-to-GPU transfer
iterations_per_epoch=1000 # For large datasets
)

# Optimized GPU memory transfer
loader.to("cuda", non_blocking=True)
```

**Optimizations**:
**Optimizations** (powered by PyTorch DataLoader):

- **Prefetch Factor**: Background data loading to maximize GPU utilization
- **Pin Memory**: Fast CPU-to-GPU transfers via pinned memory (auto-enabled on CUDA)
- **Persistent Workers**: Reduced overhead by keeping workers alive between epochs
- **PyTorch's Optimized Multiprocessing**: Battle-tested parallel data loading
- **Smart Defaults**: Automatic optimization based on hardware configuration

- CUDA streams for parallel GPU transfer
- Persistent workers for reduced overhead
- Automatic memory estimation and optimization
- Thread-safe multiprocessing
See [DataLoader Optimization Guide](docs/DATALOADER_OPTIMIZATION.md) for performance tuning tips.

### CellMapDataSplit

Expand Down
42 changes: 27 additions & 15 deletions src/cellmap_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
"""
CellMap Data Loading Module.
"""Utility for loading CellMap data for machine learning training."""

Utility for loading CellMap data for machine learning training,
utilizing PyTorch, TensorStore, XArray, and PyDantic.
"""

from importlib.metadata import PackageNotFoundError, version
try:
from importlib.metadata import PackageNotFoundError, version
except ImportError:
from importlib_metadata import PackageNotFoundError, version

try:
__version__ = version("cellmap_data")
__version__ = version("cellmap-data")
except PackageNotFoundError:
__version__ = "0.1.1"

__version__ = "uninstalled"
__author__ = "Jeff Rhoades"
__email__ = "[email protected]"

from .multidataset import CellMapMultiDataset
from .base_dataset import CellMapBaseDataset
from .base_image import CellMapImageBase
from .dataloader import CellMapDataLoader
from .datasplit import CellMapDataSplit
from .dataset import CellMapDataset
from .dataset_writer import CellMapDatasetWriter
from .image import CellMapImage
from .datasplit import CellMapDataSplit
from .empty_image import EmptyImage
from .image import CellMapImage
from .image_writer import ImageWriter
from .multidataset import CellMapMultiDataset
from .subdataset import CellMapSubset
from .mutable_sampler import MutableSubsetRandomSampler
from . import transforms
from . import utils

__all__ = [
"CellMapBaseDataset",
"CellMapImageBase",
"CellMapDataLoader",
"CellMapDataset",
"CellMapDatasetWriter",
"CellMapDataSplit",
"CellMapImage",
"ImageWriter",
"CellMapMultiDataset",
"CellMapSubset",
"EmptyImage",
"MutableSubsetRandomSampler",
]
108 changes: 108 additions & 0 deletions src/cellmap_data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Abstract base class for CellMap dataset objects."""

from abc import ABC, abstractmethod
from typing import Any, Callable, Mapping, Sequence

import torch


class CellMapBaseDataset(ABC):
"""
Abstract base class for CellMap dataset objects.

This class defines the common interface that all CellMap dataset objects
must implement, ensuring consistency across different dataset types.

Note: `classes`, `input_arrays`, and `target_arrays` are not abstract
properties because implementing classes define them as instance attributes
in __init__, not as properties.
"""

# These are instance attributes set in __init__, not properties
classes: Sequence[str] | None
input_arrays: Mapping[str, Mapping[str, Any]]
target_arrays: Mapping[str, Mapping[str, Any]] | None

@property
@abstractmethod
def class_counts(self) -> dict[str, float]:
"""
Return the number of samples in each class, normalized by resolution.

Returns
-------
dict[str, float]
Dictionary mapping class names to their counts.
"""
pass

@property
@abstractmethod
def class_weights(self) -> dict[str, float]:
"""
Return the class weights based on the number of samples in each class.

Returns
-------
dict[str, float]
Dictionary mapping class names to their weights.
"""
pass

@property
@abstractmethod
def validation_indices(self) -> Sequence[int]:
"""
Return the indices for the validation set.

Returns
-------
Sequence[int]
List of validation indices.
"""
pass

@abstractmethod
def to(
self, device: str | torch.device, non_blocking: bool = True
) -> "CellMapBaseDataset":
"""
Move the dataset to the specified device.

Parameters
----------
device : str | torch.device
The target device.
non_blocking : bool, optional
Whether to use non-blocking transfer, by default True.

Returns
-------
CellMapBaseDataset
Self for method chaining.
"""
pass

@abstractmethod
def set_raw_value_transforms(self, transforms: Callable) -> None:
"""
Set the value transforms for raw input data.

Parameters
----------
transforms : Callable
Transform function to apply to raw data.
"""
pass

@abstractmethod
def set_target_value_transforms(self, transforms: Callable) -> None:
"""
Set the value transforms for target data.

Parameters
----------
transforms : Callable
Transform function to apply to target data.
"""
pass
100 changes: 100 additions & 0 deletions src/cellmap_data/base_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Abstract base class for CellMap image objects."""

from abc import ABC, abstractmethod
from typing import Any, Mapping

import torch


class CellMapImageBase(ABC):
"""
Abstract base class for CellMap image objects.

This class defines the common interface that all CellMap image objects
must implement, ensuring consistency across different image types.
"""

@abstractmethod
def __getitem__(self, center: Mapping[str, float]) -> torch.Tensor:
"""
Return image data centered around the given point.

Parameters
----------
center : Mapping[str, float]
The center coordinates in world units.

Returns
-------
torch.Tensor
The image data as a PyTorch tensor.
"""
pass

@property
@abstractmethod
def bounding_box(self) -> Mapping[str, tuple[float, float]] | None:
"""
Return the bounding box of the image in world units.

Returns
-------
Mapping[str, tuple[float, float]] | None
Dictionary mapping axis names to (min, max) tuples, or None.
"""
pass

@property
@abstractmethod
def sampling_box(self) -> Mapping[str, tuple[float, float]] | None:
"""
Return the sampling box of the image in world units.

The sampling box is the region where centers can be drawn from and
still have full samples drawn from within the bounding box.

Returns
-------
Mapping[str, tuple[float, float]] | None
Dictionary mapping axis names to (min, max) tuples, or None.
"""
pass

@property
@abstractmethod
def class_counts(self) -> float | dict[str, float]:
"""
Return the number of voxels for each class in the image.

Returns
-------
float | dict[str, float]
Class counts, either as a single float or dictionary.
"""
pass

@abstractmethod
def to(self, device: str | torch.device, non_blocking: bool = True) -> None:
"""
Move the image data to the specified device.

Parameters
----------
device : str | torch.device
The target device.
non_blocking : bool, optional
Whether to use non-blocking transfer, by default True.
"""
pass

@abstractmethod
def set_spatial_transforms(self, transforms: Mapping[str, Any] | None) -> None:
"""
Set spatial transformations for the image data.

Parameters
----------
transforms : Mapping[str, Any] | None
Dictionary of spatial transformations to apply.
"""
pass
Loading