@@ -41,17 +41,13 @@ class HistogramAnalysis:
4141
4242def spectral_analysis (
4343 device : torch .device ,
44- imageset : Imageset ,
44+ dataloader : DataLoader [ tuple [ Tensor , Tensor , int ]] ,
4545 brain : Brain ,
46- max_sample_size : int = 0 ,
4746) -> dict [str , SpectralAnalysis ]:
4847 brain .eval ()
4948 brain .to (device )
5049 _ , cnn_layers = get_cnn_circuit (brain )
5150
52- # Prepare dataset
53- dataloader = _prepare_dataset (imageset , max_sample_size )
54-
5551 # Initialize results
5652 results : dict [str , SpectralAnalysis ] = {}
5753
@@ -73,17 +69,13 @@ def spectral_analysis(
7369
7470def histogram_analysis (
7571 device : torch .device ,
76- imageset : Imageset ,
72+ dataloader : DataLoader [ tuple [ Tensor , Tensor , int ]] ,
7773 brain : Brain ,
78- max_sample_size : int = 0 ,
7974) -> dict [str , HistogramAnalysis ]:
8075 brain .eval ()
8176 brain .to (device )
8277 _ , cnn_layers = get_cnn_circuit (brain )
8378
84- # Prepare dataset
85- dataloader = _prepare_dataset (imageset , max_sample_size ) # TODO: Move outside?
86-
8779 # Initialize results
8880 results : dict [str , HistogramAnalysis ] = {}
8981
@@ -101,7 +93,7 @@ def histogram_analysis(
10193 return results
10294
10395
104- def _prepare_dataset (
96+ def prepare_dataset (
10597 imageset : Imageset , max_sample_size : int = 0
10698) -> DataLoader [tuple [Tensor , Tensor , int ]]:
10799 """Prepare dataset and dataloader for analysis."""
@@ -394,9 +386,8 @@ def _plot_receptive_fields(ax: Axes, rf: FloatArray):
394386
395387
396388def analyze_input (
397- device : torch .device , imageset : Imageset , max_sample_size : int
389+ device : torch .device , dataloader : DataLoader [ tuple [ Tensor , Tensor , int ]]
398390) -> tuple [SpectralAnalysis , HistogramAnalysis ]:
399- dataloader = _prepare_dataset (imageset , max_sample_size )
400391 spectral_result = _layer_spectral_analysis (device , dataloader , nn .Identity ())
401392 histogram_result = _layer_pixel_histograms (device , dataloader , nn .Identity ())
402393 return spectral_result , histogram_result
0 commit comments