Skip to content

Commit bd7b303

Browse files
committed
Added jax version + tests
1 parent 3cb4d43 commit bd7b303

File tree

7 files changed

+1559
-60
lines changed

7 files changed

+1559
-60
lines changed
Lines changed: 96 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -6,80 +6,129 @@
66
"metadata": {},
77
"outputs": [],
88
"source": [
9-
"import numpy as np\n",
9+
"%pip install -U open_spiel"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": 1,
15+
"metadata": {},
16+
"outputs": [
17+
{
18+
"name": "stdout",
19+
"output_type": "stream",
20+
"text": [
21+
"Optional module pokerkit_wrapper was not importable: No module named 'pokerkit'\n"
22+
]
23+
}
24+
],
25+
"source": [
1026
"import pyspiel \n",
11-
"import tensorflow.compat.v1 as tf\n",
27+
"import functools\n",
1228
"import torch \n",
1329
"import torch.nn as nn\n",
1430
"\n",
15-
"import algorithms.rcfr as rcfr_tf\n",
16-
"import pytorch.rcfr as rcfr_pt\n",
17-
"tf.disable_v2_behavior()\n",
31+
"import jax\n",
32+
"import flax.nnx as nnx\n",
33+
"import optax\n",
1834
"\n",
19-
"tf.enable_eager_execution()\n",
35+
"from open_spiel.python.pytorch import rcfr as rcfr_pt\n",
36+
"from open_spiel.python.jax import rcfr as rcfr_jax\n",
2037
"\n",
21-
"_GAME = pyspiel.load_game('kuhn_poker')\n",
22-
"_BATCH_SIZE = 12"
38+
"game = pyspiel.load_game('kuhn_poker')\n",
39+
"batch_size = 12"
2340
]
2441
},
2542
{
2643
"cell_type": "code",
27-
"execution_count": null,
44+
"execution_count": 2,
2845
"metadata": {},
2946
"outputs": [],
3047
"source": [
31-
"def tnsorflow_example(game_name, num_epochs, iterations):\n",
48+
"def flax_example(game_name, num_epochs, iterations):\n",
49+
" \n",
50+
" @nnx.vmap(in_axes=(None, 0), out_axes=0)\n",
51+
" def forward(model: nnx.Module, x: jax.Array) -> jax.Array:\n",
52+
" \"\"\"Batched call for the flax.nnx model.\"\"\"\n",
53+
" return model(x)\n",
54+
" \n",
55+
" @functools.partial(jax.jit, static_argnames=(\"graphdef\",))\n",
56+
" def jax_train_step(\n",
57+
" graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array, y: jax.Array\n",
58+
" ) -> tuple:\n",
59+
" \"\"\"Train step in pure jax.\"\"\"\n",
60+
"\n",
61+
" model, optimizer = nnx.merge(graphdef, state, copy=True)\n",
62+
"\n",
63+
" def loss_fn(model):\n",
64+
" y_pred = forward(model, x)\n",
65+
" return optax.hinge_loss(y_pred, y).mean()\n",
66+
"\n",
67+
" loss, grads = nnx.value_and_grad(loss_fn)(model)\n",
68+
" optimizer.update(model, grads)\n",
69+
" state = nnx.state((model, optimizer))\n",
70+
" return loss, state\n",
71+
"\n",
3272
" game = pyspiel.load_game(game_name)\n",
3373
"\n",
3474
" models = []\n",
3575
" for _ in range(game.num_players()):\n",
3676
" models.append(\n",
37-
" rcfr_tf.DeepRcfrModel(\n",
77+
" rcfr_jax.DeepRcfrModel(\n",
3878
" game,\n",
3979
" num_hidden_layers=1,\n",
4080
" num_hidden_units=13,\n",
4181
" num_hidden_factors=8,\n",
4282
" use_skip_connections=True))\n",
4383
"\n",
84+
" # these parameters are fixed initially\n",
4485
" buffer_size = -1\n",
4586
" truncate_negative = False\n",
4687
" bootstrap = False\n",
88+
"\n",
4789
" if buffer_size > 0:\n",
48-
" solver = rcfr_tf.ReservoirRcfrSolver(\n",
90+
" solver = rcfr_jax.ReservoirRcfrSolver(\n",
4991
" game,\n",
5092
" models,\n",
5193
" buffer_size,\n",
5294
" truncate_negative=truncate_negative)\n",
5395
" else:\n",
54-
" solver = rcfr_tf.RcfrSolver(\n",
96+
" solver = rcfr_jax.RcfrSolver(\n",
5597
" game,\n",
5698
" models,\n",
5799
" truncate_negative=truncate_negative,\n",
58100
" bootstrap=bootstrap)\n",
101+
" \n",
102+
" batch_size = 12\n",
103+
" step_size = 0.01\n",
59104
"\n",
60-
" def _train_fn(model, data):\n",
61-
" \"\"\"Train `model` on `data`.\"\"\"\n",
62-
" batch_size = 100\n",
63-
" step_size = 0.01\n",
64-
" data = data.shuffle(batch_size * 10)\n",
65-
" data = data.batch(batch_size)\n",
66-
" data = data.repeat(num_epochs)\n",
105+
" def _train_fn(model: nn.Module, data: tuple) -> None:\n",
67106
"\n",
68-
" optimizer = tf.keras.optimizers.Adam(lr=step_size, amsgrad=True)\n",
107+
" \"\"\"Train `model` on `data`.\"\"\"\n",
108+
" data_, rng = data\n",
109+
" optimizer = nnx.Optimizer(\n",
110+
" model, optax.amsgrad(learning_rate=step_size), wrt=nnx.Param\n",
111+
" )\n",
112+
" graphdef, state = nnx.split((model, optimizer))\n",
69113
"\n",
70-
" @tf.function\n",
71-
" def _train():\n",
72-
" for x, y in data:\n",
73-
" optimizer.minimize(\n",
74-
" lambda: tf.losses.huber_loss(y, model(x), delta=0.01), # pylint: disable=cell-var-from-loop\n",
75-
" model.trainable_variables)\n",
114+
" num_batches = len(data_[0]) // batch_size\n",
115+
" data_ = jax.tree.map(\n",
116+
" lambda x: jax.random.permutation(rng, x, axis=0).reshape(\n",
117+
" num_batches, batch_size, -1\n",
118+
" ),\n",
119+
" data_,\n",
120+
" )\n",
76121
"\n",
77-
" _train()\n",
122+
" for _ in range(num_epochs):\n",
123+
" for x, y in zip(*data_):\n",
124+
" _, state = jax_train_step(graphdef, state, x, y.squeeze(-1))\n",
78125
"\n",
79-
" # End of _train_fn\n",
126+
" nnx.update((model, optimizer), state)\n",
127+
" return\n",
128+
" \n",
80129
" result = []\n",
81130
" for i in range(iterations):\n",
82-
" solver.evaluate_and_update_policy(_train_fn)\n",
131+
" solver.evaluate_and_update_policy(_train_fn, jax.random.key(i))\n",
83132
" if i % 10 == 0:\n",
84133
" conv = pyspiel.exploitability(game, solver.average_policy())\n",
85134
" result.append(conv)\n",
@@ -89,7 +138,7 @@
89138
},
90139
{
91140
"cell_type": "code",
92-
"execution_count": null,
141+
"execution_count": 3,
93142
"metadata": {},
94143
"outputs": [],
95144
"source": [
@@ -125,15 +174,14 @@
125174
" def _train_fn(model, data):\n",
126175
" \"\"\"Train `model` on `data`.\"\"\"\n",
127176
" batch_size = 100\n",
128-
" num_epochs = 20\n",
129177
" step_size = 0.01\n",
130178
" \n",
131179
" data = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)\n",
132180
" loss_fn = nn.SmoothL1Loss()\n",
133181
" optimizer = torch.optim.Adam(model.parameters(), lr=step_size, amsgrad=True)\n",
134182
"\n",
135183
" def _train(model, data):\n",
136-
" for epoch in range(num_epochs):\n",
184+
" for _ in range(num_epochs):\n",
137185
" for x, y in data:\n",
138186
" optimizer.zero_grad()\n",
139187
" output = model(x)\n",
@@ -150,7 +198,6 @@
150198
" if i % 10 == 0:\n",
151199
" conv = pyspiel.exploitability(game, solver.average_policy())\n",
152200
" result.append(conv)\n",
153-
" # print(\"Iteration {} exploitability {}\".format(i, conv))\n",
154201
" return result"
155202
]
156203
},
@@ -160,11 +207,11 @@
160207
"metadata": {},
161208
"outputs": [],
162209
"source": [
163-
"tensorflow_rcfr = []\n",
210+
"flax_rcfr = []\n",
164211
"pytorch_rcfr = []\n",
165212
"num_epochs, iterations = 20, 100\n",
166213
"for _ in range(10):\n",
167-
" tensorflow_rcfr.append(tnsorflow_example('kuhn_poker', num_epochs, iterations))\n",
214+
" flax_rcfr.append(flax_example('kuhn_poker', num_epochs, iterations))\n",
168215
" pytorch_rcfr.append(pytorch_example('kuhn_poker', num_epochs, iterations))"
169216
]
170217
},
@@ -177,10 +224,10 @@
177224
"import matplotlib.pyplot as plt\n",
178225
"\n",
179226
"x = [i for i in range(10)]\n",
180-
"tf_exploitability = [sum(tfe) for tfe in zip(*tensorflow_rcfr)]\n",
227+
"flax_exploitability = [sum(tfe) for tfe in zip(*flax_rcfr)]\n",
181228
"pt_exploitability = [sum(pte) for pte in zip(*pytorch_rcfr)]\n",
182229
"\n",
183-
"plt.plot(x, tf_exploitability, label=\"tensorflow\")\n",
230+
"plt.plot(x, flax_exploitability, label=\"flax.nnx\")\n",
184231
"plt.plot(x, pt_exploitability, label=\"pytorch\")\n",
185232
"\n",
186233
"plt.legend()\n",
@@ -194,11 +241,11 @@
194241
"metadata": {},
195242
"outputs": [],
196243
"source": [
197-
"tensorflow_rcfr = []\n",
244+
"flax_rcfr = []\n",
198245
"pytorch_rcfr = []\n",
199-
"num_epochs, iterations = 200, 100\n",
246+
"num_epochs, iterations = 200, 100 \n",
200247
"for _ in range(10):\n",
201-
" tensorflow_rcfr.append(tnsorflow_example('kuhn_poker', num_epochs, iterations))\n",
248+
" flax_rcfr.append(flax_example('kuhn_poker', num_epochs, iterations))\n",
202249
" pytorch_rcfr.append(pytorch_example('kuhn_poker', num_epochs, iterations))"
203250
]
204251
},
@@ -211,10 +258,10 @@
211258
"import matplotlib.pyplot as plt\n",
212259
"\n",
213260
"x = [i for i in range(10)]\n",
214-
"tf_exploitability = [sum(tfe) for tfe in zip(*tensorflow_rcfr)]\n",
261+
"flax_exploitability = [sum(tfe) for tfe in zip(*flax_rcfr)]\n",
215262
"pt_exploitability = [sum(pte) for pte in zip(*pytorch_rcfr)]\n",
216263
"\n",
217-
"plt.plot(x, tf_exploitability, label=\"tensorflow\")\n",
264+
"plt.plot(x, flax_exploitability, label=\"flax_nnx\")\n",
218265
"plt.plot(x, pt_exploitability, label=\"pytorch\")\n",
219266
"\n",
220267
"plt.legend()\n",
@@ -228,11 +275,11 @@
228275
"metadata": {},
229276
"outputs": [],
230277
"source": [
231-
"tensorflow_rcfr = []\n",
278+
"flax_rcfr = []\n",
232279
"pytorch_rcfr = []\n",
233280
"num_epochs, iterations = 20, 100\n",
234281
"for _ in range(10):\n",
235-
" tensorflow_rcfr.append(tnsorflow_example('leduc_poker', num_epochs, iterations))\n",
282+
" flax_rcfr.append(flax_example('leduc_poker', num_epochs, iterations))\n",
236283
" pytorch_rcfr.append(pytorch_example('leduc_poker', num_epochs, iterations))"
237284
]
238285
},
@@ -245,10 +292,10 @@
245292
"import matplotlib.pyplot as plt\n",
246293
"\n",
247294
"x = [i for i in range(10)]\n",
248-
"tf_exploitability = [sum(tfe) for tfe in zip(*tensorflow_rcfr)]\n",
295+
"flax_exploitability = [sum(tfe) for tfe in zip(*flax_rcfr)]\n",
249296
"pt_exploitability = [sum(pte) for pte in zip(*pytorch_rcfr)]\n",
250297
"\n",
251-
"plt.plot(x, tf_exploitability, label=\"tensorflow\")\n",
298+
"plt.plot(x, flax_exploitability, label=\"flax_nnx\")\n",
252299
"plt.plot(x, pt_exploitability, label=\"pytorch\")\n",
253300
"\n",
254301
"plt.legend()\n",
@@ -266,7 +313,7 @@
266313
],
267314
"metadata": {
268315
"kernelspec": {
269-
"display_name": "Python 3",
316+
"display_name": "open_spiel",
270317
"language": "python",
271318
"name": "python3"
272319
},
@@ -280,7 +327,7 @@
280327
"name": "python",
281328
"nbconvert_exporter": "python",
282329
"pygments_lexer": "ipython3",
283-
"version": "3.7.3"
330+
"version": "3.12.11"
284331
}
285332
},
286333
"nbformat": 4,

0 commit comments

Comments
 (0)