Skip to content

Commit 9b5cbd8

Browse files
committed
Add dimension_order parameter support to fix TYX/CYX confusion
- Add dimension_order parameter to key processing functions (Otsu, Gaussian, Median) - Functions now respect dimension hints: TYX (time), ZYX (z-stack), CYX (channels) - Fix file saving logic to respect dimension_order when deciding whether to split channels - Add dimension_order dropdown UI (positioned before function selector) - Pass dimension_order from UI to processing functions via param_values - Fix split_channels to remove singleton dimensions with np.squeeze() - Use tifffile.imread instead of skimage.io.imread to preserve dimension order - Remove automatic channel splitting in file selector (only split when dimension_order indicates channels) - Add comprehensive test suites for split_channels and TYX display handling - All 9 split_channels tests passing, all 3 file_selector tests passing
1 parent 4e840bd commit 9b5cbd8

6 files changed

Lines changed: 706 additions & 138 deletions

File tree

src/napari_tmidas/_file_selector.py

Lines changed: 122 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,15 @@ def load_image_file(filepath: str) -> Union[np.ndarray, List, Any]:
437437
# Fallback to basic zarr loading with dask
438438
return load_zarr_basic(filepath)
439439
else:
440-
return imread(filepath)
440+
# Use tifffile for TIFF files to preserve dimension order
441+
# (skimage.io.imread may transpose dimensions)
442+
if _HAS_TIFFFILE and (
443+
filepath.lower().endswith(".tif")
444+
or filepath.lower().endswith(".tiff")
445+
):
446+
return tifffile.imread(filepath)
447+
else:
448+
return imread(filepath)
441449

442450

443451
class ProcessedFilesTableWidget(QTableWidget):
@@ -613,21 +621,35 @@ def _clear_current_images(self, image_list):
613621
image_list.clear()
614622

615623
def _should_enable_3d_view(self, data):
616-
"""Check if 3D view should be enabled based on data dimensions"""
624+
"""
625+
Check if 3D view should be enabled based on data dimensions.
626+
627+
Conservative approach: Only enable 3D view for clearly spatial 3D data (Z-stacks),
628+
not for time series which should use 2D view with time slider.
629+
"""
617630
if not hasattr(data, "shape") or len(data.shape) < 3:
618631
return False
619632

620-
# Check if we have meaningful 3D data (excluding potential channel dimensions)
621633
shape = data.shape
622634

623-
# If first dimension is small (<=10), it's likely channels, check remaining dims
624-
meaningful_dims = shape[1:] if shape[0] <= 10 else shape
635+
# If first dimension is channels (2-4), check remaining dims
636+
if shape[0] >= 2 and shape[0] <= 4:
637+
meaningful_dims = shape[1:]
638+
else:
639+
meaningful_dims = shape
625640

626-
# Enable 3D if we have at least 3 spatial dimensions with substantial size
627-
if len(meaningful_dims) >= 3:
628-
# Check that we have meaningful Z depth (not just singleton)
629-
z_dim = meaningful_dims[0] if len(meaningful_dims) >= 3 else 1
641+
# Only enable 3D view for data with 4+ dimensions (like TZYX, CZYX)
642+
# or 3D data with many slices (likely a Z-stack, not time series)
643+
if len(meaningful_dims) >= 4:
644+
# TZYX or similar - check Z dimension
645+
z_dim = meaningful_dims[1] if len(meaningful_dims) >= 4 else 1
630646
return z_dim > 1
647+
elif len(meaningful_dims) == 3:
648+
# Could be ZYX (spatial) or TYX (temporal)
649+
# Only enable 3D for many slices (likely Z-stack)
650+
# 10+ slices suggests Z-stack, fewer suggests time series
651+
first_dim = meaningful_dims[0]
652+
return first_dim > 10
631653

632654
return False
633655

@@ -822,46 +844,9 @@ def _load_original_image(self, filepath: str):
822844
if hasattr(image, "squeeze") and not hasattr(image, "chunks"):
823845
image = np.squeeze(image)
824846

825-
# Check for multi-channel data in single array
826-
if (
827-
hasattr(image, "shape")
828-
and len(image.shape) > 2
829-
and image.shape[0] <= 10
830-
and image.shape[0] > 1
831-
): # Likely channels first
832-
print(
833-
f"Using napari channel_axis=0 for {image.shape[0]} channels"
834-
)
835-
836-
# Use napari's channel_axis to automatically split channels with proper colormaps
837-
layers = self.viewer.add_image(
838-
image,
839-
channel_axis=0,
840-
name=f"Original: {os.path.basename(filepath)}",
841-
blending="additive",
842-
)
843-
844-
# Track all the layers napari created
845-
if isinstance(layers, list):
846-
self.current_original_images.extend(layers)
847-
else:
848-
self.current_original_images.append(layers)
849-
850-
# Switch to 3D view if data has meaningful 3D dimensions (excluding channel dim)
851-
if len(self.current_original_images) > 0:
852-
first_layer = self.current_original_images[0]
853-
if hasattr(first_layer, "data"):
854-
channel_data = first_layer.data
855-
if self._should_enable_3d_view(channel_data):
856-
self.viewer.dims.ndisplay = 3
857-
print(
858-
f"Switched to 3D view for multi-channel data with shape: {channel_data.shape}"
859-
)
860-
861-
self.viewer.status = f"Loaded {len(self.current_original_images)} channels from {os.path.basename(filepath)}"
862-
return
863-
864-
# Single channel image
847+
# Don't automatically split channels - let napari handle with sliders
848+
# This avoids confusion between channels (C) and time (T) dimensions
849+
# Users can manually split if needed using the "Split Color Channels" function
865850
base_filename = os.path.basename(filepath)
866851
# check if label image by checking image dtype
867852
is_label = is_label_image(image)
@@ -879,14 +864,8 @@ def _load_original_image(self, filepath: str):
879864

880865
self.current_original_images.append(layer)
881866

882-
# Switch to 3D view if data has meaningful 3D dimensions
883-
if hasattr(layer, "data") and self._should_enable_3d_view(
884-
layer.data
885-
):
886-
self.viewer.dims.ndisplay = 3
887-
print(
888-
f"Switched to 3D view for single-channel data with shape: {layer.data.shape}"
889-
)
867+
# Don't automatically switch to 3D view - let user decide
868+
# napari will show appropriate sliders for all dimensions
890869

891870
self.viewer.status = f"Loaded {base_filename}"
892871

@@ -1179,57 +1158,8 @@ def _load_processed_image(self, filepath: str):
11791158
if hasattr(image, "squeeze") and not hasattr(image, "chunks"):
11801159
image = np.squeeze(image)
11811160

1182-
filename = os.path.basename(filepath)
1183-
1184-
# Check for multi-channel data in single array
1185-
if (
1186-
hasattr(image, "shape")
1187-
and len(image.shape) > 2
1188-
and image.shape[0] <= 10
1189-
and image.shape[0] > 1
1190-
): # Likely channels first
1191-
print(
1192-
f"Using napari channel_axis=0 for {image.shape[0]} processed channels"
1193-
)
1194-
1195-
# Use napari's channel_axis to automatically split channels with proper colormaps
1196-
layers = self.viewer.add_image(
1197-
image,
1198-
channel_axis=0,
1199-
name=f"Processed: {filename}",
1200-
blending="additive",
1201-
)
1202-
1203-
# Track all the layers napari created
1204-
if isinstance(layers, list):
1205-
self.current_processed_images.extend(layers)
1206-
else:
1207-
self.current_processed_images.append(layers)
1208-
1209-
# Switch to 3D view if data has meaningful 3D dimensions (excluding channel dim)
1210-
if len(self.current_processed_images) > 0:
1211-
first_layer = self.current_processed_images[0]
1212-
if hasattr(first_layer, "data"):
1213-
channel_data = first_layer.data
1214-
if self._should_enable_3d_view(channel_data):
1215-
self.viewer.dims.ndisplay = 3
1216-
print(
1217-
f"Switched to 3D view for processed multi-channel data with shape: {channel_data.shape}"
1218-
)
1219-
1220-
# Move all processed layers to top
1221-
for layer in self.current_processed_images:
1222-
if layer in self.viewer.layers:
1223-
layer_index = self.viewer.layers.index(layer)
1224-
if layer_index < len(self.viewer.layers) - 1:
1225-
self.viewer.layers.move(
1226-
layer_index, len(self.viewer.layers) - 1
1227-
)
1228-
1229-
self.viewer.status = f"Loaded {len(self.current_processed_images)} processed channels from {filename}"
1230-
return
1231-
1232-
# Single channel processed image
1161+
# Don't automatically split channels - let napari handle with sliders
1162+
# This avoids confusion between channels (C) and time (T) dimensions
12331163
filename = os.path.basename(filepath)
12341164
# Check if image dtype indicates labels
12351165
is_label = is_label_image(image)
@@ -1252,14 +1182,8 @@ def _load_processed_image(self, filepath: str):
12521182

12531183
self.current_processed_images.append(layer)
12541184

1255-
# Switch to 3D view if data has meaningful 3D dimensions
1256-
if hasattr(layer, "data") and self._should_enable_3d_view(
1257-
layer.data
1258-
):
1259-
self.viewer.dims.ndisplay = 3
1260-
print(
1261-
f"Switched to 3D view for processed single-channel data with shape: {layer.data.shape}"
1262-
)
1185+
# Don't automatically switch to 3D view - let user decide
1186+
# napari will show appropriate sliders for all dimensions
12631187

12641188
# Move the processed layer to the top of the stack
12651189
if layer in self.viewer.layers:
@@ -1844,10 +1768,46 @@ def process_file(self, filepath):
18441768
ext = ".tif"
18451769

18461770
# Check if the first dimension should be treated as channels
1847-
is_multi_channel = (
1848-
processed_image.ndim > 2 and processed_image.shape[0] <= 10
1771+
# Respect dimension_order hint if provided, otherwise use heuristic (2-4 channels for RGB/RGBA)
1772+
dimension_order_hint = processing_params.get(
1773+
"dimension_order", "Auto"
18491774
)
18501775

1776+
# Only split if dimension_order indicates channels (CYX, TCYX, etc. with C first)
1777+
# or if Auto and shape suggests channels (2-4)
1778+
is_multi_channel = False
1779+
if dimension_order_hint in [
1780+
"CYX",
1781+
"CZYX",
1782+
"TCYX",
1783+
"ZCYX",
1784+
"TZCYX",
1785+
]:
1786+
# User explicitly said first dim is channels - split it
1787+
is_multi_channel = (
1788+
processed_image.ndim > 2 and processed_image.shape[0] > 1
1789+
)
1790+
print(
1791+
f"dimension_order='{dimension_order_hint}' indicates channels, will split {processed_image.shape[0]} channels"
1792+
)
1793+
elif dimension_order_hint in ["TYX", "ZYX", "TZYX"]:
1794+
# User explicitly said it's NOT channels (time or Z) - don't split
1795+
is_multi_channel = False
1796+
print(
1797+
f"dimension_order='{dimension_order_hint}' indicates time/Z dimension, will NOT split channels"
1798+
)
1799+
elif dimension_order_hint == "Auto":
1800+
# Auto mode: use old heuristic (2-4 suggests channels)
1801+
is_multi_channel = (
1802+
processed_image.ndim > 2
1803+
and processed_image.shape[0] <= 4
1804+
and processed_image.shape[0] > 1
1805+
)
1806+
if is_multi_channel:
1807+
print(
1808+
f"Auto mode: shape[0]={processed_image.shape[0]} <= 4, assuming channels"
1809+
)
1810+
18511811
if is_multi_channel:
18521812
# Save each channel as a separate image
18531813
processed_files = []
@@ -2032,6 +1992,44 @@ def __init__(
20321992

20331993
# Create processing function selector
20341994
processing_layout = QVBoxLayout()
1995+
1996+
# Add dimension order selector FIRST (before function selector)
1997+
dim_order_layout = QHBoxLayout()
1998+
dim_order_label = QLabel("Dimension Order (optional hint):")
1999+
dim_order_label.setToolTip(
2000+
"Help processing functions interpret multi-dimensional data.\n"
2001+
"• Auto: Let function decide (default)\n"
2002+
"• YX: 2D image\n"
2003+
"• CYX: Channels first (e.g., RGB)\n"
2004+
"• TYX: Time series\n"
2005+
"• ZYX: Z-stack\n"
2006+
"• TCYX, TZYX, etc.: Combined dimensions\n"
2007+
"\nNote: Not all functions use this hint."
2008+
)
2009+
dim_order_layout.addWidget(dim_order_label)
2010+
2011+
self.dimension_order = QComboBox()
2012+
self.dimension_order.addItems(
2013+
[
2014+
"Auto",
2015+
"YX",
2016+
"CYX",
2017+
"TYX",
2018+
"ZYX",
2019+
"TCYX",
2020+
"TZYX",
2021+
"ZCYX",
2022+
"TZCYX",
2023+
]
2024+
)
2025+
self.dimension_order.setToolTip(
2026+
"Dimension interpretation hint for processing functions"
2027+
)
2028+
dim_order_layout.addWidget(self.dimension_order)
2029+
dim_order_layout.addStretch()
2030+
processing_layout.addLayout(dim_order_layout)
2031+
2032+
# Now add processing function selector
20352033
processing_label = QLabel("Select Processing Function:")
20362034
processing_layout.addWidget(processing_label)
20372035

@@ -2220,6 +2218,12 @@ def start_batch_processing(self):
22202218
):
22212219
param_values = self.param_widget_instance.get_parameter_values()
22222220

2221+
# Add dimension order hint if not "Auto"
2222+
if hasattr(self, "dimension_order"):
2223+
dim_order = self.dimension_order.currentText()
2224+
if dim_order != "Auto":
2225+
param_values["dimension_order"] = dim_order
2226+
22232227
# Determine output folder
22242228
output_folder = self.output_folder.text().strip()
22252229
if not output_folder:

0 commit comments

Comments
 (0)