We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ad9e0c9 commit 0678aa3Copy full SHA for 0678aa3
t5x/train.py
@@ -354,6 +354,11 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig):
354
checkpoint_cfg.save and checkpoint_cfg.save.save_dataset
355
),
356
state_transformation_fns=state_transforms_for_restore,
357
+ strict=(checkpoint_cfg.restore.strict
358
+ if checkpoint_cfg.restore is not None else True
359
+ ),
360
+ fallback_to_scratch=(checkpoint_cfg.restore.fallback_to_scratch
361
+ if checkpoint_cfg.restore is not None else False)
362
)
363
]
364
# 2. From a checkpoint specified by `checkpoint_cfg.restore.path`, if set.
0 commit comments