|
| 1 | +from typing import Any, AnyStr, Dict, List, Optional, Tuple, Union |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.atom_selector.atom_selector_factory import \ |
| 6 | + create_atom_selector_parameters |
| 7 | +from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.excisor.excisor_factory import \ |
| 8 | + create_excisor_parameters |
| 9 | +from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.sample_maker.base_sample_maker import \ |
| 10 | + BaseSampleMaker |
| 11 | +from diffusion_for_multi_scale_molecular_dynamics.active_learning_loop.sample_maker.sample_maker_factory import ( |
| 12 | + create_sample_maker, create_sample_maker_parameters) |
| 13 | +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ |
| 14 | + PredictorCorrectorSamplingParameters |
| 15 | +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ |
| 16 | + ScoreNetwork |
| 17 | +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ |
| 18 | + NoiseParameters |
| 19 | +from diffusion_for_multi_scale_molecular_dynamics.sample_diffusion import \ |
| 20 | + get_axl_network |
| 21 | + |
| 22 | + |
| 23 | +def get_repaint_parameters( |
| 24 | + sampling_dictionary: Dict[AnyStr, Any], |
| 25 | + element_list: List[str], |
| 26 | + path_to_score_network_checkpoint: Optional[str] = None, |
| 27 | +) -> Tuple[ |
| 28 | + Union[NoiseParameters, None], |
| 29 | + Union[PredictorCorrectorSamplingParameters, None], |
| 30 | + Union[ScoreNetwork, None], |
| 31 | + str, |
| 32 | +]: |
| 33 | + """Get repaint parameters. |
| 34 | +
|
| 35 | + This convenience method is responsible for extracting the relevant configuration objects in the |
| 36 | + case that the sample maker algorithm is "Excise and Repaint", and to return a "None" default for |
| 37 | + these configuration objects if a different algorithm is used. |
| 38 | +
|
| 39 | + Args: |
| 40 | + sampling_dictionary: Dictionary of sampling parameters, as read in from a yaml configuration file. |
| 41 | + element_list: List of element names. |
| 42 | + path_to_score_network_checkpoint: Path to score network checkpoint. |
| 43 | +
|
| 44 | + Returns: |
| 45 | + noise_parameters: a NoiseParameters object if the config is present, otherwise None. |
| 46 | + sampling_parameters: a PredictorCorrectorSamplingParameters object if the config is present, otherwise None. |
| 47 | + axl_network: a Score Network object to draw constrained samples if the config is present, otherwise None. |
| 48 | + device: a string indicating which device should be used: either cpu or cuda. |
| 49 | + """ |
| 50 | + algorithm = sampling_dictionary["algorithm"] |
| 51 | + # Default values |
| 52 | + device = "cpu" |
| 53 | + axl_network = None |
| 54 | + noise_parameters = None |
| 55 | + sampling_parameters = None |
| 56 | + if algorithm != "excise_and_repaint": |
| 57 | + return noise_parameters, sampling_parameters, axl_network, device |
| 58 | + |
| 59 | + if torch.cuda.is_available(): |
| 60 | + device = "cuda" |
| 61 | + assert ( |
| 62 | + path_to_score_network_checkpoint is not None |
| 63 | + ), "A path to a valid score network checkpoint must be provided to use 'excise_and_repaint'." |
| 64 | + axl_network = get_axl_network(path_to_score_network_checkpoint) |
| 65 | + |
| 66 | + assert ( |
| 67 | + "noise" in sampling_dictionary |
| 68 | + ), "A 'noise' configuration must be defined in the 'sampling' field in order to use 'excise_and_repaint'." |
| 69 | + |
| 70 | + noise_dictionary = sampling_dictionary["noise"] |
| 71 | + noise_parameters = NoiseParameters(**noise_dictionary) |
| 72 | + |
| 73 | + assert "repaint_generator" in sampling_dictionary, ( |
| 74 | + "A 'repaint_generator' configuration must be defined in the 'sampling' field in order to use " |
| 75 | + "'excise_and_repaint'." |
| 76 | + ) |
| 77 | + |
| 78 | + sampling_generator_dictionary = sampling_dictionary["repaint_generator"] |
| 79 | + |
| 80 | + assert "algorithm" not in sampling_generator_dictionary, ( |
| 81 | + "Do not specify the 'algorithm' for the repaint generator: only the predictor_corrector repaint generator " |
| 82 | + "algorithm is valid and will be automatically selected." |
| 83 | + ) |
| 84 | + sampling_generator_dictionary["algorithm"] = "predictor_corrector" |
| 85 | + |
| 86 | + assert "num_atom_types" not in sampling_generator_dictionary, ( |
| 87 | + "Do not specify the 'num_atom_types' for the repaint generator: the value will be inferred from " |
| 88 | + "the element list." |
| 89 | + ) |
| 90 | + sampling_generator_dictionary["num_atom_types"] = len(element_list) |
| 91 | + |
| 92 | + assert "number_of_samples" not in sampling_generator_dictionary, ( |
| 93 | + "Do not specify the 'number_of_samples' for the repaint generator: the value will be inferred from " |
| 94 | + "the 'number_of_samples_per_substructure' sampling field." |
| 95 | + ) |
| 96 | + sampling_generator_dictionary["number_of_samples"] = sampling_dictionary.get( |
| 97 | + "number_of_samples_per_substructure", 1 |
| 98 | + ) |
| 99 | + |
| 100 | + assert ( |
| 101 | + "use_fixed_lattice_parameters" not in sampling_generator_dictionary |
| 102 | + and "cell_dimensions" not in sampling_generator_dictionary |
| 103 | + ), ( |
| 104 | + "Do not specify 'use_fixed_lattice_parameters' or 'cell_dimensions' for the repaint generator: these values " |
| 105 | + "will be inferred from the sampling field." |
| 106 | + ) |
| 107 | + sampling_generator_dictionary["use_fixed_lattice_parameters"] = ( |
| 108 | + sampling_dictionary.get("sample_box_strategy", "fixed") |
| 109 | + ) |
| 110 | + |
| 111 | + if sampling_generator_dictionary["use_fixed_lattice_parameters"] == "fixed": |
| 112 | + sampling_generator_dictionary["cell_dimensions"] = sampling_dictionary[ |
| 113 | + "sample_box_size" |
| 114 | + ] |
| 115 | + |
| 116 | + sampling_parameters = PredictorCorrectorSamplingParameters( |
| 117 | + **sampling_generator_dictionary |
| 118 | + ) |
| 119 | + |
| 120 | + return noise_parameters, sampling_parameters, axl_network, device |
| 121 | + |
| 122 | + |
| 123 | +def get_sample_maker_from_configuration( |
| 124 | + sampling_dictionary: Dict, |
| 125 | + uncertainty_threshold: float, |
| 126 | + element_list: List[str], |
| 127 | + path_to_score_network_checkpoint: Optional[str] = None, |
| 128 | +) -> BaseSampleMaker: |
| 129 | + """Get sample maker from configuration. |
| 130 | +
|
| 131 | + the sampling dictionary should have the following structure: |
| 132 | +
|
| 133 | + sampling: |
| 134 | + algorithm: ... |
| 135 | + (other sample maker parameters) |
| 136 | +
|
| 137 | + excision [Only if using Excise and *]: |
| 138 | + (excision parameters) |
| 139 | +
|
| 140 | + noise [Only if using Excise and Repaint]: |
| 141 | + (noise parameters) |
| 142 | +
|
| 143 | + repaint_generator [Only if using Excise and Repaint]: |
| 144 | + (constrained sampling parameters) |
| 145 | +
|
| 146 | + Args: |
| 147 | + sampling_dictionary: Dictionary of sampling parameters, as read in from a yaml configuration file. |
| 148 | + uncertainty_threshold: Uncertainty threshold. |
| 149 | + element_list: List of element names. |
| 150 | + path_to_score_network_checkpoint: Path to score network checkpoint. |
| 151 | +
|
| 152 | + Returns: |
| 153 | + sample_maker: A configured Sample Maker instance. |
| 154 | + """ |
| 155 | + # Let's make sure we don't modify the input, which would lead to undesirable side effects! |
| 156 | + sampling_dict = sampling_dictionary.copy() |
| 157 | + |
| 158 | + noise_parameters, sampling_parameters, axl_network, device = get_repaint_parameters( |
| 159 | + sampling_dictionary=sampling_dict, |
| 160 | + element_list=element_list, |
| 161 | + path_to_score_network_checkpoint=path_to_score_network_checkpoint, |
| 162 | + ) |
| 163 | + |
| 164 | + atom_selector_parameter_dictionary = dict( |
| 165 | + algorithm="threshold", uncertainty_threshold=uncertainty_threshold |
| 166 | + ) |
| 167 | + atom_selector_parameters = create_atom_selector_parameters( |
| 168 | + atom_selector_parameter_dictionary |
| 169 | + ) |
| 170 | + |
| 171 | + excisor_parameter_dictionary = sampling_dict.pop("excision", None) |
| 172 | + if excisor_parameter_dictionary is not None: |
| 173 | + excisor_parameters = create_excisor_parameters(excisor_parameter_dictionary) |
| 174 | + else: |
| 175 | + excisor_parameters = None |
| 176 | + |
| 177 | + # Let's extract only the sample_maker configuration, popping out components that don't belong. |
| 178 | + sample_maker_dictionary = sampling_dict.copy() |
| 179 | + sample_maker_dictionary["element_list"] = element_list |
| 180 | + sample_maker_dictionary.pop("noise", None) |
| 181 | + sample_maker_dictionary.pop("repaint_generator", None) |
| 182 | + |
| 183 | + sample_maker_parameters = create_sample_maker_parameters(sample_maker_dictionary) |
| 184 | + |
| 185 | + sample_maker = create_sample_maker( |
| 186 | + sample_maker_parameters=sample_maker_parameters, |
| 187 | + atom_selector_parameters=atom_selector_parameters, |
| 188 | + excisor_parameters=excisor_parameters, |
| 189 | + noise_parameters=noise_parameters, |
| 190 | + sampling_parameters=sampling_parameters, |
| 191 | + diffusion_model=axl_network, |
| 192 | + device=device, |
| 193 | + ) |
| 194 | + return sample_maker |
0 commit comments