|
6 | 6 | "metadata": {}, |
7 | 7 | "outputs": [], |
8 | 8 | "source": [ |
9 | | - "import numpy as np\n", |
| 9 | + "%pip install -U open_spiel" |
| 10 | + ] |
| 11 | + }, |
| 12 | + { |
| 13 | + "cell_type": "code", |
| 14 | + "execution_count": 1, |
| 15 | + "metadata": {}, |
| 16 | + "outputs": [ |
| 17 | + { |
| 18 | + "name": "stdout", |
| 19 | + "output_type": "stream", |
| 20 | + "text": [ |
| 21 | + "Optional module pokerkit_wrapper was not importable: No module named 'pokerkit'\n" |
| 22 | + ] |
| 23 | + } |
| 24 | + ], |
| 25 | + "source": [ |
10 | 26 | "import pyspiel \n", |
11 | | - "import tensorflow.compat.v1 as tf\n", |
| 27 | + "import functools\n", |
12 | 28 | "import torch \n", |
13 | 29 | "import torch.nn as nn\n", |
14 | 30 | "\n", |
15 | | - "import algorithms.rcfr as rcfr_tf\n", |
16 | | - "import pytorch.rcfr as rcfr_pt\n", |
17 | | - "tf.disable_v2_behavior()\n", |
| 31 | + "import jax\n", |
| 32 | + "import flax.nnx as nnx\n", |
| 33 | + "import optax\n", |
18 | 34 | "\n", |
19 | | - "tf.enable_eager_execution()\n", |
| 35 | + "from open_spiel.python.pytorch import rcfr as rcfr_pt\n", |
| 36 | + "from open_spiel.python.jax import rcfr as rcfr_jax\n", |
20 | 37 | "\n", |
21 | | - "_GAME = pyspiel.load_game('kuhn_poker')\n", |
22 | | - "_BATCH_SIZE = 12" |
| 38 | + "game = pyspiel.load_game('kuhn_poker')\n", |
| 39 | + "batch_size = 12" |
23 | 40 | ] |
24 | 41 | }, |
25 | 42 | { |
26 | 43 | "cell_type": "code", |
27 | | - "execution_count": null, |
| 44 | + "execution_count": 2, |
28 | 45 | "metadata": {}, |
29 | 46 | "outputs": [], |
30 | 47 | "source": [ |
31 | | - "def tnsorflow_example(game_name, num_epochs, iterations):\n", |
| 48 | + "def flax_example(game_name, num_epochs, iterations):\n", |
| 49 | + " \n", |
| 50 | + " @nnx.vmap(in_axes=(None, 0), out_axes=0)\n", |
| 51 | + " def forward(model: nnx.Module, x: jax.Array) -> jax.Array:\n", |
| 52 | + " \"\"\"Batched call for the flax.nnx model.\"\"\"\n", |
| 53 | + " return model(x)\n", |
| 54 | + " \n", |
| 55 | + " @functools.partial(jax.jit, static_argnames=(\"graphdef\",))\n", |
| 56 | + " def jax_train_step(\n", |
| 57 | + " graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array, y: jax.Array\n", |
| 58 | + " ) -> tuple:\n", |
| 59 | + " \"\"\"Train step in pure jax.\"\"\"\n", |
| 60 | + "\n", |
| 61 | + " model, optimizer = nnx.merge(graphdef, state, copy=True)\n", |
| 62 | + "\n", |
| 63 | + " def loss_fn(model):\n", |
| 64 | + " y_pred = forward(model, x)\n", |
| 65 | + " return optax.hinge_loss(y_pred, y).mean()\n", |
| 66 | + "\n", |
| 67 | + " loss, grads = nnx.value_and_grad(loss_fn)(model)\n", |
| 68 | + " optimizer.update(model, grads)\n", |
| 69 | + " state = nnx.state((model, optimizer))\n", |
| 70 | + " return loss, state\n", |
| 71 | + "\n", |
32 | 72 | " game = pyspiel.load_game(game_name)\n", |
33 | 73 | "\n", |
34 | 74 | " models = []\n", |
35 | 75 | " for _ in range(game.num_players()):\n", |
36 | 76 | " models.append(\n", |
37 | | - " rcfr_tf.DeepRcfrModel(\n", |
| 77 | + " rcfr_jax.DeepRcfrModel(\n", |
38 | 78 | " game,\n", |
39 | 79 | " num_hidden_layers=1,\n", |
40 | 80 | " num_hidden_units=13,\n", |
41 | 81 | " num_hidden_factors=8,\n", |
42 | 82 | " use_skip_connections=True))\n", |
43 | 83 | "\n", |
| 84 | + " # these parameters are fixed initially\n", |
44 | 85 | " buffer_size = -1\n", |
45 | 86 | " truncate_negative = False\n", |
46 | 87 | " bootstrap = False\n", |
| 88 | + "\n", |
47 | 89 | " if buffer_size > 0:\n", |
48 | | - " solver = rcfr_tf.ReservoirRcfrSolver(\n", |
| 90 | + " solver = rcfr_jax.ReservoirRcfrSolver(\n", |
49 | 91 | " game,\n", |
50 | 92 | " models,\n", |
51 | 93 | " buffer_size,\n", |
52 | 94 | " truncate_negative=truncate_negative)\n", |
53 | 95 | " else:\n", |
54 | | - " solver = rcfr_tf.RcfrSolver(\n", |
| 96 | + " solver = rcfr_jax.RcfrSolver(\n", |
55 | 97 | " game,\n", |
56 | 98 | " models,\n", |
57 | 99 | " truncate_negative=truncate_negative,\n", |
58 | 100 | " bootstrap=bootstrap)\n", |
| 101 | + " \n", |
| 102 | + " batch_size = 12\n", |
| 103 | + " step_size = 0.01\n", |
59 | 104 | "\n", |
60 | | - " def _train_fn(model, data):\n", |
61 | | - " \"\"\"Train `model` on `data`.\"\"\"\n", |
62 | | - " batch_size = 100\n", |
63 | | - " step_size = 0.01\n", |
64 | | - " data = data.shuffle(batch_size * 10)\n", |
65 | | - " data = data.batch(batch_size)\n", |
66 | | - " data = data.repeat(num_epochs)\n", |
| 105 | + " def _train_fn(model: nn.Module, data: tuple) -> None:\n", |
67 | 106 | "\n", |
68 | | - " optimizer = tf.keras.optimizers.Adam(lr=step_size, amsgrad=True)\n", |
| 107 | + " \"\"\"Train `model` on `data`.\"\"\"\n", |
| 108 | + " data_, rng = data\n", |
| 109 | + " optimizer = nnx.Optimizer(\n", |
| 110 | + " model, optax.amsgrad(learning_rate=step_size), wrt=nnx.Param\n", |
| 111 | + " )\n", |
| 112 | + " graphdef, state = nnx.split((model, optimizer))\n", |
69 | 113 | "\n", |
70 | | - " @tf.function\n", |
71 | | - " def _train():\n", |
72 | | - " for x, y in data:\n", |
73 | | - " optimizer.minimize(\n", |
74 | | - " lambda: tf.losses.huber_loss(y, model(x), delta=0.01), # pylint: disable=cell-var-from-loop\n", |
75 | | - " model.trainable_variables)\n", |
| 114 | + " num_batches = len(data_[0]) // batch_size\n", |
| 115 | + " data_ = jax.tree.map(\n", |
| 116 | + " lambda x: jax.random.permutation(rng, x, axis=0).reshape(\n", |
| 117 | + " num_batches, batch_size, -1\n", |
| 118 | + " ),\n", |
| 119 | + " data_,\n", |
| 120 | + " )\n", |
76 | 121 | "\n", |
77 | | - " _train()\n", |
| 122 | + " for _ in range(num_epochs):\n", |
| 123 | + " for x, y in zip(*data_):\n", |
| 124 | + " _, state = jax_train_step(graphdef, state, x, y.squeeze(-1))\n", |
78 | 125 | "\n", |
79 | | - " # End of _train_fn\n", |
| 126 | + " nnx.update((model, optimizer), state)\n", |
| 127 | + " return\n", |
| 128 | + " \n", |
80 | 129 | " result = []\n", |
81 | 130 | " for i in range(iterations):\n", |
82 | | - " solver.evaluate_and_update_policy(_train_fn)\n", |
| 131 | + " solver.evaluate_and_update_policy(_train_fn, jax.random.key(i))\n", |
83 | 132 | " if i % 10 == 0:\n", |
84 | 133 | " conv = pyspiel.exploitability(game, solver.average_policy())\n", |
85 | 134 | " result.append(conv)\n", |
|
89 | 138 | }, |
90 | 139 | { |
91 | 140 | "cell_type": "code", |
92 | | - "execution_count": null, |
| 141 | + "execution_count": 3, |
93 | 142 | "metadata": {}, |
94 | 143 | "outputs": [], |
95 | 144 | "source": [ |
|
125 | 174 | " def _train_fn(model, data):\n", |
126 | 175 | " \"\"\"Train `model` on `data`.\"\"\"\n", |
127 | 176 | " batch_size = 100\n", |
128 | | - " num_epochs = 20\n", |
129 | 177 | " step_size = 0.01\n", |
130 | 178 | " \n", |
131 | 179 | " data = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n", |
132 | 180 | " loss_fn = nn.SmoothL1Loss()\n", |
133 | 181 | " optimizer = torch.optim.Adam(model.parameters(), lr=step_size, amsgrad=True)\n", |
134 | 182 | "\n", |
135 | 183 | " def _train(model, data):\n", |
136 | | - " for epoch in range(num_epochs):\n", |
| 184 | + " for _ in range(num_epochs):\n", |
137 | 185 | " for x, y in data:\n", |
138 | 186 | " optimizer.zero_grad()\n", |
139 | 187 | " output = model(x)\n", |
|
150 | 198 | " if i % 10 == 0:\n", |
151 | 199 | " conv = pyspiel.exploitability(game, solver.average_policy())\n", |
152 | 200 | " result.append(conv)\n", |
153 | | - " # print(\"Iteration {} exploitability {}\".format(i, conv))\n", |
154 | 201 | " return result" |
155 | 202 | ] |
156 | 203 | }, |
|
160 | 207 | "metadata": {}, |
161 | 208 | "outputs": [], |
162 | 209 | "source": [ |
163 | | - "tensorflow_rcfr = []\n", |
| 210 | + "flax_rcfr = []\n", |
164 | 211 | "pytorch_rcfr = []\n", |
165 | 212 | "num_epochs, iterations = 20, 100\n", |
166 | 213 | "for _ in range(10):\n", |
167 | | - " tensorflow_rcfr.append(tnsorflow_example('kuhn_poker', num_epochs, iterations))\n", |
| 214 | + " flax_rcfr.append(flax_example('kuhn_poker', num_epochs, iterations))\n", |
168 | 215 | " pytorch_rcfr.append(pytorch_example('kuhn_poker', num_epochs, iterations))" |
169 | 216 | ] |
170 | 217 | }, |
|
177 | 224 | "import matplotlib.pyplot as plt\n", |
178 | 225 | "\n", |
179 | 226 | "x = [i for i in range(10)]\n", |
180 | | - "tf_exploitability = [sum(tfe) for tfe in zip(*tensorflow_rcfr)]\n", |
| 227 | + "flax_exploitability = [sum(tfe) for tfe in zip(*flax_rcfr)]\n", |
181 | 228 | "pt_exploitability = [sum(pte) for pte in zip(*pytorch_rcfr)]\n", |
182 | 229 | "\n", |
183 | | - "plt.plot(x, tf_exploitability, label=\"tensorflow\")\n", |
| 230 | + "plt.plot(x, flax_exploitability, label=\"flax.nnx\")\n", |
184 | 231 | "plt.plot(x, pt_exploitability, label=\"pytorch\")\n", |
185 | 232 | "\n", |
186 | 233 | "plt.legend()\n", |
|
194 | 241 | "metadata": {}, |
195 | 242 | "outputs": [], |
196 | 243 | "source": [ |
197 | | - "tensorflow_rcfr = []\n", |
| 244 | + "flax_rcfr = []\n", |
198 | 245 | "pytorch_rcfr = []\n", |
199 | | - "num_epochs, iterations = 200, 100\n", |
| 246 | + "num_epochs, iterations = 200, 100 \n", |
200 | 247 | "for _ in range(10):\n", |
201 | | - " tensorflow_rcfr.append(tnsorflow_example('kuhn_poker', num_epochs, iterations))\n", |
| 248 | + " flax_rcfr.append(flax_example('kuhn_poker', num_epochs, iterations))\n", |
202 | 249 | " pytorch_rcfr.append(pytorch_example('kuhn_poker', num_epochs, iterations))" |
203 | 250 | ] |
204 | 251 | }, |
|
211 | 258 | "import matplotlib.pyplot as plt\n", |
212 | 259 | "\n", |
213 | 260 | "x = [i for i in range(10)]\n", |
214 | | - "tf_exploitability = [sum(tfe) for tfe in zip(*tensorflow_rcfr)]\n", |
| 261 | + "flax_exploitability = [sum(tfe) for tfe in zip(*flax_rcfr)]\n", |
215 | 262 | "pt_exploitability = [sum(pte) for pte in zip(*pytorch_rcfr)]\n", |
216 | 263 | "\n", |
217 | | - "plt.plot(x, tf_exploitability, label=\"tensorflow\")\n", |
| 264 | + "plt.plot(x, flax_exploitability, label=\"flax_nnx\")\n", |
218 | 265 | "plt.plot(x, pt_exploitability, label=\"pytorch\")\n", |
219 | 266 | "\n", |
220 | 267 | "plt.legend()\n", |
|
228 | 275 | "metadata": {}, |
229 | 276 | "outputs": [], |
230 | 277 | "source": [ |
231 | | - "tensorflow_rcfr = []\n", |
| 278 | + "flax_rcfr = []\n", |
232 | 279 | "pytorch_rcfr = []\n", |
233 | 280 | "num_epochs, iterations = 20, 100\n", |
234 | 281 | "for _ in range(10):\n", |
235 | | - " tensorflow_rcfr.append(tnsorflow_example('leduc_poker', num_epochs, iterations))\n", |
| 282 | + " flax_rcfr.append(flax_example('leduc_poker', num_epochs, iterations))\n", |
236 | 283 | " pytorch_rcfr.append(pytorch_example('leduc_poker', num_epochs, iterations))" |
237 | 284 | ] |
238 | 285 | }, |
|
245 | 292 | "import matplotlib.pyplot as plt\n", |
246 | 293 | "\n", |
247 | 294 | "x = [i for i in range(10)]\n", |
248 | | - "tf_exploitability = [sum(tfe) for tfe in zip(*tensorflow_rcfr)]\n", |
| 295 | + "flax_exploitability = [sum(tfe) for tfe in zip(*flax_rcfr)]\n", |
249 | 296 | "pt_exploitability = [sum(pte) for pte in zip(*pytorch_rcfr)]\n", |
250 | 297 | "\n", |
251 | | - "plt.plot(x, tf_exploitability, label=\"tensorflow\")\n", |
| 298 | + "plt.plot(x, flax_exploitability, label=\"flax_nnx\")\n", |
252 | 299 | "plt.plot(x, pt_exploitability, label=\"pytorch\")\n", |
253 | 300 | "\n", |
254 | 301 | "plt.legend()\n", |
|
266 | 313 | ], |
267 | 314 | "metadata": { |
268 | 315 | "kernelspec": { |
269 | | - "display_name": "Python 3", |
| 316 | + "display_name": "open_spiel", |
270 | 317 | "language": "python", |
271 | 318 | "name": "python3" |
272 | 319 | }, |
|
280 | 327 | "name": "python", |
281 | 328 | "nbconvert_exporter": "python", |
282 | 329 | "pygments_lexer": "ipython3", |
283 | | - "version": "3.7.3" |
| 330 | + "version": "3.12.11" |
284 | 331 | } |
285 | 332 | }, |
286 | 333 | "nbformat": 4, |
|
0 commit comments