Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 43 additions & 21 deletions pyroved/models/ivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
**kwargs: Union[str, float]
) -> None:
args = (data_dim, invariances)
super(iVAE, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)

# Reset the pyro ParamStoreDict object's dictionaries
pyro.clear_param_store()
Expand Down Expand Up @@ -243,37 +243,59 @@ def decode(self,
loc = self._decode(z, **kwargs)
return loc

def manifold2d(self, d: int,
y: torch.Tensor = None,
plot: bool = True,
**kwargs: Union[str, int, float]) -> torch.Tensor:
"""
Plots a learned latent manifold in the image space
def manifold2d(
self, d: int,
y: torch.Tensor = None,
plot: bool = True,
figsize: Tuple[float] = (8.0, 8.0),
latents: Tuple[int] = [0, 1],
**kwargs: Union[str, int, float]
) -> torch.Tensor:
"""Plots a learned latent manifold in the image space

Args:
d: Grid size
plot: Plots the generated manifold (Default: True)
y: Conditional "property" vector (e.g. one-hot encoded class vector)
kwargs: Keyword arguments include custom min/max values
for grid boundaries passed as 'z_coord'
(e.g. z_coord = [-3, 3, -3, 3]), 'angle' and
'shift' to condition a generative model on, and plot parameters
('padding', 'padding_value', 'cmap', 'origin', 'ylim')
d:
Grid size
plot:
Plots the generated manifold (Default: True)
y:
Conditional "property" vector (e.g. one-hot encoded class
vector)
kwargs:
Keyword arguments include custom min/max values for grid
boundaries passed as 'z_coord' (e.g. z_coord = [-3, 3, -3, 3]),
'angle' and 'shift' to condition a generative model on, and
plot parameters ('padding', 'padding_value', 'cmap', 'origin',
'ylim').
"""
z, (grid_x, grid_y) = generate_latent_grid(d, **kwargs)
z = [z]

# We silence all other latent variables except those listed in the
# latent list by setting them to zero. Note this choice is arbitrary
# and is just a consequence of only being able to easily visualize
# two axes at once.
z_tmp = torch.zeros(size=(z.shape[0], self.z_dim - self.coord))
z_tmp[:, latents[0]] = z[:, 0]
z_tmp[:, latents[1]] = z[:, 1]
z = [z_tmp]

if self.c_dim > 0:
if y is None:
raise ValueError("To generate a manifold pass a conditional vector y")
raise ValueError(
"To generate a manifold pass a conditional vector y"
)
y = y.unsqueeze(1) if 0 < y.ndim < 2 else y
z = z + [y.expand(z[0].shape[0], *y.shape[1:])]

loc = self.decode(*z, **kwargs)

if plot:
if self.ndim == 2:
plot_img_grid(
loc, d,
extent=[grid_x.min(), grid_x.max(), grid_y.min(), grid_y.max()],
**kwargs)
loc, d, extent=[
grid_x.min(), grid_x.max(), grid_y.min(),
grid_y.max()
], figsize=figsize, **kwargs)
elif self.ndim == 1:
plot_spect_grid(loc, d, **kwargs)
plot_spect_grid(loc, d, figsize=figsize, **kwargs)
return loc
25 changes: 16 additions & 9 deletions pyroved/trainers/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,22 @@ def step(self,
self.loss_history["test_loss"].append(test_loss)
self.current_epoch += 1

def print_statistics(self) -> None:
"""
Prints training and test (if any) losses for current epoch
def print_statistics(self, print_every: int = 1) -> None:
"""Prints training and test (if any) losses for current epoch.

Args:
print_every:
Only prints every "print_every" epochs.
"""

e = self.current_epoch

if e % print_every != 0:
return

print(f"Epoch {e:04}")
training_loss = self.loss_history["training_loss"][-1]
print(f"\tTraining loss: {training_loss:.04f}")
if len(self.loss_history["test_loss"]) > 0:
template = 'Epoch: {} Training loss: {:.4f}, Test loss: {:.4f}'
print(template.format(e, self.loss_history["training_loss"][-1],
self.loss_history["test_loss"][-1]))
else:
template = 'Epoch: {} Training loss: {:.4f}'
print(template.format(e, self.loss_history["training_loss"][-1]))
test_loss = self.loss_history["test_loss"][-1]
print(f"\tTesting loss: {test_loss:.04f}")
39 changes: 24 additions & 15 deletions pyroved/utils/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
import matplotlib.pyplot as plt


def plot_img_grid(imgdata: torch.Tensor, d: int,
**kwargs: Union[str, int, List[float]]) -> None:
"""
Plots a *d*-by-*d* square grid of 2D images
"""
def plot_img_grid(
imgdata: torch.Tensor, d: int, figsize: Tuple[float] = (8.0, 8.0),
**kwargs: Union[str, int, List[float]]
) -> None:
"""Plots a *d*-by-*d* square grid of 2D images."""

if imgdata.ndim < 3:
raise AssertionError("Images must be passed as a 3D or 4D tensor")
imgdata = imgdata[:, None] if imgdata.ndim == 3 else imgdata
grid = make_grid(imgdata, nrow=d,
padding=kwargs.get("padding", 2),
pad_value=kwargs.get("pad_value", 0))
grid = make_grid(
imgdata, nrow=d, padding=kwargs.get("padding", 2),
pad_value=kwargs.get("pad_value", 0)
)

plt.figure(figsize=(8, 8))
plt.figure(figsize=figsize)
plt.imshow(grid[0].squeeze(), cmap=kwargs.get("cmap", "gnuplot"),
origin=kwargs.get("origin", "upper"),
extent=kwargs.get("extent"))
Expand All @@ -27,13 +29,20 @@ def plot_img_grid(imgdata: torch.Tensor, d: int,
plt.show()


def plot_spect_grid(spectra: torch.Tensor, d: int, **kwargs: List[float]): # TODO: Add 'axes' and 'extent'
"""
Plots a *d*-by-*d* square grid with 1D spectral plots
def plot_spect_grid(
spectra: torch.Tensor, d: int, figsize: Tuple[float] = (8.0, 8.0),
**kwargs: List[float]
) -> None:
"""Plots a *d*-by-*d* square grid with 1D spectral plots.

TODO: Add 'axes' and 'extent'
"""
_, axes = plt.subplots(d, d, figsize=(8, 8),
subplot_kw={'xticks': [], 'yticks': []},
gridspec_kw=dict(hspace=0.1, wspace=0.1))

_, axes = plt.subplots(
d, d, figsize=figsize,
subplot_kw={'xticks': [], 'yticks': []},
gridspec_kw=dict(hspace=0.1, wspace=0.1)
)
ylim = kwargs.get("ylim")
for ax, y in zip(axes.flat, spectra):
ax.plot(y.squeeze())
Expand Down