Skip to content

Commit 4533da3

Browse files
committed
Add unit test for TF Profiler.
* Tests the profiling server by sending a request and ensuring the profile is written to the expected location.
1 parent 58b4ab8 commit 4533da3

1 file changed

Lines changed: 68 additions & 0 deletions

File tree

gematria/model/python/main_function_test.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import functools
1717
from os import path
1818
import re
19+
from threading import Thread
1920
from unittest import mock
2021

2122
from absl import flags
@@ -33,6 +34,7 @@
3334
from gematria.testing.python import model_test
3435
import numpy as np
3536
import tensorflow.compat.v1 as tf
37+
from tensorflow import profiler
3638

3739
FLAGS = flags.FLAGS
3840

@@ -831,6 +833,72 @@ def test_multi_task_flags(self):
831833
FLAGS.gematria_throughput_source_filter = ['alice', 'bob']
832834
FLAGS.validate_all_flags()
833835

836+
@flagsaver.flagsaver
837+
def test_train_under_tf_profiler(self):
838+
"""Tests the profiling of model training using the TF Profiler.
839+
840+
The tests prepares training data and runs the actual training for a small
841+
number of epochs under the TF Profiler. Then checks that the expected profiles
842+
were recorded and stored at the expected directory.
843+
"""
844+
num_epochs = 10
845+
max_blocks_in_batch = 15
846+
max_instructions_in_batch = 124
847+
learning_rate = 0.321
848+
randomize_batches = False
849+
training_throughput_selection = io_options.ThroughputSelection.RANDOM
850+
checkpoint_dir = path.join(self.work_directory.full_path, 'checkpoint')
851+
summary_dir = path.join(self.work_directory.full_path, 'summary')
852+
tf_profiler_port = 6009
853+
use_seq2seq_loss = False # The default is True.
854+
855+
model = None
856+
857+
def MockModel(*args, **kwargs):
858+
nonlocal model
859+
self.assertEqual(kwargs['learning_rate'], learning_rate)
860+
model = TestModel(*args, **kwargs)
861+
# Record calls to model.train(), but still call the original method.
862+
mock_train = mock.MagicMock(side_effect=model.train)
863+
model.train = mock_train
864+
return model
865+
866+
FLAGS.gematria_action = model_options.Action.TRAIN
867+
FLAGS.gematria_run_tf_profiler = True
868+
FLAGS.gematria_tf_profiler_port = tf_profiler_port
869+
FLAGS.gematria_input_file = (self.input_filename,)
870+
FLAGS.gematria_checkpoint_dir = checkpoint_dir
871+
FLAGS.gematria_summary_dir = summary_dir
872+
FLAGS.gematria_training_num_epochs = num_epochs
873+
FLAGS.gematria_training_randomize_batches = randomize_batches
874+
FLAGS.gematria_max_blocks_in_batch = max_blocks_in_batch
875+
FLAGS.gematria_max_instructions_in_batch = max_instructions_in_batch
876+
FLAGS.gematria_use_seq2seq_loss = use_seq2seq_loss
877+
FLAGS.gematria_learning_rate = learning_rate
878+
FLAGS.gematria_training_throughput_selection = training_throughput_selection
879+
880+
# Set up a thread for the training process running the profiling server.
881+
server_thread = Thread(
882+
target=main_function.run_gematria_model_from_command_line_flags,
883+
args=(MockModel,),
884+
kwargs={'dtype': tf.dtypes.float32},
885+
)
886+
server_thread.start()
887+
888+
# Try sending a trace request to the TF Profiler.
889+
profiler.experimental.client.trace(
890+
service_addr=f'grpc://localhost:{tf_profiler_port}',
891+
logdir=summary_dir,
892+
duration_ms=1000,
893+
num_tracing_attempts=4000, # Keep trying until the server is ready.
894+
)
895+
server_thread.join()
896+
897+
# Check that profile has been written to the expected location.
898+
self._assert_file_exists(
899+
f'summary/plugins/profile/*/localhost_{tf_profiler_port}.xplane.pb'
900+
)
901+
834902

835903
if __name__ == '__main__':
836904
tf.disable_v2_behavior()

0 commit comments

Comments
 (0)