
Commit 88b8ce6

Refactor JaccardIndex and StringContainment metrics to use ReductionInstanceMetric (#1816)
* Refactor JaccardIndex and StringContainment metrics to use ReductionInstanceMetric
* Add JaccardIndexWords metric and corresponding tests
* Update StringContainment metric to use Any type for predictions and references
* Update spearmanr and mean squared error
* Return old metric and fix old tests
* Update documentation and test of JaccardIndex: renamed JaccardIndexWords to JaccardIndexString and made it more general by default (not providing the split). Also moved to a regex tokenizer in jaccard_index_words to make it work with multiple spaces.
* Added documentation and checking for MeanSquaredError
* Updated json
* Fix spearmanr
* Added RMSE metric
* Format
* Revert naming
* Add spearmanr_p_value to evaluation metrics in TestAPI
* Fix tests

---------

Signed-off-by: elronbandel <[email protected]>
Signed-off-by: Yoav Katz <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
1 parent e5c72b7 commit 88b8ce6

File tree: 12 files changed (+442, -117 lines)


prepare/metrics/jaccard_index.py

Lines changed: 48 additions & 2 deletions
@@ -1,8 +1,18 @@
 from unitxt import add_to_catalog
-from unitxt.metrics import JaccardIndex
+from unitxt.metrics import JaccardIndex, JaccardIndexString
+from unitxt.string_operators import RegexSplit
 from unitxt.test_utils.metrics import test_metric
 
-metric = JaccardIndex()
+metric = JaccardIndex(
+    __description__="""JaccardIndex metric that operates on predictions and references that are lists of elements.
+    For each prediction, it calculates the score as Intersect(prediction,reference)/Union(prediction,reference).
+    If multiple references exist, it takes, for each prediction, the best ratio achieved by one of the references.
+    It then aggregates the mean over all instances.
+
+    Note the metric assumes the prediction and references are either a set of elements or a list of elements.
+    If the prediction and references are strings, use JaccardIndexString metrics like "metrics.jaccard_index_words".
+    """
+)
 
 predictions = [["A", "B", "C"]]
 references = [[["B", "A", "D"]]]
@@ -27,3 +37,39 @@
 )
 
 add_to_catalog(metric, "metrics.jaccard_index", overwrite=True)
+
+
+metric = JaccardIndexString(
+    __description__="""JaccardIndex metric that operates on predictions and references that are strings.
+    It first splits the string into words using whitespace as a separator.
+
+    For each prediction, it calculates the ratio Intersect(prediction_words,reference_words)/Union(prediction_words,reference_words).
+    If multiple references exist, it takes the best ratio achieved by one of the references.
+
+    """,
+    splitter=RegexSplit(by=r"\s+"),
+)
+
+predictions = ["A B C"]
+references = [["B A D"]]
+
+instance_targets = [
+    {"jaccard_index": 0.5, "score": 0.5, "score_name": "jaccard_index"},
+]
+
+global_target = {
+    "jaccard_index": 0.5,
+    "score": 0.5,
+    "score_name": "jaccard_index",
+    "num_of_instances": 1,
+}
+
+outputs = test_metric(
+    metric=metric,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+
+add_to_catalog(metric, "metrics.jaccard_index_words", overwrite=True)
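
As a quick illustration of what these two catalog entries compute, here is a small standalone sketch in plain Python (not the unitxt implementation): prediction and reference are treated as sets, the intersection-over-union ratio is taken, and with multiple references the best ratio per prediction is kept.

# Illustrative sketch only - plain Python, not the unitxt implementation.
import re

def jaccard_index(prediction, reference):
    """Intersection-over-union of two collections treated as sets."""
    pred_set, ref_set = set(prediction), set(reference)
    if not pred_set | ref_set:
        return 1.0  # convention chosen for this sketch when both sides are empty
    return len(pred_set & ref_set) / len(pred_set | ref_set)

def best_over_references(prediction, references):
    """With several references, keep the best ratio achieved for this prediction."""
    return max(jaccard_index(prediction, ref) for ref in references)

# List form, as tested for metrics.jaccard_index:
print(best_over_references(["A", "B", "C"], [["B", "A", "D"]]))  # 0.5

# String form, as tested for metrics.jaccard_index_words: split on whitespace first.
print(jaccard_index(re.split(r"\s+", "A B C"), re.split(r"\s+", "B A D")))  # 0.5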
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+from unitxt import add_to_catalog
+from unitxt.metrics import MeanSquaredError, RootMeanSquaredError
+from unitxt.test_utils.metrics import test_metric
+
+metric = MeanSquaredError(
+    __description__="""Metric to calculate the mean squared error (MSE) between the prediction and the reference values.
+
+    Assumes both the prediction and reference are floats.
+
+    Supports only a single reference per prediction.
+    """
+)
+predictions = [1.0, 2.0, 1.0]
+references = [[-1.0], [1.0], [0.0]]
+
+instance_targets = [
+    {"mean_squared_error": 4.0, "score": 4.0, "score_name": "mean_squared_error"},
+    {"mean_squared_error": 1.0, "score": 1.0, "score_name": "mean_squared_error"},
+    {"mean_squared_error": 1.0, "score": 1.0, "score_name": "mean_squared_error"},
+]
+
+global_target = {
+    "mean_squared_error": 2.0,
+    "score": 2.0,
+    "score_name": "mean_squared_error",
+    "mean_squared_error_ci_low": 1.0,
+    "mean_squared_error_ci_high": 4.0,
+    "score_ci_low": 1.0,
+    "score_ci_high": 4.0,
+    "num_of_instances": 3,
+}
+
+outputs = test_metric(
+    metric=metric,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+
+add_to_catalog(metric, "metrics.mean_squared_error", overwrite=True)
+
+
+metric = RootMeanSquaredError(
+    __description__="""Metric to calculate the root mean squared error (RMSE) between the prediction and the reference values.
+
+    Assumes both the prediction and reference are floats.
+
+    Supports only a single reference per prediction.
+    """
+)
+
+
+instance_targets = [
+    {
+        "root_mean_squared_error": 2.0,
+        "score": 2.0,
+        "score_name": "root_mean_squared_error",
+    },
+    {
+        "root_mean_squared_error": 1.0,
+        "score": 1.0,
+        "score_name": "root_mean_squared_error",
+    },
+    {
+        "root_mean_squared_error": 1.0,
+        "score": 1.0,
+        "score_name": "root_mean_squared_error",
+    },
+]
+
+global_target = {
+    "root_mean_squared_error": 1.41,
+    "score": 1.41,
+    "score_name": "root_mean_squared_error",
+    "root_mean_squared_error_ci_low": 1.0,
+    "root_mean_squared_error_ci_high": 2.0,
+    "score_ci_low": 1.0,
+    "score_ci_high": 2.0,
+    "num_of_instances": 3,
+}
+
+outputs = test_metric(
+    metric=metric,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+
+add_to_catalog(metric, "metrics.root_mean_squared_error", overwrite=True)
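
A back-of-the-envelope check of the expected global targets above, in plain Python (a sketch, not the unitxt reduction code): the per-instance squared errors are 4.0, 1.0 and 1.0, so the global MSE is their mean (2.0) and the global RMSE is the square root of that mean (about 1.41).

# Sketch only: verify the expected global MSE/RMSE values by hand.
import math

predictions = [1.0, 2.0, 1.0]
references = [[-1.0], [1.0], [0.0]]  # a single reference per prediction

squared_errors = [(p - r[0]) ** 2 for p, r in zip(predictions, references)]
print(squared_errors)                 # [4.0, 1.0, 1.0]

mse = sum(squared_errors) / len(squared_errors)
print(mse)                            # 2.0  -> metrics.mean_squared_error
print(round(math.sqrt(mse), 2))       # 1.41 -> metrics.root_mean_squared_error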

prepare/metrics/spearman.py

Lines changed: 14 additions & 20 deletions
@@ -1,36 +1,30 @@
 import numpy as np
 from unitxt import add_to_catalog
-from unitxt.metrics import MetricPipeline, Spearmanr
-from unitxt.operators import Copy
+from unitxt.metrics import Spearmanr
 from unitxt.test_utils.metrics import test_metric
 
-metric = MetricPipeline(
-    main_score="spearmanr",
-    preprocess_steps=[
-        Copy(field="references/0", to_field="references"),
-    ],
-    metric=Spearmanr(),
-    prediction_type=float,
-)
-
-predictions = [1.0, 2.0, 1.0]
-references = [[-1.0], [1.0], [0.0]]
+metric = Spearmanr(n_resamples=100)
+predictions = [1.0, 3.0, 1.1, 2.0, 8.0]
+references = [[-1.0], [1.0], [0.10], [2.0], [6.0]]
 
 instance_targets = [
     {"spearmanr": np.nan, "score": np.nan, "score_name": "spearmanr"},
     {"spearmanr": np.nan, "score": np.nan, "score_name": "spearmanr"},
     {"spearmanr": np.nan, "score": np.nan, "score_name": "spearmanr"},
+    {"spearmanr": np.nan, "score": np.nan, "score_name": "spearmanr"},
+    {"spearmanr": np.nan, "score": np.nan, "score_name": "spearmanr"},
 ]
 
 global_target = {
-    "spearmanr": 0.87,
-    "score": 0.87,
+    "num_of_instances": 5,
+    "score": 0.9,
+    "score_ci_high": 1.0,
+    "score_ci_low": 0.11,
     "score_name": "spearmanr",
-    "spearmanr_ci_low": np.nan,
-    "spearmanr_ci_high": np.nan,
-    "score_ci_low": np.nan,
-    "score_ci_high": np.nan,
-    "num_of_instances": 3,
+    "spearmanr": 0.9,
+    "spearmanr_ci_high": 1.0,
+    "spearmanr_ci_low": 0.11,
+    "spearmanr_p_value": 0.04,
 }
 
 outputs = test_metric(
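
The new expected values can be cross-checked directly with scipy. Spearman correlation is rank-based and computed over the whole set of instances, which is why the per-instance targets above are NaN. This is only an illustrative check using scipy.stats.spearmanr, not the metric's own code path:

# Sketch only: cross-check the expected global Spearman values with scipy.
from scipy.stats import spearmanr

predictions = [1.0, 3.0, 1.1, 2.0, 8.0]
references = [-1.0, 1.0, 0.10, 2.0, 6.0]  # the single reference of each instance

rho, p_value = spearmanr(predictions, references)
print(round(rho, 2), round(p_value, 2))  # 0.9 0.04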
Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 {
-    "__type__": "jaccard_index"
+    "__type__": "jaccard_index",
+    "__description__": "JaccardIndex metric that operates on predictions and references that are lists of elements.\n    For each prediction, it calculates the score as Intersect(prediction,reference)/Union(prediction,reference).\n    If multiple references exist, it takes, for each prediction, the best ratio achieved by one of the references.\n    It then aggregates the mean over all instances.\n\n    Note the metric assumes the prediction and references are either a set of elements or a list of elements.\n    If the prediction and references are strings, use JaccardIndexString metrics like \"metrics.jaccard_index_words\".\n    "
 }
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+{
+    "__type__": "jaccard_index_string",
+    "__description__": "JaccardIndex metric that operates on predictions and references that are strings.\n    It first splits the string into words using whitespace as a separator.\n\n    For each prediction, it calculates the ratio Intersect(prediction_words,reference_words)/Union(prediction_words,reference_words).\n    If multiple references exist, it takes the best ratio achieved by one of the references.\n\n    ",
+    "splitter": {
+        "__type__": "regex_split",
+        "by": "\\s+"
+    }
+}
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "__type__": "mean_squared_error",
+    "__description__": "Metric to calculate the mean squared error (MSE) between the prediction and the reference values.\n\n    Assumes both the prediction and reference are floats.\n\n    Supports only a single reference per prediction.\n    "
+}
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "__type__": "root_mean_squared_error",
+    "__description__": "Metric to calculate the root mean squared error (RMSE) between the prediction and the reference values.\n\n    Assumes both the prediction and reference are floats.\n\n    Supports only a single reference per prediction.\n    "
+}
Lines changed: 2 additions & 13 deletions
@@ -1,15 +1,4 @@
 {
-    "__type__": "metric_pipeline",
-    "main_score": "spearmanr",
-    "preprocess_steps": [
-        {
-            "__type__": "copy",
-            "field": "references/0",
-            "to_field": "references"
-        }
-    ],
-    "metric": {
-        "__type__": "spearmanr"
-    },
-    "prediction_type": "float"
+    "__type__": "spearmanr",
+    "n_resamples": 100
 }
