Skip to content

Commit 841f6d8

Browse files
authored
Merge branch 'master' into dee_cfr_jax_refactor
2 parents 26161df + 76ca30e commit 841f6d8

File tree

13 files changed

+2548
-95
lines changed

13 files changed

+2548
-95
lines changed

open_spiel/colabs/CFR_and_REINFORCE.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@
134134
"\n",
135135
" # Compute regrets at this state.\n",
136136
" cfr_prob = np.prod(reach[:player]) * np.prod(reach[player+1:])\n",
137-
" value = np.einsum('ap,a-\u003ep', utility, curr_policy[index])\n",
137+
" value = np.einsum('ap,a->p', utility, curr_policy[index])\n",
138138
" for action in state.legal_actions():\n",
139139
" regrets[index][action] += cfr_prob * (utility[action][player] - value[player])\n",
140140
"\n",
@@ -173,7 +173,7 @@
173173
" policy.action_probability_array += curr_policy * lr\n",
174174
"\n",
175175
" # Evaluate the average policy\n",
176-
" if step \u0026 (step-1) == 0:\n",
176+
" if step & (step-1) == 0:\n",
177177
" nc = exploitability.nash_conv(game, policy)\n",
178178
" eval_steps.append(step)\n",
179179
" eval_nash_conv.append(nc)\n",

open_spiel/colabs/crowd_modelling_4rooms_MFGsurvey.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@
169169
"\n",
170170
"def decode_distribution(game: pyspiel.Game,\n",
171171
" dist: Dict[str, float],\n",
172-
" nans: bool = True) -\u003e np.ndarray:\n",
172+
" nans: bool = True) -> np.ndarray:\n",
173173
" \"\"\"Decodes the distribution of a 2D crowd modelling game from a dictionary.\"\"\"\n",
174174
" # Extract the size of the distribution from the game parameters. Time, i.e.\n",
175175
" # horizon is the leading dimension so that we can easily present the temporal\n",
@@ -179,7 +179,7 @@
179179
" decoded = np.zeros(dist_size)\n",
180180
"\n",
181181
" for key, value in dist.items():\n",
182-
" m = re.fullmatch(r'\\((?P\u003cx\u003e\\d+),\\s*(?P\u003cy\u003e\\d+),\\s*(?P\u003ct\u003e\\d+)\\)', key)\n",
182+
" m = re.fullmatch(r'\\((?P<x>\\d+),\\s*(?P<y>\\d+),\\s*(?P<t>\\d+)\\)', key)\n",
183183
" if m:\n",
184184
" g = m.group\n",
185185
" decoded[(int(g('t')), int(g('y')), int(g('x')))] = value\n",
@@ -188,14 +188,14 @@
188188
"\n",
189189
"\n",
190190
"def get_policy_distribution(game: pyspiel.Game,\n",
191-
" policy: policy_std.Policy) -\u003e np.ndarray:\n",
191+
" policy: policy_std.Policy) -> np.ndarray:\n",
192192
" \"\"\"Returns the distribution of the policy.\"\"\"\n",
193193
" dist_policy = distribution.DistributionPolicy(game, policy)\n",
194194
" return decode_distribution(game, dist_policy.distribution)\n",
195195
"\n",
196196
"\n",
197197
"def animate_distributions(dists: np.ndarray,\n",
198-
" fixed_cbar: bool = False) -\u003e animation.FuncAnimation:\n",
198+
" fixed_cbar: bool = False) -> animation.FuncAnimation:\n",
199199
" \"\"\"Animates the given distributions.\n",
200200
"\n",
201201
" Args:\n",

0 commit comments

Comments
 (0)