diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index 47940c071..e730a021f 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -757,8 +757,18 @@ def run(args) -> None:
 
     # Optimizer
     param_options = get_params_options(args, model)
+    optimizer: torch.optim.Optimizer
     optimizer = get_optimizer(args, param_options)
+    logging.info("=== Layer's learning rates ===")
+    for name, p in model.named_parameters():
+        st = optimizer.state.get(p, {})
+        if st:
+            logging.info(f"Param: {name}: {list(st.keys())}")
+
+    for i, param_group in enumerate(optimizer.param_groups):
+        logging.info(f"Param group {i}: lr = {param_group['lr']}")
+
     if args.device == "xpu":
         logging.info("Optimzing model and optimzier for XPU")
         model, optimizer = ipex.optimize(model, optimizer=optimizer)
@@ -805,9 +815,6 @@ def run(args) -> None:
     ema: Optional[ExponentialMovingAverage] = None
     if args.ema:
         ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
-    else:
-        for group in optimizer.param_groups:
-            group["lr"] = args.lr
 
     if args.lbfgs:
         logging.info("Switching optimizer to LBFGS")
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index a074c6cb8..0b93cbf9b 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -787,6 +787,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7
     )
+    parser.add_argument(
+        "--lr_params_factors",
+        help="Learning rate factors to multiply on the original lr",
+        type=str,
+        default='{"embedding_lr_factor": 1.0, "interactions_lr_factor": 1.0, "products_lr_factor": 1.0, "readouts_lr_factor": 1.0}',
+    )
+    parser.add_argument(
+        "--freeze",
+        help="Freeze layers from 1 to N. Can be positive or negative, e.g. -1 means the last layer is frozen. 0 or None means all layers are active and is a default setting",
+        type=int,
+        default=None,
+    )
     parser.add_argument(
         "--amsgrad",
         help="use amsgrad variant of optimizer",
diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py
index 218679666..aea4741d8 100644
--- a/mace/tools/scripts_utils.py
+++ b/mace/tools/scripts_utils.py
@@ -778,6 +778,11 @@ def get_swa(
     return swa, swas
 
 
+def freeze_module(module: torch.nn.Module, freeze: bool = True):
+    for p in module.parameters():
+        p.requires_grad = not freeze
+
+
 def get_params_options(
     args: argparse.Namespace, model: torch.nn.Module
 ) -> Dict[str, Any]:
@@ -789,32 +794,57 @@ def get_params_options(
         else:
             no_decay_interactions[name] = param
 
+    lr_params_factors = json.loads(args.lr_params_factors)
+
+    if args.freeze:
+        if args.freeze >= 7:
+            logging.info("Freezing readout weights")
+            lr_params_factors["readouts_lr_factor"] = 0.0
+            freeze_module(model.readouts, True)
+        if args.freeze >= 6:
+            logging.info("Freezing product weights")
+            lr_params_factors["products_lr_factor"] = 0.0
+            freeze_module(model.products, True)
+        if args.freeze >= 5:
+            logging.info("Freezing interaction linear weights")
+            lr_params_factors["interactions_lr_factor"] = 0.0
+            freeze_module(model.interactions, True)
+        if args.freeze >= 1:
+            logging.info("Freezing embedding weights")
+            lr_params_factors["embedding_lr_factor"] = 0.0
+            freeze_module(model.node_embedding, True)
+
     param_options = dict(
         params=[
             {
                 "name": "embedding",
                 "params": model.node_embedding.parameters(),
                 "weight_decay": 0.0,
+                "lr": lr_params_factors.get("embedding_lr_factor", 1.0) * args.lr,
             },
             {
                 "name": "interactions_decay",
                 "params": list(decay_interactions.values()),
                 "weight_decay": args.weight_decay,
+                "lr": lr_params_factors.get("interactions_lr_factor", 1.0) * args.lr,
             },
             {
                 "name": "interactions_no_decay",
                 "params": list(no_decay_interactions.values()),
                 "weight_decay": 0.0,
+                "lr": lr_params_factors.get("interactions_lr_factor", 1.0) * args.lr,
             },
             {
                 "name": "products",
                 "params": model.products.parameters(),
                 "weight_decay": args.weight_decay,
+                "lr": lr_params_factors.get("products_lr_factor", 1.0) * args.lr,
             },
             {
                 "name": "readouts",
                 "params": model.readouts.parameters(),
                 "weight_decay": 0.0,
+                "lr": lr_params_factors.get("readouts_lr_factor", 1.0) * args.lr,
             },
         ],
         lr=args.lr,
diff --git a/mace/tools/train.py b/mace/tools/train.py
index 2110ac81e..d8b97c64a 100644
--- a/mace/tools/train.py
+++ b/mace/tools/train.py
@@ -8,7 +8,7 @@
 import logging
 import time
 from collections import defaultdict
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -538,6 +538,20 @@ def closure():
         return loss, loss_dict
 
 
+# Keep parameters frozen/active after evaluation
+@contextmanager
+def preserve_grad_state(model):
+    # save the original requires_grad state for all parameters
+    requires_grad_backup = {param: param.requires_grad for param in model.parameters()}
+    try:
+        # temporarily disable gradients for all parameters
+        for param in model.parameters():
+            param.requires_grad = False
+        yield  # perform evaluation here
+    finally:
+        # restore the original requires_grad states
+        for param, requires_grad in requires_grad_backup.items():
+            param.requires_grad = requires_grad
 
 def evaluate(
     model: torch.nn.Module,
@@ -546,31 +560,28 @@ def evaluate(
     output_args: Dict[str, bool],
     device: torch.device,
 ) -> Tuple[float, Dict[str, Any]]:
-    for param in model.parameters():
-        param.requires_grad = False
+
 
     metrics = MACELoss(loss_fn=loss_fn).to(device)
     start_time = time.time()
-    for batch in data_loader:
-        batch = batch.to(device)
-        batch_dict = batch.to_dict()
-        output = model(
-            batch_dict,
-            training=False,
-            compute_force=output_args["forces"],
-            compute_virials=output_args["virials"],
-            compute_stress=output_args["stress"],
-        )
-        avg_loss, aux = metrics(batch, output)
+    with preserve_grad_state(model):
+        for batch in data_loader:
+            batch = batch.to(device)
+            batch_dict = batch.to_dict()
+            output = model(
+                batch_dict,
+                training=False,
+                compute_force=output_args["forces"],
+                compute_virials=output_args["virials"],
+                compute_stress=output_args["stress"],
+            )
+            avg_loss, aux = metrics(batch, output)
 
     avg_loss, aux = metrics.compute()
     aux["time"] = time.time() - start_time
     metrics.reset()
 
-    for param in model.parameters():
-        param.requires_grad = True
-
     return avg_loss, aux
 
 
diff --git a/tests/test_freeze.py b/tests/test_freeze.py
new file mode 100644
index 000000000..5d2c277a4
--- /dev/null
+++ b/tests/test_freeze.py
@@ -0,0 +1,257 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+import ase.io
+import numpy as np
+import pytest
+import torch
+from ase.atoms import Atoms
+
+from mace.calculators import MACECalculator
+
+try:
+    import cuequivariance as cue  # pylint: disable=unused-import
+
+    CUET_AVAILABLE = True
+except ImportError:
+    CUET_AVAILABLE = False
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+# device = "cpu"
+
+run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py"
+
+
+@pytest.fixture(name="fitting_configs")
+def fixture_fitting_configs():
+    water = Atoms(
+        numbers=[8, 1, 1],
+        positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]],
+        cell=[4] * 3,
+        pbc=[True] * 3,
+    )
+    fit_configs = [
+        Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3),
+        Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3),
+    ]
+    fit_configs[0].info["REF_energy"] = 0.0
+    fit_configs[0].info["config_type"] = "IsolatedAtom"
+    fit_configs[1].info["REF_energy"] = 0.0
+    fit_configs[1].info["config_type"] = "IsolatedAtom"
+
+    np.random.seed(5)
+    for _ in range(20):
+        c = water.copy()
+        c.positions += np.random.normal(0.1, size=c.positions.shape)
+        c.info["REF_energy"] = np.random.normal(0.1)
+        print(c.info["REF_energy"])
+        c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape))
+        c.info["REF_stress"] = np.random.normal(0.1, size=6)
+        fit_configs.append(c)
+
+    return fit_configs
+
+
+@pytest.fixture(name="pretraining_configs")
+def fixture_pretraining_configs():
+    configs = []
+    for _ in range(10):
+        atoms = Atoms(
+            numbers=[8, 1, 1],
+            positions=np.random.rand(3, 3) * 3,
+            cell=[5, 5, 5],
+            pbc=[True] * 3,
+        )
+        atoms.info["REF_energy"] = np.random.normal(0, 1)
+        atoms.arrays["REF_forces"] = np.random.normal(0, 1, size=(3, 3))
+        atoms.info["REF_stress"] = np.random.normal(0, 1, size=6)
+        configs.append(atoms)
+    configs.append(
+        Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3),
+    )
+    configs.append(
+        Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3)
+    )
+    configs[-2].info["REF_energy"] = -2.0
+    configs[-2].info["config_type"] = "IsolatedAtom"
+    configs[-1].info["REF_energy"] = -4.0
+    configs[-1].info["config_type"] = "IsolatedAtom"
+    return configs
+
+
+_mace_params = {
+    "name": "MACE",
+    "valid_fraction": 0.05,
+    "energy_weight": 1.0,
+    "forces_weight": 10.0,
+    "stress_weight": 1.0,
+    "model": "MACE",
+    "hidden_irreps": "128x0e",
+    "r_max": 3.5,
+    "batch_size": 5,
+    "max_num_epochs": 2,
+    "swa": None,
+    "start_swa": 5,
+    "ema": None,
+    "ema_decay": 0.99,
+    "amsgrad": None,
+    "restart_latest": None,
+    "device": device,
+    "seed": 5,
+    "loss": "stress",
+    "energy_key": "REF_energy",
+    "forces_key": "REF_forces",
+    "stress_key": "REF_stress",
+    "eval_interval": 2,
+}
+
+
+def test_run_train_freeze(tmp_path, fitting_configs):
+    ase.io.write(tmp_path / "fit.xyz", fitting_configs)
+
+    mace_params = _mace_params.copy()
+    mace_params["checkpoints_dir"] = str(tmp_path)
+    mace_params["model_dir"] = str(tmp_path)
+    mace_params["train_file"] = tmp_path / "fit.xyz"
+    mace_params["loss"] = "weighted"
+    mace_params["foundation_model"] = "small"
+    mace_params["hidden_irreps"] = "128x0e"
+    mace_params["r_max"] = 6.0
+    mace_params["default_dtype"] = "float64"
+    mace_params["num_radial_basis"] = 10
+    mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock"
+    mace_params["multiheads_finetuning"] = False
+    mace_params["freeze"] = 6
+
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+
+    cmd = (
+        sys.executable
+        + " "
+        + str(run_train)
+        + " "
+        + " ".join(
+            [
+                (f"--{k}={v}" if v is not None else f"--{k}")
+                for k, v in mace_params.items()
+            ]
+        )
+    )
+
+    print(f"Running command: {cmd}")
+    p = subprocess.run(cmd.split(), env=run_env, check=True)
+    assert p.returncode == 0
+
+    calc = MACECalculator(
+        model_paths=tmp_path / "MACE.model", device=device, default_dtype="float64"
+    )
+
+    Es = []
+    for at in fitting_configs:
+        at.calc = calc
+        Es.append(at.get_potential_energy())
+
+    print("Es", Es)
+
+    ref_Es = [
+        5.348334089807952,
+        2.4128907878403982,
+        8.5566950528953,
+        7.743803832228654,
+        5.788643738738498,
+        9.103127501095454,
+        8.719323994063377,
+        8.169843256425096,
+        8.077166786336269,
+        8.679676296893602,
+        12.189297325152948,
+        6.911712148654615,
+        8.290506707079263,
+        5.303821445834231,
+        7.296761518032694,
+        5.946962420990914,
+        9.043336244248948,
+        7.446979685692335,
+        5.764245581904601,
+        6.975111618768769,
+        6.931624082425803,
+        6.72206658924676,
+    ]
+
+    assert np.allclose(Es, ref_Es)
+
+
+def test_run_train_soft_freeze(tmp_path, fitting_configs):
+    ase.io.write(tmp_path / "fit.xyz", fitting_configs)
+
+    mace_params = _mace_params.copy()
+    mace_params["checkpoints_dir"] = str(tmp_path)
+    mace_params["model_dir"] = str(tmp_path)
+    mace_params["train_file"] = tmp_path / "fit.xyz"
+    mace_params["loss"] = "weighted"
+    mace_params["foundation_model"] = "small"
+    mace_params["hidden_irreps"] = "128x0e"
+    mace_params["r_max"] = 6.0
+    mace_params["default_dtype"] = "float64"
+    mace_params["num_radial_basis"] = 10
+    mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock"
+    mace_params["multiheads_finetuning"] = False
+    mace_params["lr_params_factors"] = '{"embedding_lr_factor": 0.0, "interactions_lr_factor": 1.0, "products_lr_factor": 1.0, "readouts_lr_factor": 1.0}'
+
+    run_env = os.environ.copy()
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+    run_env["PYTHONPATH"] = ":".join(sys.path)
+
+    cmd = [sys.executable, str(run_train)]
+    for k, v in mace_params.items():
+        if v is not None:
+            cmd.append(f"--{k}={v}")
+        else:
+            cmd.append(f"--{k}")
+
+    print(f"Running command: {cmd}")
+    p = subprocess.run(cmd, env=run_env, check=True)
+    assert p.returncode == 0
+
+    calc = MACECalculator(
+        model_paths=tmp_path / "MACE.model", device=device, default_dtype="float64"
+    )
+
+    Es = []
+    for at in fitting_configs:
+        at.calc = calc
+        Es.append(at.get_potential_energy())
+
+    print("Es", Es)
+
+    ref_Es = [
+        4.077101520328611,
+        1.9125514950167353,
+        4.6390361860381795,
+        4.6415570296531214,
+        3.9153698530138845,
+        4.487578378535444,
+        4.439674506695098,
+        4.906251552572849,
+        4.6943771636613985,
+        4.443480673870315,
+        12.392544826986759,
+        4.8014551746345475,
+        4.6380462142293455,
+        4.126315015844008,
+        4.923222049125721,
+        4.442558518514199,
+        4.556565520687697,
+        4.935513763430022,
+        4.077869607943539,
+        4.4407761603911124,
+        5.10253699303561,
+        4.537672050884654,
+    ]
+
+    assert np.allclose(Es, ref_Es)