|
169 | 169 | "\n", |
170 | 170 | "def decode_distribution(game: pyspiel.Game,\n", |
171 | 171 | " dist: Dict[str, float],\n", |
172 | | - " nans: bool = True) -\u003e np.ndarray:\n", |
| 172 | + " nans: bool = True) -> np.ndarray:\n", |
173 | 173 | " \"\"\"Decodes the distribution of a 2D crowd modelling game from a dictionary.\"\"\"\n", |
174 | 174 | " # Extract the size of the distribution from the game parameters. Time, i.e.\n", |
175 | 175 | " # horizon is the leading dimension so that we can easily present the temporal\n", |
|
179 | 179 | " decoded = np.zeros(dist_size)\n", |
180 | 180 | "\n", |
181 | 181 | " for key, value in dist.items():\n", |
182 | | - " m = re.fullmatch(r'\\((?P\u003cx\u003e\\d+),\\s*(?P\u003cy\u003e\\d+),\\s*(?P\u003ct\u003e\\d+)\\)', key)\n", |
| 182 | + " m = re.fullmatch(r'\\((?P<x>\\d+),\\s*(?P<y>\\d+),\\s*(?P<t>\\d+)\\)', key)\n", |
183 | 183 | " if m:\n", |
184 | 184 | " g = m.group\n", |
185 | 185 | " decoded[(int(g('t')), int(g('y')), int(g('x')))] = value\n", |
|
188 | 188 | "\n", |
189 | 189 | "\n", |
190 | 190 | "def get_policy_distribution(game: pyspiel.Game,\n", |
191 | | - " policy: policy_std.Policy) -\u003e np.ndarray:\n", |
| 191 | + " policy: policy_std.Policy) -> np.ndarray:\n", |
192 | 192 | " \"\"\"Returns the distribution of the policy.\"\"\"\n", |
193 | 193 | " dist_policy = distribution.DistributionPolicy(game, policy)\n", |
194 | 194 | " return decode_distribution(game, dist_policy.distribution)\n", |
195 | 195 | "\n", |
196 | 196 | "\n", |
197 | 197 | "def animate_distributions(dists: np.ndarray,\n", |
198 | | - " fixed_cbar: bool = False) -\u003e animation.FuncAnimation:\n", |
| 198 | + " fixed_cbar: bool = False) -> animation.FuncAnimation:\n", |
199 | 199 | " \"\"\"Animates the given distributions.\n", |
200 | 200 | "\n", |
201 | 201 | " Args:\n", |
|
0 commit comments