55from __future__ import print_function
66from __future__ import division
77
8+ import warnings
9+
810from sklearn import tree
911
1012import numpy as np
1113import pandas as pd
1214
1315import metrics
16+ from exceptions import *
1417
1518__all__ = ["ClassHierarchy" , "DecisionTreeHierarchicalClassifier" ]
1619
1720# =============================================================================
1821# Class Hierarchy
1922# =============================================================================
2023
24+
2125class ClassHierarchy :
2226 """
2327 Class for class hierarchy.
@@ -40,7 +44,8 @@ def _get_parent(self, child):
4044
4145 def _get_children (self , parent ):
4246 # Return a list of children nodes in alpha order
43- return sorted ([child for child , childs_parent in self .nodes .iteritems () if childs_parent == parent ])
47+ return sorted ([child for child , childs_parent in
48+ self .nodes .iteritems () if childs_parent == parent ])
4449
4550 def _get_ancestors (self , child ):
4651 # Return a list of the ancestors of this node
@@ -97,7 +102,8 @@ def add_node(self, child, parent):
97102 raise ValueError ('The hierarchy root: ' + str (child ) + ' is not a valid child node.' )
98103 if child in self .nodes .keys ():
99104 if self .nodes [child ] != parent :
100- raise ValueError ('Node: ' + str (child ) + ' has already been assigned parnet: ' + str (child ) )
105+ raise ValueError ('Node: ' + str (child ) + ' has already been assigned parent: ' +
106+ str (child ))
101107 else :
102108 return
103109 self .nodes [child ] = parent
@@ -126,6 +132,7 @@ def print_(self):
126132# Decision Tree Hierarchical Classifier
127133# =============================================================================
128134
135+
129136class DecisionTreeHierarchicalClassifier :
130137
131138 def __init__ (self , class_hierarchy ):
@@ -145,7 +152,8 @@ def _depth_first_class_prob(self, tree, node, indent, last, hand):
145152 indent += u"\u2502 "
146153 print (hand + " " + str (node ))
147154 for k , count in enumerate (tree .tree_ .value [node ][0 ]):
148- print (indent + str (tree .classes_ [k ]) + ":" + str (stage (count / tree .tree_ .n_node_samples [node ], 2 )))
155+ print (indent + str (tree .classes_ [k ]) + ":" +
156+ str (stage (count / tree .tree_ .n_node_samples [node ], 2 )))
149157 self ._depth_first_class_prob (tree , tree .tree_ .children_right [node ], indent , False , "R" )
150158 self ._depth_first_class_prob (tree , tree .tree_ .children_left [node ], indent , True , "L" )
151159
@@ -183,8 +191,7 @@ def _prep_data(self, X, y):
183191 for stage_number , stage in enumerate (self .stages ):
184192 df [stage ['target' ]] = pd .DataFrame .apply (
185193 df [[target ]],
186- lambda row : self ._recode_label (stage ['classes' ],
187- row [target ]),
194+ lambda row : self ._recode_label (stage ['classes' ], row [target ]),
188195 axis = 1 )
189196 return df , dm_cols
190197
@@ -196,27 +203,41 @@ def fit(self, X, y):
196203 df , dm_cols = self ._prep_data (X , y )
197204 # Fit each stage
198205 for stage_number , stage in enumerate (self .stages ):
206+ dm = df [df [stage ['target' ]].isin (stage ['classes' ])][dm_cols ]
207+ y_stage = df [df [stage ['target' ]].isin (stage ['classes' ])][[stage ['target' ]]]
199208 stage ['tree' ] = tree .DecisionTreeClassifier ()
200- stage ['tree' ] = stage ['tree' ].fit (
201- df [df [stage ['target' ]].isin (stage ['classes' ])][dm_cols ],
202- df [df [stage ['target' ]].isin (stage ['classes' ])][[stage ['target' ]]])
209+ if dm .empty :
210+ warnings .warn ('No samples to fit for stage ' + str (stage ['stage' ]),
211+ NoSamplesForStageWarning )
212+ continue
213+ stage ['tree' ] = stage ['tree' ].fit (dm , y_stage )
203214 return self
204215
205216 def _check_fit (self ):
206217 for stage in self .stages :
207218 if 'tree' not in stage .keys ():
208- raise ValueError ('Estimators not fitted, call `fit` before exploiting the model.' )
219+ raise ClassifierNotFitError (
220+ 'Estimators not fitted, call `fit` before exploiting the model.' )
209221
210222 def _predict_stages (self , X ):
211223 # Score each stage
212224 for stage_number , stage in enumerate (self .stages ):
213225 if stage_number == 0 :
214- y_hat = pd .DataFrame ([self .class_hierarchy .root ] * len (X ), columns = [self .class_hierarchy .root ], index = X .index )
226+ y_hat = pd .DataFrame (
227+ [self .class_hierarchy .root ] * len (X ),
228+ columns = [self .class_hierarchy .root ],
229+ index = X .index )
215230 else :
216231 y_hat [stage ['stage' ]] = y_hat [self .stages [stage_number - 1 ]['stage' ]]
217232 dm = X [y_hat [stage ['stage' ]].isin ([stage ['stage' ]])]
218233 # Skip empty matrices
219234 if dm .empty :
235+ warnings .warn ('No samples to predict for stage ' + str (stage ['stage' ]),
236+ NoSamplesForStageWarning )
237+ continue
238+ if not stage ['tree' ].tree_ :
239+ warnings .warn ('No tree was fit for stage ' + str (stage ['stage' ]),
240+ StageNotFitWarning )
220241 continue
221242 # combine_first reorders DataFrames, so we have to do this the ugly way
222243 y_hat_stage = pd .DataFrame (stage ['tree' ].predict (dm ), index = dm .index )
0 commit comments