From 10a567c736d25c8c180a7642746b8b71ebb34131 Mon Sep 17 00:00:00 2001 From: lianctrl Date: Wed, 5 Mar 2025 18:08:56 +0100 Subject: [PATCH 1/7] added quartic prior --- mlcg/nn/prior.py | 241 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) diff --git a/mlcg/nn/prior.py b/mlcg/nn/prior.py index 664399b..51bf57a 100644 --- a/mlcg/nn/prior.py +++ b/mlcg/nn/prior.py @@ -1065,3 +1065,244 @@ def from_user(*args): def neighbor_list(topology) -> None: nl = topology.neighbor_list(Dihedral.name) return {Dihedral.name: nl} + + +class Quartic(torch.nn.Module, _Prior): + r""" + Prior that helps in fitting tighter bimodal distributions + using the following energy ansatz. + N.B. the linear term is missing + + .. math: + + V(x) = a*(x-xa)**2 + b*(x-xb)**3 + c*(x-xc)**4 + d + + + Especially useful for CA angles, to avoid exploration + toward pi + """ + _order_map = { + "bonds": 2, + "angles": 3, + "dihedrals": 4, + } + _compute_map = { + "bonds": compute_distances, + "angles": compute_angles, + "dihedrals": compute_torsions, + } + _neighbor_list_map = { + "bonds": "bonds", + "angles": "angles", + "dihedrals" : "dihedrals", + } + + + def __init__(self, statistics, name, order : Optional[int] = None, n_degs:int = 4) -> None: + super(Quartic, self).__init__() + keys = torch.tensor(list(statistics.keys()), dtype=torch.long) + self.allowed_interaction_keys = list(statistics.keys()) + self.name = name + if order is not None: + self.order = order + elif name in Quartic._order_map.keys(): + self.order = Quartic._order_map[self.name] + else: + raise ValueError(f"Uncompatible order {order}") + self.neighbor_list_type = Quartic._neighbor_list_map[self.name] + + unique_types = torch.unique(keys.flatten()) + assert unique_types.min() >= 0 + max_type = unique_types.max() + sizes = tuple([max_type + 1 for _ in range(self.order)]) + + self.n_degs = n_degs + self.k_names = ["k_" + str(ii) for ii in range(2, self.n_degs+1)] + self.x0_names = ["x0_" + str(ii) for ii in range(2, self.n_degs+1)] + + k = torch.zeros(self.n_degs-1, *sizes) + x0 = torch.zeros(self.n_degs-1, *sizes) + v_0 = torch.zeros(*sizes) + print(k.shape) + for key in statistics.keys(): + for ii in range(self.n_degs-1): + k_name = self.k_names[ii] + x0_name = self.x0_names[ii] + k[ii][key] = statistics[key]["ks"][k_name] + x0[ii][key] = statistics[key]["x0s"][x0_name] + v_0[key] = statistics[key]["v_0"] + self.register_buffer("ks", k) + self.register_buffer("v_0", v_0) + self.register_buffer("x0s", x0) + + + @staticmethod + def compute_features(pos, mapping, target): + compute_map_type = Quartic._neighbor_list_map[target] + return Quartic._compute_map[compute_map_type](pos, mapping) + + + def data2features(self, data): + mapping = data.neighbor_list[self.name]["index_mapping"] + return Quartic.compute_features(data.pos, mapping, self.name) + + + def data2parameters(self, data): + mapping = data.neighbor_list[self.name]["index_mapping"] + interaction_types = [ + data.atom_types[mapping[ii]] for ii in range(self.order) + ] + # the parameters have shape n_features x n_degs-1 since + # linear term is missing + ks = torch.vstack( + [self.ks[ii][interaction_types] for ii in range(self.n_degs-1)] + ).t() + + x0s = torch.vstack( + [self.x0s[ii][interaction_types] for ii in range(self.n_degs-1)] + ).t() + v_0s = self.v_0[interaction_types].t() + return {"ks": ks, "x0s" : x0s, "v_0s": v_0s} + + + def forward(self, data): + mapping_batch = data.neighbor_list[self.name]["mapping_batch"] + features = self.data2features(data).flatten() + params = self.data2parameters(data) + V0s = params["v_0s"].t() + ks = params["ks"].t() + x0s = params["x0s"].t() + y = Quartic.compute( + features, + ks, + V0s, + x0s, + ) + y = scatter(y, mapping_batch, dim=0, reduce="sum") + data.out[self.name] = {"energy": y} + return data + + + @staticmethod + def compute(x: torch.Tensor, ks: torch.Tensor, + V0 : torch.Tensor, x0s: torch.Tensor): + """Harmonic interaction in the form of a series. The shape of the tensors + should match between each other. + + .. math: + + V(r) = V0 + \sum_{n=2}^{4} k_n (x-x_n)^n + + """ + V = 0 + for i in range(3): + V+=ks[i]*(x-x0s[i])**(i+2) + + V += V0 + return V + + + @staticmethod + def _quartic_model(x, a, b, c, d, xa, xb, xc): + return a*(x-xa)**2 + b*(x-xb)**3 + c*(x-xc)**4 + d + + + @staticmethod + def _init_quartic_parameters(n_degs): + """ + Helper method for guessing initial parameter values + Not used + """ + ks = [1.0 for _ in range(n_degs-1)] + x0s = [0.0 for _ in range(n_degs-1)] + V0 = -1.0 + p0 = [V0] + p0.extend(ks) + p0.extend(x0s) + return p0 + + + @staticmethod + def _init_quartic_parameter_dict(n_degs): + """Helper method for initializing the parameter dictionary""" + stat = { + "ks": {}, + "x0s": {}, + "v_0" : 0.0 + } + k_names = ["k_" + str(ii) for ii in range(2,n_degs+1)] + x0_names = ["x0_" + str(ii) for ii in range(2,n_degs+1)] + for ii in range(n_degs-1): + k_name = k_names[ii] + x0_name = x0_names[ii] + stat["ks"][k_name] = {} + stat["x0s"][x0_name] = {} + return stat + + + @staticmethod + def _make_quartic_dict(stat, popt, n_degs): + """Helper method for constructing a fitted parameter dictionary""" + stat["v_0"] = popt[0] + k_names = sorted(list(stat["ks"].keys())) + x0_names = sorted(list(stat["x0s"].keys())) + for ii in range(n_degs-1): + k_name = k_names[ii] + x0_name = x0_names[ii] + stat["ks"][k_name] = popt[ii] + stat["x0s"][x0_name] = popt[ii+n_degs] + + return stat + + + @staticmethod + def fit_quartic_from_potential_estimates( + bin_centers_nz: torch.Tensor, + dG_nz: torch.Tensor, + **kwargs, + ): + """ + + Parameters + ---------- + bin_centers_nz: + Bin centers over which the fit is carried out + dG_nz: + The emperical free energy correspinding to the bin centers + + Returns + ------- + Statistics dictionary with fitted interaction parameters + """ + n_degs = 4 + + integral = torch.tensor( + float(trapezoid(dG_nz.cpu().numpy(), bin_centers_nz.cpu().numpy())) + ) + + mask = torch.abs(dG_nz) > 1e-4 * torch.abs(integral) + try: + stat = Quartic._init_quartic_parameter_dict(n_degs) + popt, _ = curve_fit( + Quartic._quartic_model, + bin_centers_nz.cpu().numpy()[mask], + dG_nz.cpu().numpy()[mask], + p0=[1, 0, 0, torch.argmin(dG_nz[mask]), + 0, 0, 0], + bounds=((0, 0, 0, -np.inf, -np.pi, -np.pi, -np.pi), + (np.inf, np.inf, np.inf, np.inf, np.pi, np.pi, np.pi)), + maxfev=5000 + ) + stat = Quartic._make_quartic_dict(stat, popt, n_degs) + except: + print(f"failed to fit potential estimate for QuarticPrior") + stat = Quartic._init_quartic_parameter_dict(n_degs) + k_names = sorted(list(stat["ks"].keys())) + x_0_names = sorted(list(stat["x0s"].keys())) + for ii in range(n_degs-1): + k1_name = k_names[ii] + k2_name = x_0_names[ii] + stat["ks"][k1_name] = torch.tensor(float("nan")) + stat["x0s"][k2_name] = torch.tensor(float("nan")) + + return stat From d360a343107d7fa73a146ae099518115c7ba304f Mon Sep 17 00:00:00 2001 From: lianctrl Date: Thu, 6 Mar 2025 11:55:48 +0100 Subject: [PATCH 2/7] modify init and minor fixes --- mlcg/nn/__init__.py | 2 +- mlcg/nn/prior.py | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/mlcg/nn/__init__.py b/mlcg/nn/__init__.py index 885861c..5bbfc83 100644 --- a/mlcg/nn/__init__.py +++ b/mlcg/nn/__init__.py @@ -3,7 +3,7 @@ from .radial_basis import GaussianBasis, ExpNormalBasis from .cutoff import CosineCutoff, IdentityCutoff from .losses import ForceMSE, ForceRMSE, Loss -from .prior import Harmonic, HarmonicAngles, HarmonicBonds, Repulsion, Dihedral +from .prior import Harmonic, HarmonicAngles, HarmonicBonds, Repulsion, Dihedral, Quartic from .mlp import MLP, TypesMLP from .attention import ExactAttention, FavorAttention, Nonlocalinteractionblock from .pyg_forward_compatibility import ( diff --git a/mlcg/nn/prior.py b/mlcg/nn/prior.py index 51bf57a..7715de4 100644 --- a/mlcg/nn/prior.py +++ b/mlcg/nn/prior.py @@ -1071,15 +1071,15 @@ class Quartic(torch.nn.Module, _Prior): r""" Prior that helps in fitting tighter bimodal distributions using the following energy ansatz. - N.B. the linear term is missing + .. math: V(x) = a*(x-xa)**2 + b*(x-xb)**3 + c*(x-xc)**4 + d - - Especially useful for CA angles, to avoid exploration - toward pi + N.B. the linear term is missing + Especially useful for CA angles, to restrain them + avoiding exploration toward pi """ _order_map = { "bonds": 2, @@ -1186,8 +1186,7 @@ def forward(self, data): @staticmethod def compute(x: torch.Tensor, ks: torch.Tensor, V0 : torch.Tensor, x0s: torch.Tensor): - """Harmonic interaction in the form of a series. The shape of the tensors - should match between each other. + """Quartic potential interaction with missing linear term. .. math: @@ -1211,7 +1210,7 @@ def _quartic_model(x, a, b, c, d, xa, xb, xc): def _init_quartic_parameters(n_degs): """ Helper method for guessing initial parameter values - Not used + Not used for now """ ks = [1.0 for _ in range(n_degs-1)] x0s = [0.0 for _ in range(n_degs-1)] @@ -1268,11 +1267,11 @@ def fit_quartic_from_potential_estimates( bin_centers_nz: Bin centers over which the fit is carried out dG_nz: - The emperical free energy correspinding to the bin centers + The free energy values correspinding to the bin centers Returns ------- - Statistics dictionary with fitted interaction parameters + Statistics dictionary with fitted quartic parameters """ n_degs = 4 @@ -1295,7 +1294,7 @@ def fit_quartic_from_potential_estimates( ) stat = Quartic._make_quartic_dict(stat, popt, n_degs) except: - print(f"failed to fit potential estimate for QuarticPrior") + print(f"failed to fit potential estimate for the prior Quartic") stat = Quartic._init_quartic_parameter_dict(n_degs) k_names = sorted(list(stat["ks"].keys())) x_0_names = sorted(list(stat["x0s"].keys())) From 919436b21b009d8bab77fa614356aab08f7fe8d1 Mon Sep 17 00:00:00 2001 From: lianctrl Date: Thu, 6 Mar 2025 12:08:13 +0100 Subject: [PATCH 3/7] remove check --- mlcg/nn/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlcg/nn/prior.py b/mlcg/nn/prior.py index 7715de4..b756f38 100644 --- a/mlcg/nn/prior.py +++ b/mlcg/nn/prior.py @@ -1123,7 +1123,7 @@ def __init__(self, statistics, name, order : Optional[int] = None, n_degs:int = k = torch.zeros(self.n_degs-1, *sizes) x0 = torch.zeros(self.n_degs-1, *sizes) v_0 = torch.zeros(*sizes) - print(k.shape) + for key in statistics.keys(): for ii in range(self.n_degs-1): k_name = self.k_names[ii] From 21b75975ef6cff9733653b18421439bcd26f29a0 Mon Sep 17 00:00:00 2001 From: lianctrl Date: Thu, 6 Mar 2025 12:15:17 +0100 Subject: [PATCH 4/7] blacked --- mlcg/nn/prior.py | 104 +++++++++++++++++++++-------------------------- 1 file changed, 47 insertions(+), 57 deletions(-) diff --git a/mlcg/nn/prior.py b/mlcg/nn/prior.py index b756f38..51c5350 100644 --- a/mlcg/nn/prior.py +++ b/mlcg/nn/prior.py @@ -1078,9 +1078,10 @@ class Quartic(torch.nn.Module, _Prior): V(x) = a*(x-xa)**2 + b*(x-xb)**3 + c*(x-xc)**4 + d N.B. the linear term is missing - Especially useful for CA angles, to restrain them + Especially useful for CA angles, to restrain them avoiding exploration toward pi """ + _order_map = { "bonds": 2, "angles": 3, @@ -1094,20 +1095,21 @@ class Quartic(torch.nn.Module, _Prior): _neighbor_list_map = { "bonds": "bonds", "angles": "angles", - "dihedrals" : "dihedrals", + "dihedrals": "dihedrals", } - - def __init__(self, statistics, name, order : Optional[int] = None, n_degs:int = 4) -> None: + def __init__( + self, statistics, name, order: Optional[int] = None, n_degs: int = 4 + ) -> None: super(Quartic, self).__init__() keys = torch.tensor(list(statistics.keys()), dtype=torch.long) self.allowed_interaction_keys = list(statistics.keys()) self.name = name if order is not None: self.order = order - elif name in Quartic._order_map.keys(): + elif name in Quartic._order_map.keys(): self.order = Quartic._order_map[self.name] - else: + else: raise ValueError(f"Uncompatible order {order}") self.neighbor_list_type = Quartic._neighbor_list_map[self.name] @@ -1117,15 +1119,15 @@ def __init__(self, statistics, name, order : Optional[int] = None, n_degs:int = sizes = tuple([max_type + 1 for _ in range(self.order)]) self.n_degs = n_degs - self.k_names = ["k_" + str(ii) for ii in range(2, self.n_degs+1)] - self.x0_names = ["x0_" + str(ii) for ii in range(2, self.n_degs+1)] + self.k_names = ["k_" + str(ii) for ii in range(2, self.n_degs + 1)] + self.x0_names = ["x0_" + str(ii) for ii in range(2, self.n_degs + 1)] - k = torch.zeros(self.n_degs-1, *sizes) - x0 = torch.zeros(self.n_degs-1, *sizes) + k = torch.zeros(self.n_degs - 1, *sizes) + x0 = torch.zeros(self.n_degs - 1, *sizes) v_0 = torch.zeros(*sizes) for key in statistics.keys(): - for ii in range(self.n_degs-1): + for ii in range(self.n_degs - 1): k_name = self.k_names[ii] x0_name = self.x0_names[ii] k[ii][key] = statistics[key]["ks"][k_name] @@ -1135,35 +1137,31 @@ def __init__(self, statistics, name, order : Optional[int] = None, n_degs:int = self.register_buffer("v_0", v_0) self.register_buffer("x0s", x0) - @staticmethod def compute_features(pos, mapping, target): compute_map_type = Quartic._neighbor_list_map[target] return Quartic._compute_map[compute_map_type](pos, mapping) - def data2features(self, data): mapping = data.neighbor_list[self.name]["index_mapping"] return Quartic.compute_features(data.pos, mapping, self.name) - def data2parameters(self, data): mapping = data.neighbor_list[self.name]["index_mapping"] interaction_types = [ data.atom_types[mapping[ii]] for ii in range(self.order) ] - # the parameters have shape n_features x n_degs-1 since + # the parameters have shape n_features x n_degs-1 since # linear term is missing ks = torch.vstack( - [self.ks[ii][interaction_types] for ii in range(self.n_degs-1)] + [self.ks[ii][interaction_types] for ii in range(self.n_degs - 1)] ).t() x0s = torch.vstack( - [self.x0s[ii][interaction_types] for ii in range(self.n_degs-1)] + [self.x0s[ii][interaction_types] for ii in range(self.n_degs - 1)] ).t() v_0s = self.v_0[interaction_types].t() - return {"ks": ks, "x0s" : x0s, "v_0s": v_0s} - + return {"ks": ks, "x0s": x0s, "v_0s": v_0s} def forward(self, data): mapping_batch = data.neighbor_list[self.name]["mapping_batch"] @@ -1182,11 +1180,11 @@ def forward(self, data): data.out[self.name] = {"energy": y} return data - @staticmethod - def compute(x: torch.Tensor, ks: torch.Tensor, - V0 : torch.Tensor, x0s: torch.Tensor): - """Quartic potential interaction with missing linear term. + def compute( + x: torch.Tensor, ks: torch.Tensor, V0: torch.Tensor, x0s: torch.Tensor + ): + """Quartic potential interaction with missing linear term. .. math: @@ -1195,70 +1193,61 @@ def compute(x: torch.Tensor, ks: torch.Tensor, """ V = 0 for i in range(3): - V+=ks[i]*(x-x0s[i])**(i+2) + V += ks[i] * (x - x0s[i]) ** (i + 2) V += V0 return V - @staticmethod def _quartic_model(x, a, b, c, d, xa, xb, xc): - return a*(x-xa)**2 + b*(x-xb)**3 + c*(x-xc)**4 + d - + return a * (x - xa) ** 2 + b * (x - xb) ** 3 + c * (x - xc) ** 4 + d @staticmethod def _init_quartic_parameters(n_degs): - """ + """ Helper method for guessing initial parameter values Not used for now """ - ks = [1.0 for _ in range(n_degs-1)] - x0s = [0.0 for _ in range(n_degs-1)] + ks = [1.0 for _ in range(n_degs - 1)] + x0s = [0.0 for _ in range(n_degs - 1)] V0 = -1.0 p0 = [V0] p0.extend(ks) p0.extend(x0s) return p0 - @staticmethod def _init_quartic_parameter_dict(n_degs): """Helper method for initializing the parameter dictionary""" - stat = { - "ks": {}, - "x0s": {}, - "v_0" : 0.0 - } - k_names = ["k_" + str(ii) for ii in range(2,n_degs+1)] - x0_names = ["x0_" + str(ii) for ii in range(2,n_degs+1)] - for ii in range(n_degs-1): + stat = {"ks": {}, "x0s": {}, "v_0": 0.0} + k_names = ["k_" + str(ii) for ii in range(2, n_degs + 1)] + x0_names = ["x0_" + str(ii) for ii in range(2, n_degs + 1)] + for ii in range(n_degs - 1): k_name = k_names[ii] x0_name = x0_names[ii] stat["ks"][k_name] = {} stat["x0s"][x0_name] = {} return stat - @staticmethod def _make_quartic_dict(stat, popt, n_degs): """Helper method for constructing a fitted parameter dictionary""" stat["v_0"] = popt[0] k_names = sorted(list(stat["ks"].keys())) x0_names = sorted(list(stat["x0s"].keys())) - for ii in range(n_degs-1): + for ii in range(n_degs - 1): k_name = k_names[ii] x0_name = x0_names[ii] stat["ks"][k_name] = popt[ii] - stat["x0s"][x0_name] = popt[ii+n_degs] - - return stat + stat["x0s"][x0_name] = popt[ii + n_degs] + return stat @staticmethod def fit_quartic_from_potential_estimates( - bin_centers_nz: torch.Tensor, - dG_nz: torch.Tensor, - **kwargs, + bin_centers_nz: torch.Tensor, + dG_nz: torch.Tensor, + **kwargs, ): """ @@ -1283,22 +1272,23 @@ def fit_quartic_from_potential_estimates( try: stat = Quartic._init_quartic_parameter_dict(n_degs) popt, _ = curve_fit( - Quartic._quartic_model, - bin_centers_nz.cpu().numpy()[mask], - dG_nz.cpu().numpy()[mask], - p0=[1, 0, 0, torch.argmin(dG_nz[mask]), - 0, 0, 0], - bounds=((0, 0, 0, -np.inf, -np.pi, -np.pi, -np.pi), - (np.inf, np.inf, np.inf, np.inf, np.pi, np.pi, np.pi)), - maxfev=5000 - ) + Quartic._quartic_model, + bin_centers_nz.cpu().numpy()[mask], + dG_nz.cpu().numpy()[mask], + p0=[1, 0, 0, torch.argmin(dG_nz[mask]), 0, 0, 0], + bounds=( + (0, 0, 0, -np.inf, -np.pi, -np.pi, -np.pi), + (np.inf, np.inf, np.inf, np.inf, np.pi, np.pi, np.pi), + ), + maxfev=5000, + ) stat = Quartic._make_quartic_dict(stat, popt, n_degs) except: print(f"failed to fit potential estimate for the prior Quartic") stat = Quartic._init_quartic_parameter_dict(n_degs) k_names = sorted(list(stat["ks"].keys())) x_0_names = sorted(list(stat["x0s"].keys())) - for ii in range(n_degs-1): + for ii in range(n_degs - 1): k1_name = k_names[ii] k2_name = x_0_names[ii] stat["ks"][k1_name] = torch.tensor(float("nan")) From 1c5ac7aeaee2c49d10b7727c481746dd6ed871c3 Mon Sep 17 00:00:00 2001 From: lianctrl Date: Thu, 6 Mar 2025 12:16:48 +0100 Subject: [PATCH 5/7] blacked init --- mlcg/nn/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mlcg/nn/__init__.py b/mlcg/nn/__init__.py index 5bbfc83..5aeb582 100644 --- a/mlcg/nn/__init__.py +++ b/mlcg/nn/__init__.py @@ -3,7 +3,14 @@ from .radial_basis import GaussianBasis, ExpNormalBasis from .cutoff import CosineCutoff, IdentityCutoff from .losses import ForceMSE, ForceRMSE, Loss -from .prior import Harmonic, HarmonicAngles, HarmonicBonds, Repulsion, Dihedral, Quartic +from .prior import ( + Harmonic, + HarmonicAngles, + HarmonicBonds, + Repulsion, + Dihedral, + Quartic, +) from .mlp import MLP, TypesMLP from .attention import ExactAttention, FavorAttention, Nonlocalinteractionblock from .pyg_forward_compatibility import ( From d6bec57a8b486209c40f7e4ca4da1cc389ff53fc Mon Sep 17 00:00:00 2001 From: sayeg84 Date: Tue, 1 Apr 2025 16:14:06 +0200 Subject: [PATCH 6/7] Replace Quartic with Polynomial and QuarticAngles --- mlcg/nn/__init__.py | 3 +- mlcg/nn/prior.py | 232 ++++++++++++++++---------------------------- 2 files changed, 86 insertions(+), 149 deletions(-) diff --git a/mlcg/nn/__init__.py b/mlcg/nn/__init__.py index 5aeb582..5166b00 100644 --- a/mlcg/nn/__init__.py +++ b/mlcg/nn/__init__.py @@ -9,7 +9,8 @@ HarmonicBonds, Repulsion, Dihedral, - Quartic, + Polynomial, + QuarticAngles, ) from .mlp import MLP, TypesMLP from .attention import ExactAttention, FavorAttention, Nonlocalinteractionblock diff --git a/mlcg/nn/prior.py b/mlcg/nn/prior.py index 51c5350..7cb5cf9 100644 --- a/mlcg/nn/prior.py +++ b/mlcg/nn/prior.py @@ -1067,21 +1067,36 @@ def neighbor_list(topology) -> None: return {Dihedral.name: nl} -class Quartic(torch.nn.Module, _Prior): +class Polynomial(torch.nn.Module, _Prior): r""" - Prior that helps in fitting tighter bimodal distributions - using the following energy ansatz. - + Prior representing a polynomial with + the following energy ansatz: .. math: - V(x) = a*(x-xa)**2 + b*(x-xb)**3 + c*(x-xc)**4 + d + V(r) = V_0 + \sum_{n=1}^{n_deg} k_n (x-x_0)^n - N.B. the linear term is missing - Especially useful for CA angles, to restrain them - avoiding exploration toward pi - """ + + Parameters + ---------- + statistics: + Dictionary of interaction parameters for each type of atom combination, + where the keys are tuples of interacting bead types and the + corresponding values define the interaction parameters. These + Can be hand-designed or taken from the output of + `mlcg.geometry.statistics.compute_statistics`, but must minimally + contain the following information for each key: + .. code-block:: python + + tuple(*specific_types) : { + "ks" : torch.Tensor that contains all k_1,..,k_{n_degs} coefficients + "v_0" : torch.Tensor that contains the constant offset + ... + } + + The keys must be tuples of 2,3,4 atoms. + """ _order_map = { "bonds": 2, "angles": 3, @@ -1095,203 +1110,124 @@ class Quartic(torch.nn.Module, _Prior): _neighbor_list_map = { "bonds": "bonds", "angles": "angles", - "dihedrals": "dihedrals", + "dihedrals" : "dihedrals", } - def __init__( - self, statistics, name, order: Optional[int] = None, n_degs: int = 4 - ) -> None: - super(Quartic, self).__init__() + def __init__(self, statistics: dict, name: str, order : Optional[int] = None, n_degs:int = 4) -> None: + r""" + + """ + super(Polynomial, self).__init__() keys = torch.tensor(list(statistics.keys()), dtype=torch.long) self.allowed_interaction_keys = list(statistics.keys()) self.name = name + if order is not None: self.order = order - elif name in Quartic._order_map.keys(): - self.order = Quartic._order_map[self.name] - else: + elif name in Polynomial._order_map.keys(): + self.order = Polynomial._order_map[self.name] + else: raise ValueError(f"Uncompatible order {order}") - self.neighbor_list_type = Quartic._neighbor_list_map[self.name] - + unique_types = torch.unique(keys.flatten()) assert unique_types.min() >= 0 + max_type = unique_types.max() sizes = tuple([max_type + 1 for _ in range(self.order)]) + + unique_degs = torch.unique(torch.tensor([len(val["ks"]) for _,val in statistics.items()])) + assert len(unique_degs) == 1, "ks in the statistics dictionary must be of the same size for all the keys" + assert unique_degs[0] == n_degs, f"length of parameters {unique_degs[0]} doesn't match degrees {n_degs}" self.n_degs = n_degs - self.k_names = ["k_" + str(ii) for ii in range(2, self.n_degs + 1)] - self.x0_names = ["x0_" + str(ii) for ii in range(2, self.n_degs + 1)] - - k = torch.zeros(self.n_degs - 1, *sizes) - x0 = torch.zeros(self.n_degs - 1, *sizes) + self.k_names = ["k_" + str(ii) for ii in range(1, self.n_degs+1)] + k = torch.zeros(self.n_degs, *sizes) v_0 = torch.zeros(*sizes) - for key in statistics.keys(): - for ii in range(self.n_degs - 1): + for ii in range(self.n_degs): k_name = self.k_names[ii] - x0_name = self.x0_names[ii] k[ii][key] = statistics[key]["ks"][k_name] - x0[ii][key] = statistics[key]["x0s"][x0_name] v_0[key] = statistics[key]["v_0"] self.register_buffer("ks", k) self.register_buffer("v_0", v_0) - self.register_buffer("x0s", x0) - - @staticmethod - def compute_features(pos, mapping, target): - compute_map_type = Quartic._neighbor_list_map[target] - return Quartic._compute_map[compute_map_type](pos, mapping) + return None def data2features(self, data): mapping = data.neighbor_list[self.name]["index_mapping"] - return Quartic.compute_features(data.pos, mapping, self.name) + if hasattr(data, "pbc"): + return Polynomial.compute_features( + data.pos, + mapping, + self.name, + data.pbc, + data.cell, + data.batch, + ) + else: + return Polynomial.compute_features(data.pos, mapping, self.name) def data2parameters(self, data): mapping = data.neighbor_list[self.name]["index_mapping"] interaction_types = [ data.atom_types[mapping[ii]] for ii in range(self.order) ] - # the parameters have shape n_features x n_degs-1 since - # linear term is missing + # the parameters have shape n_features x n_degs ks = torch.vstack( - [self.ks[ii][interaction_types] for ii in range(self.n_degs - 1)] - ).t() - - x0s = torch.vstack( - [self.x0s[ii][interaction_types] for ii in range(self.n_degs - 1)] + [self.ks[ii][interaction_types] for ii in range(self.n_degs)] ).t() v_0s = self.v_0[interaction_types].t() - return {"ks": ks, "x0s": x0s, "v_0s": v_0s} + return {"ks": ks, "v_0s": v_0s} def forward(self, data): mapping_batch = data.neighbor_list[self.name]["mapping_batch"] features = self.data2features(data).flatten() params = self.data2parameters(data) + #V0s = params["v_0"] if "v_0" in params.keys() else [0 for ii in range(self.n_degs)] V0s = params["v_0s"].t() + # format parameters + #ks = [params["ks"][:,i] for i in range(self.n_degs)] ks = params["ks"].t() - x0s = params["x0s"].t() - y = Quartic.compute( + y = Polynomial.compute( features, ks, V0s, - x0s, ) y = scatter(y, mapping_batch, dim=0, reduce="sum") data.out[self.name] = {"energy": y} return data @staticmethod - def compute( - x: torch.Tensor, ks: torch.Tensor, V0: torch.Tensor, x0s: torch.Tensor - ): - """Quartic potential interaction with missing linear term. + def compute_features(pos, mapping, target): + compute_map_type = Polynomial._neighbor_list_map[target] + return Polynomial._compute_map[compute_map_type](pos, mapping) + + @staticmethod + def compute(x: torch.Tensor, ks: torch.Tensor, V0 : torch.Tensor): + """Harmonic interaction in the form of a series. The shape of the tensors + should match between each other. .. math: - V(r) = V0 + \sum_{n=2}^{4} k_n (x-x_n)^n + V(r) = V0 + \sum_{n=1}^{deg} k_n x^n """ - V = 0 - for i in range(3): - V += ks[i] * (x - x0s[i]) ** (i + 2) - + V = ks[0]*x + for p, k in enumerate(ks[1:],start=2): + V += k * torch.pow(x, p) V += V0 return V + - @staticmethod - def _quartic_model(x, a, b, c, d, xa, xb, xc): - return a * (x - xa) ** 2 + b * (x - xb) ** 3 + c * (x - xc) ** 4 + d - - @staticmethod - def _init_quartic_parameters(n_degs): - """ - Helper method for guessing initial parameter values - Not used for now - """ - ks = [1.0 for _ in range(n_degs - 1)] - x0s = [0.0 for _ in range(n_degs - 1)] - V0 = -1.0 - p0 = [V0] - p0.extend(ks) - p0.extend(x0s) - return p0 - - @staticmethod - def _init_quartic_parameter_dict(n_degs): - """Helper method for initializing the parameter dictionary""" - stat = {"ks": {}, "x0s": {}, "v_0": 0.0} - k_names = ["k_" + str(ii) for ii in range(2, n_degs + 1)] - x0_names = ["x0_" + str(ii) for ii in range(2, n_degs + 1)] - for ii in range(n_degs - 1): - k_name = k_names[ii] - x0_name = x0_names[ii] - stat["ks"][k_name] = {} - stat["x0s"][x0_name] = {} - return stat - - @staticmethod - def _make_quartic_dict(stat, popt, n_degs): - """Helper method for constructing a fitted parameter dictionary""" - stat["v_0"] = popt[0] - k_names = sorted(list(stat["ks"].keys())) - x0_names = sorted(list(stat["x0s"].keys())) - for ii in range(n_degs - 1): - k_name = k_names[ii] - x0_name = x0_names[ii] - stat["ks"][k_name] = popt[ii] - stat["x0s"][x0_name] = popt[ii + n_degs] +class QuarticAngles(Polynomial): + """Wrapper class for angle priors + (order 3 Polynomial priors of degree 4) + """ - return stat + def __init__(self, statistics, name="angles", n_degs: int = 4) -> None: + super(QuarticAngles, self).__init__(statistics, name, order=3, n_degs=n_degs) + @staticmethod - def fit_quartic_from_potential_estimates( - bin_centers_nz: torch.Tensor, - dG_nz: torch.Tensor, - **kwargs, - ): - """ - - Parameters - ---------- - bin_centers_nz: - Bin centers over which the fit is carried out - dG_nz: - The free energy values correspinding to the bin centers - - Returns - ------- - Statistics dictionary with fitted quartic parameters - """ - n_degs = 4 - - integral = torch.tensor( - float(trapezoid(dG_nz.cpu().numpy(), bin_centers_nz.cpu().numpy())) - ) - - mask = torch.abs(dG_nz) > 1e-4 * torch.abs(integral) - try: - stat = Quartic._init_quartic_parameter_dict(n_degs) - popt, _ = curve_fit( - Quartic._quartic_model, - bin_centers_nz.cpu().numpy()[mask], - dG_nz.cpu().numpy()[mask], - p0=[1, 0, 0, torch.argmin(dG_nz[mask]), 0, 0, 0], - bounds=( - (0, 0, 0, -np.inf, -np.pi, -np.pi, -np.pi), - (np.inf, np.inf, np.inf, np.inf, np.pi, np.pi, np.pi), - ), - maxfev=5000, - ) - stat = Quartic._make_quartic_dict(stat, popt, n_degs) - except: - print(f"failed to fit potential estimate for the prior Quartic") - stat = Quartic._init_quartic_parameter_dict(n_degs) - k_names = sorted(list(stat["ks"].keys())) - x_0_names = sorted(list(stat["x0s"].keys())) - for ii in range(n_degs - 1): - k1_name = k_names[ii] - k2_name = x_0_names[ii] - stat["ks"][k1_name] = torch.tensor(float("nan")) - stat["x0s"][k2_name] = torch.tensor(float("nan")) + def compute_features(pos, mapping): + return Polynomial.compute_features(pos, mapping, "angles") - return stat From 8716f84796e4f0b95aa46d3fb02835601a6d80a5 Mon Sep 17 00:00:00 2001 From: sayeg84 Date: Tue, 1 Apr 2025 16:14:41 +0200 Subject: [PATCH 7/7] Black --- mlcg/nn/prior.py | 61 ++++++++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/mlcg/nn/prior.py b/mlcg/nn/prior.py index 7cb5cf9..86b9f2d 100644 --- a/mlcg/nn/prior.py +++ b/mlcg/nn/prior.py @@ -1076,7 +1076,7 @@ class Polynomial(torch.nn.Module, _Prior): V(r) = V_0 + \sum_{n=1}^{n_deg} k_n (x-x_0)^n - + Parameters ---------- statistics: @@ -1097,6 +1097,7 @@ class Polynomial(torch.nn.Module, _Prior): The keys must be tuples of 2,3,4 atoms. """ + _order_map = { "bonds": 2, "angles": 3, @@ -1110,13 +1111,17 @@ class Polynomial(torch.nn.Module, _Prior): _neighbor_list_map = { "bonds": "bonds", "angles": "angles", - "dihedrals" : "dihedrals", + "dihedrals": "dihedrals", } - def __init__(self, statistics: dict, name: str, order : Optional[int] = None, n_degs:int = 4) -> None: - r""" - - """ + def __init__( + self, + statistics: dict, + name: str, + order: Optional[int] = None, + n_degs: int = 4, + ) -> None: + r""" """ super(Polynomial, self).__init__() keys = torch.tensor(list(statistics.keys()), dtype=torch.long) self.allowed_interaction_keys = list(statistics.keys()) @@ -1124,23 +1129,29 @@ def __init__(self, statistics: dict, name: str, order : Optional[int] = None, n_ if order is not None: self.order = order - elif name in Polynomial._order_map.keys(): + elif name in Polynomial._order_map.keys(): self.order = Polynomial._order_map[self.name] - else: + else: raise ValueError(f"Uncompatible order {order}") - + unique_types = torch.unique(keys.flatten()) assert unique_types.min() >= 0 - + max_type = unique_types.max() sizes = tuple([max_type + 1 for _ in range(self.order)]) - - unique_degs = torch.unique(torch.tensor([len(val["ks"]) for _,val in statistics.items()])) - assert len(unique_degs) == 1, "ks in the statistics dictionary must be of the same size for all the keys" - assert unique_degs[0] == n_degs, f"length of parameters {unique_degs[0]} doesn't match degrees {n_degs}" + + unique_degs = torch.unique( + torch.tensor([len(val["ks"]) for _, val in statistics.items()]) + ) + assert ( + len(unique_degs) == 1 + ), "ks in the statistics dictionary must be of the same size for all the keys" + assert ( + unique_degs[0] == n_degs + ), f"length of parameters {unique_degs[0]} doesn't match degrees {n_degs}" self.n_degs = n_degs - self.k_names = ["k_" + str(ii) for ii in range(1, self.n_degs+1)] + self.k_names = ["k_" + str(ii) for ii in range(1, self.n_degs + 1)] k = torch.zeros(self.n_degs, *sizes) v_0 = torch.zeros(*sizes) for key in statistics.keys(): @@ -1182,10 +1193,10 @@ def forward(self, data): mapping_batch = data.neighbor_list[self.name]["mapping_batch"] features = self.data2features(data).flatten() params = self.data2parameters(data) - #V0s = params["v_0"] if "v_0" in params.keys() else [0 for ii in range(self.n_degs)] + # V0s = params["v_0"] if "v_0" in params.keys() else [0 for ii in range(self.n_degs)] V0s = params["v_0s"].t() # format parameters - #ks = [params["ks"][:,i] for i in range(self.n_degs)] + # ks = [params["ks"][:,i] for i in range(self.n_degs)] ks = params["ks"].t() y = Polynomial.compute( features, @@ -1202,7 +1213,7 @@ def compute_features(pos, mapping, target): return Polynomial._compute_map[compute_map_type](pos, mapping) @staticmethod - def compute(x: torch.Tensor, ks: torch.Tensor, V0 : torch.Tensor): + def compute(x: torch.Tensor, ks: torch.Tensor, V0: torch.Tensor): """Harmonic interaction in the form of a series. The shape of the tensors should match between each other. @@ -1211,23 +1222,23 @@ def compute(x: torch.Tensor, ks: torch.Tensor, V0 : torch.Tensor): V(r) = V0 + \sum_{n=1}^{deg} k_n x^n """ - V = ks[0]*x - for p, k in enumerate(ks[1:],start=2): + V = ks[0] * x + for p, k in enumerate(ks[1:], start=2): V += k * torch.pow(x, p) V += V0 return V - + class QuarticAngles(Polynomial): """Wrapper class for angle priors (order 3 Polynomial priors of degree 4) """ - def __init__(self, statistics, name="angles", n_degs: int = 4) -> None: - super(QuarticAngles, self).__init__(statistics, name, order=3, n_degs=n_degs) - + super(QuarticAngles, self).__init__( + statistics, name, order=3, n_degs=n_degs + ) + @staticmethod def compute_features(pos, mapping): return Polynomial.compute_features(pos, mapping, "angles") -