Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions t5x/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ def evaluate(
input_shapes=input_shapes,
partitioner=partitioner,
)
train_state_axes = train_state_initializer.train_state_axes
train_state_axes = train_state_initializer.train_state_axes # pytype: disable=attribute-error # jax-api-types
# Log the variable shapes information and write to a file.
log_file = os.path.join(output_dir, 'model-info.txt')
utils.log_model_info(
log_file, train_state_initializer.global_train_state_shape, partitioner
log_file, train_state_initializer.global_train_state_shape, partitioner # pytype: disable=attribute-error # jax-api-types
)

if training_evaluator_cls:
Expand Down
4 changes: 2 additions & 2 deletions t5x/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_train_state_initializer(
partitioner=partitioner,
)
utils.log_model_info(
None, train_state_initializer.global_train_state_shape, partitioner
None, train_state_initializer.global_train_state_shape, partitioner # pytype: disable=attribute-error # jax-api-types
)
return train_state_initializer

Expand Down Expand Up @@ -317,7 +317,7 @@ def create_inference_function(
# TODO(b/121310741): Re-enable pytype.
# pytype:disable=wrong-arg-types
in_axis_resources=(
train_state_initializer.train_state_axes.params,
train_state_initializer.train_state_axes.params, # pytype: disable=attribute-error # jax-api-types
partitioning.PartitionSpec(
'data',
),
Expand Down
4 changes: 2 additions & 2 deletions t5x/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def _get_dataset(dataset_provider):
if shard_id == 0:
utils.log_model_info(
model_info_log_file,
train_state_initializer.global_train_state_shape,
train_state_initializer.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types
partitioner,
)

Expand Down Expand Up @@ -541,7 +541,7 @@ def _get_dataset(dataset_provider):
utils.get_infer_fn(
infer_step=infer_step,
batch_size=batch_size,
train_state_axes=train_state_initializer.train_state_axes,
train_state_axes=train_state_initializer.train_state_axes, # pytype: disable=attribute-error # jax-api-types
partitioner=partitioner,
keep_aux_as_numpy=keep_aux_as_numpy,
),
Expand Down
4 changes: 2 additions & 2 deletions t5x/precompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def precompile(
input_types=input_types,
partitioner=partitioner,
)
train_state_shape = train_state_initializer.global_train_state_shape
train_state_axes = train_state_initializer.train_state_axes
train_state_shape = train_state_initializer.global_train_state_shape # pytype: disable=attribute-error # jax-api-types
train_state_axes = train_state_initializer.train_state_axes # pytype: disable=attribute-error # jax-api-types

def train_step(train_state, batch):
return trainer_lib.train_with_lr( # pytype: disable=wrong-arg-types # jax-ndarray
Expand Down
4 changes: 2 additions & 2 deletions t5x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig):

# 3. If no checkpoint to restore, init from scratch.
train_state = train_state or train_state_initializer.from_scratch(init_rng)
train_state_axes = train_state_initializer.train_state_axes
train_state_axes = train_state_initializer.train_state_axes # pytype: disable=attribute-error # jax-api-types
init_or_restore_secs = time.time() - init_or_restore_tick
logging.info(
'Initialize/restore complete (%.2f seconds).', init_or_restore_secs
Expand All @@ -395,7 +395,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig):
# Log the variable shapes information and write to a file.
log_file = os.path.join(model_dir, 'model-info.txt')
utils.log_model_info(
log_file, train_state_initializer.global_train_state_shape, partitioner
log_file, train_state_initializer.global_train_state_shape, partitioner # pytype: disable=attribute-error # jax-api-types
)

# Restore step from last checkpoint or set to 0 if training from scratch.
Expand Down
10 changes: 5 additions & 5 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,9 +1091,9 @@ def from_scratch(self, init_rng: Array) -> train_state_lib.TrainState:
# setup as the training step/loop to initialize everything "in-place" and
# avoid communication or OOM.
p_initialize_train_state_fn = self._partitioner.partition(
self._initialize_train_state,
self._initialize_train_state, # pytype: disable=attribute-error # jax-api-types
in_axis_resources=None,
out_axis_resources=self.train_state_axes,
out_axis_resources=self.train_state_axes, # pytype: disable=attribute-error # jax-api-types
)
return p_initialize_train_state_fn(init_rng)

Expand Down Expand Up @@ -1129,7 +1129,7 @@ def _restore_path(path, cfg):
if cfg is None:
raise ValueError('Expected valid `RestoreCheckpointConfig`.')
restore_checkpointer = cfg.checkpointer_cls(
train_state=self.global_train_state_shape,
train_state=self.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types
partitioner=self._partitioner,
checkpoints_dir='', # unused for restore
dataset_iterator=ds_iter if cfg.restore_dataset else None,
Expand Down Expand Up @@ -1229,7 +1229,7 @@ def _init(rng):
checkpoint_manager = create_orbax_checkpoint_manager(
save_cfg=save_checkpoint_cfg,
restore_cfg=restore_checkpoint_cfg,
train_state=train_state_initializer.global_train_state_shape,
train_state=train_state_initializer.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types
partitioner=partitioner,
ds_iter=ds_iter,
model_dir=model_dir,
Expand All @@ -1244,7 +1244,7 @@ def _init(rng):
checkpoint_manager = LegacyCheckpointManager(
save_cfg=save_checkpoint_cfg,
restore_cfg=restore_checkpoint_cfg,
train_state_shape=train_state_initializer.global_train_state_shape,
train_state_shape=train_state_initializer.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types
partitioner=partitioner,
ds_iter=ds_iter,
model_dir=model_dir,
Expand Down
Loading