|
32 | 32 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise |
33 | 33 | from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first |
34 | 34 | from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike # noqa: F401 |
35 | | -from stable_baselines3.common.utils import constant_fn |
| 35 | +from stable_baselines3.common.utils import ConstantSchedule |
36 | 36 | from stable_baselines3.common.vec_env import ( |
37 | 37 | DummyVecEnv, |
38 | 38 | SubprocVecEnv, |
|
50 | 50 | import rl_zoo3.import_envs # noqa: F401 |
51 | 51 | from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback |
52 | 52 | from rl_zoo3.hyperparams_opt import HYPERPARAMS_CONVERTER, HYPERPARAMS_SAMPLER |
53 | | -from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule |
| 53 | +from rl_zoo3.utils import ( |
| 54 | + ALGOS, |
| 55 | + SimpleLinearSchedule, |
| 56 | + get_callback_list, |
| 57 | + get_class_by_name, |
| 58 | + get_latest_run_id, |
| 59 | + get_wrapper_class, |
| 60 | +) |
54 | 61 |
|
55 | 62 |
|
56 | 63 | class ExperimentManager: |
@@ -381,12 +388,12 @@ def _preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]: |
381 | 388 | if isinstance(hyperparams[key], str): |
382 | 389 | schedule, initial_value = hyperparams[key].split("_") |
383 | 390 | initial_value = float(initial_value) |
384 | | - hyperparams[key] = linear_schedule(initial_value) |
| 391 | + hyperparams[key] = SimpleLinearSchedule(initial_value) |
385 | 392 | elif isinstance(hyperparams[key], (float, int)): |
386 | 393 | # Negative value: ignore (ex: for clipping) |
387 | 394 | if hyperparams[key] < 0: |
388 | 395 | continue |
389 | | - hyperparams[key] = constant_fn(float(hyperparams[key])) |
| 396 | + hyperparams[key] = ConstantSchedule(float(hyperparams[key])) |
390 | 397 | else: |
391 | 398 | raise ValueError(f"Invalid value for {key}: {hyperparams[key]}") |
392 | 399 | return hyperparams |
|
0 commit comments