Commit 5b5cfac

Merge pull request #9 from davidwarshaw/empty-fit-bug
Fixed empty DM on fit bug. Conform to pep8.
2 parents cec498e + b0d18ba commit 5b5cfac

8 files changed, with 131 additions and 56 deletions.

hmc/__init__.py

Lines changed: 10 additions & 7 deletions
@@ -4,12 +4,15 @@
 from .datasets import load_shades_class_hierachy
 from .datasets import load_shades_data
 from .metrics import accuracy_score
+from .exceptions import *
+

 __all__ = ["ClassHierarchy",
-"DecisionTreeHierarchicalClassifier",
-"load_shades_class_hierachy",
-"load_shades_data",
-"accuracy_score",
-"precision_score_ancestors", "recall_score_ancestors",
-"precision_score_descendants", "recall_score_descendants",
-"f1_score_ancestors", "f1_score_descendants"]
+"DecisionTreeHierarchicalClassifier",
+"load_shades_class_hierachy",
+"load_shades_data",
+"accuracy_score",
+"precision_score_ancestors", "recall_score_ancestors",
+"precision_score_descendants", "recall_score_descendants",
+"f1_score_ancestors", "f1_score_descendants",
+"NoSamplesForStageWarning", "StageNotFitWarning", "ClassifierNotFitError"]

hmc/datasets.py

Lines changed: 24 additions & 22 deletions
@@ -10,29 +10,30 @@
 from hmc import DecisionTreeHierarchicalClassifier

 seeds = [
-{"node":"dark",
-"mu_1":0, "mu_2":10, "mu_3":0,
-"sigma_1":4, "sigma_2":4, "sigma_3":4},
-{"node":"black",
-"mu_1":1, "mu_2":9, "mu_3":1,
-"sigma_1":3, "sigma_2":3, "sigma_3":3},
-{"node":"gray",
-"mu_1":2, "mu_2":8, "mu_3":2,
-"sigma_1":2, "sigma_2":2, "sigma_3":2},
-{"node":"ash",
-"mu_1":3, "mu_2":7, "mu_3":3,
-"sigma_1":1, "sigma_2":1, "sigma_3":1},
-{"node":"slate",
-"mu_1":4, "mu_2":8, "mu_3":2,
-"sigma_1":1, "sigma_2":1, "sigma_3":1},
-{"node":"light",
-"mu_1":10, "mu_2":0, "mu_3":10,
-"sigma_1":4, "sigma_2":4, "sigma_3":4},
-{"node":"white",
-"mu_1":9, "mu_2":1, "mu_3":9,
-"sigma_1":3, "sigma_2":3, "sigma_3":3},
+{"node": "dark",
+"mu_1": 0, "mu_2": 10, "mu_3": 0,
+"sigma_1": 4, "sigma_2": 4, "sigma_3": 4},
+{"node": "black",
+"mu_1": 1, "mu_2": 9, "mu_3": 1,
+"sigma_1": 3, "sigma_2": 3, "sigma_3": 3},
+{"node": "gray",
+"mu_1": 2, "mu_2": 8, "mu_3": 2,
+"sigma_1": 2, "sigma_2": 2, "sigma_3": 2},
+{"node": "ash",
+"mu_1": 3, "mu_2": 7, "mu_3": 3,
+"sigma_1": 1, "sigma_2": 1, "sigma_3": 1},
+{"node": "slate",
+"mu_1": 4, "mu_2": 8, "mu_3": 2,
+"sigma_1": 1, "sigma_2": 1, "sigma_3": 1},
+{"node": "light",
+"mu_1": 10, "mu_2": 0, "mu_3": 10,
+"sigma_1": 4, "sigma_2": 4, "sigma_3": 4},
+{"node": "white",
+"mu_1": 9, "mu_2": 1, "mu_3": 9,
+"sigma_1": 3, "sigma_2": 3, "sigma_3": 3},
 ]

+
 def load_shades_class_hierachy():
 ch = ClassHierarchy("colors")
 ch.add_node("light", "colors")
@@ -44,10 +45,11 @@ def load_shades_class_hierachy():
 ch.add_node("ash", "gray")
 return ch

+
 def load_shades_data(random_seed=1):
 random.seed(random_seed)
 data_rows = []
-label_rows =[]
+label_rows = []
 for seed in seeds:
 for i in range(0, int(100 + 100 * random.random())):
 data_row = {}
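
The seeds above parameterize the synthetic "shades" dataset: per the loop in load_shades_data, each of the seven seed nodes contributes roughly 100 to 200 rows built from its mu/sigma values (the row-construction details sit in the unchanged part of the file and are not shown in this diff). A short usage sketch, mirroring how the test suite loads the data:

    import hmc

    # Build the color hierarchy and draw the synthetic shades data set.
    ch = hmc.load_shades_class_hierachy()
    X, y = hmc.load_shades_data(random_seed=1)  # random_seed=1 is the default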

hmc/exceptions.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+"""
+Exceptions and warnings particular to hierachical multi-classification.
+"""
+
+__all__ = ["NoSamplesForStageWarning", "StageNotFitWarning", "ClassifierNotFitError"]
+
+
+class NoSamplesForStageWarning(UserWarning):
+"""Warning used to notify that classification stage has no eligible samples.
+This warning happens when no samples in the input set (the design matrix) are eligible for
+classification at the current stage. This can happen during training or prediction if no
+samples are in, or descend from, the class in the class hierarchy corresponding to the stage.
+"""
+
+
+class StageNotFitWarning(UserWarning):
+"""Warning used to notify that no estimator was fit for the classification stage.
+This warning happens when samples are eligible for prediction at a classification stage that
+was not fit when the hierachical classifier was fit. This can happen if the training set used
+to fit the hierachical classifier had no samples in, or descending from, the class in the
+class hierarchy corresponding to the stage.
+"""
+
+
+class ClassifierNotFitError(ValueError):
+"""Warning used to notify that no estimators were fit for the hierachical classifier.
+This warning happens when the hierachical classifier is exploited without first being fit.
+"""

hmc/hmc.py

Lines changed: 31 additions & 10 deletions
@@ -5,19 +5,23 @@
 from __future__ import print_function
 from __future__ import division

+import warnings
+
 from sklearn import tree

 import numpy as np
 import pandas as pd

 import metrics
+from exceptions import *

 __all__ = ["ClassHierarchy", "DecisionTreeHierarchicalClassifier"]

 # =============================================================================
 # Class Hierarchy
 # =============================================================================

+
 class ClassHierarchy:
 """
 Class for class heirarchy.
@@ -40,7 +44,8 @@ def _get_parent(self, child):

 def _get_children(self, parent):
 # Return a list of children nodes in alpha order
-return sorted([child for child, childs_parent in self.nodes.iteritems() if childs_parent == parent])
+return sorted([child for child, childs_parent in
+self.nodes.iteritems() if childs_parent == parent])

 def _get_ancestors(self, child):
 # Return a list of the ancestors of this node
@@ -97,7 +102,8 @@ def add_node(self, child, parent):
 raise ValueError('The hierarchy root: ' + str(child) + ' is not a valid child node.')
 if child in self.nodes.keys():
 if self.nodes[child] != parent:
-raise ValueError('Node: ' + str(child) + ' has already been assigned parnet: ' + str(child) )
+raise ValueError('Node: ' + str(child) + ' has already been assigned parent: ' +
+str(child))
 else:
 return
 self.nodes[child] = parent
@@ -126,6 +132,7 @@ def print_(self):
 # Decision Tree Hierarchical Classifier
 # =============================================================================

+
 class DecisionTreeHierarchicalClassifier:

 def __init__(self, class_hierarchy):
@@ -145,7 +152,8 @@ def _depth_first_class_prob(self, tree, node, indent, last, hand):
 indent += u"\u2502 "
 print(hand + " " + str(node))
 for k, count in enumerate(tree.tree_.value[node][0]):
-print(indent + str(tree.classes_[k]) + ":" + str(stage(count / tree.tree_.n_node_samples[node], 2)))
+print(indent + str(tree.classes_[k]) + ":" +
+str(stage(count / tree.tree_.n_node_samples[node], 2)))
 self._depth_first_class_prob(tree, tree.tree_.children_right[node], indent, False, "R")
 self._depth_first_class_prob(tree, tree.tree_.children_left[node], indent, True, "L")

@@ -183,8 +191,7 @@ def _prep_data(self, X, y):
 for stage_number, stage in enumerate(self.stages):
 df[stage['target']] = pd.DataFrame.apply(
 df[[target]],
-lambda row: self._recode_label(stage['classes'],
-row[target]),
+lambda row: self._recode_label(stage['classes'], row[target]),
 axis=1)
 return df, dm_cols

@@ -196,27 +203,41 @@ def fit(self, X, y):
 df, dm_cols = self._prep_data(X, y)
 # Fit each stage
 for stage_number, stage in enumerate(self.stages):
+dm = df[df[stage['target']].isin(stage['classes'])][dm_cols]
+y_stage = df[df[stage['target']].isin(stage['classes'])][[stage['target']]]
 stage['tree'] = tree.DecisionTreeClassifier()
-stage['tree'] = stage['tree'].fit(
-df[df[stage['target']].isin(stage['classes'])][dm_cols],
-df[df[stage['target']].isin(stage['classes'])][[stage['target']]])
+if dm.empty:
+warnings.warn('No samples to fit for stage ' + str(stage['stage']),
+NoSamplesForStageWarning)
+continue
+stage['tree'] = stage['tree'].fit(dm, y_stage)
 return self

 def _check_fit(self):
 for stage in self.stages:
 if 'tree' not in stage.keys():
-raise ValueError('Estimators not fitted, call `fit` before exploiting the model.')
+raise ClassifierNotFitError(
+'Estimators not fitted, call `fit` before exploiting the model.')

 def _predict_stages(self, X):
 # Score each stage
 for stage_number, stage in enumerate(self.stages):
 if stage_number == 0:
-y_hat = pd.DataFrame([self.class_hierarchy.root] * len(X), columns=[self.class_hierarchy.root], index=X.index)
+y_hat = pd.DataFrame(
+[self.class_hierarchy.root] * len(X),
+columns=[self.class_hierarchy.root],
+index=X.index)
 else:
 y_hat[stage['stage']] = y_hat[self.stages[stage_number - 1]['stage']]
 dm = X[y_hat[stage['stage']].isin([stage['stage']])]
 # Skip empty matrices
 if dm.empty:
+warnings.warn('No samples to predict for stage ' + str(stage['stage']),
+NoSamplesForStageWarning)
+continue
+if not stage['tree'].tree_:
+warnings.warn('No tree was fit for stage ' + str(stage['stage']),
+StageNotFitWarning)
 continue
 # combine_first reorders DataFrames, so we have to do this the ugly way
 y_hat_stage = pd.DataFrame(stage['tree'].predict(dm), index=dm.index)
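
The dm.empty guard added to fit() matters because scikit-learn estimators refuse to train on a zero-sample design matrix, which is exactly what a stage gets when no training samples fall in or under its class; the prediction path gains the matching checks and reports NoSamplesForStageWarning or StageNotFitWarning instead of failing. A standalone illustration of the failure the guard avoids (illustrative only, not repository code):

    import pandas as pd
    from sklearn import tree

    # A stage with no eligible samples sees an empty design matrix like this one.
    empty_dm = pd.DataFrame(columns=["x1", "x2", "x3"])
    empty_y = pd.Series(dtype=object)

    try:
        tree.DecisionTreeClassifier().fit(empty_dm, empty_y)
    except ValueError as err:
        # scikit-learn rejects the empty frame; fit() now warns and skips the stage instead.
        print("empty design matrix:", err)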

hmc/metrics.py

Lines changed: 20 additions & 6 deletions
@@ -16,6 +16,7 @@
 import numpy as np
 import pandas as pd

+
 def _check_targets_hmc(y_true, y_pred):
 check_consistent_length(y_true, y_pred)
 y_type = set([type_of_target(y_true), type_of_target(y_pred)])
@@ -27,13 +28,15 @@ def _check_targets_hmc(y_true, y_pred):
 y_pred = column_or_1d(y_pred)
 return y_true, y_pred

-## General Scores
+
+# General Scores
 # Average accuracy
 def accuracy_score(class_hierarchy, y_true, y_pred):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
 return skmetrics.accuracy_score(y_true, y_pred)

-## Hierarchy Precision / Recall
+
+# Hierarchy Precision / Recall
 def _aggregate_class_sets(set_function, y_true, y_pred):
 intersection_sum = 0
 true_sum = 0
@@ -46,32 +49,41 @@ def _aggregate_class_sets(set_function, y_true, y_pred):
 predicted_sum += len(pred_set)
 return (true_sum, predicted_sum, intersection_sum)

+
 # Ancestors Scores (Super Class)
 # Precision
 def precision_score_ancestors(class_hierarchy, y_true, y_pred):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
-true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(class_hierarchy._get_ancestors, y_true, y_pred)
+true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(
+class_hierarchy._get_ancestors, y_true, y_pred)
 return intersection_sum / predicted_sum

+
 # Recall
 def recall_score_ancestors(class_hierarchy, y_true, y_pred):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
-true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(class_hierarchy._get_ancestors, y_true, y_pred)
+true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(
+class_hierarchy._get_ancestors, y_true, y_pred)
 return intersection_sum / true_sum

+
 # Descendants Scores (Sub Class)
 # Precision
 def precision_score_descendants(class_hierarchy, y_true, y_pred):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
-true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(class_hierarchy._get_descendants, y_true, y_pred)
+true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(
+class_hierarchy._get_descendants, y_true, y_pred)
 return intersection_sum / predicted_sum

+
 # Recall
 def recall_score_descendants(class_hierarchy, y_true, y_pred):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
-true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(class_hierarchy._get_descendants, y_true, y_pred)
+true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(
+class_hierarchy._get_descendants, y_true, y_pred)
 return intersection_sum / true_sum

+
 # Hierarchy Fscore
 def _fbeta_score_class_sets(set_function, y_true, y_pred, beta=1):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
@@ -80,10 +92,12 @@ def _fbeta_score_class_sets(set_function, y_true, y_pred, beta=1):
 recall = intersection_sum / true_sum
 return ((beta ** 2 + 1) * precision * recall) / ((beta ** 2 * precision) + recall)

+
 def f1_score_ancestors(class_hierarchy, y_true, y_pred):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
 return _fbeta_score_class_sets(class_hierarchy._get_ancestors, y_true, y_pred)

+
 def f1_score_descendants(class_hierarchy, y_true, y_pred):
 y_true, y_pred = _check_targets_hmc(y_true, y_pred)
 return _fbeta_score_class_sets(class_hierarchy._get_descendants, y_true, y_pred)
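
All of the metric functions above take the class hierarchy as their first argument, so hierarchy-aware ancestor and descendant scores can be reported alongside flat accuracy. A usage sketch; it assumes the classifier exposes a predict method in addition to the fit, score, and _predict_stages methods that appear in this commit:

    import hmc
    from hmc import accuracy_score, precision_score_ancestors, f1_score_descendants

    ch = hmc.load_shades_class_hierachy()
    X, y = hmc.load_shades_data()

    dt = hmc.DecisionTreeHierarchicalClassifier(ch).fit(X, y)
    y_pred = dt.predict(X)  # predict() is assumed; only score()/_predict_stages() appear in this diff

    # Flat accuracy plus the hierarchy-aware set-based scores.
    print(accuracy_score(ch, y, y_pred))
    print(precision_score_ancestors(ch, y, y_pred))
    print(f1_score_descendants(ch, y, y_pred))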

setup.py

Lines changed: 2 additions & 2 deletions
@@ -3,10 +3,10 @@
 from setuptools import setup

 setup(name='hmc',
-version='0.2',
+version='0.3',
 description='Decision tree based hierachical multi-classifier',
 url='https://github.com/davidwarshaw/hmc',
 author='David Warshaw',
 author_email='[email protected]',
-py_modules=['hmc.hmc', 'hmc.datasets', 'hmc.metrics'],
+py_modules=['hmc.hmc', 'hmc.datasets', 'hmc.metrics', 'hmc.exceptions'],
 requires=['sklearn', 'numpy', 'pandas'])
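
Adding hmc.exceptions to py_modules is what actually ships the new module; without it, an installed 0.3 package would break on the from .exceptions import * added to hmc/__init__.py. A quick post-install sanity check (hypothetical, not part of this commit):

    # After installing the package, e.g. with `pip install .`
    import hmc.exceptions

    print(hmc.exceptions.__all__)
    # ['NoSamplesForStageWarning', 'StageNotFitWarning', 'ClassifierNotFitError']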

tests/test_hmc.py

Lines changed: 8 additions & 4 deletions
@@ -10,6 +10,8 @@
 from sklearn.cross_validation import train_test_split

 import hmc
+from hmc.exceptions import *
+

 class TestClassHierarchy(unittest.TestCase):

@@ -66,6 +68,7 @@ def test_add_dag_node(self):
 # Adding a child with a new parent should throw an exception
 self.assertRaises(ValueError, ch.add_node, "slate", "light")

+
 class TestDecisionTreeHierarchicalClassifier(unittest.TestCase):

 def test_fit(self):
@@ -103,15 +106,15 @@ def row_is_hierarchical(row):
 return is_hierarchical

 stage_predictions = dt._predict_stages(X)
-stage_predictions['Hierarchical'] = stage_predictions.apply(lambda row: row_is_hierarchical(row), axis=1)
+stage_predictions['Hierarchical'] = stage_predictions.apply(
+lambda row: row_is_hierarchical(row), axis=1)
 # Each stage of classification should descend from the previous class
 self.assertEqual(len(stage_predictions[stage_predictions['Hierarchical'] != True]), 0)

 def test_score(self):
 ch = hmc.load_shades_class_hierachy()
 X, y = hmc.load_shades_data()
-X_train, X_test, y_train, y_test = train_test_split(X, y,
-test_size = 0.50, random_state = 0)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.50, random_state=0)
 dt = hmc.DecisionTreeHierarchicalClassifier(ch)
 dt_nonh = tree.DecisionTreeClassifier()
 dt = dt.fit(X_train, y_train)
@@ -126,7 +129,8 @@ def test_score_before_fit(self):
 X, y = hmc.load_shades_data()
 dt = hmc.DecisionTreeHierarchicalClassifier(ch)
 # Scoring without fitting should raise exception
-self.assertRaises(ValueError, dt.score, X, y)
+self.assertRaises(ClassifierNotFitError, dt.score, X, y)
+

 if __name__ == '__main__':
 unittest.main()
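
Because ClassifierNotFitError subclasses ValueError, the tightened assertion above is stricter without breaking older call sites that still catch ValueError. A sketch of a companion test in the same style (hypothetical, assuming the imports already at the top of tests/test_hmc.py):

    class TestNotFitError(unittest.TestCase):

        def test_score_before_fit_is_still_a_value_error(self):
            ch = hmc.load_shades_class_hierachy()
            X, y = hmc.load_shades_data()
            dt = hmc.DecisionTreeHierarchicalClassifier(ch)
            # The new exception type remains catchable as the old ValueError.
            self.assertRaises(ClassifierNotFitError, dt.score, X, y)
            self.assertRaises(ValueError, dt.score, X, y)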
