Skip to content

Commit b627b70

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feat/simplify-save-load
2 parents 27120e9 + b5472d0 commit b627b70

64 files changed

Lines changed: 1065 additions & 1099 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

examples/aurora.ipynb

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,11 @@
122122
"prior_descriptor_dim = 2 #@param {type:\"integer\"}\n",
123123
"\n",
124124
"log_freq = 5 #@param {type:\"integer\"}\n",
125+
"\n",
126+
"# Custom observations key that will be used to store the observations in the\n",
127+
"# extra_scores of the repertoire\n",
128+
"aurora_observations_key = \"observations\"\n",
129+
"\n",
125130
"#@markdown ---"
126131
]
127132
},
@@ -258,6 +263,7 @@
258263
"aurora_scoring_fn = get_aurora_scoring_fn(\n",
259264
" scoring_fn=scoring_fn,\n",
260265
" observation_extractor_fn=observation_extractor_fn,\n",
266+
" observations_key=aurora_observations_key,\n",
261267
")\n",
262268
"\n",
263269
"# Get minimum reward value to make sure qd_score are positive\n",
@@ -389,6 +395,7 @@
389395
" metrics_function=metrics_fn,\n",
390396
" encoder_function=encoder_fn,\n",
391397
" training_function=train_fn,\n",
398+
" observations_key=aurora_observations_key,\n",
392399
")\n",
393400
"\n",
394401
"# define arbitrary observation's mean/std\n",
@@ -444,6 +451,7 @@
444451
"n_target = 1024\n",
445452
"\n",
446453
"previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n",
454+
"container_size_control_fn = jax.jit(aurora.container_size_control)\n",
447455
"\n",
448456
"iteration = 0\n",
449457
"while iteration < max_iterations:\n",
@@ -472,7 +480,7 @@
472480
" )\n",
473481
"\n",
474482
" elif iteration % 2 == 0:\n",
475-
" repertoire, previous_error = aurora.container_size_control(\n",
483+
" repertoire, previous_error = container_size_control_fn(\n",
476484
" repertoire,\n",
477485
" target_size=n_target,\n",
478486
" previous_error=previous_error,\n",

examples/cmaes.ipynb

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -208,18 +208,21 @@
208208
"covs = [(state.sigma**2) * state.cov_matrix]\n",
209209
"\n",
210210
"iteration_count = 0\n",
211+
"sample_fn = jax.jit(cmaes.sample)\n",
212+
"update_fn = jax.jit(cmaes.update)\n",
213+
"stop_condition_fn = jax.jit(cmaes.stop_condition)\n",
211214
"for _ in range(num_iterations):\n",
212215
" iteration_count += 1\n",
213216
"\n",
214217
" # sample\n",
215218
" key, subkey = jax.random.split(key)\n",
216-
" samples = cmaes.sample(state, subkey)\n",
219+
" samples = sample_fn(state, subkey)\n",
217220
"\n",
218221
" # update\n",
219-
" state = cmaes.update(state, samples)\n",
222+
" state = update_fn(state, samples)\n",
220223
"\n",
221224
" # check stop condition\n",
222-
" stop_condition = cmaes.stop_condition(state)\n",
225+
" stop_condition = stop_condition_fn(state)\n",
223226
"\n",
224227
" if stop_condition:\n",
225228
" break\n",

examples/mapelites_asktell.ipynb

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -326,20 +326,23 @@
326326
"except ImportError:\n",
327327
" bar = range(num_iterations)\n",
328328
"\n",
329+
"ask_fn = jax.jit(map_elites.ask)\n",
330+
"tell_fn = jax.jit(map_elites.tell)\n",
331+
"\n",
329332
"# Main loop\n",
330333
"for i in bar:\n",
331334
" start_time = time.time()\n",
332335
" key, subkey = jax.random.split(key)\n",
333336
" # Generate solutions\n",
334-
" genotypes, extra_info = map_elites.ask(repertoire, emitter_state, subkey)\n",
337+
" genotypes, extra_info = ask_fn(repertoire, emitter_state, subkey)\n",
335338
"\n",
336339
" # Evaluate solutions: get fitness, descriptor and extra scores.\n",
337340
" # This is where custom evaluations on CPU or GPU can be added.\n",
338341
" key, subkey = jax.random.split(key)\n",
339342
" fitnesses, descriptors, extra_scores = scoring_fn(genotypes, subkey)\n",
340343
"\n",
341344
" # Update MAP-Elites\n",
342-
" repertoire, emitter_state, current_metrics = map_elites.tell(\n",
345+
" repertoire, emitter_state, current_metrics = tell_fn(\n",
343346
" genotypes=genotypes,\n",
344347
" fitnesses=fitnesses,\n",
345348
" descriptors=descriptors,\n",

examples/mees.ipynb

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -248,11 +248,11 @@
248248
"\n",
249249
"# Prepare the scoring functions for the offspring generated following\n",
250250
"# the approximated gradient (each of them is evaluated 30 times)\n",
251-
"sampling_fn = functools.partial(\n",
251+
"sampling_fn = jax.jit(functools.partial(\n",
252252
" sampling,\n",
253253
" scoring_fn=scoring_fn,\n",
254254
" num_samples=30,\n",
255-
")\n",
255+
"))\n",
256256
"\n",
257257
"# Get minimum reward value to make sure qd_score are positive\n",
258258
"reward_offset = environments.reward_offset[env_name]\n",
@@ -448,11 +448,8 @@
448448
"provenance": []
449449
},
450450
"gpuClass": "standard",
451-
"interpreter": {
452-
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
453-
},
454451
"kernelspec": {
455-
"display_name": "Python 3 (ipykernel)",
452+
"display_name": ".venv",
456453
"language": "python",
457454
"name": "python3"
458455
},
@@ -466,7 +463,7 @@
466463
"name": "python",
467464
"nbconvert_exporter": "python",
468465
"pygments_lexer": "ipython3",
469-
"version": "3.10.12"
466+
"version": "3.11.10"
470467
}
471468
},
472469
"nbformat": 4,

examples/nsga2_spea2.ipynb

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,6 @@
9898
"proportion_mutation = 0.80 #@param {type:\"number\"}\n",
9999
"minval = -5.12 #@param {type:\"number\"}\n",
100100
"maxval = 5.12 #@param {type:\"number\"}\n",
101-
"batch_size = 100 #@param {type:\"integer\"}\n",
102101
"genotype_dim = 6 #@param {type:\"integer\"}\n",
103102
"lag = 2.2 #@param {type:\"number\"}\n",
104103
"base_lag = 0 #@param {type:\"number\"}\n",
@@ -184,7 +183,7 @@
184183
"key = jax.random.key(0)\n",
185184
"key, subkey = jax.random.split(key)\n",
186185
"genotypes = jax.random.uniform(\n",
187-
" subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n",
186+
" subkey, (population_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n",
188187
")\n",
189188
"\n",
190189
"# Mutation & Crossover\n",
@@ -202,11 +201,12 @@
202201
")\n",
203202
"\n",
204203
"# Define the emitter\n",
204+
"# NSGA-II and SPEA2 use batch size = population size\n",
205205
"mixing_emitter = MixingEmitter(\n",
206206
" mutation_fn=mutation_function, \n",
207207
" variation_fn=crossover_function, \n",
208208
" variation_percentage=1-proportion_mutation, \n",
209-
" batch_size=batch_size\n",
209+
" batch_size=population_size, \n",
210210
")"
211211
]
212212
},

examples/pga_aurora.ipynb

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,11 @@
140140
"policy_delay = 2 #@param {type:\"number\"}\n",
141141
"\n",
142142
"log_freq = 5 #@param {type:\"integer\"}\n",
143+
"\n",
144+
"# Custom observations key that will be used to store the observations in the\n",
145+
"# extra_scores of the repertoire\n",
146+
"aurora_observations_key = \"observations\"\n",
147+
"\n",
143148
"#@markdown ---"
144149
]
145150
},
@@ -276,6 +281,7 @@
276281
"aurora_scoring_fn = get_aurora_scoring_fn(\n",
277282
" scoring_fn=scoring_fn,\n",
278283
" observation_extractor_fn=observation_extractor_fn,\n",
284+
" observations_key=aurora_observations_key,\n",
279285
")\n",
280286
"\n",
281287
"# Get minimum reward value to make sure qd_score are positive\n",
@@ -441,6 +447,7 @@
441447
" metrics_function=metrics_fn,\n",
442448
" encoder_function=encoder_fn,\n",
443449
" training_function=train_fn,\n",
450+
" observations_key=aurora_observations_key,\n",
444451
")\n",
445452
"\n",
446453
"# init the model params\n",
@@ -502,6 +509,7 @@
502509
"n_target = 1024\n",
503510
"\n",
504511
"previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target\n",
512+
"container_size_control_fn = jax.jit(aurora.container_size_control)\n",
505513
"\n",
506514
"iteration = 0\n",
507515
"while iteration < max_iterations:\n",
@@ -530,7 +538,7 @@
530538
" )\n",
531539
"\n",
532540
" elif iteration % 2 == 0:\n",
533-
" repertoire, previous_error = aurora.container_size_control(\n",
541+
" repertoire, previous_error = container_size_control_fn(\n",
534542
" repertoire,\n",
535543
" target_size=n_target,\n",
536544
" previous_error=previous_error,\n",

qdax/baselines/cmaes.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
a CMA optimization script. Link to the paper: https://arxiv.org/abs/1604.00772
44
"""
55

6-
from functools import partial
76
from typing import Callable, Optional, Tuple
87

98
import flax
@@ -165,7 +164,6 @@ def init(self) -> CMAESState:
165164
invsqrt_cov=invsqrt_cov,
166165
)
167166

168-
@partial(jax.jit, static_argnames=("self",))
169167
def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype:
170168
"""
171169
Sample a population.
@@ -186,7 +184,6 @@ def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype:
186184
)
187185
return samples
188186

189-
@partial(jax.jit, static_argnames=("self",))
190187
def update_state(
191188
self,
192189
cmaes_state: CMAESState,
@@ -198,7 +195,6 @@ def update_state(
198195
weights=self._weights,
199196
)
200197

201-
@partial(jax.jit, static_argnames=("self",))
202198
def update_state_with_mask(
203199
self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask
204200
) -> CMAESState:
@@ -217,7 +213,6 @@ def update_state_with_mask(
217213
weights=weights,
218214
)
219215

220-
@partial(jax.jit, static_argnames=("self",))
221216
def _update_state(
222217
self,
223218
cmaes_state: CMAESState,
@@ -332,7 +327,6 @@ def update_eigen(
332327

333328
return cmaes_state
334329

335-
@partial(jax.jit, static_argnames=("self",))
336330
def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
337331
"""Updates the distribution.
338332
@@ -352,7 +346,6 @@ def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
352346

353347
return new_state # type: ignore
354348

355-
@partial(jax.jit, static_argnames=("self",))
356349
def stop_condition(self, cmaes_state: CMAESState) -> bool:
357350
"""Determines if the current optimization path must be stopped.
358351

qdax/baselines/dads.py

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
"""
55

66
from dataclasses import dataclass
7-
from functools import partial
87
from typing import Callable, Tuple
98

109
import jax
@@ -191,7 +190,6 @@ def init( # type: ignore
191190
steps=jnp.array(0),
192191
)
193192

194-
@partial(jax.jit, static_argnames=("self",))
195193
def _compute_diversity_reward(
196194
self, transition: QDTransition, training_state: DadsTrainingState
197195
) -> Reward:
@@ -244,8 +242,7 @@ def _compute_diversity_reward(
244242

245243
return reward
246244

247-
@partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
248-
def play_step_fn(
245+
def play_step_fn( # type: ignore
249246
self,
250247
env_state: EnvState,
251248
training_state: DadsTrainingState,
@@ -339,14 +336,13 @@ def play_step_fn(
339336

340337
return next_env_state, training_state, transition
341338

342-
@partial(jax.jit, static_argnames=("self", "play_step_fn", "env_batch_size"))
343-
def eval_policy_fn(
339+
def eval_policy_fn( # type: ignore
344340
self,
345341
training_state: DadsTrainingState,
346342
eval_env_first_state: EnvState,
347343
play_step_fn: Callable[
348-
[EnvState, Params, RNGKey],
349-
Tuple[EnvState, Params, RNGKey, QDTransition],
344+
[EnvState, Params],
345+
Tuple[EnvState, Params, QDTransition],
350346
],
351347
env_batch_size: int,
352348
) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
@@ -400,7 +396,6 @@ def eval_policy_fn(
400396

401397
return true_return, true_returns, diversity_returns, transitions.state_desc
402398

403-
@partial(jax.jit, static_argnames=("self",))
404399
def _compute_reward(
405400
self, transition: QDTransition, training_state: DadsTrainingState
406401
) -> Reward:
@@ -417,7 +412,6 @@ def _compute_reward(
417412
transition=transition, training_state=training_state
418413
)
419414

420-
@partial(jax.jit, static_argnames=("self",))
421415
def _update_dynamics(
422416
self, operand: Tuple[DadsTrainingState, QDTransition]
423417
) -> Tuple[Params, float, optax.OptState]:
@@ -448,7 +442,6 @@ def _update_dynamics(
448442
dynamics_optimizer_state,
449443
)
450444

451-
@partial(jax.jit, static_argnames=("self",))
452445
def _not_update_dynamics(
453446
self, operand: Tuple[DadsTrainingState, QDTransition]
454447
) -> Tuple[Params, float, optax.OptState]:
@@ -464,7 +457,6 @@ def _not_update_dynamics(
464457
training_state.dynamics_optimizer_state,
465458
)
466459

467-
@partial(jax.jit, static_argnames=("self",))
468460
def _update_networks(
469461
self,
470462
training_state: DadsTrainingState,
@@ -566,7 +558,6 @@ def _update_networks(
566558

567559
return new_training_state, metrics
568560

569-
@partial(jax.jit, static_argnames=("self",))
570561
def update(
571562
self,
572563
training_state: DadsTrainingState,

qdax/baselines/dads_smerl.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
"""
66

77
from dataclasses import dataclass
8-
from functools import partial
98
from typing import Optional, Tuple
109

1110
import jax
@@ -40,8 +39,7 @@ def __init__(self, config: DadsSmerlConfig, action_size: int, descriptor_size: i
4039
super(DADSSMERL, self).__init__(config, action_size, descriptor_size)
4140
self._config: DadsSmerlConfig = config
4241

43-
@partial(jax.jit, static_argnames=("self",))
44-
def _compute_reward(
42+
def _compute_reward( # type: ignore
4543
self,
4644
transition: QDTransition,
4745
training_state: DadsTrainingState,
@@ -74,7 +72,6 @@ def _compute_reward(
7472

7573
return rewards
7674

77-
@partial(jax.jit, static_argnames=("self",))
7875
def update(
7976
self,
8077
training_state: DadsTrainingState,

0 commit comments

Comments
 (0)