Skip to content

Commit 103d7d5

Browse files
No public description
PiperOrigin-RevId: 712914895
1 parent 501e0f6 commit 103d7d5

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

Diff for: official/projects/pix2seq/configs/pix2seq.py

+5
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ class BackboneConfig(hyperparams.Config):
115115
)
116116
# Optional checkpoint to load for this backbone.
117117
init_checkpoint: Optional[str] = None
118+
# If loading an init_checkpoint, whether to assert that all objects in the
119+
# Python program are matched by the checkpoint.
120+
# If False, understand that only the weak assertion of a non-trivial match
121+
# will be made.
122+
assert_existing_objects_matched: bool = True
118123

119124

120125
@dataclasses.dataclass

Diff for: official/projects/pix2seq/tasks/pix2seq_task.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,13 @@ def initialize(self, model: tf_keras.Model):
141141
backbone.load_checkpoint(ckpt_filepath=backbone_init_ckpt)
142142
else:
143143
ckpt = tf.train.Checkpoint(backbone=backbone)
144-
status = ckpt.restore(backbone_init_ckpt)
145-
status.expect_partial().assert_existing_objects_matched()
144+
status = (
145+
ckpt.restore(backbone_init_ckpt)
146+
.expect_partial()
147+
.assert_nontrivial_match()
148+
)
149+
if backbone_config.assert_existing_objects_matched:
150+
status.assert_existing_objects_matched()
146151

147152
logging.info(
148153
'Finished loading pretrained backbone from %s', backbone_init_ckpt

0 commit comments

Comments
 (0)