Skip to content

Commit 3bfcaf8

Browse files
committed
added history matching and test
1 parent 005b3c0 commit 3bfcaf8

File tree

4 files changed

+165
-1231
lines changed

4 files changed

+165
-1231
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ notebooks/
22
__pycache__/
33
.pytest_cache/
44
dist/
5+
my_tests/
56

67
# Ignore Sphinx build artifacts
78
docs/build/

autoemulate/history_matching.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
3+
4+
def history_matching(obs, expectations, threshold=3.0, discrepancy=0.0, rank=1):
5+
"""
6+
Perform history matching to compute implausibility and identify NROY and RO points.
7+
8+
Parameters:
9+
obs (tuple): Observations as (mean, variance).
10+
expectations (tuple): Predicted (mean, variance).
11+
threshold (float): Implausibility threshold for NROY classification.
12+
discrepancy (float or ndarray): Discrepancy value(s).
13+
rank (int): Rank for implausibility calculation.
14+
15+
Returns:
16+
dict: Contains implausibility (I), NROY indices, and RO indices.
17+
"""
18+
obs_mean, obs_var = np.atleast_1d(obs[0]), np.atleast_1d(obs[1])
19+
pred_mean, pred_var = np.atleast_1d(expectations[0]), np.atleast_1d(expectations[1])
20+
21+
discrepancy = np.atleast_1d(discrepancy)
22+
n_obs = len(obs_mean)
23+
rank = min(max(rank, 0), n_obs - 1)
24+
if discrepancy.size == 1:
25+
discrepancy = np.full(n_obs, discrepancy)
26+
27+
Vs = pred_var + discrepancy + obs_var
28+
I = np.abs(obs_mean - pred_mean) / np.sqrt(Vs)
29+
30+
NROY = np.where(I <= threshold)[0]
31+
RO = np.where(I > threshold)[0]
32+
33+
return {"I": I, "NROY": list(NROY), "RO": list(RO)}

docs/tutorials/01_start.ipynb

Lines changed: 75 additions & 1231 deletions
Large diffs are not rendered by default.

tests/test_history_matching.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
2+
import pytest
3+
4+
from autoemulate.history_matching import history_matching
5+
6+
7+
@pytest.fixture
8+
def sample_data_2d():
9+
pred_mean = np.array([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1], [5.0, 5.1]])
10+
pred_std = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4], [0.5, 0.5]])
11+
pred_var = np.square(pred_std)
12+
expectations = (pred_mean, pred_var)
13+
obs = [(1.5, 0.1), (2.5, 0.2)]
14+
return expectations, obs
15+
16+
17+
@pytest.fixture
18+
def sample_data_1d():
19+
pred_mean = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
20+
pred_std = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
21+
pred_var = np.square(pred_std)
22+
expectations = (pred_mean, pred_var)
23+
obs = [1.5, 10]
24+
return expectations, obs
25+
26+
27+
def test_history_matching_1d(sample_data_1d):
28+
expectations, obs = sample_data_1d
29+
result = history_matching(expectations=expectations, obs=obs, threshold=1.0)
30+
assert "NROY" in result # Ensure the key exists in the result
31+
assert isinstance(result["NROY"], list) # Validate that NROY is a list
32+
assert len(result["NROY"]) > 0 # Ensure the list is not empty
33+
34+
35+
def test_history_matching_threshold_1d(sample_data_1d):
36+
expectations, obs = sample_data_1d
37+
result = history_matching(expectations=expectations, obs=obs, threshold=0.5)
38+
assert "NROY" in result
39+
assert isinstance(result["NROY"], list)
40+
assert len(result["NROY"]) <= len(expectations[0])
41+
42+
43+
def test_history_matching_2d(sample_data_2d):
44+
expectations, obs = sample_data_2d
45+
result = history_matching(expectations=expectations, obs=obs, threshold=1.0)
46+
assert "NROY" in result # Ensure the key exists in the result
47+
assert isinstance(result["NROY"], list) # Validate that NROY is a list
48+
assert len(result["NROY"]) > 0 # Ensure the list is not empty
49+
50+
51+
def test_history_matching_threshold_2d(sample_data_2d):
52+
expectations, obs = sample_data_2d
53+
result = history_matching(expectations=expectations, obs=obs, threshold=0.5)
54+
assert "NROY" in result
55+
assert isinstance(result["NROY"], list)
56+
assert len(result["NROY"]) <= len(expectations[0])

0 commit comments

Comments
 (0)