Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,18 @@ def run(args) -> None:

# Optimizer
param_options = get_params_options(args, model)

optimizer: torch.optim.Optimizer
optimizer = get_optimizer(args, param_options)
logging.info("=== Layer's learning rates ===")
for name, p in model.named_parameters():
st = optimizer.state.get(p, {})
if st:
logging.info(f"Param: {name}: {list(st.keys())}")

for i, param_group in enumerate(optimizer.param_groups):
logging.info(f"Param group {i}: lr = {param_group['lr']}")

if args.device == "xpu":
logging.info("Optimzing model and optimzier for XPU")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
Expand Down Expand Up @@ -805,9 +815,6 @@ def run(args) -> None:
ema: Optional[ExponentialMovingAverage] = None
if args.ema:
ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
else:
for group in optimizer.param_groups:
group["lr"] = args.lr

if args.lbfgs:
logging.info("Switching optimizer to LBFGS")
Expand Down
12 changes: 12 additions & 0 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7
)
parser.add_argument(
"--lr_params_factors",
help="Learning rate factors to multiply on the original lr",
type=str,
default='{"embedding_lr_factor": 1.0, "interactions_lr_factor": 1.0, "products_lr_factor": 1.0, "readouts_lr_factor": 1.0}',
)
parser.add_argument(
"--freeze",
help="Freeze layers from 1 to N. Can be positive or negative, e.g. -1 means the last layer is frozen. 0 or None means all layers are active and is a default setting",
type=int,
default=None,
)
parser.add_argument(
"--amsgrad",
help="use amsgrad variant of optimizer",
Expand Down
30 changes: 30 additions & 0 deletions mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,11 @@ def get_swa(
return swa, swas


def freeze_module(module: torch.nn.Module, freeze: bool = True):
    """Enable or disable gradient updates for every parameter of *module*.

    With ``freeze=True`` (default) all parameters stop receiving gradients;
    with ``freeze=False`` they are made trainable again.
    """
    trainable = not freeze
    for parameter in module.parameters():
        parameter.requires_grad = trainable


def get_params_options(
args: argparse.Namespace, model: torch.nn.Module
) -> Dict[str, Any]:
Expand All @@ -789,32 +794,57 @@ def get_params_options(
else:
no_decay_interactions[name] = param

lr_params_factors = json.loads(args.lr_params_factors)

if args.freeze:
if args.freeze >= 7:
logging.info("Freezing readout weights")
lr_params_factors["readouts_lr_factor"] = 0.0
freeze_module(model.readouts, True)
if args.freeze >= 6:
logging.info("Freezing product weights")
lr_params_factors["products_lr_factor"] = 0.0
freeze_module(model.products, True)
if args.freeze >= 5:
logging.info("Freezing interaction linear weights")
lr_params_factors["interactions_lr_factor"] = 0.0
freeze_module(model.interactions, True)
if args.freeze >= 1:
logging.info("Freezing embedding weights")
lr_params_factors["embedding_lr_factor"] = 0.0
freeze_module(model.node_embedding, True)

param_options = dict(
params=[
{
"name": "embedding",
"params": model.node_embedding.parameters(),
"weight_decay": 0.0,
"lr": lr_params_factors.get("embedding_lr_factor", 1.0) * args.lr,
},
{
"name": "interactions_decay",
"params": list(decay_interactions.values()),
"weight_decay": args.weight_decay,
"lr": lr_params_factors.get("interactions_lr_factor", 1.0) * args.lr,
},
{
"name": "interactions_no_decay",
"params": list(no_decay_interactions.values()),
"weight_decay": 0.0,
"lr": lr_params_factors.get("interactions_lr_factor", 1.0) * args.lr,
},
{
"name": "products",
"params": model.products.parameters(),
"weight_decay": args.weight_decay,
"lr": lr_params_factors.get("products_lr_factor", 1.0) * args.lr,
},
{
"name": "readouts",
"params": model.readouts.parameters(),
"weight_decay": 0.0,
"lr": lr_params_factors.get("readouts_lr_factor", 1.0) * args.lr,
},
],
lr=args.lr,
Expand Down
45 changes: 28 additions & 17 deletions mace/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import time
from collections import defaultdict
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -538,6 +538,20 @@ def closure():

return loss, loss_dict

# Keep parameters frozen/active after evaluation
@contextmanager
def preserve_grad_state(model):
    """Temporarily set ``requires_grad = False`` on all of *model*'s parameters.

    Inside the ``with`` block every parameter has gradients disabled (useful
    during evaluation); on exit each parameter's original ``requires_grad``
    flag is restored, so parameters that were frozen stay frozen and those
    that were trainable become trainable again — even if the body raises.
    """
    # Snapshot (parameter, original flag) pairs before touching anything.
    saved_flags = [(p, p.requires_grad) for p in model.parameters()]
    try:
        for p, _ in saved_flags:
            p.requires_grad = False
        yield  # evaluation happens here
    finally:
        for p, original_flag in saved_flags:
            p.requires_grad = original_flag

def evaluate(
model: torch.nn.Module,
Expand All @@ -546,31 +560,28 @@ def evaluate(
output_args: Dict[str, bool],
device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
for param in model.parameters():
param.requires_grad = False


metrics = MACELoss(loss_fn=loss_fn).to(device)

start_time = time.time()
for batch in data_loader:
batch = batch.to(device)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=False,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
avg_loss, aux = metrics(batch, output)

with preserve_grad_state(model):
for batch in data_loader:
batch = batch.to(device)
batch_dict = batch.to_dict()
output = model(
batch_dict,
training=False,
compute_force=output_args["forces"],
compute_virials=output_args["virials"],
compute_stress=output_args["stress"],
)
avg_loss, aux = metrics(batch, output)
avg_loss, aux = metrics.compute()
aux["time"] = time.time() - start_time
metrics.reset()

for param in model.parameters():
param.requires_grad = True

return avg_loss, aux


Expand Down
Loading
Loading