1-
2-
3-
41import logging
52from dataclasses import dataclass
63from typing import cast
2320
2421logger = logging .getLogger (__name__ )
2522
23+
2624@dataclass
2725class SpectralAnalysis :
2826 """Results of spectral analysis for a layer."""
@@ -40,6 +38,7 @@ class HistogramAnalysis:
4038 channel_histograms : FloatArray
4139 bin_edges : FloatArray
4240
41+
4342def spectral_analysis (
4443 device : torch .device ,
4544 imageset : Imageset ,
@@ -54,7 +53,7 @@ def spectral_analysis(
5453 dataloader = _prepare_dataset (imageset , max_sample_size )
5554
5655 # Initialize results
57- results = { "input" : _layer_spectral_analysis ( device , dataloader , nn . Identity ()) }
56+ results : dict [ str , SpectralAnalysis ] = { }
5857
5958 # Analyze each layer
6059 head_layers : list [nn .Module ] = []
@@ -64,7 +63,6 @@ def spectral_analysis(
6463
6564 if is_nonlinearity (layer ):
6665 continue
67- # TODO: Possible for non Conv2D layers?
6866
6967 results [layer_name ] = _layer_spectral_analysis (
7068 device , dataloader , nn .Sequential (* head_layers )
@@ -84,10 +82,10 @@ def histogram_analysis(
8482 _ , cnn_layers = get_cnn_circuit (brain )
8583
8684 # Prepare dataset
87- dataloader = _prepare_dataset (imageset , max_sample_size )
85+ dataloader = _prepare_dataset (imageset , max_sample_size ) # TODO: Move outside?
8886
8987 # Initialize results
90- results = { "input" : _layer_pixel_histograms ( device , dataloader , nn . Identity ()) }
88+ results : dict [ str , HistogramAnalysis ] = { }
9189
9290 # Analyze each layer
9391 head_layers : list [nn .Module ] = []
@@ -96,7 +94,6 @@ def histogram_analysis(
9694 head_layers .append (layer )
9795 if is_nonlinearity (layer ):
9896 continue
99- # TODO: Possible for non Conv2D layers?
10097 results [layer_name ] = _layer_pixel_histograms (
10198 device , dataloader , nn .Sequential (* head_layers )
10299 )
@@ -243,8 +240,6 @@ def plot(
243240 copy_checkpoint : bool ,
244241):
245242 for layer_name , layer_rfs in rf_result .items ():
246- if layer_name != "input" :
247- continue
248243 layer_spectral = spectral_result [layer_name ]
249244 layer_histogram = histogram_result [layer_name ]
250245 for channel in range (layer_rfs .shape [0 ]):
@@ -395,4 +390,37 @@ def _plot_receptive_fields(ax: Axes, rf: FloatArray):
395390 horizontalalignment = "center" ,
396391 verticalalignment = "center" ,
397392 transform = ax .transAxes ,
398- )
393+ )
394+
395+
396+ def analyze_input (
397+ device : torch .device , imageset : Imageset , max_sample_size : int
398+ ) -> tuple [SpectralAnalysis , HistogramAnalysis ]:
399+ dataloader = _prepare_dataset (imageset , max_sample_size )
400+ spectral_result = _layer_spectral_analysis (device , dataloader , nn .Identity ())
401+ histogram_result = _layer_pixel_histograms (device , dataloader , nn .Identity ())
402+ return spectral_result , histogram_result
403+
404+
405+ def input_plot (
406+ log : FigureLogger ,
407+ rf_result : FloatArray ,
408+ spectral_result : SpectralAnalysis ,
409+ histogram_result : HistogramAnalysis ,
410+ init_dir : str ,
411+ ):
412+ for channel in range (histogram_result .channel_histograms .shape [0 ]):
413+ channel_fig = layer_channel_plots (
414+ rf_result ,
415+ spectral_result ,
416+ histogram_result ,
417+ layer_name = "input" ,
418+ channel = channel ,
419+ )
420+ log .log_figure (
421+ channel_fig ,
422+ init_dir ,
423+ f"input_channel_{ channel } " ,
424+ 0 ,
425+ False ,
426+ )
0 commit comments