diff --git a/advanced_ML/model_tree/models/DT_sklearn_clf.py b/advanced_ML/model_tree/models/DT_sklearn_clf.py
index 24dcecc..4d70e6d 100644
--- a/advanced_ML/model_tree/models/DT_sklearn_clf.py
+++ b/advanced_ML/model_tree/models/DT_sklearn_clf.py
@@ -21,6 +21,8 @@ def predict(self, X):
 
     def loss(self, X, y, y_pred):
         return gini_impurity(y)
+    def get_params(self):
+        return None
 
 def gini_impurity(y):
     p2 = 0.0
@@ -28,4 +30,4 @@ def gini_impurity(y):
     for c in y_classes:
         p2 += (np.sum(y == c) / len(y))**2
     loss = 1.0 - p2
-    return loss
\ No newline at end of file
+    return loss
diff --git a/advanced_ML/model_tree/models/DT_sklearn_regr.py b/advanced_ML/model_tree/models/DT_sklearn_regr.py
index 1ba3c2a..4925bac 100644
--- a/advanced_ML/model_tree/models/DT_sklearn_regr.py
+++ b/advanced_ML/model_tree/models/DT_sklearn_regr.py
@@ -20,4 +20,6 @@ def predict(self, X):
         return self.model.predict(X)
 
     def loss(self, X, y, y_pred):
-        return mean_squared_error(y, y_pred)
\ No newline at end of file
+        return mean_squared_error(y, y_pred)
+    def get_params(self):
+        return None
diff --git a/advanced_ML/model_tree/models/NN_regr.py b/advanced_ML/model_tree/models/NN_regr.py
index 5989f21..de40192 100644
--- a/advanced_ML/model_tree/models/NN_regr.py
+++ b/advanced_ML/model_tree/models/NN_regr.py
@@ -33,4 +33,7 @@ def predict(self, X):
         return self.model.predict(X)
 
     def loss(self, X, y, y_pred):
-        return mean_squared_error(y, y_pred)
\ No newline at end of file
+        return mean_squared_error(y, y_pred)
+
+    def get_params(self):
+        return None
diff --git a/advanced_ML/model_tree/models/linear_regr.py b/advanced_ML/model_tree/models/linear_regr.py
index 8cbbf58..4c16e48 100644
--- a/advanced_ML/model_tree/models/linear_regr.py
+++ b/advanced_ML/model_tree/models/linear_regr.py
@@ -21,4 +21,5 @@ def predict(self, X):
 
     def loss(self, X, y, y_pred):
         return mean_squared_error(y, y_pred)
-
+    def get_params(self):
+        return self.model.coef_
diff --git a/advanced_ML/model_tree/models/logistic_regr.py b/advanced_ML/model_tree/models/logistic_regr.py
index 86f4518..f3ba6c1 100644
--- a/advanced_ML/model_tree/models/logistic_regr.py
+++ b/advanced_ML/model_tree/models/logistic_regr.py
@@ -4,6 +4,7 @@
 """
 
 from sklearn.metrics import mean_squared_error
+import numpy as np
 
 
 class logistic_regr:
@@ -26,6 +27,8 @@ def predict(self, X):
             return self.flag_y_pred * np.ones((len(X),), dtype=int)
         else:
             return self.model.predict(X)
+    def get_params(self):
+        return None
 
     def loss(self, X, y, y_pred):
-        return mean_squared_error(y, y_pred)
\ No newline at end of file
+        return mean_squared_error(y, y_pred)
diff --git a/advanced_ML/model_tree/models/mean_regr.py b/advanced_ML/model_tree/models/mean_regr.py
index 95f01d4..a972662 100644
--- a/advanced_ML/model_tree/models/mean_regr.py
+++ b/advanced_ML/model_tree/models/mean_regr.py
@@ -19,4 +19,7 @@ def predict(self, X):
         return self.y_mean * np.ones(len(X))
 
     def loss(self, X, y, y_pred):
-        return mean_squared_error(y, y_pred)
\ No newline at end of file
+        return mean_squared_error(y, y_pred)
+
+    def get_params(self):
+        return None
diff --git a/advanced_ML/model_tree/models/modal_clf.py b/advanced_ML/model_tree/models/modal_clf.py
index 7270516..eb5bfcf 100644
--- a/advanced_ML/model_tree/models/modal_clf.py
+++ b/advanced_ML/model_tree/models/modal_clf.py
@@ -20,6 +20,9 @@ def predict(self, X):
 
     def loss(self, X, y, y_pred):
         return gini_impurity(y)
+
+    def get_params(self):
+        return None
 
 def gini_impurity(y):
     p2 = 0.0
@@ -27,4 +30,4 @@ def gini_impurity(y):
     for c in y_classes:
         p2 += (np.sum(y == c) / len(y))**2
     loss = 1.0 - p2
-    return loss
\ No newline at end of file
+    return loss
diff --git a/advanced_ML/model_tree/models/svm_regr.py b/advanced_ML/model_tree/models/svm_regr.py
index 58a1f13..f685eeb 100644
--- a/advanced_ML/model_tree/models/svm_regr.py
+++ b/advanced_ML/model_tree/models/svm_regr.py
@@ -18,4 +18,7 @@ def predict(self, X):
         return self.model.predict(X)
 
     def loss(self, X, y, y_pred):
-        return mean_squared_error(y, y_pred)
\ No newline at end of file
+        return mean_squared_error(y, y_pred)
+
+    def get_params(self):
+        return None
diff --git a/advanced_ML/model_tree/src/ModelTree.py b/advanced_ML/model_tree/src/ModelTree.py
index cdead5e..0a96b0d 100644
--- a/advanced_ML/model_tree/src/ModelTree.py
+++ b/advanced_ML/model_tree/src/ModelTree.py
@@ -149,6 +149,8 @@ def _explain(node, x, explanation):
             no_children = node["children"]["left"] is None and \
                           node["children"]["right"] is None
             if no_children:
+
+
                 return explanation
             else:
                 if x[node["j_feature"]] <= node["threshold"]:  # x[j] < threshold
@@ -196,7 +198,17 @@ def build_graphviz_recurse(node, parent_node_index=0, parent_depth=0, edge_label
         # Create node
         node_index = node["index"]
         if node["children"]["left"] is None and node["children"]["right"] is None:
-            threshold_str = ""
+
+            params = node['model'].get_params()
+            if params is not None:
+                threshold_str = "y = "
+                for i in range(len(params)):
+                    threshold_str += str(round(params[i], 2))
+                    threshold_str += "*X" + str(i)
+                    threshold_str += "\n"
+            else:
+                threshold_str = ""
+
         else:
             threshold_str = "{} <= {:.1f}\\n".format(feature_names[node['j_feature']],
                                                      node["threshold"])
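
For reference, a minimal standalone sketch of how the new get_params() hook is consumed by the leaf-label branch of build_graphviz_recurse(): a model that exposes fitted coefficients (here linear_regr, via self.model.coef_) gets its equation printed at the leaf, while models returning None keep the old empty label. The linear_regr_stub class and leaf_label helper below are illustrative stand-ins written for this sketch, not code from the repository.

import numpy as np

class linear_regr_stub:
    """Illustrative stand-in for models/linear_regr.py after this change:
    get_params() exposes the fitted coefficient vector."""
    def __init__(self, coef):
        self.coef_ = np.asarray(coef)

    def get_params(self):
        return self.coef_

def leaf_label(model):
    """Mirrors the new leaf branch of build_graphviz_recurse():
    builds 'y = <c0>*X0 <c1>*X1 ...' with each term on its own line
    when coefficients are available, else falls back to the empty label."""
    params = model.get_params()
    if params is None:
        return ""
    label = "y = "
    for i in range(len(params)):
        label += str(round(params[i], 2)) + "*X" + str(i) + "\n"
    return label

print(repr(leaf_label(linear_regr_stub([0.5, -1.25]))))  # 'y = 0.5*X0\n-1.25*X1\n'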