3232import math
3333import os
3434import random
35- from typing import Optional , TypeVar , Union
35+ from typing import Optional , TypeVar , Union , TypedDict
3636
3737from absl import logging
3838from gematria .basic_block .python import basic_block
@@ -120,6 +120,12 @@ def after_run(self, run_context: ..., run_values: ...):
120120 )
121121
122122
class OutputDict(TypedDict, total=False):
  """Output tensors produced by a model's forward pass.

  `total=False` because not every key is present in every mode:
    - 'output': present in seq2num mode; in seq2seq (deltas) mode it is
      added by `ModelBase.__call__` as the per-block sum of the deltas.
    - 'output_deltas': present only in seq2seq mode.
    - 'output_mask_deltas': added by `ModelBase.__call__` only when
      `train=True` in seq2seq mode.
  """

  output: tf.Tensor
  output_deltas: tf.Tensor
  output_mask_deltas: tf.Tensor
128+
123129class ModelBase (tf .Module , metaclass = abc .ABCMeta ):
124130 """Base class for Gematria basic block processing models.
125131
@@ -441,42 +447,43 @@ def output_tensor_names(self) -> Sequence[str]:
441447 return (ModelBase .OUTPUT_TENSOR_NAME ,)
442448
  @abc.abstractmethod
  def _forward(self, feed_dict: FeedDict) -> OutputDict:
    """Implements the forward pass of the model.

    This function should be implemented in downstream models and calculate the
    outputs of the model given the inputs specified in feed_dict.

    Args:
      feed_dict: Input tensors for the model, keyed by input name.

    Returns:
      A dictionary containing tensors. This should contain a key named
      'output' in seq2num mode and a key named 'output_deltas' in seq2seq
      mode.
    """
    pass
451461
452- def __call__ (self , feed_dict , train = False ):
462+ def __call__ (self , feed_dict : FeedDict , train = False ) -> OutputDict :
453463 """Implements the non-model specific part of the forward pass.
454464
455465 This function wraps the _forward method and does relevant calculations
456466 when working with models that use deltas.
457467 """
458- if self ._use_deltas :
459- output_dict = {}
460-
461- if train :
462- output_dict ['output_mask_deltas' ] = tf .nn .embedding_lookup (
463- feed_dict ['output_mask' ],
464- feed_dict ['delta_block_index' ],
465- name = 'ModelBase.output_mask_deltas' ,
466- )
468+ if not self ._use_deltas :
469+ return self ._forward (feed_dict )
467470
468- output = self ._forward (feed_dict )
471+ output = self ._forward (feed_dict )
469472
470- output_dict ['output' ] = tf .math .segment_sum (
471- output ['output_deltas' ],
473+ if train :
474+ output ['output_mask_deltas' ] = tf .nn .embedding_lookup (
475+ feed_dict ['output_mask' ],
472476 feed_dict ['delta_block_index' ],
473- name = ModelBase .OUTPUT_TENSOR_NAME ,
477+ name = ' ModelBase.output_mask_deltas' ,
474478 )
475- output_dict ['output_deltas' ] = output ['output_deltas' ]
476479
477- return output_dict
478- else :
479- return self ._forward (feed_dict )
480+ output ['output' ] = tf .math .segment_sum (
481+ output ['output_deltas' ],
482+ feed_dict ['delta_block_index' ],
483+ name = ModelBase .OUTPUT_TENSOR_NAME ,
484+ )
485+
486+ return output
480487
481488 @abc .abstractmethod
482489 def _make_model_name (self ) -> str :
@@ -1151,9 +1158,9 @@ def predict(
11511158 batch_output_blocks = []
11521159 with timer .scoped ('ModelBase.predict - one batch' ):
11531160 schedule = self .schedule_batch (batch )
1161+ output_dict = self (schedule )
1162+ output = output_dict ['output' ]
11541163 if self ._use_deltas :
1155- output_dict = self (schedule )
1156- output = output_dict ['output' ]
11571164 output_deltas = output_dict ['output_deltas' ]
11581165 output_index = 0
11591166 for block_index , block in enumerate (batch ):
@@ -1273,9 +1280,9 @@ def run_one_epoch():
12731280 for _ in range (0 , num_epochs ):
12741281 stats = run_one_epoch ()
12751282 logging .info ('Training: %s' , stats )
1276- return stats
1283+ return stats
12771284
1278- def compute_loss (self , schedule : FeedDict ):
1285+ def _compute_loss (self , schedule : FeedDict ) -> loss_utils . LossComputation :
12791286 output = self (schedule , train = True )
12801287 loss = loss_utils .LossComputation (
12811288 output ['output' ],
@@ -1301,7 +1308,7 @@ def compute_loss(self, schedule: FeedDict):
13011308
13021309 def compute_loss_tensor (self , schedule : FeedDict ):
13031310 return tf .reduce_mean (
1304- self .compute_loss (schedule ).loss_tensor (
1311+ self ._compute_loss (schedule ).loss_tensor (
13051312 self ._loss_normalization , self ._loss_type
13061313 )
13071314 )
@@ -1327,27 +1334,25 @@ def train_batch(
13271334 # TrainingEpochStats.__init__() as keyword arguments.
13281335 with tf .GradientTape () as tape :
13291336 stats = {}
1330- loss = self .compute_loss (schedule )
1337+ loss = self ._compute_loss (schedule )
13311338 loss_tensor_per_task = loss .loss_tensor (
13321339 self ._loss_normalization , self ._loss_type
13331340 )
13341341 loss_tensor = tf .reduce_mean (loss_tensor_per_task )
13351342
13361343 # The list of variables to optimize. By default, the list is empty which
13371344 # means optimize all trainable variables.
1338- variables = set ()
1345+ requested_variables = set ()
13391346 for variable_group in self ._trained_variable_groups :
1340- variables .update (
1341- [
1342- variable .ref ()
1343- for variable in self ._variable_groups .get (variable_group )
1344- ]
1347+ requested_variables .update (
1348+ variable .ref ()
1349+ for variable in self ._variable_groups .get (variable_group )
13451350 )
13461351
13471352 trainable_variables = self ._get_trainable_variables ()
13481353 variables = (
1349- [variable .deref () for variable in variables ]
1350- if variables
1354+ [variable .deref () for variable in requested_variables ]
1355+ if requested_variables
13511356 else trainable_variables
13521357 )
13531358
0 commit comments