Skip to content

Commit 8fb3892

Browse files
committed
refactor: move dataloader creation out of analysis function (ensure consistency due to rand perm in prepare_dataset)
1 parent 6f5861e commit 8fb3892

File tree

2 files changed

+10
-20
lines changed

2 files changed

+10
-20
lines changed

retinal_rl/analysis/channel_analysis.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,13 @@ class HistogramAnalysis:
4141

4242
def spectral_analysis(
4343
device: torch.device,
44-
imageset: Imageset,
44+
dataloader: DataLoader[tuple[Tensor, Tensor, int]],
4545
brain: Brain,
46-
max_sample_size: int = 0,
4746
) -> dict[str, SpectralAnalysis]:
4847
brain.eval()
4948
brain.to(device)
5049
_, cnn_layers = get_cnn_circuit(brain)
5150

52-
# Prepare dataset
53-
dataloader = _prepare_dataset(imageset, max_sample_size)
54-
5551
# Initialize results
5652
results: dict[str, SpectralAnalysis] = {}
5753

@@ -73,17 +69,13 @@ def spectral_analysis(
7369

7470
def histogram_analysis(
7571
device: torch.device,
76-
imageset: Imageset,
72+
dataloader: DataLoader[tuple[Tensor, Tensor, int]],
7773
brain: Brain,
78-
max_sample_size: int = 0,
7974
) -> dict[str, HistogramAnalysis]:
8075
brain.eval()
8176
brain.to(device)
8277
_, cnn_layers = get_cnn_circuit(brain)
8378

84-
# Prepare dataset
85-
dataloader = _prepare_dataset(imageset, max_sample_size) # TODO: Move outside?
86-
8779
# Initialize results
8880
results: dict[str, HistogramAnalysis] = {}
8981

@@ -101,7 +93,7 @@ def histogram_analysis(
10193
return results
10294

10395

104-
def _prepare_dataset(
96+
def prepare_dataset(
10597
imageset: Imageset, max_sample_size: int = 0
10698
) -> DataLoader[tuple[Tensor, Tensor, int]]:
10799
"""Prepare dataset and dataloader for analysis."""
@@ -394,9 +386,8 @@ def _plot_receptive_fields(ax: Axes, rf: FloatArray):
394386

395387

396388
def analyze_input(
397-
device: torch.device, imageset: Imageset, max_sample_size: int
389+
device: torch.device, dataloader: DataLoader[tuple[Tensor, Tensor, int]]
398390
) -> tuple[SpectralAnalysis, HistogramAnalysis]:
399-
dataloader = _prepare_dataset(imageset, max_sample_size)
400391
spectral_result = _layer_spectral_analysis(device, dataloader, nn.Identity())
401392
histogram_result = _layer_pixel_histograms(device, dataloader, nn.Identity())
402393
return spectral_result, histogram_result

runner/frameworks/classification/analyze.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,10 @@ def analyze(
6868
log.save_dict(cfg.analyses_dir / f"receptive_fields_epoch_{epoch}.json", rf_result)
6969

7070
if cfg.channel_analysis:
71-
spectral_result = channel_ana.spectral_analysis(
72-
device, test_set, brain, cfg.plot_sample_size
73-
)
74-
histogram_result = channel_ana.histogram_analysis(
75-
device, test_set, brain, cfg.plot_sample_size
76-
)
71+
# Prepare dataset
72+
dataloader = channel_ana.prepare_dataset(test_set, cfg.plot_sample_size)
73+
spectral_result = channel_ana.spectral_analysis(device, dataloader, brain)
74+
histogram_result = channel_ana.histogram_analysis(device, dataloader, brain)
7775
channel_ana.plot(
7876
log,
7977
rf_result,
@@ -142,8 +140,9 @@ def _extended_initialization_plots(
142140
if channel_analysis:
143141
# Input 'rfs' is just the colors
144142
rf_result = np.eye(input_shape[0])[:, :, np.newaxis, np.newaxis]
143+
dataloader = channel_ana.prepare_dataset(train_set, max_sample_size)
145144
spectral_result, histogram_result = channel_ana.analyze_input(
146-
device, train_set, max_sample_size
145+
device, dataloader
147146
)
148147
channel_ana.input_plot(
149148
log,

0 commit comments

Comments
 (0)