From e5c92d4e6b876cdab0a7b1efca36be1e9d35b6b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 09:55:45 +0100 Subject: [PATCH 01/12] Update ModelTree.py Added function to leaf nodes, all models must have .get_params() func --- advanced_ML/model_tree/src/ModelTree.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/src/ModelTree.py b/advanced_ML/model_tree/src/ModelTree.py index cdead5e..70a7dff 100644 --- a/advanced_ML/model_tree/src/ModelTree.py +++ b/advanced_ML/model_tree/src/ModelTree.py @@ -149,6 +149,8 @@ def _explain(node, x, explanation): no_children = node["children"]["left"] is None and \ node["children"]["right"] is None if no_children: + + return explanation else: if x[node["j_feature"]] <= node["threshold"]: # x[j] < threshold @@ -196,7 +198,14 @@ def build_graphviz_recurse(node, parent_node_index=0, parent_depth=0, edge_label # Create node node_index = node["index"] if node["children"]["left"] is None and node["children"]["right"] is None: - threshold_str = "" + + params = node['model'].get_params() + threshold_str="y = " + for i in range(len(params)): + threshold_str += str(round(params[i],2])) + threshold_str += "*X"+str(i) + threshold_str+="\n" + else: threshold_str = "{} <= {:.1f}\\n".format(feature_names[node['j_feature']], node["threshold"]) From cbcc8a273b1be191997b658619ab67dd389e545f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 10:51:36 +0100 Subject: [PATCH 02/12] Update ModelTree.py --- advanced_ML/model_tree/src/ModelTree.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/advanced_ML/model_tree/src/ModelTree.py b/advanced_ML/model_tree/src/ModelTree.py index 70a7dff..7a9e9f9 100644 --- a/advanced_ML/model_tree/src/ModelTree.py +++ b/advanced_ML/model_tree/src/ModelTree.py @@ -200,11 +200,12 @@ def build_graphviz_recurse(node, parent_node_index=0, parent_depth=0, edge_label if node["children"]["left"] is None and node["children"]["right"] is None: params = node['model'].get_params() - threshold_str="y = " - for i in range(len(params)): - threshold_str += str(round(params[i],2])) - threshold_str += "*X"+str(i) - threshold_str+="\n" + if params is not None : + threshold_str="y = " + for i in range(len(params)): + threshold_str += str(round(params[i],2])) + threshold_str += "*X"+str(i) + threshold_str+="\n" else: threshold_str = "{} <= {:.1f}\\n".format(feature_names[node['j_feature']], node["threshold"]) From 7532991cd4cacd8c15e490ad7f37cda77ee12fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 10:52:43 +0100 Subject: [PATCH 03/12] Update ModelTree.py --- advanced_ML/model_tree/src/ModelTree.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/advanced_ML/model_tree/src/ModelTree.py b/advanced_ML/model_tree/src/ModelTree.py index 7a9e9f9..1397d22 100644 --- a/advanced_ML/model_tree/src/ModelTree.py +++ b/advanced_ML/model_tree/src/ModelTree.py @@ -206,6 +206,8 @@ def build_graphviz_recurse(node, parent_node_index=0, parent_depth=0, edge_label threshold_str += str(round(params[i],2])) threshold_str += "*X"+str(i) threshold_str+="\n" + else: + threshold_str="" else: threshold_str = "{} <= {:.1f}\\n".format(feature_names[node['j_feature']], node["threshold"]) From d730b4b67f5522b1d4cb0bfe9548386d38338a63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 10:53:42 +0100 Subject: [PATCH 04/12] Update linear_regr.py --- advanced_ML/model_tree/models/linear_regr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/linear_regr.py b/advanced_ML/model_tree/models/linear_regr.py index 8cbbf58..4c16e48 100644 --- a/advanced_ML/model_tree/models/linear_regr.py +++ b/advanced_ML/model_tree/models/linear_regr.py @@ -21,4 +21,5 @@ def predict(self, X): def loss(self, X, y, y_pred): return mean_squared_error(y, y_pred) - + def get_params(self): + return self.model.coef_ From 3019da13bb00f9610124ca28d076c3d3aa51df39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:03:50 +0100 Subject: [PATCH 05/12] Update logistic_regr.py Fixed numpy bug --- advanced_ML/model_tree/models/logistic_regr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/logistic_regr.py b/advanced_ML/model_tree/models/logistic_regr.py index 86f4518..f3ba6c1 100644 --- a/advanced_ML/model_tree/models/logistic_regr.py +++ b/advanced_ML/model_tree/models/logistic_regr.py @@ -4,6 +4,7 @@ """ from sklearn.metrics import mean_squared_error +import numpy as np class logistic_regr: @@ -26,6 +27,8 @@ def predict(self, X): return self.flag_y_pred * np.ones((len(X),), dtype=int) else: return self.model.predict(X) + def get_params(self): + return None def loss(self, X, y, y_pred): - return mean_squared_error(y, y_pred) \ No newline at end of file + return mean_squared_error(y, y_pred) From e624825a6253facc3ad53c93bee51bd5b31eb5e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:04:26 +0100 Subject: [PATCH 06/12] Update DT_sklearn_clf.py --- advanced_ML/model_tree/models/DT_sklearn_clf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/DT_sklearn_clf.py b/advanced_ML/model_tree/models/DT_sklearn_clf.py index 24dcecc..4d70e6d 100644 --- a/advanced_ML/model_tree/models/DT_sklearn_clf.py +++ b/advanced_ML/model_tree/models/DT_sklearn_clf.py @@ -21,6 +21,8 @@ def predict(self, X): def loss(self, X, y, y_pred): return gini_impurity(y) + def get_params(self): + return None def gini_impurity(y): p2 = 0.0 @@ -28,4 +30,4 @@ def gini_impurity(y): for c in y_classes: p2 += (np.sum(y == c) / len(y))**2 loss = 1.0 - p2 - return loss \ No newline at end of file + return loss From d96299f36d4108fd50ce5f8dc9f87a1dbdbde170 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:06:36 +0100 Subject: [PATCH 07/12] Update DT_sklearn_regr.py --- advanced_ML/model_tree/models/DT_sklearn_regr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/DT_sklearn_regr.py b/advanced_ML/model_tree/models/DT_sklearn_regr.py index 1ba3c2a..4925bac 100644 --- a/advanced_ML/model_tree/models/DT_sklearn_regr.py +++ b/advanced_ML/model_tree/models/DT_sklearn_regr.py @@ -20,4 +20,6 @@ def predict(self, X): return self.model.predict(X) def loss(self, X, y, y_pred): - return mean_squared_error(y, y_pred) \ No newline at end of file + return mean_squared_error(y, y_pred) + def get_params(self): + return None From 4d1af5db0bbe8e48e990937444e1f5b7832bb9bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:08:24 +0100 Subject: [PATCH 08/12] Update NN_regr.py --- advanced_ML/model_tree/models/NN_regr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/NN_regr.py b/advanced_ML/model_tree/models/NN_regr.py index 5989f21..de40192 100644 --- a/advanced_ML/model_tree/models/NN_regr.py +++ b/advanced_ML/model_tree/models/NN_regr.py @@ -33,4 +33,7 @@ def predict(self, X): return self.model.predict(X) def loss(self, X, y, y_pred): - return mean_squared_error(y, y_pred) \ No newline at end of file + return mean_squared_error(y, y_pred) + + def get_params(self): + return None From a600df65ad83ee52366aae8ef6d84e93d1a3f44a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:08:51 +0100 Subject: [PATCH 09/12] Update mean_regr.py --- advanced_ML/model_tree/models/mean_regr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/mean_regr.py b/advanced_ML/model_tree/models/mean_regr.py index 95f01d4..a972662 100644 --- a/advanced_ML/model_tree/models/mean_regr.py +++ b/advanced_ML/model_tree/models/mean_regr.py @@ -19,4 +19,7 @@ def predict(self, X): return self.y_mean * np.ones(len(X)) def loss(self, X, y, y_pred): - return mean_squared_error(y, y_pred) \ No newline at end of file + return mean_squared_error(y, y_pred) + + def get_params(self): + return None From 3e233031c8d2c3a9bb8bbc41dd82a4f33ccccc8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:09:15 +0100 Subject: [PATCH 10/12] Update modal_clf.py --- advanced_ML/model_tree/models/modal_clf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/modal_clf.py b/advanced_ML/model_tree/models/modal_clf.py index 7270516..eb5bfcf 100644 --- a/advanced_ML/model_tree/models/modal_clf.py +++ b/advanced_ML/model_tree/models/modal_clf.py @@ -20,6 +20,9 @@ def predict(self, X): def loss(self, X, y, y_pred): return gini_impurity(y) + + def get_params(self): + return None def gini_impurity(y): p2 = 0.0 @@ -27,4 +30,4 @@ def gini_impurity(y): for c in y_classes: p2 += (np.sum(y == c) / len(y))**2 loss = 1.0 - p2 - return loss \ No newline at end of file + return loss From fbfc5c028ba1bab87cd906afe2f57e19d3aea95d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:12:50 +0100 Subject: [PATCH 11/12] Update svm_regr.py --- advanced_ML/model_tree/models/svm_regr.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/models/svm_regr.py b/advanced_ML/model_tree/models/svm_regr.py index 58a1f13..f685eeb 100644 --- a/advanced_ML/model_tree/models/svm_regr.py +++ b/advanced_ML/model_tree/models/svm_regr.py @@ -18,4 +18,7 @@ def predict(self, X): return self.model.predict(X) def loss(self, X, y, y_pred): - return mean_squared_error(y, y_pred) \ No newline at end of file + return mean_squared_error(y, y_pred) + + def get_params(self): + return None From 02dee30cd7b7ce46daacadb24d5864ae85044d4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vilde=20Gj=C3=A6rum?= Date: Mon, 23 Mar 2020 11:13:42 +0100 Subject: [PATCH 12/12] Update ModelTree.py --- advanced_ML/model_tree/src/ModelTree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/advanced_ML/model_tree/src/ModelTree.py b/advanced_ML/model_tree/src/ModelTree.py index 1397d22..0a96b0d 100644 --- a/advanced_ML/model_tree/src/ModelTree.py +++ b/advanced_ML/model_tree/src/ModelTree.py @@ -203,7 +203,7 @@ def build_graphviz_recurse(node, parent_node_index=0, parent_depth=0, edge_label if params is not None : threshold_str="y = " for i in range(len(params)): - threshold_str += str(round(params[i],2])) + threshold_str += str(round(params[i],2)) threshold_str += "*X"+str(i) threshold_str+="\n" else: