-
Notifications
You must be signed in to change notification settings - Fork 3
Adding quartic prior #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
10a567c
added quartic prior
d360a34
modify init and minor fixes
919436b
remove check
21b7597
blacked
1c5ac7a
blacked init
d6bec57
Replace Quartic with Polynomial and QuarticAngles
sayeg84 8716f84
Black
sayeg84 dc23101
Merge branch 'main' into feat/quart_prior
sayeg84 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1065,3 +1065,233 @@ def from_user(*args): | |
def neighbor_list(topology) -> dict:
    """Build the dihedral neighbor list of ``topology``.

    Parameters
    ----------
    topology:
        Object exposing a ``neighbor_list(name)`` method.

    Returns
    -------
    dict
        Single-entry dictionary mapping ``Dihedral.name`` to the neighbor
        list returned by ``topology.neighbor_list(Dihedral.name)``.
    """
    # The function returns a dictionary, so the former ``-> None``
    # annotation was incorrect.
    nl = topology.neighbor_list(Dihedral.name)
    return {Dihedral.name: nl}
|
||
|
||
class Quartic(torch.nn.Module, _Prior):
    r"""
    Prior that helps in fitting tighter bimodal distributions
    using the following energy ansatz.

    .. math::

        V(x) = a (x - x_a)^2 + b (x - x_b)^3 + c (x - x_c)^4 + d

    N.B. the linear term is missing.
    Especially useful for CA angles, to restrain them and avoid
    exploration toward pi.

    Parameters
    ----------
    statistics:
        Dictionary keyed by interaction types. Each value must provide
        ``"ks"`` (entries ``k_2 ... k_{n_degs}``), ``"x0s"`` (entries
        ``x0_2 ... x0_{n_degs}``) and ``"v_0"``.
    name:
        Name of the prior; it selects the feature function and the
        neighbor list (``"bonds"``, ``"angles"`` or ``"dihedrals"``).
    order:
        Number of atoms involved in one interaction. When ``None`` it is
        looked up from ``name``.
    n_degs:
        Highest polynomial degree of the ansatz.
    """

    # number of atoms per interaction for each supported prior name
    _order_map = {
        "bonds": 2,
        "angles": 3,
        "dihedrals": 4,
    }
    # feature function for each supported neighbor list type
    _compute_map = {
        "bonds": compute_distances,
        "angles": compute_angles,
        "dihedrals": compute_torsions,
    }
    # neighbor list type for each supported prior name
    _neighbor_list_map = {
        "bonds": "bonds",
        "angles": "angles",
        "dihedrals": "dihedrals",
    }

    def __init__(
        self, statistics, name, order: Optional[int] = None, n_degs: int = 4
    ) -> None:
        super().__init__()
        keys = torch.tensor(list(statistics.keys()), dtype=torch.long)
        self.allowed_interaction_keys = list(statistics.keys())
        self.name = name
        if order is not None:
            self.order = order
        elif name in Quartic._order_map:
            self.order = Quartic._order_map[self.name]
        else:
            raise ValueError(f"Uncompatible order {order}")
        self.neighbor_list_type = Quartic._neighbor_list_map[self.name]

        unique_types = torch.unique(keys.flatten())
        assert unique_types.min() >= 0
        # One table axis per atom of the interaction, each large enough to
        # be indexed directly by the largest type found in ``statistics``.
        n_types = int(unique_types.max()) + 1
        sizes = (n_types,) * self.order

        self.n_degs = n_degs
        degrees = range(2, self.n_degs + 1)
        self.k_names = ["k_" + str(deg) for deg in degrees]
        self.x0_names = ["x0_" + str(deg) for deg in degrees]

        # The leading axis enumerates the polynomial terms (degree 2 and up).
        k = torch.zeros(self.n_degs - 1, *sizes)
        x0 = torch.zeros(self.n_degs - 1, *sizes)
        v_0 = torch.zeros(*sizes)

        for key, stat in statistics.items():
            for ii, (k_name, x0_name) in enumerate(
                zip(self.k_names, self.x0_names)
            ):
                k[ii][key] = stat["ks"][k_name]
                x0[ii][key] = stat["x0s"][x0_name]
            v_0[key] = stat["v_0"]
        self.register_buffer("ks", k)
        self.register_buffer("v_0", v_0)
        self.register_buffer("x0s", x0)

    @staticmethod
    def compute_features(pos, mapping, target):
        """Compute the features of ``target`` from ``pos`` and ``mapping``.

        ``target`` must be a key of ``_neighbor_list_map``; it selects the
        feature function that is called as ``fn(pos, mapping)``.
        """
        compute_map_type = Quartic._neighbor_list_map[target]
        return Quartic._compute_map[compute_map_type](pos, mapping)

    def data2features(self, data):
        """Compute the features of this prior from ``data``.

        Reads ``data.pos`` and the ``"index_mapping"`` entry of
        ``data.neighbor_list[self.name]``.
        """
        mapping = data.neighbor_list[self.name]["index_mapping"]
        return Quartic.compute_features(data.pos, mapping, self.name)

    def data2parameters(self, data):
        """Look up the polynomial parameters of every interaction in ``data``.

        Returns
        -------
        dict
            ``"ks"`` and ``"x0s"``: the per-degree lookups stacked and
            transposed, so the ``n_degs - 1`` terms lie on the last axis
            (the linear term is missing); ``"v_0s"``: the offsets.
        """
        mapping = data.neighbor_list[self.name]["index_mapping"]
        interaction_types = [
            data.atom_types[mapping[ii]] for ii in range(self.order)
        ]
        n_terms = self.n_degs - 1
        ks = torch.vstack(
            [self.ks[ii][interaction_types] for ii in range(n_terms)]
        ).t()
        x0s = torch.vstack(
            [self.x0s[ii][interaction_types] for ii in range(n_terms)]
        ).t()
        v_0s = self.v_0[interaction_types].t()
        return {"ks": ks, "x0s": x0s, "v_0s": v_0s}

    def forward(self, data):
        """Evaluate the prior energy and store it in ``data.out[self.name]``.

        The per-interaction energies are summed with ``scatter`` according
        to the ``"mapping_batch"`` entry of the neighbor list. The result is
        written to ``data.out[self.name]["energy"]`` and ``data`` is returned.
        """
        mapping_batch = data.neighbor_list[self.name]["mapping_batch"]
        features = self.data2features(data).flatten()
        params = self.data2parameters(data)
        # ``compute`` indexes the polynomial terms along the first axis, so
        # undo the transposition applied by ``data2parameters``.
        V0s = params["v_0s"].t()
        ks = params["ks"].t()
        x0s = params["x0s"].t()
        y = Quartic.compute(features, ks, V0s, x0s)
        y = scatter(y, mapping_batch, dim=0, reduce="sum")
        data.out[self.name] = {"energy": y}
        return data

    @staticmethod
    def compute(
        x: torch.Tensor, ks: torch.Tensor, V0: torch.Tensor, x0s: torch.Tensor
    ):
        r"""Polynomial potential interaction with missing linear term.

        .. math::

            V(x) = V_0 + \sum_{n=2}^{N} k_n (x - x_n)^n

        ``ks[i]`` and ``x0s[i]`` hold the coefficient and the shift of the
        term of degree ``i + 2``. The number of terms is taken from ``ks``
        instead of being fixed to three, so priors built with
        ``n_degs != 4`` are evaluated with all of their terms.
        """
        V = 0
        for i in range(len(ks)):
            V += ks[i] * (x - x0s[i]) ** (i + 2)

        V += V0
        return V

    @staticmethod
    def _quartic_model(x, a, b, c, d, xa, xb, xc):
        """Quartic ansatz used as model function by the curve fit.

        The parameter order ``(a, b, c, d, xa, xb, xc)`` defines the layout
        of the fitted parameter vector.
        """
        return a * (x - xa) ** 2 + b * (x - xb) ** 3 + c * (x - xc) ** 4 + d

    @staticmethod
    def _init_quartic_parameters(n_degs):
        """
        Helper method for guessing initial parameter values
        Not used for now

        Returns the list ``[V0, k_2, ..., k_n, x0_2, ..., x0_n]``.
        """
        ks = [1.0 for _ in range(n_degs - 1)]
        x0s = [0.0 for _ in range(n_degs - 1)]
        V0 = -1.0
        p0 = [V0]
        p0.extend(ks)
        p0.extend(x0s)
        return p0

    @staticmethod
    def _init_quartic_parameter_dict(n_degs):
        """Helper method for initializing the parameter dictionary"""
        stat = {"ks": {}, "x0s": {}, "v_0": 0.0}
        for deg in range(2, n_degs + 1):
            stat["ks"]["k_" + str(deg)] = {}
            stat["x0s"]["x0_" + str(deg)] = {}
        return stat

    @staticmethod
    def _make_quartic_dict(stat, popt, n_degs):
        """Helper method for constructing a fitted parameter dictionary

        ``popt`` follows the parameter order of ``_quartic_model``: the
        ``n_degs - 1`` coefficients, then the constant offset, then the
        ``n_degs - 1`` shifts.
        """
        # The offset follows the coefficients; index 0 is the quadratic
        # coefficient and must not be stored as ``v_0``.
        stat["v_0"] = popt[n_degs - 1]
        k_names = sorted(stat["ks"].keys())
        x0_names = sorted(stat["x0s"].keys())
        for ii in range(n_degs - 1):
            stat["ks"][k_names[ii]] = popt[ii]
            stat["x0s"][x0_names[ii]] = popt[ii + n_degs]

        return stat

    # NOTE(review): a matching PR to mlcg-tk is also needed, because the
    # fitting function below has to be called from there as well.
    @staticmethod
    def fit_quartic_from_potential_estimates(
        bin_centers_nz: torch.Tensor,
        dG_nz: torch.Tensor,
        **kwargs,
    ):
        """Fit the quartic ansatz to free energy estimates.

        Parameters
        ----------
        bin_centers_nz:
            Bin centers over which the fit is carried out
        dG_nz:
            The free energy values corresponding to the bin centers

        Returns
        -------
        Statistics dictionary with fitted quartic parameters. If the fit
        fails, the coefficients and shifts are NaN tensors and ``v_0`` is
        left at ``0.0``.
        """
        n_degs = 4

        centers = bin_centers_nz.cpu().numpy()
        free_energy = dG_nz.cpu().numpy()
        integral = torch.tensor(float(trapezoid(free_energy, centers)))

        # Keep only the bins whose free energy is not negligible compared
        # with the integral of the whole profile.
        mask = torch.abs(dG_nz) > 1e-4 * torch.abs(integral)
        # NumPy arrays are indexed with a NumPy mask, not a torch tensor.
        mask_np = mask.cpu().numpy()
        stat = Quartic._init_quartic_parameter_dict(n_degs)
        try:
            fit_x = centers[mask_np]
            fit_y = free_energy[mask_np]
            # Start the offset at the lowest free energy value; the index
            # returned by ``argmin`` is not an energy.
            offset_guess = float(fit_y.min())
            popt, _ = curve_fit(
                Quartic._quartic_model,
                fit_x,
                fit_y,
                p0=[1, 0, 0, offset_guess, 0, 0, 0],
                bounds=(
                    (0, 0, 0, -np.inf, -np.pi, -np.pi, -np.pi),
                    (np.inf, np.inf, np.inf, np.inf, np.pi, np.pi, np.pi),
                ),
                maxfev=5000,
            )
            stat = Quartic._make_quartic_dict(stat, popt, n_degs)
        except Exception:
            # Best effort: report the failure and return NaN parameters.
            # ``Exception`` lets KeyboardInterrupt and SystemExit propagate.
            print("failed to fit potential estimate for the prior Quartic")
            stat = Quartic._init_quartic_parameter_dict(n_degs)
            for k_name in sorted(stat["ks"].keys()):
                stat["ks"][k_name] = torch.tensor(float("nan"))
            for x0_name in sorted(stat["x0s"].keys()):
                stat["x0s"][x0_name] = torch.tensor(float("nan"))

        return stat
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A degree-4 polynomial p(x) is uniquely characterized by 5 numbers a_0, ..., a_4 such that p(x) = \sum_{k=0}^{4} a_k x^k. Your degree-4 polynomial uses 7 parameters in order to keep the shape p(x) = \sum_{k=2}^{4} b_k (x - x_0^{(k)})^k. Isn't the fit then ill-posed?
At some point I tried to use an expression similar to what you have now and found it problematic because of this redundancy of parameters.