diff --git a/mlxtend/frequent_patterns/__init__.py b/mlxtend/frequent_patterns/__init__.py index 683f87b03..b6f2689c4 100644 --- a/mlxtend/frequent_patterns/__init__.py +++ b/mlxtend/frequent_patterns/__init__.py @@ -9,5 +9,6 @@ from .fpgrowth import fpgrowth from .fpmax import fpmax from .hmine import hmine +from .pipeline import RuleExtractor __all__ = ["apriori", "association_rules", "fpgrowth", "fpmax", "hmine"] diff --git a/mlxtend/frequent_patterns/pipeline.py b/mlxtend/frequent_patterns/pipeline.py new file mode 100644 index 000000000..1f5a9f419 --- /dev/null +++ b/mlxtend/frequent_patterns/pipeline.py @@ -0,0 +1,37 @@ +import numpy as np +import pandas as pd +from sklearn.base import BaseEstimator, TransformerMixin + +from .apriori import apriori +from .association_rules import association_rules + + +class RuleExtractor(BaseEstimator, TransformerMixin): + def __init__(self, min_support=0.1, metric="confidence", min_threshold=0.8): + self.min_support = min_support + self.metric = metric + self.min_threshold = min_threshold + + def fit(self, X, y=None): + X_df = ( + X.astype(bool) + if isinstance(X, pd.DataFrame) + else pd.DataFrame(X).astype(bool) + ) + self.frequent_itemsets_ = apriori( + X_df, min_support=self.min_support, use_colnames=True + ) + return self + + def transform(self, X): + if self.frequent_itemsets_.empty: + return pd.DataFrame( + columns=["antecedents", "consequents"] + + ["support", "confidence", "lift", "leverage", "conviction"] + ) + with np.errstate(divide="ignore", invalid="ignore"): + return association_rules( + self.frequent_itemsets_, + metric=self.metric, + min_threshold=self.min_threshold, + ) diff --git a/mlxtend/frequent_patterns/tests/test_pipeline.py b/mlxtend/frequent_patterns/tests/test_pipeline.py new file mode 100644 index 000000000..b428fdf4a --- /dev/null +++ b/mlxtend/frequent_patterns/tests/test_pipeline.py @@ -0,0 +1,25 @@ +import pandas as pd +import pytest +from sklearn.pipeline import Pipeline + +from mlxtend.frequent_patterns import RuleExtractor + + +def test_rule_extractor_basic(): + data = pd.DataFrame([[1, 0, 1], [1, 1, 1], [0, 1, 1]], columns=["A", "B", "C"]) + + pipe = Pipeline([("extractor", RuleExtractor(min_support=0.1))]) + + rules = pipe.fit_transform(data) + + assert isinstance(rules, pd.DataFrame) + assert not rules.empty + assert "antecedents" in rules.columns + assert "consequents" in rules.columns + + +def test_rule_extractor_empty(): + data = pd.DataFrame([[1, 0], [0, 1]], columns=["A", "B"]) + extractor = RuleExtractor(min_support=0.9) + rules = extractor.fit_transform(data) + assert rules.empty