
Commit fa86a07

committed: cleaned codebase and moved to gbrl 1.1.0
1 parent 4dcc954 commit fa86a07

22 files changed, +788 -628 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -11,4 +11,5 @@ wandb/
 saved_models/
 videos/
 temp/
-results/
+results/
+*venv*/

algos/a2c.py

Lines changed: 36 additions & 30 deletions
@@ -20,8 +20,7 @@
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.distributions import DiagGaussianDistribution
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
-from stable_baselines3.common.type_aliases import (GymEnv, MaybeCallback,
-                                                   Schedule)
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
 from stable_baselines3.common.utils import (explained_variance, get_linear_fn,
                                             obs_as_tensor, safe_mean, update_learning_rate, get_schedule_fn)
 from stable_baselines3.common.vec_env import VecEnv
@@ -103,15 +102,17 @@ def __init__(
         self.normalize_advantage = normalize_advantage
         self.max_policy_grad_norm = max_policy_grad_norm
         self.max_value_grad_norm = max_value_grad_norm
-        num_rollouts = total_n_steps / (n_steps * env.num_envs)
+        num_rollouts = total_n_steps / (n_steps * env.num_envs)
         total_num_updates = num_rollouts
         assert 'tree_optimizer' in policy_kwargs, "tree_optimizer must be a dictionary within policy_kwargs"
-        assert 'gbrl_params' in policy_kwargs['tree_optimizer'], "gbrl_params must be a dictionary within policy_kwargs['tree_optimizer]"
+        assert 'params' in policy_kwargs['tree_optimizer'], \
+            "params must be a dictionary within policy_kwargs['tree_optimizer]"
         policy_kwargs['tree_optimizer']['policy_optimizer']['T'] = int(total_num_updates)
         policy_kwargs['tree_optimizer']['value_optimizer']['T'] = int(total_num_updates)
         policy_kwargs['tree_optimizer']['device'] = device
         self.fixed_std = fixed_std
-        is_categorical = (hasattr(env, 'is_mixed') and env.is_mixed) or (hasattr(env, 'is_categorical') and env.is_categorical)
+        is_categorical = (hasattr(env, 'is_mixed') and env.is_mixed) or (hasattr(env, 'is_categorical') and
+                                                                         env.is_categorical)
         is_mixed = (hasattr(env, 'is_mixed') and env.is_mixed)
         if is_categorical:
             policy_kwargs['is_categorical'] = True
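Note on the renamed assert above: with GBRL 1.1.0 the tree settings are looked up under the 'params' key of policy_kwargs['tree_optimizer'] rather than 'gbrl_params'. A minimal sketch of a compatible policy_kwargs layout follows; the contents of 'params' and the learning rates are illustrative assumptions, only the key names checked or written by __init__ come from this diff:

# Sketch of the expected layout (values are placeholders, not from this commit).
policy_kwargs = {
    'tree_optimizer': {
        'params': {'max_depth': 4},          # assumed GBRL tree parameters
        'policy_optimizer': {'lr': 1e-2},    # __init__ fills in 'T' = total_num_updates
        'value_optimizer': {'lr': 1e-2},     # __init__ fills in 'T' = total_num_updates
        # 'device' is also set by __init__ from the device argument
    },
}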
@@ -120,14 +121,14 @@ def __init__(
 
         if isinstance(log_std_lr, str):
             if 'lin_' in log_std_lr:
-                log_std_lr = get_linear_fn(float(log_std_lr.replace('lin_' ,'')), min_log_std_lr, 1)
+                log_std_lr = get_linear_fn(float(log_std_lr.replace('lin_', '')), min_log_std_lr, 1)
             else:
                 log_std_lr = float(log_std_lr)
         policy_kwargs['log_std_schedule'] = get_schedule_fn(log_std_lr)
         super().__init__(
             ActorCriticPolicy,
             env,
-            learning_rate=learning_rate, #does nothing for categorical output spaces
+            learning_rate=learning_rate,  # does nothing for categorical output spaces
             n_steps=n_steps,
             gamma=gamma,
             gae_lambda=gae_lambda,
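The string-schedule branch above follows SB3's convention: a value such as 'lin_0.01' becomes a linear decay from 0.01 down to min_log_std_lr via get_linear_fn. A small standalone illustration (the 0.01 and 1e-4 values are made up):

from stable_baselines3.common.utils import get_linear_fn

# 'lin_0.01' -> start at 0.01 and decay linearly to the floor over all of training.
min_log_std_lr = 1e-4
schedule = get_linear_fn(float('lin_0.01'.replace('lin_', '')), min_log_std_lr, 1)
print(schedule(1.0))  # progress_remaining = 1.0 (start of training) -> 0.01
print(schedule(0.0))  # progress_remaining = 0.0 (end of training)   -> 0.0001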
@@ -153,15 +154,15 @@ def __init__(
         # Update optimizer inside the policy if we want to use RMSProp
         # (original implementation) rather than Adam
 
-        self.rollout_buffer_class =RolloutBuffer
+        self.rollout_buffer_class = RolloutBuffer
         self.rollout_buffer_kwargs = {}
         if is_categorical or is_mixed:
-            self.rollout_buffer_class = CategoricalRolloutBuffer
+            self.rollout_buffer_class = CategoricalRolloutBuffer
             self.rollout_buffer_kwargs['is_mixed'] = is_mixed
-
+
         if _init_setup_model:
             self.a2c_setup_model()
-
+
     def a2c_setup_model(self) -> None:
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)
@@ -184,8 +185,6 @@ def a2c_setup_model(self) -> None:
         # pytype:enable=not-instantiable
         self.policy = self.policy.to(self.device)
 
-
-
     def train(self) -> None:
         """
         Update policy using the currently gathered
@@ -196,7 +195,8 @@ def train(self) -> None:
 
         # Update optimizer learning rate
         if isinstance(self.policy.action_dist, DiagGaussianDistribution):
-            update_learning_rate(self.policy.log_std_optimizer, self.policy.log_std_schedule(self._current_progress_remaining))
+            update_learning_rate(self.policy.log_std_optimizer,
+                                 self.policy.log_std_schedule(self._current_progress_remaining))
 
         policy_losses, value_losses, entropy_losses = [], [], []
         log_std_s = []
@@ -245,7 +245,8 @@ def train(self) -> None:
         theta = params[0]
         if isinstance(self.policy.action_dist, DiagGaussianDistribution) and not self.fixed_std:
             if self.max_policy_grad_norm is not None and self.max_policy_grad_norm > 0.0:
-                th.nn.utils.clip_grad_norm(self.policy.log_std, max_norm=self.max_policy_grad_norm, error_if_nonfinite=True)
+                th.nn.utils.clip_grad_norm_(self.policy.log_std, max_norm=self.max_policy_grad_norm,
+                                            error_if_nonfinite=True)
             self.policy.log_std_optimizer.step()
             log_std_grad = self.policy.log_std.grad.clone().detach().cpu().numpy()
             self.policy.log_std_optimizer.zero_grad()
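The change above swaps torch.nn.utils.clip_grad_norm, which has long been deprecated, for its in-place successor clip_grad_norm_. A self-contained illustration of the call with made-up numbers:

import torch as th

# Stand-in for self.policy.log_std: a parameter with an oversized gradient.
log_std = th.zeros(4, requires_grad=True)
log_std.grad = th.full((4,), 10.0)          # gradient norm = 20.0

# In-place clipping: rescales the gradient so its total norm is at most max_norm.
total_norm = th.nn.utils.clip_grad_norm_(log_std, max_norm=0.5, error_if_nonfinite=True)
print(total_norm)           # 20.0 (norm before clipping)
print(log_std.grad.norm())  # ~0.5 after clipping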
@@ -264,14 +265,15 @@ def train(self) -> None:
         theta_grad_maxs.append(theta_grad.max().item())
         theta_grad_mins.append(theta_grad.min().item())
 
-        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
+        explained_var = explained_variance(self.rollout_buffer.values.flatten(),
+                                           self.rollout_buffer.returns.flatten())
 
         self._n_updates += 1
-
+
         iteration = self.policy.model.get_iteration()
         num_trees = self.policy.model.get_num_trees()
         value_iteration = 0
-
+
         if isinstance(iteration, tuple):
             iteration, value_iteration = iteration
         value_num_trees = 0
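For reference, the explained_variance call wrapped above is SB3's diagnostic for how well the value function tracks the empirical returns (1.0 is a perfect fit, 0.0 no better than predicting the mean return). A standalone toy example with made-up numbers:

import numpy as np
from stable_baselines3.common.utils import explained_variance

# Toy value predictions vs. empirical returns (placeholders, not from a real run).
values = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
returns = np.array([1.1, 1.9, 3.2, 3.8], dtype=np.float32)

# explained_variance = 1 - Var(returns - values) / Var(returns)
print(explained_variance(values, returns))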
@@ -379,7 +381,8 @@ def collect_rollouts(
                     and infos[idx].get("terminal_observation") is not None
                     and infos[idx].get("TimeLimit.truncated", False)
                 ):
-                    terminal_obs = infos[idx]["terminal_observation"] if self.is_categorical else self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
+                    terminal_obs = infos[idx]["terminal_observation"] if self.is_categorical else \
+                        self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                     with th.no_grad():
                         terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
                     rewards[idx] += self.gamma * terminal_value
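The reformatted terminal_obs line is part of SB3's time-limit handling: when an episode ends only because of a TimeLimit truncation, the value of the final observation is bootstrapped into the last reward so GAE does not treat the cutoff as a genuine terminal state. A toy arithmetic illustration (the numbers are invented):

# Invented numbers: reward at the truncated step and the critic's value estimate
# of the terminal observation.
gamma = 0.99
reward_at_cutoff = 1.0
terminal_value = 5.0                        # stands in for policy.predict_values(terminal_obs)
reward_at_cutoff += gamma * terminal_value  # mirrors rewards[idx] += self.gamma * terminal_value
print(reward_at_cutoff)                     # 5.95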
@@ -397,14 +400,14 @@ def collect_rollouts(
 
         with th.no_grad():
             # Compute value for the last timestep
-            values = self.policy.predict_values(new_obs, requires_grad=False) # type: ignore[arg-type] if self.is_categorical else self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type]
+            values = self.policy.predict_values(new_obs, requires_grad=False)
 
         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
 
         callback.on_rollout_end()
 
         return True
-
+
     def learn(
         self: SelfA2C,
         total_timesteps: int,
@@ -414,7 +417,7 @@ def learn(
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfA2C:
-
+
         iteration = 0
         total_timesteps, callback = self._setup_learn(
             total_timesteps,
@@ -426,7 +429,8 @@ def learn(
         callback.on_training_start(locals(), globals())
         assert self.env is not None
         while self.num_timesteps < total_timesteps:
-            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
+            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer,
+                                                      n_rollout_steps=self.n_steps)
 
             if continue_training is False:
                 break
@@ -441,8 +445,10 @@ def learn(
                 fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
                 if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
-                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
-                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
+                    self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in
+                                                                         self.ep_info_buffer]))
+                    self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in
+                                                                         self.ep_info_buffer]))
                 self.logger.record("time/fps", fps)
                 self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
@@ -453,10 +459,10 @@ def learn(
         callback.on_training_end()
 
         return self
-
+
     def save(self,
-            path: Union[str, pathlib.Path, io.BufferedIOBase],
-            exclude: Optional[Iterable[str]] = None,
-            include: Optional[Iterable[str]] = None,
-            ) -> None:
+             path: Union[str, pathlib.Path, io.BufferedIOBase],
+             exclude: Optional[Iterable[str]] = None,
+             include: Optional[Iterable[str]] = None,
+             ) -> None:
         self.policy.model.save_model(path)
