diff --git a/avae/train.py b/avae/train.py index afed1c46..860b2432 100644 --- a/avae/train.py +++ b/avae/train.py @@ -602,7 +602,12 @@ def train( x_train, p_train, vae, data_dim, device ) - if pose and config.VIS_POSE_CLASS: + if ( + pose + and config.VIS_POS + and config.VIS_POSE_CLASS + and (epoch + 1) % config.FREQ_POS == 0 + ): vis.pose_class_disentanglement_plot( x_train, y_train, diff --git a/avae/utils.py b/avae/utils.py index 9be16e17..185f9bbc 100644 --- a/avae/utils.py +++ b/avae/utils.py @@ -191,8 +191,9 @@ def save_imshow_png( fig, _ = plt.subplots(figsize=(10, 10)) plt.imshow(array, cmap=cmap, vmin=min, vmax=max) # channels last + plt.axis("off") - plt.savefig("plots/" + fname) + plt.savefig("plots/" + fname, bbox_inches="tight", pad_inches=0) if writer: writer.add_figure(figname, fig, epoch) diff --git a/avae/vis.py b/avae/vis.py index 861cdadb..60983fef 100644 --- a/avae/vis.py +++ b/avae/vis.py @@ -1077,7 +1077,12 @@ def latent_disentamglement_plot( "################################################################" ) logging.info("Visualising latent content disentanglement ...\n") - number_of_samples = 7 + + # every SAMPLING_STEP interval from -SIGMA to SIGMA + sigma = 2 + sampling_rate = 15 + sampling_step = (sigma * 2) / (sampling_rate - 1) + padding = 0 lats = np.asarray(lats) if poses is not None: @@ -1085,22 +1090,22 @@ def latent_disentamglement_plot( lat_means = np.mean(lats, axis=0) lat_stds = np.std(lats, axis=0) + lat_dims = lats.shape[-1] - lat_grid = np.zeros((lat_dims * number_of_samples, lat_dims)) + lat_grid = np.zeros((lat_dims * sampling_rate, lat_dims)) if poses is not None: pos_means = np.mean(poses, axis=0) pos_dims = poses.shape[-1] - pos_grid = ( - np.zeros((lat_dims * number_of_samples, pos_dims)) + pos_means - ) + pos_grid = np.zeros((lat_dims * sampling_rate, pos_dims)) + pos_means # Generate vectors representing single transversals along each lat_dim for l_dim in range(lat_dims): - for grid_spot in range(7): + for grid_spot in range(sampling_rate): means = copy.deepcopy(lat_means) - # every 0.4 interval from -1.2 to 1.2 sigma - means[l_dim] += lat_stds[l_dim] * (-1.2 + 0.4 * grid_spot) - lat_grid[l_dim * number_of_samples + grid_spot, :] = means + means[l_dim] += lat_stds[l_dim] * ( + -sigma + sampling_step * grid_spot + ) + lat_grid[l_dim * sampling_rate + grid_spot, :] = means # Decode interpolated vectors with torch.no_grad(): @@ -1121,13 +1126,13 @@ def latent_disentamglement_plot( return recon = np.reshape( - np.array(recon.cpu()), (lat_dims, number_of_samples, *dsize) + np.array(recon.cpu()), (lat_dims, sampling_rate, *dsize) ) grid_for_napari = create_grid_for_plotting( - lat_dims, number_of_samples, dsize, padding + lat_dims, sampling_rate, dsize, padding ) grid_for_napari = fill_grid_for_plottting( - lat_dims, number_of_samples, grid_for_napari, dsize, recon, padding + lat_dims, sampling_rate, grid_for_napari, dsize, recon, padding ) if data_dim == 3: @@ -1182,27 +1187,32 @@ def pose_disentanglement_plot( "Visualising pose disentanglement for class {}...\n".format(label) ) - number_of_samples = 7 - padding = 0 + # every SAMPLING_STEP interval from -SIGMA to SIGMA + sigma = 2 + sampling_rate = 15 + sampling_step = (sigma * 2) / (sampling_rate - 1) + padding = 0 lats = np.asarray(lats) poses = np.asarray(poses) pos_means = np.mean(poses, axis=0) pos_stds = np.std(poses, axis=0) pos_dims = poses.shape[-1] - pos_grid = np.zeros((pos_dims * number_of_samples, pos_dims)) + pos_grid = np.zeros((pos_dims * sampling_rate, pos_dims)) lat_means = np.mean(lats, axis=0) lat_dims = lats.shape[-1] - lat_grid = np.zeros((pos_dims * number_of_samples, lat_dims)) + lat_means + lat_grid = np.zeros((pos_dims * sampling_rate, lat_dims)) + lat_means # Generate vectors representing single transversals along each lat_dim for p_dim in range(pos_dims): - for grid_spot in range(number_of_samples): + for grid_spot in range(sampling_rate): means = copy.deepcopy(pos_means) - means[p_dim] += pos_stds[p_dim] * (-1.2 + 0.4 * grid_spot) - pos_grid[p_dim * number_of_samples + grid_spot, :] = means + means[p_dim] += pos_stds[p_dim] * ( + -sigma + sampling_step * grid_spot + ) + pos_grid[p_dim * sampling_rate + grid_spot, :] = means # Decode interpolated vectors with torch.no_grad(): @@ -1222,15 +1232,15 @@ def pose_disentanglement_plot( recon = np.reshape( np.array(recon.cpu()), - (pos_dims, number_of_samples, *dsize), + (pos_dims, sampling_rate, *dsize), ) grid_for_napari = create_grid_for_plotting( - pos_dims, number_of_samples, dsize, padding + pos_dims, sampling_rate, dsize, padding ) # Create and save the mrc file with single transversals grid_for_napari = fill_grid_for_plottting( - pos_dims, number_of_samples, grid_for_napari, dsize, recon, padding + pos_dims, sampling_rate, grid_for_napari, dsize, recon, padding ) if data_dim == 3: