diff --git a/gematria/granite/python/BUILD.bazel b/gematria/granite/python/BUILD.bazel index 3cedf7ab..31069115 100644 --- a/gematria/granite/python/BUILD.bazel +++ b/gematria/granite/python/BUILD.bazel @@ -122,9 +122,6 @@ gematria_py_test( timeout = "moderate", srcs = ["rnn_token_model_test.py"], shard_count = 18, - tags = [ - "manual", - ], deps = [ ":rnn_token_model", "//gematria/basic_block/python:tokens", diff --git a/gematria/granite/python/rnn_token_model.py b/gematria/granite/python/rnn_token_model.py index deacc171..67263bdb 100644 --- a/gematria/granite/python/rnn_token_model.py +++ b/gematria/granite/python/rnn_token_model.py @@ -19,7 +19,7 @@ from gematria.granite.python import token_graph_builder_model from gematria.model.python import options -import tensorflow.compat.v1 as tf +import tensorflow as tf import tf_keras _RNN_TYPE_TO_TF = {'LSTM': tf_keras.layers.LSTM, 'GRU': tf_keras.layers.GRU} @@ -115,29 +115,35 @@ def _make_model_name(self) -> str: f',bidirectional={self._rnn_bidirectional}' ) - def _create_readout_network(self) -> tf.Tensor: - instruction_features = self._instruction_features + def initialize(self) -> None: + super().initialize() + if self._readout_input_layer_normalization: + self._rnn_layer_normalization = tf_keras.layers.LayerNormalization() + # TODO(ayazdan): Figure out how to pass `training` flag to the pipeline. + self._rnn_layer = _RNN_TYPE_TO_TF[self._rnn_type.name]( + self._rnn_output_size, + dropout=self._rnn_dropout, + return_sequences=self._use_deltas, + ) + if self._rnn_bidirectional: + self._rnn_layer = tf_keras.layers.Bidirectional(self._rnn_layer) + + def _execute_readout_network(self, graph_tuple, feed_dict) -> tf.Tensor: + instruction_features = tf.boolean_mask( + graph_tuple.nodes, feed_dict['instruction_node_mask'] + ) # Normalize the instruction features if needed. if self._readout_input_layer_normalization: - layer_normalization = tf_keras.layers.LayerNormalization() - instruction_features = layer_normalization(instruction_features) + instruction_features = self._rnn_layer_normalization(instruction_features) # A ragged tensor that contains the basic blocks in the batch. Each element # of the ragged tensor corresponds to one basic blocks in the batch, and it # contains a sequence of feature vectors of the instructions in the basic # block. blocks_ragged = tf.RaggedTensor.from_value_rowids( - self._instruction_features, self._delta_block_index_tensor + instruction_features, feed_dict['delta_block_index'] ) - # TODO(ayazdan): Figure out how to pass `training` flag to the pipeline. - rnn_layer = _RNN_TYPE_TO_TF[self._rnn_type.name]( - self._rnn_output_size, - dropout=self._rnn_dropout, - return_sequences=self._use_deltas, - ) - if self._rnn_bidirectional: - rnn_layer = tf_keras.layers.Bidirectional(rnn_layer) # Depending on the value of self._use_deltas: # * In the seq2num mode (self._use_deltas == False), rnn_outputs contains @@ -146,7 +152,7 @@ def _create_readout_network(self) -> tf.Tensor: # * In the seq2seq mode (self._use_deltas == True), it is a ragged vector # in the same format as blocks_ragged, and for each instruction we have # the output of the RNN cell at the corresponding position. - rnn_outputs = rnn_layer(blocks_ragged) + rnn_outputs = self._rnn_layer(blocks_ragged) if self._use_deltas: # In seq2seq mode, convert the ragged tensor back to a normal tensor that @@ -158,4 +164,4 @@ def _create_readout_network(self) -> tf.Tensor: # different semantic in seq2seq vs seq2num modes, the network has exactly # the same structure. The outputs of the RNN network are already in the # (-1, 1) range, so we skip any additional normalization steps. - return self._create_dense_readout_network(rnn_outputs) + return self._execute_dense_readout_network(rnn_outputs) diff --git a/gematria/granite/python/rnn_token_model_test.py b/gematria/granite/python/rnn_token_model_test.py index 424645ed..a91e5a61 100644 --- a/gematria/granite/python/rnn_token_model_test.py +++ b/gematria/granite/python/rnn_token_model_test.py @@ -18,7 +18,7 @@ from gematria.model.python import oov_token_behavior from gematria.model.python import options from gematria.testing.python import model_test -import tensorflow.compat.v1 as tf +import tensorflow as tf _OutOfVocabularyTokenBehavior = oov_token_behavior.OutOfVocabularyTokenBehavior @@ -205,5 +205,4 @@ def test_train_seq2num_multi_task(self): if __name__ == '__main__': - tf.disable_v2_behavior() tf.test.main()