32
32
from official .vision .dataloaders import tfds_factory
33
33
from official .vision .dataloaders import tf_example_label_map_decoder
34
34
from official .vision .evaluation import coco_evaluator
35
- from official .vision .modeling import backbones
35
+ from official .vision .modeling import backbones as backbones_lib
36
36
37
37
38
38
@task_factory .register_task_cls (pix2seq_cfg .Pix2SeqTask )
@@ -44,24 +44,34 @@ class Pix2SeqTask(base_task.Task):
44
44
post-processing, and customized metrics with reduction.
45
45
"""
46
46
47
- def build_model (self ):
48
- """Build Pix2Seq model."""
47
def _build_backbones_and_endpoint_names(
    self,
) -> tuple[list[tf_keras.Model], list[str]]:
  """Builds every configured backbone and collects its endpoint name.

  One backbone is created per entry of `self._task_config.model.backbones`.
  All backbones are built against the same input spec, derived from the
  model config's `input_size` with an unspecified leading dimension. Each
  backbone's `trainable` flag is set to the negation of its config's
  `freeze` field.

  Returns:
    A `(backbones, endpoint_names)` pair of equally long lists, both ordered
    like the backbone configs they were derived from.
  """
  model_config: pix2seq_cfg.Pix2Seq = self._task_config.model
  shared_input_specs = tf_keras.layers.InputSpec(
      shape=[None] + model_config.input_size
  )

  def _build_one(cfg):
    # Normalization/activation settings come from the individual backbone
    # config rather than from the top-level model config.
    net = backbones_lib.factory.build_backbone(
        input_specs=shared_input_specs,
        backbone_config=cfg,
        norm_activation_config=cfg.norm_activation,
    )
    net.trainable = not cfg.freeze
    return net

  backbone_configs = list(model_config.backbones)
  built_backbones = [_build_one(cfg) for cfg in backbone_configs]
  names = [cfg.endpoint_name for cfg in backbone_configs]
  return built_backbones, names
54
67
55
- backbone = backbones .factory .build_backbone (
56
- input_specs = input_specs ,
57
- backbone_config = config .backbone ,
58
- norm_activation_config = config .norm_activation ,
59
- )
60
-
68
+ def build_model (self ):
69
+ """Build Pix2Seq model."""
70
+ config : pix2seq_cfg .Pix2Seq = self ._task_config .model
71
+ backbones , endpoint_names = self ._build_backbones_and_endpoint_names ()
61
72
model = pix2seq_model .Pix2Seq (
62
- # TODO: b/378885339 - Support multiple backbones from the config.
63
- backbones = [backbone ],
64
- backbone_endpoint_name = config .backbone_endpoint_name ,
73
+ backbones = backbones ,
74
+ backbone_endpoint_name = endpoint_names ,
65
75
max_seq_len = config .max_num_instances * 5 ,
66
76
vocab_size = config .vocab_size ,
67
77
hidden_size = config .hidden_size ,
@@ -78,41 +88,64 @@ def build_model(self):
78
88
)
79
89
return model
80
90
91
def _get_ckpt(self, ckpt_dir_or_file: str) -> str:
  """Resolves a checkpoint location to a concrete checkpoint path.

  Args:
    ckpt_dir_or_file: Either a directory holding checkpoints or a checkpoint
      path.

  Returns:
    The result of `tf.train.latest_checkpoint` when the argument is a
    directory; otherwise the argument itself, unchanged.
    NOTE(review): `tf.train.latest_checkpoint` is documented to return None
    when the directory contains no checkpoint, so callers may receive None
    despite the `str` annotation -- confirm callers tolerate that.
  """
  is_directory = tf.io.gfile.isdir(ckpt_dir_or_file)
  return (
      tf.train.latest_checkpoint(ckpt_dir_or_file)
      if is_directory
      else ckpt_dir_or_file
  )
95
+
81
96
def initialize(self, model: tf_keras.Model):
  """Restores pretrained weights into `model` as directed by the task config.

  Weights come either from one global checkpoint (`init_checkpoint` on the
  task config) or from per-backbone checkpoints (`init_checkpoint` on the
  individual backbone configs); specifying both is rejected. When neither is
  set, nothing is restored.

  Args:
    model: The model to initialize. `model.checkpoint_items` is used for a
      global restore and `model.backbones` for per-backbone restores.

  Raises:
    ValueError: If `init_checkpoint_modules` is the deprecated value
      'backbone', if it is not one of 'all', 'partial' or '', or if a global
      checkpoint and at least one backbone checkpoint are both configured.
  """
  task_cfg = self._task_config
  modules = task_cfg.init_checkpoint_modules

  # Validation runs unconditionally, even when no checkpoint is configured.
  if modules == 'backbone':
    raise ValueError(
        'init_checkpoint_modules=backbone is deprecated. Specify backbone '
        'checkpoints in each backbone config.'
    )
  if modules not in ('all', 'partial', ''):
    raise ValueError(f'Unsupported init_checkpoint_modules: {modules}')

  if task_cfg.init_checkpoint and any(
      cfg.init_checkpoint for cfg in task_cfg.model.backbones
  ):
    raise ValueError(
        'A global init_checkpoint and a backbone init_checkpoint cannot be'
        ' specified at the same time.'
    )

  if task_cfg.init_checkpoint:
    # Global restore over all of the model's checkpoint items.
    global_ckpt_file = self._get_ckpt(task_cfg.init_checkpoint)
    restorer = tf.train.Checkpoint(**model.checkpoint_items)
    load_status = restorer.restore(global_ckpt_file)
    load_status = load_status.expect_partial()
    # Only 'partial' skips the check that existing objects were matched.
    if modules != 'partial':
      load_status.assert_existing_objects_matched()
    logging.info(
        'Finished loading pretrained checkpoint from %s', global_ckpt_file
    )
    return

  # No global checkpoint was provided; restore whichever backbones declare
  # their own checkpoint and leave the rest untouched.
  with_checkpoint = (
      (cfg, net)
      for cfg, net in zip(task_cfg.model.backbones, model.backbones)
      if cfg.init_checkpoint
  )
  for cfg, net in with_checkpoint:
    backbone_init_ckpt = self._get_ckpt(cfg.init_checkpoint)
    if cfg.type == 'uvit':
      # The UVit backbone exposes its own load_checkpoint method; the other
      # backbones are restored through tf.train.Checkpoint instead.
      net.load_checkpoint(ckpt_filepath=backbone_init_ckpt)
    else:
      restorer = tf.train.Checkpoint(backbone=net)
      load_status = restorer.restore(backbone_init_ckpt)
      load_status.expect_partial().assert_existing_objects_matched()

    logging.info(
        'Finished loading pretrained backbone from %s', backbone_init_ckpt
    )
116
149
117
150
def build_inputs (
118
151
self , params , input_context : Optional [tf .distribute .InputContext ] = None
0 commit comments