Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions keras/src/callbacks/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,16 +310,16 @@ def _configure_embeddings(self):
with file_utils.File(path, "w") as f:
f.write(config_pbtxt)

def _push_writer(self, writer, step_var):
    """Sets the default writer for custom batch-level summaries.

    Args:
        writer: Summary writer to install as the default writer.
        step_var: Variable (or int-like) holding the current batch step.
            Summaries are recorded only when its value is a multiple of
            `self.update_freq`.
    """
    # Batch-level summaries are disabled entirely in "epoch" mode.
    if self.update_freq == "epoch":
        return

    def should_record():
        # Re-evaluated at record time, so a mutable `step_var` lets the
        # condition track the latest batch index.
        return step_var % self.update_freq == 0

    # Stash the (writer, record_if) context-manager pair so it can be
    # torn down later (see `_pop_writer`, called from `on_train_end`).
    summary_context = (
        writer.as_default(step_var),
        self.summary.record_if(should_record),
    )
    self._prev_summary_state.append(summary_context)
Expand Down Expand Up @@ -404,9 +404,12 @@ def _init_profile_batch(self, profile_batch):
)

def on_train_begin(self, logs=None):
    """Resets batch counters and installs the train summary writer."""
    # TODO(review): this local import is duplicated in `on_test_begin`;
    # consider a single module-level lazy import instead, e.g.
    # `from keras.src.utils.module_utils import tensorflow as tf`.
    import tensorflow as tf

    self._global_train_batch = 0
    self._previous_epoch_iterations = 0
    # A tf.Variable is passed (rather than a plain int) presumably so the
    # `should_record` closure in `_push_writer` observes in-place updates
    # made by `on_train_batch_begin` — confirm against `_push_writer`.
    self._train_step_var = tf.Variable(0, dtype=tf.int64, trainable=False)
    self._push_writer(self._train_writer, self._train_step_var)

def on_train_end(self, logs=None):
self._pop_writer()
Expand All @@ -417,7 +420,10 @@ def on_train_end(self, logs=None):
self._close_writers()

def on_test_begin(self, logs=None):
    """Installs the validation summary writer with a step variable."""
    # TODO(review): duplicate of the local import in `on_train_begin`;
    # consider importing TensorFlow once at module top via the Keras
    # lazy loader (`from keras.src.utils.module_utils import tensorflow as tf`).
    import tensorflow as tf

    # NOTE(review): a fresh tf.Variable is created on every call (e.g. each
    # validation pass during `fit`); consider creating it once and reusing
    # it to avoid repeated variable creation.
    self._test_step_var = tf.Variable(0, dtype=tf.int64, trainable=False)
    self._push_writer(self._val_writer, self._test_step_var)

def on_test_end(self, logs=None):
if self.model.optimizer and hasattr(self.model.optimizer, "iterations"):
Expand All @@ -432,6 +438,8 @@ def on_test_end(self, logs=None):

def on_train_batch_begin(self, batch, logs=None):
self._global_train_batch += 1
if hasattr(self, "_train_step_var"):
self._train_step_var.assign(self._global_train_batch)
if self.write_steps_per_second:
self._batch_start_time = time.time()
if not self._should_trace:
Expand Down Expand Up @@ -475,6 +483,8 @@ def on_train_batch_end(self, batch, logs=None):

def on_test_batch_begin(self, batch, logs=None):
    """Advances the evaluation batch counter before each test batch."""
    self._global_test_batch += 1
    # Mirror the Python-side counter into the tf.Variable created in
    # `on_test_begin`, so batch-level summaries see the current step.
    # The hasattr guard covers the case where `on_test_begin` never ran.
    if hasattr(self, "_test_step_var"):
        self._test_step_var.assign(self._global_test_batch)

def on_epoch_begin(self, epoch, logs=None):
# Keeps track of epoch for profiling.
Expand Down