Skip to content

Commit be99a2f

Browse files
fix: correct total_loss calculate assign (#2301)
* fix: correct total_loss calculate assign * Update examples/generative/vae.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ea77717 commit be99a2f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/generative/vae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,13 @@ def train_step(self, data):
110110
total_loss = reconstruction_loss + kl_loss
111111
grads = tape.gradient(total_loss, self.trainable_weights)
112112
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
113-
self.total_loss_tracker.update_state(total_loss)
114113
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
115114
self.kl_loss_tracker.update_state(kl_loss)
115+
self.total_loss_tracker.update_state(total_loss)
116116
return {
117-
"loss": self.total_loss_tracker.result(),
118117
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
119118
"kl_loss": self.kl_loss_tracker.result(),
119+
"total_loss": self.total_loss_tracker.result(),
120120
}
121121

122122

0 commit comments

Comments
 (0)