
Commit 82102b6

Merge pull request #26 from anyscale/recsys-fix: Recsys fix

2 parents d441b73 + 9e6acb5

2 files changed: +6, -41 lines

ray-rllib/recsys/01-Recsys.ipynb

Lines changed: 5 additions & 38 deletions
@@ -215,7 +215,7 @@
 "source": [
 "This kind of cluster analysis has stochastic aspects, so results may differ on different runs. Generally, the plot shows a \"knee\" in the curve near `k=7` as the decrease in error begins to level out. That's a reasonable number of clusters, such that each cluster will tend to have ~14% of the items. That choice has an inherent trade-off:\n",
 "\n",
-" * too few clusters → poor predictions (less accuracy)\n",
+" * too few clusters → poor predictions (less precision)\n",
 " * too many clusters → poor predictive power (less recall)\n",
 "\n",
 "Now we can run K-means in `scikit-learn` with that hyperparameter `k=7` to get the clusters that we'll use in our RL environment:"
@@ -881,35 +881,9 @@
 " ]\n",
 "\n",
 " df.loc[len(df)] = row\n",
-" print(status.format(*row))"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"The learning is stochastic and not guaranteed to improve *monotonically*, i.e., increase the min/mean/max rewards per episode in every training iterations.\n",
-"We can use a [*pareto archive*](https://ieeexplore.ieee.org/document/781913) to find a *non-dominated* solution.\n",
-"In other words, among the saved checkpoints of trained policies, which have the best mean rewards per episode, and among those which have the best min and max rewards?\n",
-"The following code uses the [`paretoset`](https://github.com/tommyod/paretoset) Python implementation to select the best checkpoint:"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"from paretoset import paretoset\n",
-"\n",
-"df_front = df.drop(columns=[\"steps\", \"checkpoint\"])\n",
-"mask = paretoset(df_front, sense=[\"max\", \"max\", \"max\"])\n",
-"\n",
-"optimal = df_front[mask]\n",
-"max_val = optimal[\"avg_reward\"].max()\n",
-"\n",
-"BEST_CHECKPOINT = df.loc[df[\"avg_reward\"] == max_val, \"checkpoint\"].values[0]\n",
-"print(\"best checkpoint:\", BEST_CHECKPOINT)"
+" print(status.format(*row))\n",
+" \n",
+"BEST_CHECKPOINT = checkpoint_file"
 ]
 },
 {
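The deleted cells above selected a checkpoint from the Pareto front of min/mean/max episode rewards using the `paretoset` package, which this commit also drops from requirements.txt. For reference, the same non-dominated filter can be sketched in plain NumPy/pandas without that dependency; the column names mirror the deleted cell, but the reward values and checkpoint names below are made up for illustration:

```python
import numpy as np
import pandas as pd

def pareto_mask(values: np.ndarray) -> np.ndarray:
    """True for rows not dominated by any other row, maximizing all columns.

    Row j dominates row i when j is >= i in every objective and > in at
    least one. This mirrors paretoset(..., sense=["max", "max", "max"]).
    """
    n = len(values)
    mask = np.ones(n, dtype=bool)
    for i in range(n):
        for j in range(n):
            if i != j and np.all(values[j] >= values[i]) and np.any(values[j] > values[i]):
                mask[i] = False
                break
    return mask

# Illustrative training log: one row per saved checkpoint.
df = pd.DataFrame({
    "min_reward": [1.0, 2.0, 1.5, 0.5],
    "avg_reward": [5.0, 4.0, 6.0, 3.0],
    "max_reward": [9.0, 8.0, 7.0, 6.0],
    "checkpoint": ["ck1", "ck2", "ck3", "ck4"],
})

mask = pareto_mask(df[["min_reward", "avg_reward", "max_reward"]].to_numpy())
optimal = df[mask]

# Among the non-dominated checkpoints, pick the one with the best mean reward.
best = df.loc[df["avg_reward"] == optimal["avg_reward"].max(), "checkpoint"].values[0]
```

The O(n^2) scan is fine for a handful of checkpoints; the replacement code in this commit sidesteps the question entirely by keeping the last checkpoint.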
@@ -1077,13 +1051,6 @@
 "source": [
 "ray.shutdown()"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
 }
 ],
 "metadata": {
10891056
"metadata": {
@@ -1102,7 +1069,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.7"
+"version": "3.7.4"
 }
 },
 "nbformat": 4,

requirements.txt

Lines changed: 1 addition & 3 deletions
@@ -1,9 +1,8 @@
 gym >= 0.17.2
-paretoset >= 1.1.2
 numpy >= 1.18.5
 pandas
 requests
-pytorch
+torch
 torchvision
 tensorboard >= 2.3.0
 tensorflow >= 2.3.0
@@ -18,7 +17,6 @@ jupyterlab
 jupyter-server-proxy
 beautifulsoup4
 lxml
-setproctitle
 pytz
 ray[all]==0.8.7
 atoma
