|
21 | 21 | "- how to create a Map-elites instance\n", |
22 | 22 | "- which functions must be defined before training\n", |
23 | 23 | "- how to launch a certain number of training steps\n", |
24 | | - "- how to visualise the optimization process\n", |
25 | | - "- how to save/load a repertoire" |
| 24 | + "- how to visualise the optimization process" |
26 | 25 | ] |
27 | 26 | }, |
28 | 27 | { |
|
368 | 367 | "fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)" |
369 | 368 | ] |
370 | 369 | }, |
371 | | - { |
372 | | - "cell_type": "markdown", |
373 | | - "metadata": {}, |
374 | | - "source": [ |
375 | | - "# How to save/load a repertoire\n", |
376 | | - "\n", |
377 | | - "The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation." |
378 | | - ] |
379 | | - }, |
380 | | - { |
381 | | - "cell_type": "markdown", |
382 | | - "metadata": {}, |
383 | | - "source": [ |
384 | | - "## Load the final repertoire" |
385 | | - ] |
386 | | - }, |
387 | | - { |
388 | | - "cell_type": "code", |
389 | | - "execution_count": null, |
390 | | - "metadata": {}, |
391 | | - "outputs": [], |
392 | | - "source": [ |
393 | | - "repertoire_path = \"./last_repertoire/\"\n", |
394 | | - "os.makedirs(repertoire_path, exist_ok=True)\n", |
395 | | - "repertoire.save(path=repertoire_path)" |
396 | | - ] |
397 | | - }, |
398 | | - { |
399 | | - "cell_type": "markdown", |
400 | | - "metadata": {}, |
401 | | - "source": [ |
402 | | - "## Build the reconstruction function" |
403 | | - ] |
404 | | - }, |
405 | | - { |
406 | | - "cell_type": "code", |
407 | | - "execution_count": null, |
408 | | - "metadata": {}, |
409 | | - "outputs": [], |
410 | | - "source": [ |
411 | | - "# Init population of policies\n", |
412 | | - "key, subkey = jax.random.split(key)\n", |
413 | | - "fake_batch = jnp.zeros(shape=(env.observation_size,))\n", |
414 | | - "fake_params = policy_network.init(subkey, fake_batch)\n", |
415 | | - "\n", |
416 | | - "_, reconstruction_fn = ravel_pytree(fake_params)" |
417 | | - ] |
418 | | - }, |
419 | | - { |
420 | | - "cell_type": "markdown", |
421 | | - "metadata": {}, |
422 | | - "source": [ |
423 | | - "## Use the reconstruction function to load and re-create the repertoire" |
424 | | - ] |
425 | | - }, |
426 | | - { |
427 | | - "cell_type": "code", |
428 | | - "execution_count": null, |
429 | | - "metadata": {}, |
430 | | - "outputs": [], |
431 | | - "source": [ |
432 | | - "repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)" |
433 | | - ] |
434 | | - }, |
435 | | - { |
436 | | - "cell_type": "markdown", |
437 | | - "metadata": {}, |
438 | | - "source": [ |
439 | | - "## Get the best individual of the repertoire" |
440 | | - ] |
441 | | - }, |
442 | 370 | { |
443 | 371 | "cell_type": "code", |
444 | 372 | "execution_count": null, |
|
469 | 397 | "metadata": {}, |
470 | 398 | "outputs": [], |
471 | 399 | "source": [ |
| 400 | + "# select the parameters of the best individual\n", |
472 | 401 | "my_params = jax.tree.map(\n", |
473 | 402 | " lambda x: x[best_idx],\n", |
474 | 403 | " repertoire.genotypes\n", |
|
0 commit comments