Skip to content

Commit 4606183

Browse files
authored
Merge pull request #153 from mila-iqia/edit_generated_samples
Edit generated samples
2 parents f2edfba + 0ea4043 commit 4606183

File tree

8 files changed

+643
-160
lines changed

8 files changed

+643
-160
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#================================================================================
2+
# Configuration file for an active learning run
3+
#================================================================================
4+
exp_name: excise_and_repaint_sample_maker
5+
6+
seed: 42
7+
8+
elements: [Si]
9+
10+
uncertainty_thresholds: [0.001, 0.0001, 0.00001, 0.000001]
11+
12+
flare:
13+
cutoff: 5.0
14+
n_radial: 12
15+
lmax: 3
16+
initial_sigma: 1000.0
17+
initial_sigma_e: 1.0
18+
initial_sigma_f: 0.050
19+
initial_sigma_s: 1.0
20+
variance_type: local
21+
22+
flare_optimizer:
23+
optimize_on_the_fly: False
24+
# optimization_method: "nelder-mead"
25+
# max_optimization_iterations: 10
26+
# optimize_sigma: False
27+
# optimize_sigma_e: False
28+
# optimize_sigma_f: False
29+
# optimize_sigma_s: False
30+
31+
oracle:
32+
name: stillinger_weber
33+
sw_coeff_filename: Si.sw
34+
35+
sampling:
36+
algorithm: excise_and_repaint
37+
sample_box_strategy: fixed
38+
sample_box_size: [ 10.86, 10.86, 10.86 ]
39+
sample_edit_radius: 5.0 # in Angstrom: generated atoms within this radius from the central atom will be removed.
40+
excision:
41+
algorithm: spherical_cutoff
42+
radial_cutoff: 5.0 # radial cutoff in Angstrom
43+
noise:
44+
total_time_steps: 500
45+
sigma_min: 0.0001
46+
sigma_max: 0.2
47+
schedule_type: linear
48+
corrector_step_epsilon: 2.5e-8
49+
repaint_generator:
50+
number_of_atoms: 64
51+
number_of_corrector_steps: 2
52+
one_atom_type_transition_per_step: False
53+
atom_type_greedy_sampling: False
54+
atom_type_transition_in_corrector: False
55+
record_samples: False
56+
57+
lammps:
58+
mpi_processors: 4
59+
openmp_threads: 2
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
from typing import Any, AnyStr, Dict, List, Optional, Tuple, Union
2+
3+
import torch
4+
5+
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.atom_selector.atom_selector_factory import \
6+
create_atom_selector_parameters
7+
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.excisor.excisor_factory import \
8+
create_excisor_parameters
9+
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.sample_maker.base_sample_maker import \
10+
BaseSampleMaker
11+
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.sample_maker.sample_maker_factory import (
12+
create_sample_maker, create_sample_maker_parameters)
13+
from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \
14+
PredictorCorrectorSamplingParameters
15+
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \
16+
ScoreNetwork
17+
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
18+
NoiseParameters
19+
from diffusion_for_multi_scale_molecular_dynamics.sample_diffusion import \
20+
get_axl_network
21+
22+
23+
def get_repaint_parameters(
24+
sampling_dictionary: Dict[AnyStr, Any],
25+
element_list: List[str],
26+
path_to_score_network_checkpoint: Optional[str] = None,
27+
) -> Tuple[
28+
Union[NoiseParameters, None],
29+
Union[PredictorCorrectorSamplingParameters, None],
30+
Union[ScoreNetwork, None],
31+
str,
32+
]:
33+
"""Get repaint parameters.
34+
35+
This convenience method is responsible for extracting the relevant configuration objects in the
36+
case that the sample maker algorithm is "Excise and Repaint", and to return a "None" default for
37+
these configuration objects if a different algorithm is used.
38+
39+
Args:
40+
sampling_dictionary: Dictionary of sampling parameters, as read in from a yaml configuration file.
41+
element_list: List of element names.
42+
path_to_score_network_checkpoint: Path to score network checkpoint.
43+
44+
Returns:
45+
noise_parameters: a NoiseParameters object if the config is present, otherwise None.
46+
sampling_parameters: a PredictorCorrectorSamplingParameters object if the config is present, otherwise None.
47+
axl_network: a Score Network object to draw constrained samples if the config is present, otherwise None.
48+
device: a string indicating which device should be used: either cpu or cuda.
49+
"""
50+
algorithm = sampling_dictionary["algorithm"]
51+
# Default values
52+
device = "cpu"
53+
axl_network = None
54+
noise_parameters = None
55+
sampling_parameters = None
56+
if algorithm != "excise_and_repaint":
57+
return noise_parameters, sampling_parameters, axl_network, device
58+
59+
if torch.cuda.is_available():
60+
device = "cuda"
61+
assert (
62+
path_to_score_network_checkpoint is not None
63+
), "A path to a valid score network checkpoint must be provided to use 'excise_and_repaint'."
64+
axl_network = get_axl_network(path_to_score_network_checkpoint)
65+
66+
assert (
67+
"noise" in sampling_dictionary
68+
), "A 'noise' configuration must be defined in the 'sampling' field in order to use 'excise_and_repaint'."
69+
70+
noise_dictionary = sampling_dictionary["noise"]
71+
noise_parameters = NoiseParameters(**noise_dictionary)
72+
73+
assert "repaint_generator" in sampling_dictionary, (
74+
"A 'repaint_generator' configuration must be defined in the 'sampling' field in order to use "
75+
"'excise_and_repaint'."
76+
)
77+
78+
sampling_generator_dictionary = sampling_dictionary["repaint_generator"]
79+
80+
assert "algorithm" not in sampling_generator_dictionary, (
81+
"Do not specify the 'algorithm' for the repaint generator: only the predictor_corrector repaint generator "
82+
"algorithm is valid and will be automatically selected."
83+
)
84+
sampling_generator_dictionary["algorithm"] = "predictor_corrector"
85+
86+
assert "num_atom_types" not in sampling_generator_dictionary, (
87+
"Do not specify the 'num_atom_types' for the repaint generator: the value will be inferred from "
88+
"the element list."
89+
)
90+
sampling_generator_dictionary["num_atom_types"] = len(element_list)
91+
92+
assert "number_of_samples" not in sampling_generator_dictionary, (
93+
"Do not specify the 'number_of_samples' for the repaint generator: the value will be inferred from "
94+
"the 'number_of_samples_per_substructure' sampling field."
95+
)
96+
sampling_generator_dictionary["number_of_samples"] = sampling_dictionary.get(
97+
"number_of_samples_per_substructure", 1
98+
)
99+
100+
assert (
101+
"use_fixed_lattice_parameters" not in sampling_generator_dictionary
102+
and "cell_dimensions" not in sampling_generator_dictionary
103+
), (
104+
"Do not specify 'use_fixed_lattice_parameters' or 'cell_dimensions' for the repaint generator: these values "
105+
"will be inferred from the sampling field."
106+
)
107+
sampling_generator_dictionary["use_fixed_lattice_parameters"] = (
108+
sampling_dictionary.get("sample_box_strategy", "fixed")
109+
)
110+
111+
if sampling_generator_dictionary["use_fixed_lattice_parameters"] == "fixed":
112+
sampling_generator_dictionary["cell_dimensions"] = sampling_dictionary[
113+
"sample_box_size"
114+
]
115+
116+
sampling_parameters = PredictorCorrectorSamplingParameters(
117+
**sampling_generator_dictionary
118+
)
119+
120+
return noise_parameters, sampling_parameters, axl_network, device
121+
122+
123+
def get_sample_maker_from_configuration(
124+
sampling_dictionary: Dict,
125+
uncertainty_threshold: float,
126+
element_list: List[str],
127+
path_to_score_network_checkpoint: Optional[str] = None,
128+
) -> BaseSampleMaker:
129+
"""Get sample maker from configuration.
130+
131+
the sampling dictionary should have the following structure:
132+
133+
sampling:
134+
algorithm: ...
135+
(other sample maker parameters)
136+
137+
excision [Only if using Excise and *]:
138+
(excision parameters)
139+
140+
noise [Only if using Excise and Repaint]:
141+
(noise parameters)
142+
143+
repaint_generator [Only if using Excise and Repaint]:
144+
(constrained sampling parameters)
145+
146+
Args:
147+
sampling_dictionary: Dictionary of sampling parameters, as read in from a yaml configuration file.
148+
uncertainty_threshold: Uncertainty threshold.
149+
element_list: List of element names.
150+
path_to_score_network_checkpoint: Path to score network checkpoint.
151+
152+
Returns:
153+
sample_maker: A configured Sample Maker instance.
154+
"""
155+
# Let's make sure we don't modify the input, which would lead to undesirable side effects!
156+
sampling_dict = sampling_dictionary.copy()
157+
158+
noise_parameters, sampling_parameters, axl_network, device = get_repaint_parameters(
159+
sampling_dictionary=sampling_dict,
160+
element_list=element_list,
161+
path_to_score_network_checkpoint=path_to_score_network_checkpoint,
162+
)
163+
164+
atom_selector_parameter_dictionary = dict(
165+
algorithm="threshold", uncertainty_threshold=uncertainty_threshold
166+
)
167+
atom_selector_parameters = create_atom_selector_parameters(
168+
atom_selector_parameter_dictionary
169+
)
170+
171+
excisor_parameter_dictionary = sampling_dict.pop("excision", None)
172+
if excisor_parameter_dictionary is not None:
173+
excisor_parameters = create_excisor_parameters(excisor_parameter_dictionary)
174+
else:
175+
excisor_parameters = None
176+
177+
# Let's extract only the sample_maker configuration, popping out components that don't belong.
178+
sample_maker_dictionary = sampling_dict.copy()
179+
sample_maker_dictionary["element_list"] = element_list
180+
sample_maker_dictionary.pop("noise", None)
181+
sample_maker_dictionary.pop("repaint_generator", None)
182+
183+
sample_maker_parameters = create_sample_maker_parameters(sample_maker_dictionary)
184+
185+
sample_maker = create_sample_maker(
186+
sample_maker_parameters=sample_maker_parameters,
187+
atom_selector_parameters=atom_selector_parameters,
188+
excisor_parameters=excisor_parameters,
189+
noise_parameters=noise_parameters,
190+
sampling_parameters=sampling_parameters,
191+
diffusion_model=axl_network,
192+
device=device,
193+
)
194+
return sample_maker

src/diffusion_for_multi_scale_molecular_dynamics/active_learning_loop/sample_maker/excise_and_repaint_sample_maker.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass
2-
from typing import Any, Dict, List, Tuple
2+
from typing import Any, Dict, List, Optional, Tuple
33

4+
import numpy as np
45
import torch
56

67
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.atom_selector.base_atom_selector import \
@@ -9,6 +10,8 @@
910
BaseEnvironmentExcision
1011
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.sample_maker.base_sample_maker import (
1112
BaseExciseSampleMaker, BaseExciseSampleMakerArguments)
13+
from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.utils import \
14+
get_distances_from_reference_point
1215
from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \
1316
SamplingParameters
1417
from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import (
@@ -28,6 +31,9 @@ class ExciseAndRepaintSampleMakerArguments(BaseExciseSampleMakerArguments):
2831

2932
algorithm: str = "excise_and_repaint"
3033

34+
# in Angstrom: generated atoms within this radius from the central atom will be removed.
35+
sample_edit_radius: Optional[float] = None
36+
3137

3238
class ExciseAndRepaintSampleMaker(BaseExciseSampleMaker):
3339
"""Sample maker for the excise and repaint approach.
@@ -67,6 +73,11 @@ def __init__(
6773
"substructure requested in the sample_maker configuration (ie 'number_of_samples_per_substructure'). "
6874
"The configuration currently asks for inconsistent things. Review input.")
6975

76+
self.samples_should_be_edited = False
77+
if sample_maker_arguments.sample_edit_radius is not None:
78+
self.samples_should_be_edited = True
79+
self.sample_edit_radius = sample_maker_arguments.sample_edit_radius
80+
7081
self.sample_noise_parameters = noise_parameters
7182
self.sampling_parameters = sampling_parameters
7283
self.diffusion_model = diffusion_model
@@ -143,6 +154,11 @@ def make_samples_from_constrained_substructure(
143154
atom at the center of the excised region.
144155
list_info: list of samples additional information.
145156
"""
157+
number_of_constrained_atoms = len(substructure.X)
158+
assert active_atom_index < number_of_constrained_atoms, \
159+
("The active atom index is larger than the number of constrained atoms: "
160+
"this should be impossible, something is wrong. Review code!")
161+
146162
sampling_constraints = self.create_sampling_constraints(substructure)
147163
generator = ConstrainedLangevinGenerator(
148164
noise_parameters=self.sample_noise_parameters,
@@ -160,6 +176,14 @@ def make_samples_from_constrained_substructure(
160176
new_structures = self.torch_batch_axl_to_list_of_numpy_axl(
161177
generated_samples["original_axl"]
162178
)
179+
if self.samples_should_be_edited:
180+
# Edit the sampled structures in place.
181+
new_structures = [self.edit_generated_structure(sampled_structure,
182+
active_atom_index,
183+
number_of_constrained_atoms,
184+
self.sample_edit_radius)
185+
for sampled_structure in new_structures]
186+
163187
# Since the order of the atoms in the constrained substructure are
164188
# explicitly enforced, the index of the active atom is the same in the
165189
# constrained substructure and in the sample.
@@ -173,3 +197,44 @@ def make_samples_from_constrained_substructure(
173197
def filter_made_samples(self, structures: List[AXL]) -> List[AXL]:
174198
"""Return identical structures."""
175199
return structures
200+
201+
@staticmethod
202+
def edit_generated_structure(sampled_structure: AXL,
203+
active_atom_index: int,
204+
number_of_constrained_atoms: int,
205+
sample_edit_radius: float) -> AXL:
206+
"""Edit generated structure.
207+
208+
This method removes generated atoms that are within a sphere of radius "sample_edit_radius" around
209+
the active atom. It is assumed that the first "number_of_constrained_atoms" are the constrained atoms;
210+
these should not be edited out!
211+
212+
Args:
213+
sampled_structure: generated sampled structure
214+
number_of_constrained_atoms: number of atoms that are constrained and should not be removed.
215+
active_atom_index: index of the "active atom" in the input sample.
216+
sample_edit_radius: radius of exclusion sphere around the active index where
217+
generated atoms must be removed.
218+
219+
Returns:
220+
edited_sampled_structure: the edited sampled structure
221+
"""
222+
central_atom_relative_coordinates = sampled_structure.X[active_atom_index]
223+
distances_from_central_atom = get_distances_from_reference_point(
224+
sampled_structure.X, central_atom_relative_coordinates, sampled_structure.L
225+
)
226+
227+
number_of_atoms = len(sampled_structure.X)
228+
229+
constrained_atoms_mask = np.zeros(number_of_atoms, dtype=bool)
230+
constrained_atoms_mask[:number_of_constrained_atoms] = True
231+
232+
outside_radius_mask = distances_from_central_atom > sample_edit_radius
233+
234+
keep_mask = np.logical_or(constrained_atoms_mask, outside_radius_mask)
235+
236+
edited_structure = AXL(A=sampled_structure.A[keep_mask],
237+
X=sampled_structure.X[keep_mask],
238+
L=sampled_structure.L)
239+
240+
return edited_structure

0 commit comments

Comments
 (0)