From 82c590e49ce8c8997be595b3c12133859dccf1e3 Mon Sep 17 00:00:00 2001 From: blublinsky Date: Mon, 6 Apr 2020 18:05:27 -0500 Subject: [PATCH 1/2] Add scikit-learn --- scikitLearn/python/IncomePrediction.ipynb | 678 ++++++++++++++++++++++ 1 file changed, 678 insertions(+) create mode 100644 scikitLearn/python/IncomePrediction.ipynb diff --git a/scikitLearn/python/IncomePrediction.ipynb b/scikitLearn/python/IncomePrediction.ipynb new file mode 100644 index 0000000..12300de --- /dev/null +++ b/scikitLearn/python/IncomePrediction.ipynb @@ -0,0 +1,678 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Income prediction\n", + "based on Seldon's implementation\n", + "https://github.com/SeldonIO/alibi/blob/master/examples/anchor_tabular_adult.ipynb and\n", + "https://github.com/SeldonIO/alibi/blob/5aec3ab4ce651ca2249bf849ecb434371c9278e4/alibi/datasets.py#L183" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already up-to-date: pandas in ./.local/lib/python3.6/site-packages (1.0.3)\n", + "Requirement already satisfied, skipping upgrade: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from pandas) (1.18.1)\n", + "Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas) (2019.3)\n", + "Requirement already satisfied, skipping upgrade: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.8.1)\n", + "Requirement already satisfied, skipping upgrade: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.6.1->pandas) (1.11.0)\n", + "Requirement already up-to-date: scikit-learn in ./.local/lib/python3.6/site-packages (0.22.2.post1)\n", + "Requirement already satisfied, skipping upgrade: joblib>=0.11 in ./.local/lib/python3.6/site-packages (from scikit-learn) (0.14.1)\n", + "Requirement already satisfied, skipping 
upgrade: numpy>=1.11.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.18.1)\n", + "Requirement already satisfied, skipping upgrade: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.4.1)\n", + "Requirement already up-to-date: alibi in ./.local/lib/python3.6/site-packages (0.4.0)\n", + "Requirement already satisfied, skipping upgrade: scikit-learn in ./.local/lib/python3.6/site-packages (from alibi) (0.22.2.post1)\n", + "Requirement already satisfied, skipping upgrade: attrs in /usr/local/lib/python3.6/dist-packages (from alibi) (19.3.0)\n", + "Requirement already satisfied, skipping upgrade: beautifulsoup4 in ./.local/lib/python3.6/site-packages (from alibi) (4.8.2)\n", + "Requirement already satisfied, skipping upgrade: spacy in ./.local/lib/python3.6/site-packages (from alibi) (2.2.4)\n", + "Requirement already satisfied, skipping upgrade: shap in ./.local/lib/python3.6/site-packages (from alibi) (0.35.0)\n", + "Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from alibi) (1.4.1)\n", + "Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from alibi) (2.22.0)\n", + "Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from alibi) (1.18.1)\n", + "Requirement already satisfied, skipping upgrade: Pillow in ./.local/lib/python3.6/site-packages (from alibi) (7.0.0)\n", + "Requirement already satisfied, skipping upgrade: tensorflow<2.0 in /usr/local/lib/python3.6/dist-packages (from alibi) (1.15.2)\n", + "Requirement already satisfied, skipping upgrade: pandas in ./.local/lib/python3.6/site-packages (from alibi) (1.0.3)\n", + "Requirement already satisfied, skipping upgrade: prettyprinter in ./.local/lib/python3.6/site-packages (from alibi) (0.18.0)\n", + "Requirement already satisfied, skipping upgrade: scikit-image in ./.local/lib/python3.6/site-packages (from alibi) (0.16.2)\n", + 
"Requirement already satisfied, skipping upgrade: joblib>=0.11 in ./.local/lib/python3.6/site-packages (from scikit-learn->alibi) (0.14.1)\n", + "Requirement already satisfied, skipping upgrade: soupsieve>=1.2 in ./.local/lib/python3.6/site-packages (from beautifulsoup4->alibi) (2.0)\n", + "Requirement already satisfied, skipping upgrade: srsly<1.1.0,>=1.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.2)\n", + "Requirement already satisfied, skipping upgrade: preshed<3.1.0,>=3.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (3.0.2)\n", + "Requirement already satisfied, skipping upgrade: plac<1.2.0,>=0.9.6 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.1.3)\n", + "Requirement already satisfied, skipping upgrade: blis<0.5.0,>=0.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (0.4.1)\n", + "Requirement already satisfied, skipping upgrade: cymem<2.1.0,>=2.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (2.0.3)\n", + "Requirement already satisfied, skipping upgrade: tqdm<5.0.0,>=4.38.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (4.43.0)\n", + "Requirement already satisfied, skipping upgrade: catalogue<1.1.0,>=0.0.7 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.0)\n", + "Requirement already satisfied, skipping upgrade: thinc==7.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (7.4.0)\n", + "Requirement already satisfied, skipping upgrade: murmurhash<1.1.0,>=0.28.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.2)\n", + "Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy->alibi) (45.1.0)\n", + "Requirement already satisfied, skipping upgrade: wasabi<1.1.0,>=0.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (0.6.0)\n", + "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from 
requests->alibi) (2019.11.28)\n", + "Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /usr/lib/python3/dist-packages (from requests->alibi) (2.6)\n", + "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in ./.local/lib/python3.6/site-packages (from requests->alibi) (1.24.3)\n", + "Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->alibi) (3.0.4)\n", + "Requirement already satisfied, skipping upgrade: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.1.0)\n", + "Requirement already satisfied, skipping upgrade: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.2.2)\n", + "Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/lib/python3/dist-packages (from tensorflow<2.0->alibi) (0.30.0)\n", + "Requirement already satisfied, skipping upgrade: six>=1.10.0 in /usr/lib/python3/dist-packages (from tensorflow<2.0->alibi) (1.11.0)\n", + "Requirement already satisfied, skipping upgrade: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.1.0)\n", + "Requirement already satisfied, skipping upgrade: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.8.1)\n", + "Requirement already satisfied, skipping upgrade: keras-applications>=1.0.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.0.8)\n", + "Requirement already satisfied, skipping upgrade: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (3.1.0)\n", + "Requirement already satisfied, skipping upgrade: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (3.11.2)\n", + "Requirement already satisfied, skipping upgrade: tensorflow-estimator==1.15.1 in /usr/local/lib/python3.6/dist-packages (from 
tensorflow<2.0->alibi) (1.15.1)\n", + "Requirement already satisfied, skipping upgrade: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.1.8)\n", + "Requirement already satisfied, skipping upgrade: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.11.2)\n", + "Requirement already satisfied, skipping upgrade: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.26.0)\n", + "Requirement already satisfied, skipping upgrade: tensorboard<1.16.0,>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.15.0)\n", + "Requirement already satisfied, skipping upgrade: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.9.0)\n", + "Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->alibi) (2019.3)\n", + "Requirement already satisfied, skipping upgrade: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->alibi) (2.8.1)\n", + "Requirement already satisfied, skipping upgrade: Pygments>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from prettyprinter->alibi) (2.5.2)\n", + "Requirement already satisfied, skipping upgrade: colorful>=0.4.0 in ./.local/lib/python3.6/site-packages (from prettyprinter->alibi) (0.5.4)\n", + "Requirement already satisfied, skipping upgrade: networkx>=2.0 in ./.local/lib/python3.6/site-packages (from scikit-image->alibi) (2.4)\n", + "Requirement already satisfied, skipping upgrade: imageio>=2.3.0 in ./.local/lib/python3.6/site-packages (from scikit-image->alibi) (2.8.0)\n", + "Requirement already satisfied, skipping upgrade: matplotlib!=3.0.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image->alibi) (3.1.2)\n", + "Requirement already satisfied, skipping upgrade: PyWavelets>=0.4.0 in ./.local/lib/python3.6/site-packages (from scikit-image->alibi) (1.1.1)\n", + 
"Requirement already satisfied, skipping upgrade: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy->alibi) (1.4.0)\n", + "Requirement already satisfied, skipping upgrade: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.8->tensorflow<2.0->alibi) (2.10.0)\n", + "Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0->alibi) (0.16.1)\n", + "Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0->alibi) (3.1.1)\n", + "Requirement already satisfied, skipping upgrade: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx>=2.0->scikit-image->alibi) (4.4.1)\n", + "Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->alibi) (1.1.0)\n", + "Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->alibi) (0.10.0)\n", + "Requirement already satisfied, skipping upgrade: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->alibi) (2.4.6)\n", + "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy->alibi) (2.1.0)\r\n" + ] + } + ], + "source": [ + "!pip install pandas --upgrade --user\n", + "!pip install scikit-learn --upgrade --user\n", + "!pip install alibi --upgrade --user" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from 
def fetch_adult(features_drop: list = None, return_X_y: bool = False, url_id: int = 0) -> Union[Bunch, Tuple[np.ndarray, np.ndarray]]:
    """
    Downloads and pre-processes 'adult' dataset.
    More info: http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/

    Parameters
    ----------
    features_drop
        List of features to be dropped from dataset, by default drops ["fnlwgt", "Education-Num"].
        The sequence passed in is copied and never modified.
    return_X_y
        If true, return features X and labels y as numpy arrays, if False return a Bunch object
    url_id
        Index specifying which URL to use for downloading

    Returns
    -------
    Bunch
        Dataset, labels, a list of features and a dictionary containing a list with the potential categories
        for each categorical feature where the key refers to the feature column.
    (data, target)
        Tuple if ``return_X_y`` is true

    Raises
    ------
    requests.RequestException
        Re-raised (after being logged) when the download fails or the server
        answers with an error status.
    """
    # Local import: the notebook's import cell does not bring in `logging`, and
    # the original code referenced an undefined `logger`, which turned every
    # download failure into a NameError.
    import logging
    logger = logging.getLogger(__name__)

    ADULT_URLS = ['https://storage.googleapis.com/seldon-datasets/adult/adult.data',
                  'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data',
                  'http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data']
    if features_drop is None:
        features_drop = ["fnlwgt", "Education-Num"]

    # download data
    dataset_url = ADULT_URLS[url_id]
    raw_features = ['Age', 'Workclass', 'fnlwgt', 'Education', 'Education-Num', 'Marital Status',
                    'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss',
                    'Hours per week', 'Country', 'Target']
    try:
        resp = requests.get(dataset_url)
        resp.raise_for_status()
    except RequestException:
        logger.exception("Could not connect, URL may be out of service")
        raise

    # The file has no header row and separates fields with ", "; missing cells
    # are replaced by the '?' placeholder so they become an ordinary category.
    raw_data = pd.read_csv(StringIO(resp.text), names=raw_features, delimiter=', ', engine='python').fillna('?')

    # get labels, features and drop unnecessary features
    labels = (raw_data['Target'] == '>50K').astype(int).values
    # Build a new list instead of `+=` so the caller's list is left untouched.
    columns_to_drop = list(features_drop) + ['Target']
    data = raw_data.drop(columns_to_drop, axis=1)
    features = list(data.columns)

    # map categorical features
    education_map = {
        '10th': 'Dropout', '11th': 'Dropout', '12th': 'Dropout', '1st-4th':
            'Dropout', '5th-6th': 'Dropout', '7th-8th': 'Dropout', '9th':
            'Dropout', 'Preschool': 'Dropout', 'HS-grad': 'High School grad',
        'Some-college': 'High School grad', 'Masters': 'Masters',
        'Prof-school': 'Prof-School', 'Assoc-acdm': 'Associates',
        'Assoc-voc': 'Associates'
    }
    occupation_map = {
        "Adm-clerical": "Admin", "Armed-Forces": "Military",
        "Craft-repair": "Blue-Collar", "Exec-managerial": "White-Collar",
        "Farming-fishing": "Blue-Collar", "Handlers-cleaners":
            "Blue-Collar", "Machine-op-inspct": "Blue-Collar", "Other-service":
            "Service", "Priv-house-serv": "Service", "Prof-specialty":
            "Professional", "Protective-serv": "Other", "Sales":
            "Sales", "Tech-support": "Other", "Transport-moving":
            "Blue-Collar"
    }
    country_map = {
        'Cambodia': 'SE-Asia', 'Canada': 'British-Commonwealth', 'China':
            'China', 'Columbia': 'South-America', 'Cuba': 'Other',
        'Dominican-Republic': 'Latin-America', 'Ecuador': 'South-America',
        'El-Salvador': 'South-America', 'England': 'British-Commonwealth',
        'France': 'Euro_1', 'Germany': 'Euro_1', 'Greece': 'Euro_2',
        'Guatemala': 'Latin-America', 'Haiti': 'Latin-America',
        'Holand-Netherlands': 'Euro_1', 'Honduras': 'Latin-America',
        'Hong': 'China', 'Hungary': 'Euro_2', 'India':
            'British-Commonwealth', 'Iran': 'Other', 'Ireland':
            'British-Commonwealth', 'Italy': 'Euro_1', 'Jamaica':
            'Latin-America', 'Japan': 'Other', 'Laos': 'SE-Asia', 'Mexico':
            'Latin-America', 'Nicaragua': 'Latin-America',
        'Outlying-US(Guam-USVI-etc)': 'Latin-America', 'Peru':
            'South-America', 'Philippines': 'SE-Asia', 'Poland': 'Euro_2',
        'Portugal': 'Euro_2', 'Puerto-Rico': 'Latin-America', 'Scotland':
            'British-Commonwealth', 'South': 'Euro_2', 'Taiwan': 'China',
        'Thailand': 'SE-Asia', 'Trinadad&Tobago': 'Latin-America',
        'United-States': 'United-States', 'Vietnam': 'SE-Asia'
    }
    married_map = {
        'Never-married': 'Never-Married', 'Married-AF-spouse': 'Married',
        'Married-civ-spouse': 'Married', 'Married-spouse-absent':
            'Separated', 'Separated': 'Separated', 'Divorced':
            'Separated', 'Widowed': 'Widowed'
    }
    mapping = {'Education': education_map, 'Occupation': occupation_map, 'Country': country_map,
               'Marital Status': married_map}

    for f, f_map in mapping.items():
        if f not in data.columns:
            # The caller dropped this feature; nothing to coarsen.
            continue
        # Values without an entry in the map (e.g. 'Bachelors', '?') are kept.
        data[f] = data[f].map(lambda value, lookup=f_map: lookup.get(value, value))

    # get categorical features and apply labelencoding
    categorical_features = [f for f in features if data[f].dtype == 'O']
    category_map = {}
    for f in categorical_features:
        le = LabelEncoder()
        data[f] = le.fit_transform(data[f].values)
        # Keyed by column position so that code i in column k decodes to
        # category_map[k][i].
        category_map[features.index(f)] = list(le.classes_)

    # only return data values
    data = data.values
    target_names = ['<=50K', '>50K']

    if return_X_y:
        return data, labels

    return Bunch(data=data, target=labels, feature_names=features, target_names=target_names, category_map=category_map)
# %% Load adult dataset
# fetch_adult returns a Bunch with the features, the targets, the feature
# names and the mapping of categorical variables to numbers that is needed to
# format the output of the Anchor explainer.
adult = fetch_adult()
adult.keys()

# %% Unpack the Bunch into working variables
data = adult.data
target = adult.target
feature_names = adult.feature_names
category_map = adult.category_map

# %% Define shuffled training and test set
np.random.seed(0)
# Features and labels are shuffled together so the rows stay aligned.
data_perm = np.random.permutation(np.c_[data, target])
data = data_perm[:, :-1]
target = data_perm[:, -1]
idx = 30000
X_train = data[:idx, :]
Y_train = target[:idx]
# NOTE(review): the test slice starts at idx+1, so row `idx` ends up in
# neither split. Left unchanged here because the accuracy and anchors recorded
# in the notebook outputs were produced with this split.
X_test = data[idx + 1:, :]
Y_test = target[idx + 1:]

# %% Create feature transformation pipeline - ordinal features
# Ordinal features are the columns without an entry in category_map; they are
# median-imputed and standardized.
ordinal_features = [col for col, _ in enumerate(feature_names) if col not in category_map]
ordinal_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler()),
])

# %% Categorical features
# NOTE(review): median imputation is applied to label-encoded category codes;
# presumably harmless because fetch_adult fills missing cells with '?' before
# encoding - confirm.
categorical_features = list(category_map)
categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('onehot', OneHotEncoder(handle_unknown='ignore')),
])

# %% Combine and fit
preprocessor = ColumnTransformer(transformers=[
    ('num', ordinal_transformer, ordinal_features),
    ('cat', categorical_transformer, categorical_features),
])
preprocessor.fit(X_train)

# %% Train Random Forest model
# Fit on pre-processed (imputing, OHE, standardizing) data.
np.random.seed(0)
clf = RandomForestClassifier(n_estimators=50)
clf.fit(preprocessor.transform(X_train), Y_train)

# %% Define predict function
def predict_fn(x):
    """Predict labels for raw (label-encoded) rows by pre-processing them first."""
    return clf.predict(preprocessor.transform(x))

print('Train accuracy: ', accuracy_score(Y_train, predict_fn(X_train)))
print('Test accuracy: ', accuracy_score(Y_test, predict_fn(X_test)))

# %% Initialize and fit anchor explainer for tabular data
explainer = AnchorTabular(predict_fn, feature_names, categorical_names=category_map, seed=1)

# %% Discretize the ordinal features into quartiles
explainer.fit(X_train, disc_perc=[25, 50, 75])

# %% Getting an anchor
# Prediction for the first observation of the test set; an anchor for it is
# computed in the next cell.
idx = 0
class_names = adult.target_names
print('Prediction: ', class_names[explainer.predictor(X_test[idx].reshape(1, -1))[0]])
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Anchor: Marital Status = Separated AND Sex = Female\n", + "Precision: 0.95\n", + "Coverage: 0.18\n" + ] + } + ], + "source": [ + "explanation = explainer.explain(X_test[idx], threshold=0.95)\n", + "print('Anchor: %s' % (' AND '.join(explanation.anchor)))\n", + "print('Precision: %.2f' % explanation.precision)\n", + "print('Coverage: %.2f' % explanation.coverage)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ...or not?\n", + "Let's try getting an anchor for a different observation in the test set - one for which the prediction is >50K." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction: >50K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Could not find an result satisfying the 0.95 precision constraint. 
# %% ...or not? Explain a second test observation
# Same steps as for the first observation, repeated for test row 6, which the
# recorded output shows is predicted as >50K.
idx = 6
class_names = adult.target_names
print('Prediction: ', class_names[explainer.predictor(X_test[idx].reshape(1, -1))[0]])

# Ask for an anchor with at least 0.95 precision. In the recorded run no such
# anchor was found and the best non-eligible result (precision 0.71) was
# returned instead. The notebook's narrative attributes this to the imbalanced
# dataset (roughly 25:75 high:low earners): feature ranges of low earners are
# oversampled during the sampling stage.
explanation = explainer.explain(X_test[idx], threshold=0.95)
print('Anchor: %s' % (' AND '.join(explanation.anchor)))
print('Precision: %.2f' % explanation.precision)
print('Coverage: %.2f' % explanation.coverage)
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 5290f4d6788fdb43b5d9e65a7b61cabc79544ac0 Mon Sep 17 00:00:00 2001 From: blublinsky Date: Mon, 6 Apr 2020 18:57:12 -0500 Subject: [PATCH 2/2] Add scikit-learn model save --- scikitLearn/python/IncomePrediction.ipynb | 123 ++++++++++++++-------- 1 file changed, 77 insertions(+), 46 deletions(-) diff --git a/scikitLearn/python/IncomePrediction.ipynb b/scikitLearn/python/IncomePrediction.ipynb index 12300de..f98ec68 100644 --- a/scikitLearn/python/IncomePrediction.ipynb +++ b/scikitLearn/python/IncomePrediction.ipynb @@ -20,76 +20,76 @@ "output_type": "stream", "text": [ "Requirement already up-to-date: pandas in ./.local/lib/python3.6/site-packages (1.0.3)\n", - "Requirement already satisfied, skipping upgrade: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from pandas) (1.18.1)\n", "Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas) (2019.3)\n", + "Requirement already satisfied, skipping upgrade: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from pandas) (1.18.1)\n", "Requirement already satisfied, skipping upgrade: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas) (2.8.1)\n", "Requirement already satisfied, skipping upgrade: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.6.1->pandas) (1.11.0)\n", "Requirement already up-to-date: scikit-learn in ./.local/lib/python3.6/site-packages (0.22.2.post1)\n", 
- "Requirement already satisfied, skipping upgrade: joblib>=0.11 in ./.local/lib/python3.6/site-packages (from scikit-learn) (0.14.1)\n", - "Requirement already satisfied, skipping upgrade: numpy>=1.11.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.18.1)\n", "Requirement already satisfied, skipping upgrade: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.4.1)\n", + "Requirement already satisfied, skipping upgrade: numpy>=1.11.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn) (1.18.1)\n", + "Requirement already satisfied, skipping upgrade: joblib>=0.11 in ./.local/lib/python3.6/site-packages (from scikit-learn) (0.14.1)\n", "Requirement already up-to-date: alibi in ./.local/lib/python3.6/site-packages (0.4.0)\n", - "Requirement already satisfied, skipping upgrade: scikit-learn in ./.local/lib/python3.6/site-packages (from alibi) (0.22.2.post1)\n", - "Requirement already satisfied, skipping upgrade: attrs in /usr/local/lib/python3.6/dist-packages (from alibi) (19.3.0)\n", - "Requirement already satisfied, skipping upgrade: beautifulsoup4 in ./.local/lib/python3.6/site-packages (from alibi) (4.8.2)\n", + "Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from alibi) (1.4.1)\n", + "Requirement already satisfied, skipping upgrade: prettyprinter in ./.local/lib/python3.6/site-packages (from alibi) (0.18.0)\n", "Requirement already satisfied, skipping upgrade: spacy in ./.local/lib/python3.6/site-packages (from alibi) (2.2.4)\n", + "Requirement already satisfied, skipping upgrade: Pillow in ./.local/lib/python3.6/site-packages (from alibi) (7.0.0)\n", + "Requirement already satisfied, skipping upgrade: scikit-learn in ./.local/lib/python3.6/site-packages (from alibi) (0.22.2.post1)\n", + "Requirement already satisfied, skipping upgrade: scikit-image in ./.local/lib/python3.6/site-packages (from alibi) (0.16.2)\n", "Requirement already satisfied, skipping 
upgrade: shap in ./.local/lib/python3.6/site-packages (from alibi) (0.35.0)\n", - "Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from alibi) (1.4.1)\n", - "Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from alibi) (2.22.0)\n", "Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from alibi) (1.18.1)\n", - "Requirement already satisfied, skipping upgrade: Pillow in ./.local/lib/python3.6/site-packages (from alibi) (7.0.0)\n", "Requirement already satisfied, skipping upgrade: tensorflow<2.0 in /usr/local/lib/python3.6/dist-packages (from alibi) (1.15.2)\n", + "Requirement already satisfied, skipping upgrade: attrs in /usr/local/lib/python3.6/dist-packages (from alibi) (19.3.0)\n", + "Requirement already satisfied, skipping upgrade: beautifulsoup4 in ./.local/lib/python3.6/site-packages (from alibi) (4.8.2)\n", + "Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from alibi) (2.22.0)\n", "Requirement already satisfied, skipping upgrade: pandas in ./.local/lib/python3.6/site-packages (from alibi) (1.0.3)\n", - "Requirement already satisfied, skipping upgrade: prettyprinter in ./.local/lib/python3.6/site-packages (from alibi) (0.18.0)\n", - "Requirement already satisfied, skipping upgrade: scikit-image in ./.local/lib/python3.6/site-packages (from alibi) (0.16.2)\n", - "Requirement already satisfied, skipping upgrade: joblib>=0.11 in ./.local/lib/python3.6/site-packages (from scikit-learn->alibi) (0.14.1)\n", - "Requirement already satisfied, skipping upgrade: soupsieve>=1.2 in ./.local/lib/python3.6/site-packages (from beautifulsoup4->alibi) (2.0)\n", - "Requirement already satisfied, skipping upgrade: srsly<1.1.0,>=1.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.2)\n", - "Requirement already satisfied, skipping upgrade: preshed<3.1.0,>=3.0.2 in 
./.local/lib/python3.6/site-packages (from spacy->alibi) (3.0.2)\n", - "Requirement already satisfied, skipping upgrade: plac<1.2.0,>=0.9.6 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.1.3)\n", - "Requirement already satisfied, skipping upgrade: blis<0.5.0,>=0.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (0.4.1)\n", - "Requirement already satisfied, skipping upgrade: cymem<2.1.0,>=2.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (2.0.3)\n", - "Requirement already satisfied, skipping upgrade: tqdm<5.0.0,>=4.38.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (4.43.0)\n", - "Requirement already satisfied, skipping upgrade: catalogue<1.1.0,>=0.0.7 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.0)\n", + "Requirement already satisfied, skipping upgrade: colorful>=0.4.0 in ./.local/lib/python3.6/site-packages (from prettyprinter->alibi) (0.5.4)\n", + "Requirement already satisfied, skipping upgrade: Pygments>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from prettyprinter->alibi) (2.5.2)\n", + "Requirement already satisfied, skipping upgrade: wasabi<1.1.0,>=0.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (0.6.0)\n", "Requirement already satisfied, skipping upgrade: thinc==7.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (7.4.0)\n", "Requirement already satisfied, skipping upgrade: murmurhash<1.1.0,>=0.28.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.2)\n", + "Requirement already satisfied, skipping upgrade: plac<1.2.0,>=0.9.6 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.1.3)\n", + "Requirement already satisfied, skipping upgrade: preshed<3.1.0,>=3.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (3.0.2)\n", + "Requirement already satisfied, skipping upgrade: catalogue<1.1.0,>=0.0.7 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.0)\n", + "Requirement already satisfied, skipping 
upgrade: tqdm<5.0.0,>=4.38.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (4.43.0)\n", + "Requirement already satisfied, skipping upgrade: cymem<2.1.0,>=2.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (2.0.3)\n", + "Requirement already satisfied, skipping upgrade: blis<0.5.0,>=0.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (0.4.1)\n", "Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy->alibi) (45.1.0)\n", - "Requirement already satisfied, skipping upgrade: wasabi<1.1.0,>=0.4.0 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (0.6.0)\n", - "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->alibi) (2019.11.28)\n", - "Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /usr/lib/python3/dist-packages (from requests->alibi) (2.6)\n", - "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in ./.local/lib/python3.6/site-packages (from requests->alibi) (1.24.3)\n", - "Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->alibi) (3.0.4)\n", - "Requirement already satisfied, skipping upgrade: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.1.0)\n", + "Requirement already satisfied, skipping upgrade: srsly<1.1.0,>=1.0.2 in ./.local/lib/python3.6/site-packages (from spacy->alibi) (1.0.2)\n", + "Requirement already satisfied, skipping upgrade: joblib>=0.11 in ./.local/lib/python3.6/site-packages (from scikit-learn->alibi) (0.14.1)\n", + "Requirement already satisfied, skipping upgrade: PyWavelets>=0.4.0 in ./.local/lib/python3.6/site-packages (from scikit-image->alibi) (1.1.1)\n", + "Requirement already satisfied, skipping upgrade: networkx>=2.0 in ./.local/lib/python3.6/site-packages (from 
scikit-image->alibi) (2.4)\n", + "Requirement already satisfied, skipping upgrade: imageio>=2.3.0 in ./.local/lib/python3.6/site-packages (from scikit-image->alibi) (2.8.0)\n", + "Requirement already satisfied, skipping upgrade: matplotlib!=3.0.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image->alibi) (3.1.2)\n", + "Requirement already satisfied, skipping upgrade: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.11.2)\n", + "Requirement already satisfied, skipping upgrade: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.1.8)\n", + "Requirement already satisfied, skipping upgrade: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.26.0)\n", "Requirement already satisfied, skipping upgrade: gast==0.2.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.2.2)\n", - "Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/lib/python3/dist-packages (from tensorflow<2.0->alibi) (0.30.0)\n", + "Requirement already satisfied, skipping upgrade: tensorboard<1.16.0,>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.15.0)\n", "Requirement already satisfied, skipping upgrade: six>=1.10.0 in /usr/lib/python3/dist-packages (from tensorflow<2.0->alibi) (1.11.0)\n", - "Requirement already satisfied, skipping upgrade: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.1.0)\n", + "Requirement already satisfied, skipping upgrade: tensorflow-estimator==1.15.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.15.1)\n", "Requirement already satisfied, skipping upgrade: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.8.1)\n", + "Requirement already satisfied, skipping upgrade: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from 
tensorflow<2.0->alibi) (1.1.0)\n", + "Requirement already satisfied, skipping upgrade: wheel>=0.26; python_version >= \"3\" in /usr/lib/python3/dist-packages (from tensorflow<2.0->alibi) (0.30.0)\n", "Requirement already satisfied, skipping upgrade: keras-applications>=1.0.8 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.0.8)\n", - "Requirement already satisfied, skipping upgrade: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (3.1.0)\n", - "Requirement already satisfied, skipping upgrade: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (3.11.2)\n", - "Requirement already satisfied, skipping upgrade: tensorflow-estimator==1.15.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.15.1)\n", - "Requirement already satisfied, skipping upgrade: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.1.8)\n", - "Requirement already satisfied, skipping upgrade: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.11.2)\n", - "Requirement already satisfied, skipping upgrade: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.26.0)\n", - "Requirement already satisfied, skipping upgrade: tensorboard<1.16.0,>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.15.0)\n", + "Requirement already satisfied, skipping upgrade: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (1.1.0)\n", "Requirement already satisfied, skipping upgrade: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (0.9.0)\n", + "Requirement already satisfied, skipping upgrade: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (3.11.2)\n", + "Requirement already satisfied, skipping upgrade: opt-einsum>=2.3.2 in 
/usr/local/lib/python3.6/dist-packages (from tensorflow<2.0->alibi) (3.1.0)\n", + "Requirement already satisfied, skipping upgrade: soupsieve>=1.2 in ./.local/lib/python3.6/site-packages (from beautifulsoup4->alibi) (2.0)\n", + "Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /usr/lib/python3/dist-packages (from requests->alibi) (2.6)\n", + "Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in ./.local/lib/python3.6/site-packages (from requests->alibi) (1.24.3)\n", + "Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->alibi) (2019.11.28)\n", + "Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->alibi) (3.0.4)\n", "Requirement already satisfied, skipping upgrade: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas->alibi) (2019.3)\n", "Requirement already satisfied, skipping upgrade: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from pandas->alibi) (2.8.1)\n", - "Requirement already satisfied, skipping upgrade: Pygments>=2.2.0 in /usr/local/lib/python3.6/dist-packages (from prettyprinter->alibi) (2.5.2)\n", - "Requirement already satisfied, skipping upgrade: colorful>=0.4.0 in ./.local/lib/python3.6/site-packages (from prettyprinter->alibi) (0.5.4)\n", - "Requirement already satisfied, skipping upgrade: networkx>=2.0 in ./.local/lib/python3.6/site-packages (from scikit-image->alibi) (2.4)\n", - "Requirement already satisfied, skipping upgrade: imageio>=2.3.0 in ./.local/lib/python3.6/site-packages (from scikit-image->alibi) (2.8.0)\n", - "Requirement already satisfied, skipping upgrade: matplotlib!=3.0.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image->alibi) (3.1.2)\n", - "Requirement already satisfied, skipping upgrade: PyWavelets>=0.4.0 in ./.local/lib/python3.6/site-packages (from 
scikit-image->alibi) (1.1.1)\n", "Requirement already satisfied, skipping upgrade: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy->alibi) (1.4.0)\n", - "Requirement already satisfied, skipping upgrade: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.8->tensorflow<2.0->alibi) (2.10.0)\n", - "Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0->alibi) (0.16.1)\n", - "Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0->alibi) (3.1.1)\n", "Requirement already satisfied, skipping upgrade: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx>=2.0->scikit-image->alibi) (4.4.1)\n", "Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->alibi) (1.1.0)\n", - "Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->alibi) (0.10.0)\n", "Requirement already satisfied, skipping upgrade: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->alibi) (2.4.6)\n", + "Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image->alibi) (0.10.0)\n", + "Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0->alibi) (3.1.1)\n", + "Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.16.0,>=1.15.0->tensorflow<2.0->alibi) 
(0.16.1)\r\n", + "Requirement already satisfied, skipping upgrade: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.8->tensorflow<2.0->alibi) (2.10.0)\r\n", "Requirement already satisfied, skipping upgrade: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy->alibi) (2.1.0)\r\n" ] } @@ -120,7 +120,8 @@ "from typing import Tuple, Union\n", "import requests\n", "from requests import RequestException\n", - "from io import BytesIO, StringIO" + "from io import BytesIO, StringIO\n", + "from joblib import dump" ] }, { @@ -646,6 +647,36 @@ "This is due to the imbalanced dataset (roughly 25:75 high:low earner proportion), so during the sampling stage feature ranges corresponding to low-earners will be oversampled. This is a feature because it can point out an imbalanced dataset, but it can also be fixed by producing balanced datasets to enable anchors to be found for either class." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exporting model" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "['income.joblib']" + ], + "text/plain": [ + "['income.joblib']" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dump(clf, 'income.joblib')" + ] + }, { "cell_type": "code", "execution_count": null,