@@ -20,9 +20,9 @@

 # Super class for Classification and Regression
 class DecisionTree(object):
-    """ Super class of RegressionTree and ClassificationTree.
+    ''' Super class of RegressionTree and ClassificationTree.

-    """
+    '''
     def __init__(self, min_samples_split=2, min_impurity=1e-7, max_depth=float("inf"), thresholdFromMean=False):

        '''
@@ -104,8 +104,8 @@ def fit(self, X, y):
         print('---------------------------------------')
         self.trained = True
     def _build_tree(self, X, y, current_depth=0):
-        """ Recursive method which builds out the decision tree and splits X and respective y
-        on the feature of X which (based on impurity) best separates the data"""
+        ''' Recursive method which builds out the decision tree and splits X and respective y
+        on the feature of X which (based on impurity) best separates the data'''

         largest_impurity = 0
         best_criteria = None    # Feature index and threshold
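As context for this hunk: `_build_tree` recursively scans every feature and candidate threshold, keeping the split with the largest impurity gain in `best_criteria`. A minimal standalone sketch of that search, assuming numeric features and an `impurity_gain(y, y1, y2)` callable standing in for the class's entropy or variance criterion (all names here are illustrative, not this repo's API):

```python
import numpy as np

def best_split(X, y, impurity_gain):
    # Scan every feature and unique threshold; keep the split with the largest gain.
    best = {'gain': 0.0, 'feature': None, 'threshold': None}
    for fi in range(X.shape[1]):
        for thr in np.unique(X[:, fi]):
            left = X[:, fi] >= thr            # samples routed to the 'True' branch
            if left.all() or (~left).all():   # skip degenerate (one-sided) splits
                continue
            gain = impurity_gain(y, y[left], y[~left])
            if gain > best['gain']:
                best = {'gain': gain, 'feature': fi, 'threshold': thr}
    return best
```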
@@ -267,8 +267,8 @@ def _build_tree(self, X, y, current_depth=0):

         return node
     def predict_value(self, x, tree=None, path=''):
-        """ Do a recursive search down the tree and make a prediction of the data sample by the
-        value of the leaf that we end up at """
+        ''' Do a recursive search down the tree and make a prediction of the data sample by the
+        value of the leaf that we end up at '''

         # check if sample has same number of features
         assert len(x)==self.nfeatures
@@ -301,7 +301,7 @@ def predict_value(self, x, tree=None,path=''):
             # Test subtree
             return self.predict_value(x, branch, path)
     def predict(self, X, treePath=False):
-        """ Classify samples one by one and return the set of labels """
+        ''' Classify samples one by one and return the set of labels '''

         if treePath:
             y_pred = np.array([list(self.predict_value(x)) for x in X])
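A hedged usage note for this hunk: with `treePath=True`, each `predict_value` result is unpacked into the predicted label plus the decision path taken through the tree (the `'T'`/`'F'` branch keys seen in `pruneTree` suggest the path encoding). Illustrative calls, assuming a trained tree `model`:

```python
# labels only (default)
y_hat = model.predict(X_test)

# labels together with the branch path for each sample (illustrative)
y_hat_with_path = model.predict(X_test, treePath=True)
```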
@@ -338,7 +338,7 @@ def pruneTree(self,DT):
             DT['T'] = self.pruneTree(DT['T'])
             DT['F'] = self.pruneTree(DT['F'])
         return DT
-    def plotTree(self, scale=True, show=True, showtitle=True, showDirection=True, DiffBranchColor=False, legend=True):
+    def plotTree(self, scale=True, show=True, showtitle=True, showDirection=False, DiffBranchColor=True, legend=True):
         import copy
         self.DT = copy.deepcopy(self.tree)
         if not (self.DT['leaf']):
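Worth flagging in review: unlike the docstring-quote hunks, this one changes runtime behavior. `plotTree()` now defaults to differently colored True/False branches (`DiffBranchColor=True`) and hides branch-direction labels (`showDirection=False`). Callers who want the old rendering can pass the flags explicitly:

```python
# restore the pre-change rendering explicitly (illustrative call)
model.plotTree(showDirection=True, DiffBranchColor=False)
```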
@@ -444,9 +444,10 @@ def plotTreePath(self,path,ax=None,fig=None):

 class ClassificationTree(DecisionTree):
     def entropy(self, y):
-        """ Calculate the entropy of array y
+        '''
+        Calculate the entropy of array y
         H(y) = - sum(p(y)*log2(p(y)))
-        """
+        '''
         yi = y
         if len(y.shape)>1 and y.shape[1]>1:
             yi = np.argmax(y, axis=1)
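The docstring's formula is standard Shannon entropy over the class proportions in `y` (the `np.argmax` above first collapses one-hot labels to class indices). A self-contained sketch of the same computation, not the repo's exact code:

```python
import numpy as np

def entropy(y):
    # H(y) = -sum(p * log2(p)) over the class proportions in y
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

print(entropy(np.array([0, 0, 1, 1])))  # 1.0 bit for a 50/50 split
```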
@@ -457,11 +458,10 @@ def entropy(self,y):
         return Hy

     def _infoGain(self, y, y1, y2):
-        # Calculate information gain
-        """ Calculate the information Gain with Entropy
-        H(y) = - sum(p(y)*log2(p(y)))
-        I_gain = H(y) - P(y1) * H(y1) - (1 - P(y1)) * H(y2)
-        """
+        '''
+        Calculate the information Gain with Entropy
+        I_gain = H(y) - P(y1) * H(y1) - (1 - P(y1)) * H(y2)
+        '''
         p = len(y1) / len(y)
         info_gain = self.entropy(y) - p * self.entropy(y1) - (1 - p) * self.entropy(y2)
         return info_gain
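The `+` side rightly drops the duplicated entropy line and keeps only the gain formula, which the method body implements verbatim. A small worked instance of that formula (`H` restates the entropy sketch above; illustrative, not the class method):

```python
import numpy as np

def H(y):
    # Shannon entropy, as in the entropy() sketch above
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def info_gain(y, y1, y2):
    # I_gain = H(y) - P(y1)*H(y1) - (1 - P(y1))*H(y2)
    p = len(y1) / len(y)
    return H(y) - p * H(y1) - (1 - p) * H(y2)

# a perfect split removes all entropy: gain = 1.0 - 0 - 0
print(info_gain(np.array([0, 0, 1, 1]), np.array([0, 0]), np.array([1, 1])))  # 1.0
```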
@@ -476,13 +476,13 @@ def fit(self, X, y,verbose=0,feature_names=None,randomBranch=False):
         '''
         Parameters:
         -----------
-        X:: ndarray (number of sample, number of features)
-        y:: list of 1D array
-        verbose::0 - no progress or tree (silent)
-               ::1 - show progress in short
-               ::2 - show progress with details with branches
-               ::3 - show progress with branches True/False
-               ::4 - show progress in short with plot tree
+        X :: ndarray (number of samples, number of features)
+        y :: list of 1D array
+        verbose ::0 - no progress or tree (silent)
+                ::1 - show progress in short
+                ::2 - show progress with details with branches
+                ::3 - show progress with branches True/False
+                ::4 - show progress in short with plot tree

         feature_names:: (optional) list, provide for a better look at the tree while plotting or showing
             the progress; defaults to None, in which case features are named f1,...fn
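A hedged usage sketch of the documented `fit` interface (the data variables and feature names are placeholders, not from this repo; the constructor arguments come from `DecisionTree.__init__` above):

```python
# illustrative: fit a classification tree with named features and short progress output
clf = ClassificationTree(max_depth=4, min_samples_split=2)
clf.fit(X_train, y_train, verbose=1, feature_names=['f1', 'f2', 'f3'])
y_hat = clf.predict(X_test)
```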
@@ -496,10 +496,10 @@ def fit(self, X, y,verbose=0,feature_names=None,randomBranch=False):

 class RegressionTree(DecisionTree):
     def _varReduction(self, y, y1, y2):
-        '''
-        Calculate the variance reduction
-        VarRed = Var(y) - P(y1) * Var(y1) - P(p2) * Var(y2)
-        '''
+        '''
+        Calculate the variance reduction
+        VarRed = Var(y) - P(y1) * Var(y1) - P(y2) * Var(y2)
+        '''
         assert len(y.shape)==1 or y.shape[1]==1

         p1 = len(y1) / len(y)
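On the formula itself: `P(p2)` in the old docstring reads as a typo for `P(y2)`, fixed on the `+` side above; with that, this is the standard weighted variance reduction used as the regression split criterion. A minimal numpy sketch (illustrative, not the repo's method):

```python
import numpy as np

def variance_reduction(y, y1, y2):
    # VarRed = Var(y) - P(y1)*Var(y1) - P(y2)*Var(y2)
    p1 = len(y1) / len(y)
    p2 = len(y2) / len(y)
    return np.var(y) - p1 * np.var(y1) - p2 * np.var(y2)

y = np.array([1.0, 1.0, 5.0, 5.0])
print(variance_reduction(y, y[:2], y[2:]))  # 4.0: the split explains all variance
```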