From 087c7242b43a158bf560bcb44ec5e8824cd1dbc6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 4 Jun 2025 09:07:02 -0700 Subject: [PATCH] Disable pytype for errors revealed by JAX build refactor PiperOrigin-RevId: 767169324 --- t5x/eval.py | 4 ++-- t5x/export_lib.py | 4 ++-- t5x/infer.py | 4 ++-- t5x/precompile.py | 4 ++-- t5x/train.py | 4 ++-- t5x/utils.py | 10 +++++----- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/t5x/eval.py b/t5x/eval.py index ac2f0e59b..eb4e3c89e 100644 --- a/t5x/eval.py +++ b/t5x/eval.py @@ -281,11 +281,11 @@ def evaluate( input_shapes=input_shapes, partitioner=partitioner, ) - train_state_axes = train_state_initializer.train_state_axes + train_state_axes = train_state_initializer.train_state_axes # pytype: disable=attribute-error # jax-api-types # Log the variable shapes information and write to a file. log_file = os.path.join(output_dir, 'model-info.txt') utils.log_model_info( - log_file, train_state_initializer.global_train_state_shape, partitioner + log_file, train_state_initializer.global_train_state_shape, partitioner # pytype: disable=attribute-error # jax-api-types ) if training_evaluator_cls: diff --git a/t5x/export_lib.py b/t5x/export_lib.py index 9e53a5631..ac999ba94 100644 --- a/t5x/export_lib.py +++ b/t5x/export_lib.py @@ -269,7 +269,7 @@ def get_train_state_initializer( partitioner=partitioner, ) utils.log_model_info( - None, train_state_initializer.global_train_state_shape, partitioner + None, train_state_initializer.global_train_state_shape, partitioner # pytype: disable=attribute-error # jax-api-types ) return train_state_initializer @@ -317,7 +317,7 @@ def create_inference_function( # TODO(b/121310741): Re-enable pytype. # pytype:disable=wrong-arg-types in_axis_resources=( - train_state_initializer.train_state_axes.params, + train_state_initializer.train_state_axes.params, # pytype: disable=attribute-error # jax-api-types partitioning.PartitionSpec( 'data', ), diff --git a/t5x/infer.py b/t5x/infer.py index 78f03d3d5..bf6281b67 100644 --- a/t5x/infer.py +++ b/t5x/infer.py @@ -510,7 +510,7 @@ def _get_dataset(dataset_provider): if shard_id == 0: utils.log_model_info( model_info_log_file, - train_state_initializer.global_train_state_shape, + train_state_initializer.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types partitioner, ) @@ -541,7 +541,7 @@ def _get_dataset(dataset_provider): utils.get_infer_fn( infer_step=infer_step, batch_size=batch_size, - train_state_axes=train_state_initializer.train_state_axes, + train_state_axes=train_state_initializer.train_state_axes, # pytype: disable=attribute-error # jax-api-types partitioner=partitioner, keep_aux_as_numpy=keep_aux_as_numpy, ), diff --git a/t5x/precompile.py b/t5x/precompile.py index c8b9dd634..7794a42af 100644 --- a/t5x/precompile.py +++ b/t5x/precompile.py @@ -95,8 +95,8 @@ def precompile( input_types=input_types, partitioner=partitioner, ) - train_state_shape = train_state_initializer.global_train_state_shape - train_state_axes = train_state_initializer.train_state_axes + train_state_shape = train_state_initializer.global_train_state_shape # pytype: disable=attribute-error # jax-api-types + train_state_axes = train_state_initializer.train_state_axes # pytype: disable=attribute-error # jax-api-types def train_step(train_state, batch): return trainer_lib.train_with_lr( # pytype: disable=wrong-arg-types # jax-ndarray diff --git a/t5x/train.py b/t5x/train.py index 7c4afe0df..d25bef7b7 100644 --- a/t5x/train.py +++ b/t5x/train.py @@ -386,7 +386,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): # 3. If no checkpoint to restore, init from scratch. train_state = train_state or train_state_initializer.from_scratch(init_rng) - train_state_axes = train_state_initializer.train_state_axes + train_state_axes = train_state_initializer.train_state_axes # pytype: disable=attribute-error # jax-api-types init_or_restore_secs = time.time() - init_or_restore_tick logging.info( 'Initialize/restore complete (%.2f seconds).', init_or_restore_secs @@ -395,7 +395,7 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig): # Log the variable shapes information and write to a file. log_file = os.path.join(model_dir, 'model-info.txt') utils.log_model_info( - log_file, train_state_initializer.global_train_state_shape, partitioner + log_file, train_state_initializer.global_train_state_shape, partitioner # pytype: disable=attribute-error # jax-api-types ) # Restore step from last checkpoint or set to 0 if training from scratch. diff --git a/t5x/utils.py b/t5x/utils.py index 7fc7c5706..098734f6a 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -1091,9 +1091,9 @@ def from_scratch(self, init_rng: Array) -> train_state_lib.TrainState: # setup as the training step/loop to initialize everything "in-place" and # avoid communication or OOM. p_initialize_train_state_fn = self._partitioner.partition( - self._initialize_train_state, + self._initialize_train_state, # pytype: disable=attribute-error # jax-api-types in_axis_resources=None, - out_axis_resources=self.train_state_axes, + out_axis_resources=self.train_state_axes, # pytype: disable=attribute-error # jax-api-types ) return p_initialize_train_state_fn(init_rng) @@ -1129,7 +1129,7 @@ def _restore_path(path, cfg): if cfg is None: raise ValueError('Expected valid `RestoreCheckpointConfig`.') restore_checkpointer = cfg.checkpointer_cls( - train_state=self.global_train_state_shape, + train_state=self.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types partitioner=self._partitioner, checkpoints_dir='', # unused for restore dataset_iterator=ds_iter if cfg.restore_dataset else None, @@ -1229,7 +1229,7 @@ def _init(rng): checkpoint_manager = create_orbax_checkpoint_manager( save_cfg=save_checkpoint_cfg, restore_cfg=restore_checkpoint_cfg, - train_state=train_state_initializer.global_train_state_shape, + train_state=train_state_initializer.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types partitioner=partitioner, ds_iter=ds_iter, model_dir=model_dir, @@ -1244,7 +1244,7 @@ def _init(rng): checkpoint_manager = LegacyCheckpointManager( save_cfg=save_checkpoint_cfg, restore_cfg=restore_checkpoint_cfg, - train_state_shape=train_state_initializer.global_train_state_shape, + train_state_shape=train_state_initializer.global_train_state_shape, # pytype: disable=attribute-error # jax-api-types partitioner=partitioner, ds_iter=ds_iter, model_dir=model_dir,