
Commit a80ad54

No public description
PiperOrigin-RevId: 700419920
1 parent be8acfe commit a80ad54

File tree

4 files changed: +124 -73 lines

- official/projects/pix2seq/configs/pix2seq.py
- official/projects/pix2seq/modeling/pix2seq_model.py
- official/projects/pix2seq/modeling/pix2seq_model_test.py
- official/projects/pix2seq/tasks/pix2seq_task.py

Diff for: official/projects/pix2seq/configs/pix2seq.py

+30-18
@@ -99,11 +99,20 @@ class Backbone(backbones.Backbone):
   resnet: backbones.ResNet = dataclasses.field(default_factory=backbones.ResNet)
   uvit: uvit_backbones.VisionTransformer = dataclasses.field(
       default_factory=uvit_backbones.VisionTransformer)
+  # Whether to freeze this backbone during training.
+  freeze: bool = False
+  # The endpoint name of the features to extract from the backbone.
+  endpoint_name: str = '5'
+  norm_activation: common.NormActivation = dataclasses.field(
+      default_factory=common.NormActivation
+  )
+  # Optional checkpoint to load for this backbone.
+  init_checkpoint: Optional[str] = None
 
 
 @dataclasses.dataclass
 class Pix2Seq(hyperparams.Config):
-  """Pix2Seq model definations."""
+  """Pix2Seq model definitions."""
 
   max_num_instances: int = 100
   hidden_size: int = 256
@@ -115,16 +124,16 @@ class Pix2Seq(hyperparams.Config):
   shared_decoder_embedding: bool = True
   decoder_output_bias: bool = True
   input_size: List[int] = dataclasses.field(default_factory=list)
-  backbone: Backbone = dataclasses.field(
-      default_factory=lambda: Backbone(  # pylint: disable=g-long-lambda
-          type='resnet',
-          resnet=backbones.ResNet(model_id=50, bn_trainable=False),
-      )
-  )
-  norm_activation: common.NormActivation = dataclasses.field(
-      default_factory=common.NormActivation
+  # Backbones for each image modality. If just using RGB, you should only set
+  # one backbone.
+  backbones: List[Backbone] = dataclasses.field(
+      default_factory=lambda: [
+          Backbone(  # pylint: disable=g-long-lambda
+              type='resnet',
+              resnet=backbones.ResNet(model_id=50, bn_trainable=False),
+          )
+      ]
   )
-  backbone_endpoint_name: str = '5'
   drop_path: float = 0.1
   drop_units: float = 0.1
   drop_att: float = 0.0
@@ -172,13 +181,16 @@ def pix2seq_r50_coco() -> cfg.ExperimentConfig:
           ),
           model=Pix2Seq(
              input_size=[640, 640, 3],
-              norm_activation=common.NormActivation(
-                  norm_momentum=0.9,
-                  norm_epsilon=1e-5,
-                  use_sync_bn=True),
-              backbone=Backbone(
-                  type='resnet', resnet=backbones.ResNet(model_id=50)
-              ),
+              backbones=[
+                  Backbone(
+                      type='resnet',
+                      resnet=backbones.ResNet(model_id=50),
+                      norm_activation=common.NormActivation(
+                          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True
+                      ),
+                      init_checkpoint='',
+                  )
+              ],
           ),
           losses=Losses(l2_weight_decay=0.0),
           train_data=DataConfig(
@@ -188,7 +200,7 @@ def pix2seq_r50_coco() -> cfg.ExperimentConfig:
              shuffle_buffer_size=train_batch_size * 10,
              aug_scale_min=0.3,
              aug_scale_max=2.0,
-              aug_color_jitter_strength=0.0
+              aug_color_jitter_strength=0.0,
          ),
          validation_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
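
The new per-backbone fields (freeze, endpoint_name, norm_activation, init_checkpoint) make multi-backbone setups expressible directly in the config. As a rough sketch of how a two-backbone override might look under the new schema (not part of this commit; the second backbone, its frozen setting, and the checkpoint path are hypothetical):

  model = Pix2Seq(
      input_size=[640, 640, 3],
      backbones=[
          # RGB branch, trained as usual.
          Backbone(
              type='resnet',
              resnet=backbones.ResNet(model_id=50),
              endpoint_name='5',
          ),
          # Hypothetical second modality: frozen, initialized from its own checkpoint.
          Backbone(
              type='resnet',
              resnet=backbones.ResNet(model_id=50),
              endpoint_name='5',
              freeze=True,
              init_checkpoint='/path/to/second_modality_backbone',  # hypothetical path
          ),
      ],
  )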

Diff for: official/projects/pix2seq/modeling/pix2seq_model.py

+12-6
@@ -215,7 +215,7 @@ class Pix2Seq(tf_keras.Model):
   def __init__(
       self,
       backbones: Sequence[tf_keras.Model],
-      backbone_endpoint_name,
+      backbone_endpoint_names: Sequence[str],
       max_seq_len,
       vocab_size,
       hidden_size,
@@ -233,7 +233,7 @@ def __init__(
   ):
     super().__init__(**kwargs)
     self._backbones = backbones
-    self._backbone_endpoint_name = backbone_endpoint_name
+    self._backbone_endpoint_names = backbone_endpoint_names
     self._max_seq_len = max_seq_len
     self._vocab_size = vocab_size
     self._hidden_size = hidden_size
@@ -285,9 +285,7 @@ def transformer(self) -> tf_keras.Model:
     return self._transformer
 
   def get_config(self):
-    return {
-        "backbone": self._backbone,
-        "backbone_endpoint_name": self._backbone_endpoint_name,
+    config = {
         "max_seq_len": self._max_seq_len,
         "vocab_size": self._vocab_size,
         "hidden_size": self._hidden_size,
@@ -302,6 +300,12 @@ def get_config(self):
         "early_stopping_token": self._early_stopping_token,
         "num_heads": self._num_heads,
     }
+    config["backbone"] = self._backbones[0]
+    config["backbone_endpoint_name"] = self._backbone_endpoint_names[0]
+    for i in range(1, len(self._backbones)):
+      config[f"backbone_{i+1}"] = self._backbones[i]
+      config[f"backbone_endpoint_name_{i+1}"] = self._backbone_endpoint_names[i]
+    return config
 
   @classmethod
   def from_config(cls, config):
@@ -354,7 +358,9 @@ def call(
       if use_input_as_backbone_features:
         features = inputs_i
       else:
-        features = self._backbones[i](inputs_i)[self._backbone_endpoint_name]
+        features = self._backbones[i](inputs_i)[
+            self._backbone_endpoint_names[i]
+        ]
       mask = tf.ones_like(features)
       batch_size, h, w, num_channels = get_shape(features)
       features = tf.reshape(features, [batch_size, h * w, num_channels])
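
With this change, Pix2Seq takes a list of backbones and a matching list of endpoint names, one per backbone. A minimal sketch of the calling convention, mirroring the updated tests below (the trailing constructor arguments are elided here exactly as in the test excerpts, so this is not a complete call):

  from official.vision.modeling.backbones import resnet

  num_backbones = 2
  backbones = [
      resnet.ResNet(50, bn_trainable=False) for _ in range(num_backbones)
  ]
  # One endpoint name per backbone; the two lists must have the same length.
  backbone_endpoint_names = ['5' for _ in range(num_backbones)]
  model = pix2seq_model.Pix2Seq(
      backbones,
      backbone_endpoint_names,
      ...  # max_seq_len, vocab_size, hidden_size, etc., as in the tests
  )

Note that get_config() keeps the original "backbone" and "backbone_endpoint_name" keys for the first backbone and adds numbered keys ("backbone_2", "backbone_endpoint_name_2", ...) for any additional ones, which keeps single-backbone configs backward compatible.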

Diff for: official/projects/pix2seq/modeling/pix2seq_model_test.py

+10-10
@@ -37,10 +37,10 @@ def test_forward(self, num_backbones: int):
     backbones = [
         resnet.ResNet(50, bn_trainable=False) for _ in range(num_backbones)
     ]
-    backbone_endpoint_name = '5'
+    backbone_endpoint_names = ['5' for _ in range(num_backbones)]
     model = pix2seq_model.Pix2Seq(
         backbones,
-        backbone_endpoint_name,
+        backbone_endpoint_names,
         max_seq_len,
         vocab_size,
         hidden_size,
@@ -68,10 +68,10 @@ def test_forward_infer_teacher_forcing(self, num_backbones: int):
     backbones = [
         resnet.ResNet(50, bn_trainable=False) for _ in range(num_backbones)
     ]
-    backbone_endpoint_name = '5'
+    backbone_endpoint_names = ['5' for _ in range(num_backbones)]
     model = pix2seq_model.Pix2Seq(
         backbones,
-        backbone_endpoint_name,
+        backbone_endpoint_names,
         max_seq_len,
         vocab_size,
         hidden_size,
@@ -100,10 +100,10 @@ def test_forward_infer(self, num_backbones: int):
     backbones = [
         resnet.ResNet(50, bn_trainable=False) for _ in range(num_backbones)
     ]
-    backbone_endpoint_name = '5'
+    backbone_endpoint_names = ['5' for _ in range(num_backbones)]
     model = pix2seq_model.Pix2Seq(
         backbones,
-        backbone_endpoint_name,
+        backbone_endpoint_names,
         max_seq_len,
         vocab_size,
         hidden_size,
@@ -125,10 +125,10 @@ def test_forward_infer_with_early_stopping(self):
     image_size = 640
     batch_size = 2
     backbone = resnet.ResNet(50, bn_trainable=False)
-    backbone_endpoint_name = '5'
+    backbone_endpoint_names = ['5']
     model = pix2seq_model.Pix2Seq(
         [backbone],
-        backbone_endpoint_name,
+        backbone_endpoint_names,
         max_seq_len,
         vocab_size,
         hidden_size,
@@ -151,10 +151,10 @@ def test_forward_infer_with_long_prompt(self):
     image_size = 640
     batch_size = 2
     backbone = resnet.ResNet(50, bn_trainable=False)
-    backbone_endpoint_name = '5'
+    backbone_endpoint_names = ['5']
     model = pix2seq_model.Pix2Seq(
         [backbone],
-        backbone_endpoint_name,
+        backbone_endpoint_names,
         max_seq_len,
         vocab_size,
         hidden_size,

Diff for: official/projects/pix2seq/tasks/pix2seq_task.py

+72-39
@@ -32,7 +32,7 @@
 from official.vision.dataloaders import tfds_factory
 from official.vision.dataloaders import tf_example_label_map_decoder
 from official.vision.evaluation import coco_evaluator
-from official.vision.modeling import backbones
+from official.vision.modeling import backbones as backbones_lib
 
 
 @task_factory.register_task_cls(pix2seq_cfg.Pix2SeqTask)
@@ -44,24 +44,34 @@ class Pix2SeqTask(base_task.Task):
   post-processing, and customized metrics with reduction.
   """
 
-  def build_model(self):
-    """Build Pix2Seq model."""
+  def _build_backbones_and_endpoint_names(
+      self,
+  ) -> tuple[list[tf_keras.Model], list[str]]:
+    """Build backbones and returns their corresponding endpoint names."""
     config: pix2seq_cfg.Pix2Seq = self._task_config.model
-
     input_specs = tf_keras.layers.InputSpec(
         shape=[None] + config.input_size
     )
+    backbones = []
+    endpoint_names = []
+    for backbone_config in config.backbones:
+      backbone = backbones_lib.factory.build_backbone(
+          input_specs=input_specs,
+          backbone_config=backbone_config,
+          norm_activation_config=backbone_config.norm_activation,
+      )
+      backbone.trainable = not backbone_config.freeze
+      backbones.append(backbone)
+      endpoint_names.append(backbone_config.endpoint_name)
+    return backbones, endpoint_names
 
-    backbone = backbones.factory.build_backbone(
-        input_specs=input_specs,
-        backbone_config=config.backbone,
-        norm_activation_config=config.norm_activation,
-    )
-
+  def build_model(self):
+    """Build Pix2Seq model."""
+    config: pix2seq_cfg.Pix2Seq = self._task_config.model
+    backbones, endpoint_names = self._build_backbones_and_endpoint_names()
     model = pix2seq_model.Pix2Seq(
-        # TODO: b/378885339 - Support multiple backbones from the config.
-        backbones=[backbone],
-        backbone_endpoint_name=config.backbone_endpoint_name,
+        backbones=backbones,
+        backbone_endpoint_name=endpoint_names,
         max_seq_len=config.max_num_instances * 5,
         vocab_size=config.vocab_size,
         hidden_size=config.hidden_size,
@@ -78,41 +88,64 @@ def build_model(self):
     )
     return model
 
+  def _get_ckpt(self, ckpt_dir_or_file: str) -> str:
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      return tf.train.latest_checkpoint(ckpt_dir_or_file)
+    return ckpt_dir_or_file
+
   def initialize(self, model: tf_keras.Model):
     """Loading pretrained checkpoint."""
-    if not self._task_config.init_checkpoint:
-      return
+    if self._task_config.init_checkpoint_modules == 'backbone':
+      raise ValueError(
+          'init_checkpoint_modules=backbone is deprecated. Specify backbone '
+          'checkpoints in each backbone config.'
+      )
 
-    ckpt_dir_or_file = self._task_config.init_checkpoint
+    if self._task_config.init_checkpoint_modules not in ['all', 'partial', '']:
+      raise ValueError(
+          'Unsupported init_checkpoint_modules: '
+          f'{self._task_config.init_checkpoint_modules}'
+      )
 
-    # Restoring checkpoint.
-    if tf.io.gfile.isdir(ckpt_dir_or_file):
-      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+    if self._task_config.init_checkpoint and any(
+        [b.init_checkpoint for b in self._task_config.model.backbones]
+    ):
+      raise ValueError(
+          'A global init_checkpoint and a backbone init_checkpoint cannot be'
+          ' specified at the same time.'
+      )
 
-    if self._task_config.init_checkpoint_modules == 'all':
+    if self._task_config.init_checkpoint:
+      global_ckpt_file = self._get_ckpt(self._task_config.init_checkpoint)
       ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.expect_partial().assert_existing_objects_matched()
+      status = ckpt.restore(global_ckpt_file).expect_partial()
+      if self._task_config.init_checkpoint_modules != 'partial':
+        status.assert_existing_objects_matched()
       logging.info(
-          'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
-      )
-    elif self._task_config.init_checkpoint_modules == 'backbone':
-      if self.task_config.model.backbone.type == 'uvit':
-        model.backbone.load_checkpoint(ckpt_filepath=ckpt_dir_or_file)
-      else:
-        # TODO: b/378885339 - Support multiple backbones from the config.
-        ckpt = tf.train.Checkpoint(backbone=model.backbones[0])
-        status = ckpt.restore(ckpt_dir_or_file)
-        status.expect_partial().assert_existing_objects_matched()
-      logging.info(
-          'Finished loading pretrained backbone from %s', ckpt_dir_or_file
+          'Finished loading pretrained checkpoint from %s', global_ckpt_file
      )
    else:
-      raise ValueError(
-          f'Failed to load {ckpt_dir_or_file}. Unsupported '
-          'init_checkpoint_modules: '
-          f'{self._task_config.init_checkpoint_modules}'
-      )
+      # This case means that no global checkpoint was provided. Possibly,
+      # backbone-specific checkpoints were.
+      for backbone_config, backbone in zip(
+          self._task_config.model.backbones, model.backbones
+      ):
+        if not backbone_config.init_checkpoint:
+          continue
+
+        backbone_init_ckpt = self._get_ckpt(backbone_config.init_checkpoint)
+        if backbone_config.type == 'uvit':
+          # The UVit object has a special function called load_checkpoint.
+          # The other backbones do not.
+          backbone.load_checkpoint(ckpt_filepath=backbone_init_ckpt)
+        else:
+          ckpt = tf.train.Checkpoint(backbone=backbone)
+          status = ckpt.restore(backbone_init_ckpt)
+          status.expect_partial().assert_existing_objects_matched()
+
+        logging.info(
+            'Finished loading pretrained backbone from %s', backbone_init_ckpt
+        )
 
   def build_inputs(
       self, params, input_context: Optional[tf.distribute.InputContext] = None
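
The rewritten initialize() separates whole-model restoration from backbone restoration: a task-level init_checkpoint restores the full model (with init_checkpoint_modules set to 'all', or 'partial' to skip the matched-objects assertion), while backbone weights now come from each backbone's own init_checkpoint; setting both at once, or using the old init_checkpoint_modules='backbone', raises a ValueError. A hedged sketch of the two valid configurations (the variable name task_config and the checkpoint paths are illustrative, not from this commit):

  # Option 1: one global checkpoint restores the whole Pix2Seq model.
  task_config.init_checkpoint = '/ckpts/pix2seq_full'   # hypothetical path
  task_config.init_checkpoint_modules = 'all'
  for backbone_cfg in task_config.model.backbones:
    backbone_cfg.init_checkpoint = None                 # must stay unset

  # Option 2: no global checkpoint; each backbone may load its own weights.
  task_config.init_checkpoint = ''
  task_config.model.backbones[0].init_checkpoint = '/ckpts/rgb_resnet50'  # hypothetical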
