Skip to content

Commit 6f5861e

Browse files
committed
refactor: continue analysis/plot restructuring
1 parent 78f4d4d commit 6f5861e

File tree

9 files changed

+189
-193
lines changed

9 files changed

+189
-193
lines changed

retinal_rl/analysis/channel_analysis.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
2-
3-
41
import logging
52
from dataclasses import dataclass
63
from typing import cast
@@ -23,6 +20,7 @@
2320

2421
logger = logging.getLogger(__name__)
2522

23+
2624
@dataclass
2725
class SpectralAnalysis:
2826
"""Results of spectral analysis for a layer."""
@@ -40,6 +38,7 @@ class HistogramAnalysis:
4038
channel_histograms: FloatArray
4139
bin_edges: FloatArray
4240

41+
4342
def spectral_analysis(
4443
device: torch.device,
4544
imageset: Imageset,
@@ -54,7 +53,7 @@ def spectral_analysis(
5453
dataloader = _prepare_dataset(imageset, max_sample_size)
5554

5655
# Initialize results
57-
results = {"input": _layer_spectral_analysis(device, dataloader, nn.Identity())}
56+
results: dict[str, SpectralAnalysis] = {}
5857

5958
# Analyze each layer
6059
head_layers: list[nn.Module] = []
@@ -64,7 +63,6 @@ def spectral_analysis(
6463

6564
if is_nonlinearity(layer):
6665
continue
67-
# TODO: Possible for non Conv2D layers?
6866

6967
results[layer_name] = _layer_spectral_analysis(
7068
device, dataloader, nn.Sequential(*head_layers)
@@ -84,10 +82,10 @@ def histogram_analysis(
8482
_, cnn_layers = get_cnn_circuit(brain)
8583

8684
# Prepare dataset
87-
dataloader = _prepare_dataset(imageset, max_sample_size)
85+
dataloader = _prepare_dataset(imageset, max_sample_size) # TODO: Move outside?
8886

8987
# Initialize results
90-
results = {"input": _layer_pixel_histograms(device, dataloader, nn.Identity())}
88+
results: dict[str, HistogramAnalysis] = {}
9189

9290
# Analyze each layer
9391
head_layers: list[nn.Module] = []
@@ -96,7 +94,6 @@ def histogram_analysis(
9694
head_layers.append(layer)
9795
if is_nonlinearity(layer):
9896
continue
99-
# TODO: Possible for non Conv2D layers?
10097
results[layer_name] = _layer_pixel_histograms(
10198
device, dataloader, nn.Sequential(*head_layers)
10299
)
@@ -243,8 +240,6 @@ def plot(
243240
copy_checkpoint: bool,
244241
):
245242
for layer_name, layer_rfs in rf_result.items():
246-
if layer_name != "input":
247-
continue
248243
layer_spectral = spectral_result[layer_name]
249244
layer_histogram = histogram_result[layer_name]
250245
for channel in range(layer_rfs.shape[0]):
@@ -395,4 +390,37 @@ def _plot_receptive_fields(ax: Axes, rf: FloatArray):
395390
horizontalalignment="center",
396391
verticalalignment="center",
397392
transform=ax.transAxes,
398-
)
393+
)
394+
395+
396+
def analyze_input(
397+
device: torch.device, imageset: Imageset, max_sample_size: int
398+
) -> tuple[SpectralAnalysis, HistogramAnalysis]:
399+
dataloader = _prepare_dataset(imageset, max_sample_size)
400+
spectral_result = _layer_spectral_analysis(device, dataloader, nn.Identity())
401+
histogram_result = _layer_pixel_histograms(device, dataloader, nn.Identity())
402+
return spectral_result, histogram_result
403+
404+
405+
def input_plot(
406+
log: FigureLogger,
407+
rf_result: FloatArray,
408+
spectral_result: SpectralAnalysis,
409+
histogram_result: HistogramAnalysis,
410+
init_dir: str,
411+
):
412+
for channel in range(histogram_result.channel_histograms.shape[0]):
413+
channel_fig = layer_channel_plots(
414+
rf_result,
415+
spectral_result,
416+
histogram_result,
417+
layer_name="input",
418+
channel=channel,
419+
)
420+
log.log_figure(
421+
channel_fig,
422+
init_dir,
423+
f"input_channel_{channel}",
424+
0,
425+
False,
426+
)

retinal_rl/analysis/default.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from retinal_rl.analysis.plot import (
2+
FigureLogger,
3+
plot_brain_and_optimizers,
4+
plot_receptive_field_sizes,
5+
)
6+
from retinal_rl.models.brain import Brain
7+
from retinal_rl.models.objective import ContextT, Objective
8+
from retinal_rl.util import FloatArray
9+
10+
INIT_DIR = "initialization_analysis"
11+
12+
def initialization_plots(
13+
log: FigureLogger,
14+
brain: Brain,
15+
objective: Objective[ContextT],
16+
input_shape: tuple[int, ...],
17+
rf_result: dict[str, FloatArray],
18+
):
19+
log.save_summary(brain)
20+
21+
# TODO: This is a bit of a hack, we should refactor this to get the relevant information out of cnn_stats
22+
rf_sizes_fig = plot_receptive_field_sizes(input_shape, rf_result)
23+
log.log_figure(
24+
rf_sizes_fig,
25+
INIT_DIR,
26+
"receptive_field_sizes",
27+
0,
28+
False,
29+
)
30+
31+
graph_fig = plot_brain_and_optimizers(brain, objective)
32+
log.log_figure(
33+
graph_fig,
34+
INIT_DIR,
35+
"brain_graph",
36+
0,
37+
False,
38+
)

retinal_rl/analysis/plot.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Utility functions for plotting the results of statistical analyses."""
22

3+
import json
34
import shutil
45
from pathlib import Path
6+
from typing import Any
57

68
import matplotlib.pyplot as plt
79
import networkx as nx
@@ -17,7 +19,7 @@
1719

1820
from retinal_rl.models.brain import Brain
1921
from retinal_rl.models.objective import ContextT, Objective
20-
from retinal_rl.util import FloatArray
22+
from retinal_rl.util import FloatArray, NumpyEncoder
2123

2224

2325
def make_image_grid(arrays: list[FloatArray], nrow: int) -> FloatArray:
@@ -178,8 +180,6 @@ def plot_receptive_field_sizes(
178180
rf_sizes: list[tuple[int, int]] = []
179181
layer_names: list[str] = []
180182
for name, rf in rf_layers.items():
181-
if name == "input": # TODO: Should not be possible?!
182-
continue
183183
rf_height, rf_width = rf.shape[2:]
184184
rf_sizes.append((rf_height, rf_width))
185185
layer_names.append(name)
@@ -291,13 +291,10 @@ def set_integer_ticks(ax: Axes):
291291
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
292292
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
293293

294+
294295
class FigureLogger:
295296
def __init__(
296-
self,
297-
use_wandb: bool,
298-
plot_dir: Path,
299-
checkpoint_plot_dir: Path,
300-
run_dir: Path
297+
self, use_wandb: bool, plot_dir: Path, checkpoint_plot_dir: Path, run_dir: Path
301298
):
302299
self.use_wandb = use_wandb
303300
self.plot_dir = plot_dir
@@ -340,7 +337,6 @@ def capitalize_part(part: str) -> str:
340337
return "/".join(capitalized_parts)
341338

342339
def _checkpoint_copy(self, sub_dir: str, file_name: str, epoch: int) -> None:
343-
# TODO: Does this need to be in here?
344340
src_path = self.plot_dir / sub_dir / f"{file_name}.png"
345341

346342
dest_dir = self.checkpoint_plot_dir / f"epoch_{epoch}" / sub_dir
@@ -370,3 +366,7 @@ def save_summary(self, brain: Brain):
370366

371367
if self.use_wandb:
372368
wandb.save(str(filepath), base_path=self.run_dir, policy="now")
369+
370+
def save_dict(self, path: Path, dict: dict[str, Any]):
371+
with open(path, "w") as f:
372+
json.dump(dict, f, cls=NumpyEncoder)

retinal_rl/analysis/receptive_fields.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,14 @@ def analyze(brain: Brain, device: torch.device):
4040

4141
return input_shape, results
4242

43+
4344
def plot(
4445
log: FigureLogger,
4546
rf_result: dict[str, FloatArray],
4647
epoch: int,
4748
copy_checkpoint: bool,
4849
):
4950
for layer_name, layer_rfs in rf_result.items():
50-
if layer_name == "input": # TODO: Remove, see where this goes
51-
continue
5251
layer_rf_plots = layer_receptive_field_plots(layer_rfs)
5352
log.log_figure(
5453
layer_rf_plots,
@@ -58,9 +57,8 @@ def plot(
5857
copy_checkpoint,
5958
)
6059

61-
def layer_receptive_field_plots(
62-
lyr_rfs: FloatArray, max_cols: int = 8
63-
) -> Figure:
60+
61+
def layer_receptive_field_plots(lyr_rfs: FloatArray, max_cols: int = 8) -> Figure:
6462
"""Plot the receptive fields of a convolutional layer."""
6563
ochns, _, _, _ = lyr_rfs.shape
6664

@@ -79,9 +77,7 @@ def layer_receptive_field_plots(
7977

8078
for i in range(ochns):
8179
ax = axs[i]
82-
data = np.moveaxis(
83-
lyr_rfs[i], 0, -1
84-
) # Move channel axis to the last dimension
80+
data = np.moveaxis(lyr_rfs[i], 0, -1) # Move channel axis to the last dimension
8581
data_min = data.min()
8682
data_max = data.max()
8783
data = (data - data_min) / (data_max - data_min)
@@ -96,6 +92,7 @@ def layer_receptive_field_plots(
9692
fig.tight_layout() # Adjust layout to fit color bars
9793
return fig
9894

95+
9996
def _compute_receptive_fields(
10097
device: torch.device,
10198
head_layers: list[nn.Module],

retinal_rl/analysis/reconstructions.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
2-
import json
31
from dataclasses import asdict, dataclass
42
from pathlib import Path
53

@@ -13,7 +11,7 @@
1311
from retinal_rl.models.brain import Brain
1412
from retinal_rl.models.loss import ContextT, ReconstructionLoss
1513
from retinal_rl.models.objective import Objective
16-
from retinal_rl.util import FloatArray, NumpyEncoder
14+
from retinal_rl.util import FloatArray
1715

1816

1917
@dataclass
@@ -24,42 +22,47 @@ class Reconstructions:
2422
inputs: list[tuple[FloatArray, int]]
2523
estimates: list[tuple[FloatArray, int]]
2624

25+
2726
@dataclass
2827
class ReconstructionStatistics:
2928
"""Results of image reconstruction for both training and test sets."""
3029

3130
train: Reconstructions
3231
test: Reconstructions
3332

34-
# TODO: Make structure match the analyze / plot structure as receptive_fields
3533

36-
def perform_reconstruction_analysis(
37-
log: FigureLogger,
38-
analyses_dir: Path,
34+
def analyze(
3935
device: torch.device,
4036
brain: Brain,
4137
objective: Objective[ContextT],
4238
train_set: Imageset,
4339
test_set: Imageset,
44-
epoch: int,
45-
copy_checkpoint: bool,
46-
):
40+
) -> tuple[dict[str, ReconstructionStatistics], list[float], list[float]]:
4741
reconstruction_decoders = [
4842
loss.target_decoder
4943
for loss in objective.losses
5044
if isinstance(loss, ReconstructionLoss)
5145
]
5246

47+
results: dict[str, ReconstructionStatistics] = {}
5348
for decoder in reconstruction_decoders:
54-
norm_means, norm_stds = train_set.normalization_stats
55-
rec_dict = asdict(
56-
reconstruct_images(device, brain, decoder, train_set, test_set, 5)
49+
results[decoder] = reconstruct_images(
50+
device, brain, decoder, train_set, test_set, 5
5751
)
58-
# Save the reconstructions
59-
rec_path = analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json"
60-
with open(rec_path, "w") as f:
61-
json.dump(rec_dict, f, cls=NumpyEncoder)
52+
return results, *train_set.normalization_stats
6253

54+
55+
def plot(
56+
log: FigureLogger,
57+
analyses_dir: Path,
58+
result: dict[str, ReconstructionStatistics],
59+
norm_means: list[float],
60+
norm_stds: list[float],
61+
epoch: int,
62+
copy_checkpoint: bool,
63+
):
64+
for decoder, reconstructions in result.items():
65+
rec_dict = asdict(reconstructions)
6366
recon_fig = plot_reconstructions(
6467
norm_means,
6568
norm_stds,
@@ -74,6 +77,10 @@ def perform_reconstruction_analysis(
7477
epoch,
7578
copy_checkpoint,
7679
)
80+
# Save the reconstructions #TODO: most plot functions don't do this, should stay?
81+
log.save_dict(
82+
analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json", rec_dict
83+
)
7784

7885

7986
def reconstruct_images(

retinal_rl/analysis/transforms_analysis.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@ class TransformStatistics:
1818
source_transforms: dict[str, dict[float, list[FloatArray]]]
1919
noise_transforms: dict[str, dict[float, list[FloatArray]]]
2020

21-
# TODO: Make structure match the analyze / plot structure as receptive_fields
2221

23-
def transform_base_images(
24-
imageset: Imageset, num_steps: int, num_images: int
25-
) -> TransformStatistics:
22+
def analyze(imageset: Imageset, num_steps: int, num_images: int) -> TransformStatistics:
2623
"""Apply transformations to a set of images from an Imageset."""
2724
images: list[Image.Image] = []
2825

@@ -59,7 +56,7 @@ def transform_base_images(
5956
return resultss
6057

6158

62-
def plot_transforms(
59+
def plot(
6360
source_transforms: dict[str, dict[float, list[FloatArray]]],
6461
noise_transforms: dict[str, dict[float, list[FloatArray]]],
6562
) -> Figure:
@@ -142,4 +139,4 @@ def plot_transforms(
142139
transform_index += 1
143140

144141
plt.tight_layout()
145-
return fig
142+
return fig

retinal_rl/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
### IO Handling stuff
2121

22+
2223
class NumpyEncoder(json.JSONEncoder):
2324
"""JSON encoder that handles numpy arrays."""
2425

@@ -27,6 +28,7 @@ def default(self, obj: Any) -> Any:
2728
return obj.tolist()
2829
return super().default(obj)
2930

31+
3032
### Functions
3133

3234

0 commit comments

Comments
 (0)