Skip to content

Commit 45f7739

Browse files
comments
Created using spr 1.3.4
2 parents fd71a54 + 2c0dd22 commit 45f7739

2 files changed

Lines changed: 41 additions & 36 deletions

File tree

gematria/model/python/loss_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ def test_multi_task_unknown_shape(self):
320320

321321
def test_single_task_unknown_shape(self):
322322
num_tasks = 1
323-
actual_output = tf.reshape(self.actual_outputs_array, (-1, 1))
324-
expected_output = tf.reshape(self.expected_outputs_array, (-1, 1))
323+
actual_output = tf.constant(self.actual_outputs_array)
324+
expected_output = tf.constant(self.expected_outputs_array)
325325
mask = tf.ones_like(actual_output, tf.dtypes.bool)
326326
percentile_ranks = (0, 50, 75, 100)
327327

gematria/model/python/model_base.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import math
3333
import os
3434
import random
35-
from typing import Optional, TypeVar, Union
35+
from typing import Optional, TypeVar, Union, TypedDict
3636

3737
from absl import logging
3838
from gematria.basic_block.python import basic_block
@@ -120,6 +120,12 @@ def after_run(self, run_context: ..., run_values: ...):
120120
)
121121

122122

123+
class OutputDict(TypedDict):
124+
output: tf.Tensor | None
125+
output_deltas: tf.Tensor | None
126+
output_mask_deltas: tf.Tensor | None
127+
128+
123129
class ModelBase(tf.Module, metaclass=abc.ABCMeta):
124130
"""Base class for Gematria basic block processing models.
125131
@@ -441,42 +447,43 @@ def output_tensor_names(self) -> Sequence[str]:
441447
return (ModelBase.OUTPUT_TENSOR_NAME,)
442448

443449
@abc.abstractmethod
444-
def _forward(self, feed_dict: FeedDict) -> dict[str, tf.Tensor]:
450+
def _forward(self, feed_dict: FeedDict) -> OutputDict:
445451
"""Implements the forward pass of the model.
446452
447453
This function should be implemented in downstream models and calculate the
448454
outputs of the model given the inputs specified in feed_dict.
455+
456+
Returns:
457+
A dictionary containing tensors. This should contain a key named 'output' in
458+
seq2num mode and a key named 'output_deltas' in seq2seq mode.
449459
"""
450460
pass
451461

452-
def __call__(self, feed_dict, train=False):
462+
def __call__(self, feed_dict: FeedDict, train=False) -> OutputDict:
453463
"""Implements the non-model specific part of the forward pass.
454464
455465
This function wraps the _forward method and does relevant calculations
456466
when working with models that use deltas.
457467
"""
458-
if self._use_deltas:
459-
output_dict = {}
460-
461-
if train:
462-
output_dict['output_mask_deltas'] = tf.nn.embedding_lookup(
463-
feed_dict['output_mask'],
464-
feed_dict['delta_block_index'],
465-
name='ModelBase.output_mask_deltas',
466-
)
468+
if not self._use_deltas:
469+
return self._forward(feed_dict)
467470

468-
output = self._forward(feed_dict)
471+
output = self._forward(feed_dict)
469472

470-
output_dict['output'] = tf.math.segment_sum(
471-
output['output_deltas'],
473+
if train:
474+
output['output_mask_deltas'] = tf.nn.embedding_lookup(
475+
feed_dict['output_mask'],
472476
feed_dict['delta_block_index'],
473-
name=ModelBase.OUTPUT_TENSOR_NAME,
477+
name='ModelBase.output_mask_deltas',
474478
)
475-
output_dict['output_deltas'] = output['output_deltas']
476479

477-
return output_dict
478-
else:
479-
return self._forward(feed_dict)
480+
output['output'] = tf.math.segment_sum(
481+
output['output_deltas'],
482+
feed_dict['delta_block_index'],
483+
name=ModelBase.OUTPUT_TENSOR_NAME,
484+
)
485+
486+
return output
480487

481488
@abc.abstractmethod
482489
def _make_model_name(self) -> str:
@@ -1151,9 +1158,9 @@ def predict(
11511158
batch_output_blocks = []
11521159
with timer.scoped('ModelBase.predict - one batch'):
11531160
schedule = self.schedule_batch(batch)
1161+
output_dict = self(schedule)
1162+
output = output_dict['output']
11541163
if self._use_deltas:
1155-
output_dict = self(schedule)
1156-
output = output_dict['output']
11571164
output_deltas = output_dict['output_deltas']
11581165
output_index = 0
11591166
for block_index, block in enumerate(batch):
@@ -1273,9 +1280,9 @@ def run_one_epoch():
12731280
for _ in range(0, num_epochs):
12741281
stats = run_one_epoch()
12751282
logging.info('Training: %s', stats)
1276-
return stats
1283+
return stats
12771284

1278-
def compute_loss(self, schedule: FeedDict):
1285+
def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation:
12791286
output = self(schedule, train=True)
12801287
loss = loss_utils.LossComputation(
12811288
output['output'],
@@ -1301,7 +1308,7 @@ def compute_loss(self, schedule: FeedDict):
13011308

13021309
def compute_loss_tensor(self, schedule: FeedDict):
13031310
return tf.reduce_mean(
1304-
self.compute_loss(schedule).loss_tensor(
1311+
self._compute_loss(schedule).loss_tensor(
13051312
self._loss_normalization, self._loss_type
13061313
)
13071314
)
@@ -1327,27 +1334,25 @@ def train_batch(
13271334
# TrainingEpochStats.__init__() as keyword arguments.
13281335
with tf.GradientTape() as tape:
13291336
stats = {}
1330-
loss = self.compute_loss(schedule)
1337+
loss = self._compute_loss(schedule)
13311338
loss_tensor_per_task = loss.loss_tensor(
13321339
self._loss_normalization, self._loss_type
13331340
)
13341341
loss_tensor = tf.reduce_mean(loss_tensor_per_task)
13351342

13361343
# The list of variables to optimize. By default, the list is empty which
13371344
# means optimize all trainable variables.
1338-
variables = set()
1345+
requested_variables = set()
13391346
for variable_group in self._trained_variable_groups:
1340-
variables.update(
1341-
[
1342-
variable.ref()
1343-
for variable in self._variable_groups.get(variable_group)
1344-
]
1347+
requested_variables.update(
1348+
variable.ref()
1349+
for variable in self._variable_groups.get(variable_group)
13451350
)
13461351

13471352
trainable_variables = self._get_trainable_variables()
13481353
variables = (
1349-
[variable.deref() for variable in variables]
1350-
if variables
1354+
[variable.deref() for variable in requested_variables]
1355+
if requested_variables
13511356
else trainable_variables
13521357
)
13531358

0 commit comments

Comments
 (0)