Skip to content

Commit 8e20deb

Browse files
committed
refactor: rename pure_forward to reconstruct
1 parent 25f81d7 commit 8e20deb

3 files changed

Lines changed: 4 additions & 5 deletions

File tree

src/spherinator/callbacks/log_reconstruction_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def on_train_epoch_end(self, trainer, model):
5252
images = images.to(model.device)
5353

5454
# Generate reconstructions of the samples using the model
55-
recon = model.pure_forward(images)
55+
recon = model.reconstruct(images)
5656
loss = torch.nn.MSELoss(reduction="none")(images, recon).flatten(1).mean(1)
5757

5858
# Plot the original samples and their reconstructions side by side

src/spherinator/models/autoencoder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ def forward(self, x):
3838
x = self.encode(x)
3939
return self.decode(x)
4040

41-
def pure_forward(self, x):
42-
x = self.encode(x)
43-
return self.decode(x)
41+
def reconstruct(self, x):
42+
return self.forward(x)
4443

4544
def _compute_loss(self, batch):
4645
recon = self.forward(batch)

src/spherinator/models/variational_autoencoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def forward(self, x):
118118
recon = self.decode(z)
119119
return (z_location, z_scale), (q_z, p_z), z, recon
120120

121-
def pure_forward(self, x):
121+
def reconstruct(self, x):
122122
z_location, _ = self.encode(x)
123123
return self.decode(z_location)
124124

0 commit comments

Comments
 (0)