
Commit b8e362b

Adds multi-gpu training pipeline
Approved-by: Clemens Schwarke
1 parent 5be19f3 commit b8e362b

4 files changed (+235 -41 lines)


Diff for: rsl_rl/algorithms/distillation.py

+50 -2
@@ -4,6 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 # torch
+import torch
 import torch.nn as nn
 import torch.optim as optim
 

@@ -26,23 +27,34 @@ def __init__(
         learning_rate=1e-3,
         loss_type="mse",
         device="cpu",
+        # Distributed training parameters
+        multi_gpu_cfg: dict | None = None,
     ):
+        # device-related parameters
         self.device = device
-        self.learning_rate = learning_rate
+        self.is_multi_gpu = multi_gpu_cfg is not None
+        # Multi-GPU parameters
+        if multi_gpu_cfg is not None:
+            self.gpu_global_rank = multi_gpu_cfg["global_rank"]
+            self.gpu_world_size = multi_gpu_cfg["world_size"]
+        else:
+            self.gpu_global_rank = 0
+            self.gpu_world_size = 1
 
         self.rnd = None  # TODO: remove when runner has a proper base class
 
         # distillation components
         self.policy = policy
         self.policy.to(self.device)
         self.storage = None  # initialized later
-        self.optimizer = optim.Adam(self.policy.student.parameters(), lr=self.learning_rate)
+        self.optimizer = optim.Adam(self.policy.student.parameters(), lr=learning_rate)
         self.transition = RolloutStorage.Transition()
         self.last_hidden_states = None
 
         # distillation parameters
         self.num_learning_epochs = num_learning_epochs
         self.gradient_length = gradient_length
+        self.learning_rate = learning_rate
 
         # initialize the loss function
         if loss_type == "mse":

@@ -113,6 +125,8 @@ def update(self):
                 if cnt % self.gradient_length == 0:
                     self.optimizer.zero_grad()
                     loss.backward()
+                    if self.is_multi_gpu:
+                        self.reduce_parameters()
                     self.optimizer.step()
                     self.policy.detach_hidden_states()
                     loss = 0

@@ -130,3 +144,37 @@ def update(self):
         loss_dict = {"behavior": mean_behavior_loss}
 
         return loss_dict
+
+    """
+    Helper functions
+    """
+
+    def broadcast_parameters(self):
+        """Broadcast model parameters to all GPUs."""
+        # obtain the model parameters on current GPU
+        model_params = [self.policy.state_dict()]
+        # broadcast the model parameters
+        torch.distributed.broadcast_object_list(model_params, src=0)
+        # load the model parameters on all GPUs from source GPU
+        self.policy.load_state_dict(model_params[0])
+
+    def reduce_parameters(self):
+        """Collect gradients from all GPUs and average them.
+
+        This function is called after the backward pass to synchronize the gradients across all GPUs.
+        """
+        # Create a tensor to store the gradients
+        grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
+        all_grads = torch.cat(grads)
+        # Average the gradients across all GPUs
+        torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
+        all_grads /= self.gpu_world_size
+        # Update the gradients for all parameters with the reduced gradients
+        offset = 0
+        for param in self.policy.parameters():
+            if param.grad is not None:
+                numel = param.numel()
+                # copy data back from shared buffer
+                param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
+                # update the offset for the next parameter
+                offset += numel
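Note on usage: the new multi_gpu_cfg argument is only read for its "global_rank" and "world_size" entries, and broadcast_parameters() / reduce_parameters() assume that torch.distributed has already been initialized by the caller. A minimal sketch of that caller-side wiring under a torchrun launch follows; the make_multi_gpu_cfg helper and the commented-out usage lines are illustrative assumptions, not part of this commit.

# Sketch: building multi_gpu_cfg and initializing torch.distributed under torchrun.
# Everything below except the two dict keys is an assumption for illustration.
import os

import torch
import torch.distributed as dist


def make_multi_gpu_cfg() -> dict | None:
    # torchrun sets WORLD_SIZE, RANK and LOCAL_RANK for every process it spawns
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size <= 1:
        return None  # single-process run: the algorithm falls back to rank 0 / world size 1
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    return {"global_rank": dist.get_rank(), "world_size": world_size}


# cfg = make_multi_gpu_cfg()
# algo = Distillation(policy, device=f"cuda:{os.environ.get('LOCAL_RANK', '0')}", multi_gpu_cfg=cfg)
# if algo.is_multi_gpu:
#     algo.broadcast_parameters()  # start every rank from rank 0's weights

Calling broadcast_parameters() once before training makes all ranks start from identical weights, so the per-step gradient averaging in reduce_parameters() is enough to keep them in sync afterwards.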

Diff for: rsl_rl/algorithms/ppo.py

+100 -17
@@ -8,7 +8,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import warnings
+from itertools import chain
 
 from rsl_rl.modules import ActorCritic
 from rsl_rl.modules.rnd import RandomNetworkDistillation

@@ -43,13 +43,19 @@ def __init__(
         rnd_cfg: dict | None = None,
         # Symmetry parameters
         symmetry_cfg: dict | None = None,
+        # Distributed training parameters
+        multi_gpu_cfg: dict | None = None,
     ):
+        # device-related parameters
         self.device = device
-
-        self.desired_kl = desired_kl
-        self.schedule = schedule
-        self.learning_rate = learning_rate
-        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
+        self.is_multi_gpu = multi_gpu_cfg is not None
+        # Multi-GPU parameters
+        if multi_gpu_cfg is not None:
+            self.gpu_global_rank = multi_gpu_cfg["global_rank"]
+            self.gpu_world_size = multi_gpu_cfg["world_size"]
+        else:
+            self.gpu_global_rank = 0
+            self.gpu_world_size = 1
 
         # RND components
         if rnd_cfg is not None:

@@ -68,7 +74,7 @@ def __init__(
             use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"]
             # Print that we are not using symmetry
             if not use_symmetry:
-                warnings.warn("Symmetry not used for learning. We will use it for logging instead.")
+                print("Symmetry not used for learning. We will use it for logging instead.")
             # If function is a string then resolve it to a function
             if isinstance(symmetry_cfg["data_augmentation_func"], str):
                 symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])

@@ -102,6 +108,10 @@ def __init__(
         self.lam = lam
         self.max_grad_norm = max_grad_norm
         self.use_clipped_value_loss = use_clipped_value_loss
+        self.desired_kl = desired_kl
+        self.schedule = schedule
+        self.learning_rate = learning_rate
+        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
 
     def init_storage(
         self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape

@@ -267,11 +277,28 @@ def update(self): # noqa: C901
                     )
                     kl_mean = torch.mean(kl)
 
-                    if kl_mean > self.desired_kl * 2.0:
-                        self.learning_rate = max(1e-5, self.learning_rate / 1.5)
-                    elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
-                        self.learning_rate = min(1e-2, self.learning_rate * 1.5)
-
+                    # Reduce the KL divergence across all GPUs
+                    if self.is_multi_gpu:
+                        torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
+                        kl_mean /= self.gpu_world_size
+
+                    # Update the learning rate
+                    # Perform this adaptation only on the main process
+                    # TODO: Is this needed? If KL-divergence is the "same" across all GPUs,
+                    #       then the learning rate should be the same across all GPUs.
+                    if self.gpu_global_rank == 0:
+                        if kl_mean > self.desired_kl * 2.0:
+                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
+                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
+                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)
+
+                    # Update the learning rate for all GPUs
+                    if self.is_multi_gpu:
+                        lr_tensor = torch.tensor(self.learning_rate, device=self.device)
+                        torch.distributed.broadcast(lr_tensor, src=0)
+                        self.learning_rate = lr_tensor.item()
+
+                    # Update the learning rate for all parameter groups
                     for param_group in self.optimizer.param_groups:
                         param_group["lr"] = self.learning_rate
 

@@ -335,21 +362,30 @@ def update(self): # noqa: C901
                 if self.rnd:
                     # predict the embedding and the target
                     predicted_embedding = self.rnd.predictor(rnd_state_batch)
-                    target_embedding = self.rnd.target(rnd_state_batch)
+                    target_embedding = self.rnd.target(rnd_state_batch).detach()
                     # compute the loss as the mean squared error
                     mseloss = torch.nn.MSELoss()
-                    rnd_loss = mseloss(predicted_embedding, target_embedding.detach())
+                    rnd_loss = mseloss(predicted_embedding, target_embedding)
 
-                # Gradient step
+                # Compute the gradients
                 # -- For PPO
                 self.optimizer.zero_grad()
                 loss.backward()
+                # -- For RND
+                if self.rnd:
+                    self.rnd_optimizer.zero_grad()  # type: ignore
+                    rnd_loss.backward()
+
+                # Collect gradients from all GPUs
+                if self.is_multi_gpu:
+                    self.reduce_parameters()
+
+                # Apply the gradients
+                # -- For PPO
                 nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                 self.optimizer.step()
                 # -- For RND
                 if self.rnd_optimizer:
-                    self.rnd_optimizer.zero_grad()
-                    rnd_loss.backward()
                     self.rnd_optimizer.step()
 
                 # Store the losses

@@ -389,3 +425,50 @@ def update(self): # noqa: C901
             loss_dict["symmetry"] = mean_symmetry_loss
 
         return loss_dict
+
+    """
+    Helper functions
+    """
+
+    def broadcast_parameters(self):
+        """Broadcast model parameters to all GPUs."""
+        # obtain the model parameters on current GPU
+        model_params = [self.policy.state_dict()]
+        if self.rnd:
+            model_params.append(self.rnd.predictor.state_dict())
+        # broadcast the model parameters
+        torch.distributed.broadcast_object_list(model_params, src=0)
+        # load the model parameters on all GPUs from source GPU
+        self.policy.load_state_dict(model_params[0])
+        if self.rnd:
+            self.rnd.predictor.load_state_dict(model_params[1])
+
+    def reduce_parameters(self):
+        """Collect gradients from all GPUs and average them.
+
+        This function is called after the backward pass to synchronize the gradients across all GPUs.
+        """
+        # Create a tensor to store the gradients
+        grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
+        if self.rnd:
+            grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None]
+        all_grads = torch.cat(grads)
+
+        # Average the gradients across all GPUs
+        torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
+        all_grads /= self.gpu_world_size
+
+        # Get all parameters
+        all_params = self.policy.parameters()
+        if self.rnd:
+            all_params = chain(all_params, self.rnd.parameters())
+
+        # Update the gradients for all parameters with the reduced gradients
+        offset = 0
+        for param in all_params:
+            if param.grad is not None:
+                numel = param.numel()
+                # copy data back from shared buffer
+                param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
+                # update the offset for the next parameter
+                offset += numel
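Note on the reduction: reduce_parameters() is plain data-parallel gradient averaging. Every rank flattens its gradients into one contiguous buffer, a single all_reduce sums the buffers across ranks, the sum is divided by the world size, and the averaged values are copied back into each parameter's .grad before the optimizer step. A self-contained sketch of the same flatten / all-reduce / copy-back pattern on a toy model follows; the script name, the gloo/CPU backend, and the toy model are assumptions for illustration (an actual multi-GPU run would typically use NCCL), not code from this commit.

# Toy reproduction of the reduce_parameters() pattern; run with e.g.
#   torchrun --nproc_per_node=2 toy_reduce.py
import torch
import torch.distributed as dist
import torch.nn as nn


def reduce_gradients(model: nn.Module, world_size: int) -> None:
    # flatten all gradients into one buffer so a single collective call suffices
    grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]
    all_grads = torch.cat(grads)
    # sum across ranks, then divide to get the average gradient
    dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
    all_grads /= world_size
    # copy the averaged values back into each parameter's .grad
    offset = 0
    for p in model.parameters():
        if p.grad is not None:
            numel = p.numel()
            p.grad.data.copy_(all_grads[offset : offset + numel].view_as(p.grad.data))
            offset += numel


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # gloo so the sketch also runs on CPU
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.manual_seed(0)  # identical initial weights on every rank
    model = nn.Linear(4, 2)
    data = torch.randn(8, 4) + rank  # each rank sees different data
    model(data).sum().backward()
    reduce_gradients(model, world_size)
    # after the reduction every rank holds the same averaged gradients
    print(f"rank {rank}: grad sum = {model.weight.grad.sum().item():.6f}")
    dist.destroy_process_group()

Flattening into one buffer trades a little extra memory for a single collective call per update instead of one all_reduce per parameter tensor.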

Diff for: rsl_rl/modules/rnd.py

+3 -0
@@ -106,6 +106,9 @@ def __init__(
         self.predictor = self._build_mlp(num_states, predictor_hidden_dims, num_outputs, activation).to(self.device)
         self.target = self._build_mlp(num_states, target_hidden_dims, num_outputs, activation).to(self.device)
 
+        # make target network not trainable
+        self.target.eval()
+
     def get_intrinsic_reward(self, rnd_state) -> tuple[torch.Tensor, torch.Tensor]:
         # note: the counter is updated number of env steps per learning iteration
         self.update_counter += 1
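Note: self.target.eval() only switches modules such as dropout or batch norm to inference behavior; by itself it does not freeze the weights. The target network stays fixed because its output is detached before the RND loss is computed (see the ppo.py hunk above), so no gradient ever flows into it. If one wanted to make that intent explicit, the target's parameters could additionally be frozen; a small sketch, not part of this commit:

# optional, not in this commit: explicitly disable gradients for the target network
for param in self.target.parameters():
    param.requires_grad_(False)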
