@@ -206,7 +206,7 @@ def __init__(
206206 save_cfg = self ._save_checkpoint_cfg ,
207207 restore_cfg = self ._restore_checkpoint_cfg ,
208208 train_state_shape = (
209- self ._train_state_initializer .global_train_state_shape
209+ self ._train_state_initializer .global_train_state_shape # pytype: disable=attribute-error # jax-api-types
210210 ),
211211 partitioner = self ._partitioner ,
212212 ds_iter = None ,
@@ -266,13 +266,13 @@ def get_state(rng):
266266 self ._train_state = self ._train_state_initializer .from_scratch (
267267 self ._init_rng
268268 )
269- self ._train_state_axes = self ._train_state_initializer .train_state_axes
269+ self ._train_state_axes = self ._train_state_initializer .train_state_axes # pytype: disable=attribute-error # jax-api-types
270270
271271 # Log the variable shapes information and write to a file.
272272 log_file = os .path .join (self ._output_dir , "model-info.txt" )
273273 utils .log_model_info (
274274 log_file ,
275- self ._train_state_initializer .global_train_state_shape ,
275+ self ._train_state_initializer .global_train_state_shape , # pytype: disable=attribute-error # jax-api-types
276276 self ._partitioner ,
277277 )
278278
@@ -489,7 +489,7 @@ def infer_with_preprocessors(
489489 self ._cached_infer_fns [infer_fn_key ] = utils .get_infer_fn (
490490 infer_step = functools .partial (infer_step , ** inference_kwargs ),
491491 batch_size = self ._batch_size ,
492- train_state_axes = self ._train_state_initializer .train_state_axes ,
492+ train_state_axes = self ._train_state_initializer .train_state_axes , # pytype: disable=attribute-error # jax-api-types
493493 partitioner = self ._partitioner ,
494494 )
495495 infer_fn = functools .partial (
0 commit comments