
Commit 99bab84

add include_tests flag for presets (#1530)
* add include_tests flag for presets
* add tests
1 parent 89d40f4 commit 99bab84
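
In plain terms: preset containers now accept an `include_tests` flag (default `True`). When it is `False`, metrics that would otherwise get auto-generated default tests receive an empty test list instead; explicitly passed tests are kept either way. A minimal usage sketch — `ClassificationQuality` is defined in this commit's `src/evidently/future/presets/classification.py`, but the import style and the placeholder test value are assumptions, not confirmed by the diff:

```python
# Hedged usage sketch; only the include_tests parameter itself is
# confirmed by this commit's diff.
from evidently.future.presets.classification import ClassificationQuality

# Default behaviour: preset metrics keep their auto-generated tests.
quality = ClassificationQuality()

# New flag: suppress default tests for every metric in the preset.
quality_no_tests = ClassificationQuality(include_tests=False)

# Explicitly supplied tests survive the flag (resolved via _get_tests()).
quality_custom = ClassificationQuality(
    include_tests=False,
    accuracy_tests=[...],  # placeholder for a SingleValueMetricTests value
)
```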

8 files changed: +130 -52 lines changed

src/evidently/future/container.py (+12 -1)

```diff
@@ -19,6 +19,9 @@
 
 
 class MetricContainer(abc.ABC):
+    def __init__(self, include_tests: bool = True):
+        self.include_tests = include_tests
+
     @abc.abstractmethod
     def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
         raise NotImplementedError()
@@ -50,7 +53,15 @@ def list_metrics(self, context: "Context") -> Generator[Metric, None, None]:
             else:
                 raise ValueError(f"invalid metric type {type(item)}")
 
+    def _get_tests(self, tests):
+        if tests is not None:
+            return tests
+        if self.include_tests:
+            return None
+        return []
+
 
 class ColumnMetricContainer(MetricContainer, abc.ABC):
-    def __init__(self, column: str):
+    def __init__(self, column: str, include_tests: bool = True):
+        super().__init__(include_tests=include_tests)
         self._column = column
```
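
The new `_get_tests` helper gives the flag its semantics: explicitly passed tests always win; otherwise `None` means "let the metric attach its default tests" when the flag is on, and `[]` disables tests when it is off. A standalone restatement of that three-way rule (the `resolve_tests` name and the sample test value are illustrative, not from the codebase):

```python
# Standalone sketch of the rule MetricContainer._get_tests() encodes.
# None means "attach default tests downstream"; [] means "no tests".
from typing import Optional, Sequence


def resolve_tests(explicit: Optional[Sequence], include_tests: bool) -> Optional[Sequence]:
    if explicit is not None:  # caller-supplied tests always take precedence
        return explicit
    if include_tests:  # no explicit tests, flag on: defer to defaults
        return None
    return []  # flag off: explicitly no tests


assert resolve_tests(["accuracy >= 0.9"], include_tests=False) == ["accuracy >= 0.9"]
assert resolve_tests(None, include_tests=True) is None
assert resolve_tests(None, include_tests=False) == []
```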

src/evidently/future/generators/column.py (+1 -0)

```diff
@@ -33,6 +33,7 @@ def __init__(
         self.columns = columns
         self.column_types = column_types
         self.metric_kwargs = metric_kwargs or {}
+        super().__init__(include_tests=True)
 
     def _instantiate_metric(self, column: str) -> MetricOrContainer:
         return self.metric_type(column=column, **self.metric_kwargs)
```
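
Note that the generator class this `__init__` belongs to (its name is not shown in the hunk) pins `include_tests=True` rather than exposing the flag to callers, so generated per-column metrics keep the default tests-on behaviour.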

src/evidently/future/metrics/group_by.py (+1 -0)

```diff
@@ -67,6 +67,7 @@ class GroupBy(MetricContainer):
     def __init__(self, metric: Metric, column_name: str):
         self._column_name = column_name
         self._metric = metric
+        super().__init__(True)
 
     def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
         labels = context.column(self._column_name).labels()
```
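
Same pattern for `GroupBy`: the flag is hard-wired on, here passed positionally as `super().__init__(True)` rather than by keyword.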

src/evidently/future/presets/classification.py (+30 -15)

```diff
@@ -56,6 +56,7 @@ def __init__(
         tnr_tests: SingleValueMetricTests = None,
         fpr_tests: SingleValueMetricTests = None,
         fnr_tests: SingleValueMetricTests = None,
+        include_tests: bool = True,
     ):
         self._accuracy_tests = accuracy_tests
         self._precision_tests = precision_tests
@@ -71,6 +72,7 @@ def __init__(
         self._conf_matrix = conf_matrix
         self._pr_curve = pr_curve
         self._pr_table = pr_table
+        super().__init__(include_tests=include_tests)
 
     def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
         classification = context.data_definition.get_classification("default")
@@ -80,25 +82,25 @@ def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
         metrics: List[Metric]
 
         metrics = [
-            Accuracy(probas_threshold=self._probas_threshold, tests=self._accuracy_tests),
-            Precision(probas_threshold=self._probas_threshold, tests=self._precision_tests),
-            Recall(probas_threshold=self._probas_threshold, tests=self._recall_tests),
-            F1Score(probas_threshold=self._probas_threshold, tests=self._f1score_tests),
+            Accuracy(probas_threshold=self._probas_threshold, tests=self._get_tests(self._accuracy_tests)),
+            Precision(probas_threshold=self._probas_threshold, tests=self._get_tests(self._precision_tests)),
+            Recall(probas_threshold=self._probas_threshold, tests=self._get_tests(self._recall_tests)),
+            F1Score(probas_threshold=self._probas_threshold, tests=self._get_tests(self._f1score_tests)),
         ]
         if classification.prediction_probas is not None:
             metrics.extend(
                 [
-                    RocAuc(probas_threshold=self._probas_threshold, tests=self._rocauc_test),
-                    LogLoss(probas_threshold=self._probas_threshold, tests=self._logloss_test),
+                    RocAuc(probas_threshold=self._probas_threshold, tests=self._get_tests(self._rocauc_test)),
+                    LogLoss(probas_threshold=self._probas_threshold, tests=self._get_tests(self._logloss_test)),
                 ]
             )
         if isinstance(classification, BinaryClassification):
             metrics.extend(
                 [
-                    TPR(probas_threshold=self._probas_threshold, tests=self._tpr_test),
-                    TNR(probas_threshold=self._probas_threshold, tests=self._tnr_test),
-                    FPR(probas_threshold=self._probas_threshold, tests=self._fpr_test),
-                    FNR(probas_threshold=self._probas_threshold, tests=self._fnr_test),
+                    TPR(probas_threshold=self._probas_threshold, tests=self._get_tests(self._tpr_test)),
+                    TNR(probas_threshold=self._probas_threshold, tests=self._get_tests(self._tnr_test)),
+                    FPR(probas_threshold=self._probas_threshold, tests=self._get_tests(self._fpr_test)),
+                    FNR(probas_threshold=self._probas_threshold, tests=self._get_tests(self._fnr_test)),
                 ]
             )
         return metrics
@@ -144,27 +146,35 @@ def __init__(
         precision_tests: ByLabelMetricTests = None,
         recall_tests: ByLabelMetricTests = None,
         rocauc_tests: ByLabelMetricTests = None,
+        include_tests: bool = True,
     ):
         self._probas_threshold = probas_threshold
         self._k = k
         self._f1score_tests = f1score_tests
         self._precision_tests = precision_tests
         self._recall_tests = recall_tests
         self._rocauc_tests = rocauc_tests
+        super().__init__(include_tests=include_tests)
 
     def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
         classification = context.data_definition.get_classification("default")
         if classification is None:
             raise ValueError("Cannot use ClassificationPreset without a classification configration")
         return [
-            F1ByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._f1score_tests),
-            PrecisionByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._precision_tests),
-            RecallByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._recall_tests),
+            F1ByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._get_tests(self._f1score_tests)),
+            PrecisionByLabel(
+                probas_threshold=self._probas_threshold, k=self._k, tests=self._get_tests(self._precision_tests)
+            ),
+            RecallByLabel(
+                probas_threshold=self._probas_threshold, k=self._k, tests=self._get_tests(self._recall_tests)
+            ),
         ] + (
             []
             if classification.prediction_probas is None
            else [
-                RocAucByLabel(probas_threshold=self._probas_threshold, k=self._k, tests=self._rocauc_tests),
+                RocAucByLabel(
+                    probas_threshold=self._probas_threshold, k=self._k, tests=self._get_tests(self._rocauc_tests)
+                ),
             ]
         )
 
@@ -192,6 +202,7 @@ def __init__(
     ):
         self._probas_threshold = probas_threshold
         self._k = k
+        super().__init__(include_tests=True)
 
     def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
         return [
@@ -232,7 +243,9 @@ def __init__(
         precision_by_label_tests: ByLabelMetricTests = None,
         recall_by_label_tests: ByLabelMetricTests = None,
         rocauc_by_label_tests: ByLabelMetricTests = None,
+        include_tests: bool = True,
     ):
+        super().__init__(include_tests=include_tests)
         self._probas_threshold = probas_threshold
         self._quality = ClassificationQuality(
             probas_threshold=probas_threshold,
@@ -249,17 +262,19 @@ def __init__(
             tnr_tests=tnr_tests,
             fpr_tests=fpr_tests,
             fnr_tests=fnr_tests,
+            include_tests=include_tests,
         )
         self._quality_by_label = ClassificationQualityByLabel(
             probas_threshold=probas_threshold,
             f1score_tests=f1score_by_label_tests,
             precision_tests=precision_by_label_tests,
             recall_tests=recall_by_label_tests,
             rocauc_tests=rocauc_by_label_tests,
+            include_tests=include_tests,
         )
         self._roc_auc: Optional[RocAuc] = RocAuc(
             probas_threshold=probas_threshold,
-            tests=rocauc_tests,
+            tests=self._get_tests(rocauc_tests),
         )
 
     def generate_metrics(self, context: "Context") -> Sequence[MetricOrContainer]:
```
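
With the flag threaded through, the combined preset forwards `include_tests` to both sub-containers and wraps its own `RocAuc` tests in `_get_tests`, so a single argument governs the whole metric tree. A hedged sketch of that propagation (both class names appear in this diff; the import style is an assumption):

```python
# Hedged sketch -- only the include_tests wiring is confirmed by the diff.
from evidently.future.presets.classification import (
    ClassificationQuality,
    ClassificationQualityByLabel,
)

# One flag disables default tests across aggregate and per-label metrics.
quality = ClassificationQuality(include_tests=False)
by_label = ClassificationQualityByLabel(include_tests=False)
```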

src/evidently/future/presets/dataset_stats.py (+39 -24)

```diff
@@ -57,8 +57,9 @@ def __init__(
         q50_tests: SingleValueMetricTests = None,
         q75_tests: SingleValueMetricTests = None,
         unique_values_count_tests: ByLabelMetricTests = None,
+        include_tests: bool = True,
     ):
-        super().__init__(column=column)
+        super().__init__(column=column, include_tests=include_tests)
         self._row_count_tests = row_count_tests
         self._missing_values_count_tests = missing_values_count_tests
         self._min_tests = min_tests
@@ -72,23 +73,23 @@ def __init__(
 
     def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
         metrics: List[Metric] = [
-            RowCount(tests=self._row_count_tests),
-            MissingValueCount(column=self._column, tests=self._missing_values_count_tests),
+            RowCount(tests=self._get_tests(self._row_count_tests)),
+            MissingValueCount(column=self._column, tests=self._get_tests(self._missing_values_count_tests)),
         ]
         column_type = context.column(self._column).column_type
         if column_type == ColumnType.Numerical:
             metrics += [
-                MinValue(column=self._column, tests=self._min_tests),
-                MaxValue(column=self._column, tests=self._max_tests),
-                MeanValue(column=self._column, tests=self._mean_tests),
-                StdValue(column=self._column, tests=self._std_tests),
-                QuantileValue(column=self._column, quantile=0.25, tests=self._q25_tests),
-                QuantileValue(column=self._column, quantile=0.5, tests=self._q50_tests),
-                QuantileValue(column=self._column, quantile=0.75, tests=self._q75_tests),
+                MinValue(column=self._column, tests=self._get_tests(self._min_tests)),
+                MaxValue(column=self._column, tests=self._get_tests(self._max_tests)),
+                MeanValue(column=self._column, tests=self._get_tests(self._mean_tests)),
+                StdValue(column=self._column, tests=self._get_tests(self._std_tests)),
+                QuantileValue(column=self._column, quantile=0.25, tests=self._get_tests(self._q25_tests)),
+                QuantileValue(column=self._column, quantile=0.5, tests=self._get_tests(self._q50_tests)),
+                QuantileValue(column=self._column, quantile=0.75, tests=self._get_tests(self._q75_tests)),
             ]
         if column_type == ColumnType.Categorical:
             metrics += [
-                UniqueValueCount(column=self._column, tests=self._unique_values_count_tests),
+                UniqueValueCount(column=self._column, tests=self._get_tests(self._unique_values_count_tests)),
             ]
         if column_type == ColumnType.Datetime:
             metrics += [
@@ -313,6 +314,7 @@ def __init__(
         empty_column_count_tests: SingleValueMetricTests = None,
         constant_columns_count_tests: SingleValueMetricTests = None,
         dataset_missing_value_count_tests: SingleValueMetricTests = None,
+        include_tests: bool = True,
     ):
         self.duplicated_row_count_tests = duplicated_row_count_tests
         self.duplicated_column_count_tests = duplicated_column_count_tests
@@ -324,23 +326,24 @@ def __init__(
         self.dataset_missing_value_count_tests = dataset_missing_value_count_tests
         self.column_count_tests = column_count_tests
         self.row_count_tests = row_count_tests
+        super().__init__(include_tests=include_tests)
 
     def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
         return [
-            RowCount(tests=self.row_count_tests),
-            ColumnCount(tests=self.column_count_tests),
+            RowCount(tests=self._get_tests(self.row_count_tests)),
+            ColumnCount(tests=self._get_tests(self.column_count_tests)),
             ColumnCount(column_type=ColumnType.Numerical, tests=[]),
             ColumnCount(column_type=ColumnType.Categorical, tests=[]),
             ColumnCount(column_type=ColumnType.Datetime, tests=[]),
             ColumnCount(column_type=ColumnType.Text, tests=[]),
-            DuplicatedRowCount(tests=self.duplicated_row_count_tests),
-            DuplicatedColumnsCount(tests=self.duplicated_column_count_tests),
-            AlmostDuplicatedColumnsCount(tests=self.almost_duplicated_column_count_tests),
-            AlmostConstantColumnsCount(tests=self.almost_constant_column_count_tests),
-            EmptyRowsCount(tests=self.empty_row_count_tests),
-            EmptyColumnsCount(tests=self.empty_column_count_tests),
-            ConstantColumnsCount(tests=self.constant_columns_count_tests),
-            DatasetMissingValueCount(tests=self.dataset_missing_value_count_tests),
+            DuplicatedRowCount(tests=self._get_tests(self.duplicated_row_count_tests)),
+            DuplicatedColumnsCount(tests=self._get_tests(self.duplicated_column_count_tests)),
+            AlmostDuplicatedColumnsCount(tests=self._get_tests(self.almost_duplicated_column_count_tests)),
+            AlmostConstantColumnsCount(tests=self._get_tests(self.almost_constant_column_count_tests)),
+            EmptyRowsCount(tests=self._get_tests(self.empty_row_count_tests)),
+            EmptyColumnsCount(tests=self._get_tests(self.empty_column_count_tests)),
+            ConstantColumnsCount(tests=self._get_tests(self.constant_columns_count_tests)),
+            DatasetMissingValueCount(tests=self._get_tests(self.dataset_missing_value_count_tests)),
         ]
 
     def render(
@@ -375,20 +378,27 @@ def __init__(
         columns: Optional[List[str]] = None,
         row_count_tests: SingleValueMetricTests = None,
         column_tests: Optional[Dict[str, ValueStatsTests]] = None,
+        include_tests: bool = True,
     ):
         self._columns = columns
         self._value_stats: List[ValueStats] = []
         self._row_count_tests = row_count_tests
         self._column_tests = column_tests
+        super().__init__(include_tests=include_tests)
 
     def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
         if self._columns is None:
             cols = context.data_definition.numerical_descriptors + context.data_definition.categorical_descriptors
         else:
             cols = self._columns
-        metrics: List[MetricOrContainer] = [RowCount(tests=self._row_count_tests)]
+        metrics: List[MetricOrContainer] = [RowCount(tests=self._get_tests(self._row_count_tests))]
         self._value_stats = [
-            ValueStats(column, **(self._column_tests or {}).get(column, ValueStatsTests()).__dict__) for column in cols
+            ValueStats(
+                column,
+                **(self._column_tests or {}).get(column, ValueStatsTests()).__dict__,
+                include_tests=self.include_tests,
+            )
+            for column in cols
         ]
         metrics.extend(list(chain(*[vs.metrics(context)[1:] for vs in self._value_stats])))
         return metrics
@@ -419,6 +429,7 @@ def __init__(
         constant_columns_count_tests: SingleValueMetricTests = None,
         dataset_missing_value_count_tests: SingleValueMetricTests = None,
         column_tests: Optional[Dict[str, ValueStatsTests]] = None,
+        include_tests: bool = True,
     ):
         self.duplicated_row_count_tests = duplicated_row_count_tests
         self.duplicated_column_count_tests = duplicated_column_count_tests
@@ -432,6 +443,7 @@ def __init__(
         self.row_count_tests = row_count_tests
         self._columns = columns
         self._column_tests = column_tests
+        super().__init__(include_tests=include_tests)
 
     def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
         columns_ = context.data_definition.get_categorical_columns() + context.data_definition.get_numerical_columns()
@@ -446,8 +458,11 @@ def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
             empty_column_count_tests=self.empty_column_count_tests,
             constant_columns_count_tests=self.constant_columns_count_tests,
             dataset_missing_value_count_tests=self.dataset_missing_value_count_tests,
+            include_tests=self.include_tests,
+        )
+        self._text_evals = TextEvals(
+            self._columns or columns_, column_tests=self._column_tests, include_tests=self.include_tests
         )
-        self._text_evals = TextEvals(self._columns or columns_, column_tests=self._column_tests)
         return self._dataset_stats.metrics(context) + self._text_evals.metrics(context)
 
     def render(
```
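
Here the flag fans out: `TextEvals` passes `include_tests=self.include_tests` into every generated per-column `ValueStats` container, so disabling tests at the preset level reaches each column-level metric. A hedged sketch (`TextEvals` and its positional `columns` argument appear in this diff; the import style and the column name are assumptions):

```python
# Hedged sketch -- only the include_tests propagation is confirmed.
from evidently.future.presets.dataset_stats import TextEvals

# include_tests=False propagates into each generated ValueStats container,
# so none of the per-column metrics get default tests attached.
evals = TextEvals(["review_text"], include_tests=False)
```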

src/evidently/future/presets/drift.py (+1 -0)

```diff
@@ -53,6 +53,7 @@ def __init__(
         self.embeddings_drift_method = embeddings_drift_method
         self.embeddings = embeddings
         self.columns = columns
+        super().__init__(include_tests=True)
 
     def generate_metrics(self, context: Context) -> Sequence[MetricOrContainer]:
         types = [ColumnType.Numerical, ColumnType.Categorical, ColumnType.Text]
```
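
The drift preset, by contrast, does not expose the flag at all: it hard-codes `super().__init__(include_tests=True)`, so drift metrics are unaffected by this commit's opt-out.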
