Skip to content
49 changes: 28 additions & 21 deletions mmf/models/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,25 +103,13 @@ def build(self):

def get_optimizer_parameters(self, config):
lr = config.optimizer.params.lr

backbone_param_set = set()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's name this trunk_params_set.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review. I will update my code according to your suggestion.

param_list = []
parameters = []
head_configs = self.config.get("heads", [])

for name, module in self.named_children():
# Heads can have different learning rates. This is handled here
if name == "heads":
# Parameters in the head which have a separate learning
# rate, are added as a separate param group
for head_config, head in zip(head_configs, self.heads):
parameters, param_list = self.set_lr_for_parameters(
config=head_config,
module_name="{} head".format(head_config.get("type", "MLP")),
base_lr=lr,
module=head,
parameters=parameters,
param_list=param_list,
)
elif name == "encoders":

if name == "encoders":
for key in module:
for modality in self.config.modalities:
if key == modality.key:
Expand All @@ -134,29 +122,48 @@ def get_optimizer_parameters(self, config):
parameters=parameters,
param_list=param_list,
)
else:
elif name != "heads":
# For other modules in trunk, add to same param group
param_list += list(module.named_parameters())

backbone_param_set.update(list(module.parameters()))
head_configs = self.config.get("heads", [])
# Heads can have different learning rates. This is handled here
if len(head_configs) > 0:
# Parameters in the head which have a separate learning
# rate, are added as a separate param group
for head_config, head in zip(head_configs, self.heads):
parameters, param_list = self.set_lr_for_parameters(
config=head_config,
module_name="{} head".format(head_config.get("type", "MLP")),
base_lr=lr,
module=head,
parameters=parameters,
param_list=param_list,
backbone_param_set = backbone_param_set
)
parameters += get_bert_configured_parameters(param_list)

return parameters

def set_lr_for_parameters(
    self,
    config,
    module_name,
    base_lr,
    module,
    parameters,
    param_list,
    backbone_param_set=None,
):
    """Add ``module``'s parameters to the optimizer parameter groups.

    If the head/encoder config specifies an ``lr_multiplier`` different
    from 1.0, the module's parameters are appended to ``parameters`` as a
    separate param group with learning rate ``base_lr * lr_multiplier``.
    Otherwise they are appended to ``param_list`` so the caller can fold
    them into the shared trunk param group.

    Args:
        config: head/encoder config; ``lr_multiplier`` is read from it
            (defaults to 1.0).
        module_name: human-readable name used only for logging.
        base_lr: base learning rate of the trunk.
        module: the ``nn.Module`` whose parameters are being assigned.
        parameters: accumulated list of optimizer param groups.
        param_list: accumulated ``(name, param)`` pairs sharing the trunk LR.
        backbone_param_set: optional set of parameters already owned by the
            trunk/backbone; any parameter found in it is skipped so it is
            never placed in two param groups. Defaults to the empty set.

    Returns:
        The updated ``(parameters, param_list)`` pair.
    """
    lr_multiplier = config.get("lr_multiplier", 1.0)
    # Reviewer-suggested simplification: default to an empty set so a
    # single filtered comprehension covers both the "no backbone" and
    # "exclude backbone params" cases.
    if backbone_param_set is None:
        backbone_param_set = set()
    module_param = [
        (name, param)
        for name, param in module.named_parameters()
        if param not in backbone_param_set
    ]
    if lr_multiplier != 1.0:
        logger.info(
            f"Setting learning rate of {module_name} to be {base_lr} * {lr_multiplier}."
        )  # noqa
        # Separate param group with its own scaled learning rate.
        parameters += get_bert_configured_parameters(
            module_param, base_lr * lr_multiplier
        )
    else:
        # Same learning rate as the trunk: accumulate into the shared
        # param group built by the caller.
        param_list += module_param
    return parameters, param_list

def build_encoders(self):
Expand Down