|
16 | 16 | import functools |
17 | 17 | from os import path |
18 | 18 | import re |
| 19 | +from threading import Thread |
19 | 20 | from unittest import mock |
20 | 21 |
|
21 | 22 | from absl import flags |
|
33 | 34 | from gematria.testing.python import model_test |
34 | 35 | import numpy as np |
35 | 36 | import tensorflow.compat.v1 as tf |
| 37 | +from tensorflow import profiler |
36 | 38 |
|
37 | 39 | FLAGS = flags.FLAGS |
38 | 40 |
|
@@ -831,6 +833,72 @@ def test_multi_task_flags(self): |
831 | 833 | FLAGS.gematria_throughput_source_filter = ['alice', 'bob'] |
832 | 834 | FLAGS.validate_all_flags() |
833 | 835 |
|
  @flagsaver.flagsaver
  def test_train_under_tf_profiler(self):
    """Tests the profiling of model training using the TF Profiler.

    The test prepares training data and runs the actual training for a small
    number of epochs under the TF Profiler. Then checks that the expected
    profiles were recorded and stored at the expected directory.
    """
    num_epochs = 10
    max_blocks_in_batch = 15
    max_instructions_in_batch = 124
    learning_rate = 0.321
    randomize_batches = False
    training_throughput_selection = io_options.ThroughputSelection.RANDOM
    checkpoint_dir = path.join(self.work_directory.full_path, 'checkpoint')
    summary_dir = path.join(self.work_directory.full_path, 'summary')
    # NOTE(review): a fixed port may collide with another process or a
    # concurrently running test on the same host — consider picking a free
    # port dynamically.
    tf_profiler_port = 6009
    use_seq2seq_loss = False  # The default is True.

    # Captures the model instance created inside MockModel so the test body
    # can inspect it after training.
    model = None

    def MockModel(*args, **kwargs):
      # Factory passed to the main function in place of the real model
      # constructor; verifies the learning rate flag was threaded through and
      # wraps train() in a MagicMock so calls are recorded.
      nonlocal model
      self.assertEqual(kwargs['learning_rate'], learning_rate)
      model = TestModel(*args, **kwargs)
      # Record calls to model.train(), but still call the original method.
      mock_train = mock.MagicMock(side_effect=model.train)
      model.train = mock_train
      return model

    FLAGS.gematria_action = model_options.Action.TRAIN
    FLAGS.gematria_run_tf_profiler = True
    FLAGS.gematria_tf_profiler_port = tf_profiler_port
    FLAGS.gematria_input_file = (self.input_filename,)
    FLAGS.gematria_checkpoint_dir = checkpoint_dir
    FLAGS.gematria_summary_dir = summary_dir
    FLAGS.gematria_training_num_epochs = num_epochs
    FLAGS.gematria_training_randomize_batches = randomize_batches
    FLAGS.gematria_max_blocks_in_batch = max_blocks_in_batch
    FLAGS.gematria_max_instructions_in_batch = max_instructions_in_batch
    FLAGS.gematria_use_seq2seq_loss = use_seq2seq_loss
    FLAGS.gematria_learning_rate = learning_rate
    FLAGS.gematria_training_throughput_selection = training_throughput_selection

    # Set up a thread for the training process running the profiling server.
    # NOTE(review): if trace() below raises, this thread is never joined and
    # the test leaks a running training thread; consider a try/finally around
    # the trace call. Also, join() has no timeout, so the test hangs rather
    # than fails if training never terminates — TODO confirm this is
    # acceptable under the test runner's own timeout.
    server_thread = Thread(
        target=main_function.run_gematria_model_from_command_line_flags,
        args=(MockModel,),
        kwargs={'dtype': tf.dtypes.float32},
    )
    server_thread.start()

    # Try sending a trace request to the TF Profiler.
    profiler.experimental.client.trace(
        service_addr=f'grpc://localhost:{tf_profiler_port}',
        logdir=summary_dir,
        duration_ms=1000,
        num_tracing_attempts=4000,  # Keep trying until the server is ready.
    )
    server_thread.join()

    # Check that profile has been written to the expected location.
    self._assert_file_exists(
        f'summary/plugins/profile/*/localhost_{tf_profiler_port}.xplane.pb'
    )
834 | 902 |
|
835 | 903 | if __name__ == '__main__': |
836 | 904 | tf.disable_v2_behavior() |
|
0 commit comments