Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlxtend/frequent_patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
from .fpgrowth import fpgrowth
from .fpmax import fpmax
from .hmine import hmine
from .pipeline import RuleExtractor

# Public API of the package; every name listed here is imported above.
__all__ = [
    "apriori",
    "association_rules",
    "fpgrowth",
    "fpmax",
    "hmine",
    "RuleExtractor",
]
37 changes: 37 additions & 0 deletions mlxtend/frequent_patterns/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin

from .apriori import apriori
from .association_rules import association_rules


class RuleExtractor(BaseEstimator, TransformerMixin):
    """Scikit-learn style transformer that mines association rules.

    ``fit`` casts the input to boolean, runs ``apriori`` on it and stores the
    resulting frequent itemsets. ``transform`` converts the stored itemsets
    into rules with ``association_rules``.

    Parameters
    ----------
    min_support : float, default=0.1
        Forwarded to ``apriori`` as ``min_support``.
    metric : str, default="confidence"
        Forwarded to ``association_rules`` as ``metric``.
    min_threshold : float, default=0.8
        Forwarded to ``association_rules`` as ``min_threshold``.

    Attributes
    ----------
    frequent_itemsets_ : pandas.DataFrame
        Output of ``apriori`` computed during ``fit``.
    """

    def __init__(self, min_support=0.1, metric="confidence", min_threshold=0.8):
        # Parameters are stored unmodified, as scikit-learn's get_params /
        # set_params / clone machinery expects.
        self.min_support = min_support
        self.metric = metric
        self.min_threshold = min_threshold

    def fit(self, X, y=None):
        """Compute and store the frequent itemsets of ``X``.

        Parameters
        ----------
        X : pandas.DataFrame or array-like
            Transaction data. Anything that is not a ``DataFrame`` is wrapped
            in one; the values are then cast to ``bool``.
        y : ignored
            Present only for scikit-learn API compatibility.

        Returns
        -------
        self : RuleExtractor
            The fitted estimator.
        """
        X_df = (
            X.astype(bool)
            if isinstance(X, pd.DataFrame)
            else pd.DataFrame(X).astype(bool)
        )
        self.frequent_itemsets_ = apriori(
            X_df, min_support=self.min_support, use_colnames=True
        )
        return self

    def transform(self, X):
        """Return the association rules derived from the fitted itemsets.

        Parameters
        ----------
        X : ignored
            The rules are built solely from ``frequent_itemsets_``; ``X`` is
            accepted only for scikit-learn API compatibility.

        Returns
        -------
        rules : pandas.DataFrame
            Output of ``association_rules``, or an empty frame with a fixed
            set of columns when no frequent itemsets were found.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``fit`` has not been called yet.
        """
        # Imported locally so this block does not depend on a module-level
        # import being added elsewhere in the file.
        from sklearn.utils.validation import check_is_fitted

        # Fail with the conventional scikit-learn error instead of a bare
        # AttributeError when transform() is called before fit().
        check_is_fitted(self, "frequent_itemsets_")

        if self.frequent_itemsets_.empty:
            # NOTE(review): short-circuits before calling association_rules,
            # presumably because it rejects an empty itemset frame - confirm.
            return pd.DataFrame(
                columns=["antecedents", "consequents"]
                + ["support", "confidence", "lift", "leverage", "conviction"]
            )
        # Silence NumPy divide-by-zero / invalid-value warnings raised while
        # the rule metrics are computed.
        with np.errstate(divide="ignore", invalid="ignore"):
            return association_rules(
                self.frequent_itemsets_,
                metric=self.metric,
                min_threshold=self.min_threshold,
            )
25 changes: 25 additions & 0 deletions mlxtend/frequent_patterns/tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pandas as pd
import pytest
from sklearn.pipeline import Pipeline

from mlxtend.frequent_patterns import RuleExtractor


def test_rule_extractor_basic():
    """RuleExtractor inside a Pipeline yields a non-empty rules DataFrame."""
    transactions = pd.DataFrame(
        [[1, 0, 1], [1, 1, 1], [0, 1, 1]],
        columns=["A", "B", "C"],
    )
    steps = [("extractor", RuleExtractor(min_support=0.1))]

    result = Pipeline(steps).fit_transform(transactions)

    assert isinstance(result, pd.DataFrame)
    assert not result.empty
    # Both rule sides must be present as columns.
    for column in ("antecedents", "consequents"):
        assert column in result.columns


def test_rule_extractor_empty():
    """With a 0.9 support threshold on two disjoint rows, no rules come back."""
    disjoint = pd.DataFrame([[1, 0], [0, 1]], columns=["A", "B"])
    result = RuleExtractor(min_support=0.9).fit_transform(disjoint)
    assert result.empty
Loading