
Commit db288c5

Update .predict_mean() to .predict()
1 parent 7b17732 commit db288c5

File tree: 1 file changed, +7 -5 lines

autoemulate/core/model_selection.py

Lines changed: 7 additions & 5 deletions
@@ -156,7 +156,7 @@ def bootstrap(
     x: TensorLike,
     y: TensorLike,
     n_bootstraps: int | None = 100,
-    n_samples: int = 100,
+    n_samples: int = 1000,
     device: str | torch.device = "cpu",
     metrics: list[TorchMetrics] | None = None,
 ) -> dict[str, tuple[float, float]]:

@@ -177,7 +177,7 @@ def bootstrap(
         Defaults to 100.
     n_samples: int
         Number of samples to generate to predict mean when emulator does not have a
-        mean directly available. Defaults to 100.
+        mean directly available. Defaults to 1000.
     device: str | torch.device
         The device to use for computations. Default is "cpu".
     metrics: list[MetricConfig] | None

@@ -200,7 +200,7 @@ def bootstrap(
         y_pred = model.predict(x)
         results = {}
         for metric in metrics:
-            score = evaluate(y_pred, y, metric)
+            score = evaluate(y_pred, y, metric=metric, n_samples=n_samples)
             results[metric.name] = (score, float("nan"))
         return results
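Note: n_samples is now forwarded to evaluate() instead of to the emulator. As the next hunk shows, the bootstrap loop switches from model.predict_mean(x_bootstrap, n_samples=n_samples) to model.predict(x_bootstrap), so any sampling-based estimate of the mean prediction has to happen downstream in evaluate(). A minimal sketch of how evaluate() might use the argument; the body below is an illustrative assumption, not code from this commit, and it approximates the TorchMetrics interface with torchmetrics.Metric:

import torch
from torchmetrics import Metric


def evaluate(y_pred, y_true, metric: Metric, n_samples: int = 1000) -> float:
    # Hypothetical: predict() may return a distribution rather than a tensor.
    if isinstance(y_pred, torch.distributions.Distribution):
        try:
            # Prefer a closed-form mean when the distribution defines one.
            y_pred = y_pred.mean
        except NotImplementedError:
            # Otherwise fall back to a Monte Carlo mean over n_samples draws.
            y_pred = y_pred.sample(torch.Size([n_samples])).mean(dim=0)
    metric.update(y_pred, y_true)
    score = float(metric.compute())
    metric.reset()
    return score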

@@ -218,11 +218,13 @@ def bootstrap(
         y_bootstrap = y[idxs]

         # Make predictions
-        y_pred = model.predict_mean(x_bootstrap, n_samples=n_samples)
+        y_pred = model.predict(x_bootstrap)

         # Compute metrics for this bootstrap sample
         for metric in metrics:
-            metric_scores[metric.name][i] = evaluate(y_pred, y_bootstrap, metric)
+            metric_scores[metric.name][i] = evaluate(
+                y_pred, y_bootstrap, metric=metric, n_samples=n_samples
+            )

     # Return mean and std for each metric
     return {
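For context, a hedged usage sketch of the updated function. It assumes bootstrap() also takes the fitted emulator (referenced as model in the body above) among parameters above the lines shown in the first hunk, and that x_test and y_test are held-out tensors; these names are assumptions, not part of this diff:

results = bootstrap(
    emulator,
    x_test,
    y_test,
    n_bootstraps=100,
    n_samples=1000,  # new default, forwarded to evaluate()
)
for name, (mean, std) in results.items():
    print(f"{name}: {mean:.3f} ± {std:.3f}")

Each metric maps to a (mean, std) pair across bootstrap resamples; with n_bootstraps=None, a single score is returned with float("nan") in the std slot.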

0 commit comments