|
215 | 215 | "source": [ |
216 | 216 | "This kind of cluster analysis has stochastic aspects, so results may differ on different runs. Generally, the plot shows a \"knee\" in the curve near `k=7` as the decrease in error begins to level out. That's a reasonable number of clusters, such that each cluster will tend to have ~14% of the items. That choice has an inherent trade-off:\n", |
217 | 217 | "\n", |
218 | | - " * too few clusters → poor predictions (less accuracy)\n", |
| 218 | + " * too few clusters → poor predictions (less precision)\n", |
219 | 219 | " * too many clusters → poor predictive power (less recall)\n", |
220 | 220 | "\n", |
221 | 221 | "Now we can run K-means in `scikit-learn` with that hyperparameter `k=7` to get the clusters that we'll use in our RL environment:" |
|
881 | 881 | " ]\n", |
882 | 882 | "\n", |
883 | 883 | " df.loc[len(df)] = row\n", |
884 | | - " print(status.format(*row))" |
885 | | - ] |
886 | | - }, |
887 | | - { |
888 | | - "cell_type": "markdown", |
889 | | - "metadata": {}, |
890 | | - "source": [ |
891 | | - "The learning is stochastic and not guaranteed to improve *monotonically*, i.e., increase the min/mean/max rewards per episode in every training iterations.\n", |
892 | | - "We can use a [*pareto archive*](https://ieeexplore.ieee.org/document/781913) to find a *non-dominated* solution.\n", |
893 | | - "In other words, among the saved checkpoints of trained policies, which have the best mean rewards per episode, and among those which have the best min and max rewards?\n", |
894 | | - "The following code uses the [`paretoset`](https://github.com/tommyod/paretoset) Python implementation to select the best checkpoint:" |
895 | | - ] |
896 | | - }, |
897 | | - { |
898 | | - "cell_type": "code", |
899 | | - "execution_count": null, |
900 | | - "metadata": {}, |
901 | | - "outputs": [], |
902 | | - "source": [ |
903 | | - "from paretoset import paretoset\n", |
904 | | - "\n", |
905 | | - "df_front = df.drop(columns=[\"steps\", \"checkpoint\"])\n", |
906 | | - "mask = paretoset(df_front, sense=[\"max\", \"max\", \"max\"])\n", |
907 | | - "\n", |
908 | | - "optimal = df_front[mask]\n", |
909 | | - "max_val = optimal[\"avg_reward\"].max()\n", |
910 | | - "\n", |
911 | | - "BEST_CHECKPOINT = df.loc[df[\"avg_reward\"] == max_val, \"checkpoint\"].values[0]\n", |
912 | | - "print(\"best checkpoint:\", BEST_CHECKPOINT)" |
| 884 | + " print(status.format(*row))\n", |
| 885 | + " \n", |
| 886 | + "BEST_CHECKPOINT = checkpoint_file" |
913 | 887 | ] |
914 | 888 | }, |
915 | 889 | { |
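Review note on this hunk: the removed cell chose `BEST_CHECKPOINT` from the Pareto front of the per-iteration reward statistics, whereas the added line appears to simply keep whichever `checkpoint_file` was written last. For readers who want to see what that selection did without the `paretoset` dependency, here is a minimal pure-Python sketch. It is an illustration, not the notebook's code: the column names `min_reward` and `max_reward`, the checkpoint labels, and all reward values are hypothetical (only `avg_reward` and `checkpoint` appear in the diff).

```python
from typing import Dict, List


def pareto_front(rows: List[Dict], keys: List[str]) -> List[Dict]:
    """Return the rows that no other row dominates (every key is maximised)."""

    def dominates(a: Dict, b: Dict) -> bool:
        # a dominates b if it is at least as good on every key
        # and strictly better on at least one.
        return all(a[k] >= b[k] for k in keys) and any(a[k] > b[k] for k in keys)

    return [r for r in rows if not any(dominates(o, r) for o in rows if o is not r)]


# Hypothetical training log: one row per saved checkpoint.
results = [
    {"checkpoint": "ckpt_1", "min_reward": -5.0, "avg_reward": 1.0, "max_reward": 4.0},
    {"checkpoint": "ckpt_2", "min_reward": -3.0, "avg_reward": 2.5, "max_reward": 6.0},
    {"checkpoint": "ckpt_3", "min_reward": -4.0, "avg_reward": 2.0, "max_reward": 7.0},
]
keys = ["min_reward", "avg_reward", "max_reward"]

# Keep only non-dominated checkpoints, then break ties by mean reward,
# mirroring the two-step selection in the removed cell.
front = pareto_front(results, keys)
best = max(front, key=lambda r: r["avg_reward"])
print("best checkpoint:", best["checkpoint"])
```

Because training is not guaranteed to improve monotonically, the last checkpoint is not necessarily on this front, so the two approaches can pick different policies.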
|
1077 | 1051 | "source": [ |
1078 | 1052 | "ray.shutdown()" |
1079 | 1053 | ] |
1080 | | - }, |
1081 | | - { |
1082 | | - "cell_type": "code", |
1083 | | - "execution_count": null, |
1084 | | - "metadata": {}, |
1085 | | - "outputs": [], |
1086 | | - "source": [] |
1087 | 1054 | } |
1088 | 1055 | ], |
1089 | 1056 | "metadata": { |
|
1102 | 1069 | "name": "python", |
1103 | 1070 | "nbconvert_exporter": "python", |
1104 | 1071 | "pygments_lexer": "ipython3", |
1105 | | - "version": "3.7.7" |
| 1072 | + "version": "3.7.4" |
1106 | 1073 | } |
1107 | 1074 | }, |
1108 | 1075 | "nbformat": 4, |
|