-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathClassifAI_ 2 - HW
More file actions
1 lines (1 loc) · 10.3 KB
/
ClassifAI_ 2 - HW
File metadata and controls
1 lines (1 loc) · 10.3 KB
1
{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [
    {
     "file_id": "14EW0Z-etrAqoe87V1N5fSrrBLnkjY4-T",
     "timestamp": 1653851905553
    }
   ],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ClassifAI 2 - KNN classifier on the wine dataset\n",
    "\n",
    "Predict which of the three wine classes (labelled 0, 1, 2) a sample belongs to from its 13 chemical measurements, using scikit-learn's k-nearest-neighbours classifier. Goal: at least 80% accuracy on the held-out test set.\n",
    "\n",
    "**Summary:** KNN compares samples by distance, so features with large magnitudes (e.g. `proline`) dominate unless the data is scaled. An earlier unscaled run, which also omitted `proline`, reached only 0.78 test accuracy. This version keeps all 13 features, standardises them inside a pipeline, and fixes the random seed so the result is reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import load_wine\n",
    "from sklearn.metrics import accuracy_score, classification_report\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42               # fixes the train/test split so results are reproducible\n",
    "TEST_SIZE = 0.4         # fraction of samples held out for evaluation\n",
    "N_NEIGHBORS = 5         # k: number of neighbours that vote on each prediction\n",
    "TARGET_ACCURACY = 0.80  # assignment goal"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "The second-to-last feature is renamed to the shorter `diluted_wines_metric` column label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wine = load_wine()\n",
    "\n",
    "# Copy the name list so renaming a column does not mutate the loaded dataset object.\n",
    "feature_names = list(wine[\"feature_names\"])\n",
    "feature_names[-2] = \"diluted_wines_metric\"\n",
    "\n",
    "wine_df = pd.DataFrame(wine[\"data\"], columns=feature_names)\n",
    "target = wine[\"target\"]\n",
    "target_names = wine[\"target_names\"]\n",
    "\n",
    "assert len(wine_df) == len(target), \"features and labels must align row-for-row\"\n",
    "wine_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train / test split\n",
    "\n",
    "Use every feature column. An earlier version sliced `df.iloc[:, :12]`, which silently dropped the last column, `proline`. `stratify=y` keeps the class proportions the same in both splits, and `random_state` makes the split reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = wine_df.to_numpy()\n",
    "y = target\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=TEST_SIZE, random_state=SEED, stratify=y\n",
    ")\n",
    "\n",
    "assert len(X_train) + len(X_test) == len(X)\n",
    "X_train.shape, X_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "The features are on very different scales. In the table above, `proline` is in the hundreds to thousands and `magnesium` is around 100, while most other columns are below 10. Without scaling, those large columns dominate KNN's distance computation.\n",
    "\n",
    "`StandardScaler` rescales each feature to zero mean and unit variance. Putting it in a pipeline means the scaler is fitted on the training split only, so no test-set statistics leak into the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = make_pipeline(\n",
    "    StandardScaler(),\n",
    "    KNeighborsClassifier(n_neighbors=N_NEIGHBORS),\n",
    ")\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = model.predict(X_test)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "\n",
    "print(\"Classification report:\")\n",
    "print(classification_report(y_test, y_pred, target_names=target_names))\n",
    "print(f\"Accuracy: {accuracy:.3f} (target {TARGET_ACCURACY:.0%}, met: {accuracy >= TARGET_ACCURACY})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "\n",
    "- Accuracy depends on the particular split. Running `cross_val_score` on the pipeline gives a less split-dependent estimate.\n",
    "- `N_NEIGHBORS` was not tuned. A small `GridSearchCV` over `kneighborsclassifier__n_neighbors` is a natural next step."
   ]
  }
 ]
}