Skip to content

Commit 56ad87e

Browse files
authored
Feat: remove save and load methods (#225)
* Removing save and load methods from MapElitesRepertoire and children classes, as they were tiresome to maintain and implement for new types of repertoires. * Instead, users can use their own custom way to save and load repertoire, such as pickle and orbax. * fix brax v2 wrapper inheritance causing: _n_frames=1
1 parent f5bcec9 commit 56ad87e

17 files changed

Lines changed: 28 additions & 439 deletions

examples/aurora.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
"- how to create an AURORA instance\n",
2222
"- which functions must be defined before training\n",
2323
"- how to launch a certain number of training steps\n",
24-
"- how to visualise the optimization process\n",
25-
"- how to save/load a repertoire"
24+
"- how to visualise the optimization process"
2625
]
2726
},
2827
{

examples/mapelites.ipynb

Lines changed: 2 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
"- how to create a Map-elites instance\n",
2222
"- which functions must be defined before training\n",
2323
"- how to launch a certain number of training steps\n",
24-
"- how to visualise the optimization process\n",
25-
"- how to save/load a repertoire"
24+
"- how to visualise the optimization process"
2625
]
2726
},
2827
{
@@ -368,77 +367,6 @@
368367
"fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)"
369368
]
370369
},
371-
{
372-
"cell_type": "markdown",
373-
"metadata": {},
374-
"source": [
375-
"# How to save/load a repertoire\n",
376-
"\n",
377-
"The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation."
378-
]
379-
},
380-
{
381-
"cell_type": "markdown",
382-
"metadata": {},
383-
"source": [
384-
"## Load the final repertoire"
385-
]
386-
},
387-
{
388-
"cell_type": "code",
389-
"execution_count": null,
390-
"metadata": {},
391-
"outputs": [],
392-
"source": [
393-
"repertoire_path = \"./last_repertoire/\"\n",
394-
"os.makedirs(repertoire_path, exist_ok=True)\n",
395-
"repertoire.save(path=repertoire_path)"
396-
]
397-
},
398-
{
399-
"cell_type": "markdown",
400-
"metadata": {},
401-
"source": [
402-
"## Build the reconstruction function"
403-
]
404-
},
405-
{
406-
"cell_type": "code",
407-
"execution_count": null,
408-
"metadata": {},
409-
"outputs": [],
410-
"source": [
411-
"# Init population of policies\n",
412-
"key, subkey = jax.random.split(key)\n",
413-
"fake_batch = jnp.zeros(shape=(env.observation_size,))\n",
414-
"fake_params = policy_network.init(subkey, fake_batch)\n",
415-
"\n",
416-
"_, reconstruction_fn = ravel_pytree(fake_params)"
417-
]
418-
},
419-
{
420-
"cell_type": "markdown",
421-
"metadata": {},
422-
"source": [
423-
"## Use the reconstruction function to load and re-create the repertoire"
424-
]
425-
},
426-
{
427-
"cell_type": "code",
428-
"execution_count": null,
429-
"metadata": {},
430-
"outputs": [],
431-
"source": [
432-
"repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)"
433-
]
434-
},
435-
{
436-
"cell_type": "markdown",
437-
"metadata": {},
438-
"source": [
439-
"## Get the best individual of the repertoire"
440-
]
441-
},
442370
{
443371
"cell_type": "code",
444372
"execution_count": null,
@@ -469,6 +397,7 @@
469397
"metadata": {},
470398
"outputs": [],
471399
"source": [
400+
"# select the parameters of the best individual\n",
472401
"my_params = jax.tree.map(\n",
473402
" lambda x: x[best_idx],\n",
474403
" repertoire.genotypes\n",

examples/mapelites_asktell.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
"- how to create a Map-elites instance\n",
2222
"- which functions must be defined before training\n",
2323
"- how to launch a certain number of training steps\n",
24-
"- how to visualise the optimization process\n",
25-
"- how to save/load a repertoire"
24+
"- how to visualise the optimization process"
2625
]
2726
},
2827
{

examples/mapelites_brax_v2.ipynb

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -368,70 +368,6 @@
368368
"fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)"
369369
]
370370
},
371-
{
372-
"cell_type": "markdown",
373-
"metadata": {},
374-
"source": [
375-
"# How to save/load a repertoire\n",
376-
"\n",
377-
"The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation."
378-
]
379-
},
380-
{
381-
"cell_type": "markdown",
382-
"metadata": {},
383-
"source": [
384-
"## Load the final repertoire"
385-
]
386-
},
387-
{
388-
"cell_type": "code",
389-
"execution_count": null,
390-
"metadata": {},
391-
"outputs": [],
392-
"source": [
393-
"repertoire_path = \"./last_repertoire/\"\n",
394-
"os.makedirs(repertoire_path, exist_ok=True)\n",
395-
"repertoire.save(path=repertoire_path)"
396-
]
397-
},
398-
{
399-
"cell_type": "markdown",
400-
"metadata": {},
401-
"source": [
402-
"## Build the reconstruction function"
403-
]
404-
},
405-
{
406-
"cell_type": "code",
407-
"execution_count": null,
408-
"metadata": {},
409-
"outputs": [],
410-
"source": [
411-
"# Init population of policies\n",
412-
"key, subkey = jax.random.split(key)\n",
413-
"fake_batch = jnp.zeros(shape=(env.observation_size,))\n",
414-
"fake_params = policy_network.init(subkey, fake_batch)\n",
415-
"\n",
416-
"_, reconstruction_fn = ravel_pytree(fake_params)"
417-
]
418-
},
419-
{
420-
"cell_type": "markdown",
421-
"metadata": {},
422-
"source": [
423-
"## Use the reconstruction function to load and re-create the repertoire"
424-
]
425-
},
426-
{
427-
"cell_type": "code",
428-
"execution_count": null,
429-
"metadata": {},
430-
"outputs": [],
431-
"source": [
432-
"repertoire = MapElitesRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)"
433-
]
434-
},
435371
{
436372
"cell_type": "markdown",
437373
"metadata": {},
@@ -519,7 +455,8 @@
519455
"outputs": [],
520456
"source": [
521457
"def save_rollout_html(env, pipeline_env_list, file_name: str):\n",
522-
" rollout_html = html.render(env.sys.replace(dt=env.dt), pipeline_env_list)\n",
458+
" sys = env.sys.tree_replace({'opt.timestep': env.dt})\n",
459+
" rollout_html = html.render(sys, pipeline_env_list)\n",
523460
" with open(file_name, 'w') as f:\n",
524461
" f.write(rollout_html)\n",
525462
"\n",

examples/mels.ipynb

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
"- how to create an ME-LS instance\n",
2121
"- which functions must be defined before training\n",
2222
"- how to launch a certain number of training steps\n",
23-
"- how to visualise the optimization process\n",
24-
"- how to save/load a repertoire"
23+
"- how to visualise the optimization process"
2524
]
2625
},
2726
{
@@ -383,70 +382,6 @@
383382
"fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)"
384383
]
385384
},
386-
{
387-
"cell_type": "markdown",
388-
"metadata": {},
389-
"source": [
390-
"# How to save/load a repertoire\n",
391-
"\n",
392-
"The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation."
393-
]
394-
},
395-
{
396-
"cell_type": "markdown",
397-
"metadata": {},
398-
"source": [
399-
"## Load the final repertoire"
400-
]
401-
},
402-
{
403-
"cell_type": "code",
404-
"execution_count": null,
405-
"metadata": {},
406-
"outputs": [],
407-
"source": [
408-
"repertoire_path = \"./last_repertoire/\"\n",
409-
"os.makedirs(repertoire_path, exist_ok=True)\n",
410-
"repertoire.save(path=repertoire_path)"
411-
]
412-
},
413-
{
414-
"cell_type": "markdown",
415-
"metadata": {},
416-
"source": [
417-
"## Build the reconstruction function"
418-
]
419-
},
420-
{
421-
"cell_type": "code",
422-
"execution_count": null,
423-
"metadata": {},
424-
"outputs": [],
425-
"source": [
426-
"# Init population of policies\n",
427-
"key, subkey = jax.random.split(key)\n",
428-
"fake_batch = jnp.zeros(shape=(env.observation_size,))\n",
429-
"fake_params = policy_network.init(subkey, fake_batch)\n",
430-
"\n",
431-
"_, reconstruction_fn = ravel_pytree(fake_params)"
432-
]
433-
},
434-
{
435-
"cell_type": "markdown",
436-
"metadata": {},
437-
"source": [
438-
"## Use the reconstruction function to load and re-create the repertoire"
439-
]
440-
},
441-
{
442-
"cell_type": "code",
443-
"execution_count": null,
444-
"metadata": {},
445-
"outputs": [],
446-
"source": [
447-
"repertoire = MELSRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)"
448-
]
449-
},
450385
{
451386
"cell_type": "markdown",
452387
"metadata": {},
@@ -488,6 +423,7 @@
488423
"metadata": {},
489424
"outputs": [],
490425
"source": [
426+
"# select the parameters of the best individual\n",
491427
"my_params = jax.tree.map(\n",
492428
" lambda x: x[best_idx],\n",
493429
" repertoire.genotypes\n",

examples/pga_aurora.ipynb

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
"- how to create an AURORA instance and mix it with the right emitter to define PGA-AURORA\n",
2222
"- which functions must be defined before training\n",
2323
"- how to launch a certain number of training steps\n",
24-
"- how to visualise the optimization process\n",
25-
"- how to save/load a repertoire"
24+
"- how to visualise the optimization process"
2625
]
2726
},
2827
{
@@ -426,8 +425,6 @@
426425
" model, subkey, (1, *observations_dims)\n",
427426
")\n",
428427
"\n",
429-
"print(jax.tree_map(lambda x: x.shape, model_params))\n",
430-
"\n",
431428
"# Define the encoder function\n",
432429
"encoder_fn = jax.jit(\n",
433430
" functools.partial(\n",
@@ -459,8 +456,6 @@
459456
" model, subkey, (1, *observations_dims)\n",
460457
")\n",
461458
"\n",
462-
"print(jax.tree_map(lambda x: x.shape, model_params))\n",
463-
"\n",
464459
"# define arbitrary observation's mean/std\n",
465460
"mean_observations = jnp.zeros(observations_dims[-1])\n",
466461
"std_observations = jnp.ones(observations_dims[-1])\n",

mypy.ini

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,3 @@ strict_equality = True
1818
explicit_package_bases = True
1919
follow_imports = skip
2020
ignore_missing_imports = True
21-
22-
[mypy-tensorflow_probability.*]
23-
ignore_missing_imports = True

qdax/core/containers/ga_repertoire.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
from __future__ import annotations
44

5-
from typing import Callable, Optional, Tuple
5+
from typing import Optional, Tuple
66

77
import flax
88
import jax
99
import jax.numpy as jnp
10-
from jax.flatten_util import ravel_pytree
1110

1211
from qdax.core.containers.repertoire import Repertoire
1312
from qdax.core.emitters.repertoire_selectors.selector import GARepertoireT, Selector
@@ -45,46 +44,6 @@ def size(self) -> int:
4544
first_leaf = jax.tree.leaves(self.genotypes)[0]
4645
return int(first_leaf.shape[0])
4746

48-
def save(self, path: str = "./") -> None:
49-
"""Saves the repertoire.
50-
51-
Args:
52-
path: place to store the files. Defaults to "./".
53-
"""
54-
55-
def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
56-
flatten_genotype, _ = ravel_pytree(genotype)
57-
return flatten_genotype
58-
59-
# flatten all the genotypes
60-
flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)
61-
62-
jnp.save(path + "genotypes.npy", flat_genotypes)
63-
jnp.save(path + "scores.npy", self.fitnesses)
64-
65-
@classmethod
66-
def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire:
67-
"""Loads a GA Repertoire.
68-
69-
Args:
70-
reconstruction_fn: Function to reconstruct a PyTree
71-
from a flat array.
72-
path: Path where the data is saved. Defaults to "./".
73-
74-
Returns:
75-
A GA Repertoire.
76-
"""
77-
78-
flat_genotypes = jnp.load(path + "genotypes.npy")
79-
genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)
80-
81-
fitnesses = jnp.load(path + "fitnesses.npy")
82-
83-
return cls(
84-
genotypes=genotypes,
85-
fitnesses=fitnesses,
86-
)
87-
8847
def select(
8948
self,
9049
key: RNGKey,

0 commit comments

Comments
 (0)