@@ -112,6 +112,8 @@ def append_simulations(
112112 With default settings, this is not used at all for `NRE`. Only when
113113 the user later requests `.train(discard_prior_samples=True)` do we
114114 use these indices to find which training data stemmed from the prior.
115+ algorithm: Which algorithm is being run. This is used to give a more
116+ informative warning or error message when invalid simulations are found.
115117 data_device: Where to store the data, default is on the same device where
116118 the training is happening. When training on a large dataset on a GPU with
117119 limited VRAM, set this to 'cpu' to store the data in system memory instead.
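A minimal sketch of how the `append_simulations` arguments documented above fit together, assuming sbi's usual `SNRE` workflow; the toy prior and stand-in simulator are illustrative, and the new `algorithm` argument is simply not passed explicitly here:

```python
# Sketch under assumptions: `SNRE` and its constructor follow sbi's public
# API; the toy data below is made up for illustration.
import torch
from sbi.inference import SNRE

prior = torch.distributions.MultivariateNormal(
    torch.zeros(2), torch.eye(2)
)
theta = prior.sample((1000,))
x = theta + 0.1 * torch.randn_like(theta)  # stand-in for a real simulator

# Train on the GPU, but keep the (potentially large) dataset in system
# memory, as described by `data_device` above.
inference = SNRE(prior=prior, device="cuda")
inference.append_simulations(theta, x, data_device="cpu")
```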
@@ -153,8 +155,16 @@ def train(
153155
154156 Args:
155157 num_atoms: Number of atoms to use for classification.
156- exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
157- during training. Expect errors, silent or explicit, when `False`.
158+ training_batch_size: Training batch size.
159+ learning_rate: Learning rate for Adam optimizer.
160+ validation_fraction: The fraction of data to use for validation.
161+ stop_after_epochs: The number of epochs to wait for improvement on the
162+ validation set before terminating training.
163+ max_num_epochs: Maximum number of epochs to run. If reached, we stop
164+ training even when the validation loss is still decreasing. Otherwise,
165+ we train until validation loss increases (see also `stop_after_epochs`).
166+ clip_max_norm: Value at which to clip the total gradient norm in order to
167+ prevent exploding gradients. Use None for no clipping.
158168 resume_training: Can be used in case training time is limited, e.g. on a
159169 cluster. If `True`, the split between train and validation set, the
160170 optimizer, the number of epochs, and the best validation log-prob will
@@ -164,6 +174,8 @@ def train(
164174 samples.
165175 retrain_from_scratch: Whether to retrain the conditional density
166176 estimator for the posterior from scratch each round.
177+ show_train_summary: Whether to print the number of epochs and validation
178+ loss after training finishes.
167179 dataloader_kwargs: Additional or updated kwargs to be passed to the training
168180 and validation dataloaders (e.g., a collate_fn).
169181 loss_kwargs: Additional or updated kwargs to be passed to the `self._loss` function.
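And a corresponding sketch of a `train` call exercising the keyword arguments documented above; all keywords are taken from this docstring, the specific values are illustrative only, and `inference` is the trainer from the previous sketch:

```python
# Sketch under assumptions: every keyword below is documented in this diff;
# the values shown are illustrative, not recommendations.
ratio_estimator = inference.train(
    num_atoms=10,                  # atoms used for classification
    training_batch_size=50,        # batch size for the training loader
    learning_rate=5e-4,            # Adam learning rate
    validation_fraction=0.1,       # hold out 10% of simulations
    stop_after_epochs=20,          # early-stopping patience
    max_num_epochs=500,            # hard cap, even if loss still improves
    clip_max_norm=5.0,             # gradient-norm clipping; None disables
    show_train_summary=True,       # print epochs and validation loss
    dataloader_kwargs={"num_workers": 2},  # forwarded to the DataLoaders
)
```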