import torch
import torch.nn as nn
import torch.optim as optim
-import warnings
+from itertools import chain

from rsl_rl.modules import ActorCritic
from rsl_rl.modules.rnd import RandomNetworkDistillation
@@ -43,13 +43,19 @@ def __init__(
        rnd_cfg: dict | None = None,
        # Symmetry parameters
        symmetry_cfg: dict | None = None,
+        # Distributed training parameters
+        multi_gpu_cfg: dict | None = None,
    ):
+        # device-related parameters
        self.device = device
-
-        self.desired_kl = desired_kl
-        self.schedule = schedule
-        self.learning_rate = learning_rate
-        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
+        self.is_multi_gpu = multi_gpu_cfg is not None
+        # Multi-GPU parameters
+        if multi_gpu_cfg is not None:
+            self.gpu_global_rank = multi_gpu_cfg["global_rank"]
+            self.gpu_world_size = multi_gpu_cfg["world_size"]
+        else:
+            self.gpu_global_rank = 0
+            self.gpu_world_size = 1

        # RND components
        if rnd_cfg is not None:
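
For reference, the new `multi_gpu_cfg` argument is a plain dictionary, and the constructor above only reads its `global_rank` and `world_size` keys; passing `None` keeps the existing single-GPU behaviour, since `is_multi_gpu` then evaluates to `False`. Below is a minimal launcher-side sketch of how such a dictionary could be assembled, assuming a `torchrun`-style launch; `make_multi_gpu_cfg` is a hypothetical helper, not part of this PR.

import os

import torch

def make_multi_gpu_cfg() -> dict | None:
    # torchrun sets RANK and WORLD_SIZE in the environment of every worker process
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size <= 1:
        # single-process run: return None so PPO falls back to rank 0 / world size 1
        return None
    # join the default process group (NCCL backend for multi-GPU training)
    torch.distributed.init_process_group(backend="nccl")
    return {"global_rank": int(os.environ["RANK"]), "world_size": world_size}
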
@@ -68,7 +74,7 @@ def __init__(
            use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"]
            # Print that we are not using symmetry
            if not use_symmetry:
-                warnings.warn("Symmetry not used for learning. We will use it for logging instead.")
+                print("Symmetry not used for learning. We will use it for logging instead.")
            # If function is a string then resolve it to a function
            if isinstance(symmetry_cfg["data_augmentation_func"], str):
                symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"])
@@ -102,6 +108,10 @@ def __init__(
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss
+        self.desired_kl = desired_kl
+        self.schedule = schedule
+        self.learning_rate = learning_rate
+        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

    def init_storage(
        self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape
@@ -267,11 +277,28 @@ def update(self): # noqa: C901
                    )
                    kl_mean = torch.mean(kl)

-                    if kl_mean > self.desired_kl * 2.0:
-                        self.learning_rate = max(1e-5, self.learning_rate / 1.5)
-                    elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
-                        self.learning_rate = min(1e-2, self.learning_rate * 1.5)
-
+                    # Reduce the KL divergence across all GPUs
+                    if self.is_multi_gpu:
+                        torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
+                        kl_mean /= self.gpu_world_size
+
+                    # Update the learning rate
+                    # Perform this adaptation only on the main process
+                    # TODO: Is this needed? If KL-divergence is the "same" across all GPUs,
+                    #       then the learning rate should be the same across all GPUs.
+                    if self.gpu_global_rank == 0:
+                        if kl_mean > self.desired_kl * 2.0:
+                            self.learning_rate = max(1e-5, self.learning_rate / 1.5)
+                        elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
+                            self.learning_rate = min(1e-2, self.learning_rate * 1.5)
+
+                    # Update the learning rate for all GPUs
+                    if self.is_multi_gpu:
+                        lr_tensor = torch.tensor(self.learning_rate, device=self.device)
+                        torch.distributed.broadcast(lr_tensor, src=0)
+                        self.learning_rate = lr_tensor.item()
+
+                    # Update the learning rate for all parameter groups
                    for param_group in self.optimizer.param_groups:
                        param_group["lr"] = self.learning_rate

@@ -335,21 +362,30 @@ def update(self): # noqa: C901
            if self.rnd:
                # predict the embedding and the target
                predicted_embedding = self.rnd.predictor(rnd_state_batch)
-                target_embedding = self.rnd.target(rnd_state_batch)
+                target_embedding = self.rnd.target(rnd_state_batch).detach()
                # compute the loss as the mean squared error
                mseloss = torch.nn.MSELoss()
-                rnd_loss = mseloss(predicted_embedding, target_embedding.detach())
+                rnd_loss = mseloss(predicted_embedding, target_embedding)

-            # Gradient step
+            # Compute the gradients
            # -- For PPO
            self.optimizer.zero_grad()
            loss.backward()
+            # -- For RND
+            if self.rnd:
+                self.rnd_optimizer.zero_grad()  # type: ignore
+                rnd_loss.backward()
+
+            # Collect gradients from all GPUs
+            if self.is_multi_gpu:
+                self.reduce_parameters()
+
+            # Apply the gradients
+            # -- For PPO
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()
            # -- For RND
            if self.rnd_optimizer:
-                self.rnd_optimizer.zero_grad()
-                rnd_loss.backward()
                self.rnd_optimizer.step()

            # Store the losses
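
A note on the design of the gradient synchronization introduced here: `reduce_parameters()` (added further below in this PR) flattens all gradients into a single tensor and issues one fused `all_reduce`. An unfused variant that averages each gradient tensor separately would also work; a sketch is given below, assuming the default process group is already initialised (`reduce_gradients_per_tensor` is a hypothetical name, not part of the PR). The fused version avoids launching one collective per parameter tensor, which is the same motivation behind the gradient bucketing in `torch.nn.parallel.DistributedDataParallel`.

import torch.distributed as dist

def reduce_gradients_per_tensor(parameters, world_size: int) -> None:
    # average each gradient tensor across ranks, one collective call per tensor
    for param in parameters:
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size
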
@@ -389,3 +425,50 @@ def update(self): # noqa: C901
            loss_dict["symmetry"] = mean_symmetry_loss

        return loss_dict
+
+    """
+    Helper functions
+    """
+
+    def broadcast_parameters(self):
+        """Broadcast model parameters to all GPUs."""
+        # obtain the model parameters on current GPU
+        model_params = [self.policy.state_dict()]
+        if self.rnd:
+            model_params.append(self.rnd.predictor.state_dict())
+        # broadcast the model parameters
+        torch.distributed.broadcast_object_list(model_params, src=0)
+        # load the model parameters on all GPUs from source GPU
+        self.policy.load_state_dict(model_params[0])
+        if self.rnd:
+            self.rnd.predictor.load_state_dict(model_params[1])
+
+    def reduce_parameters(self):
+        """Collect gradients from all GPUs and average them.
+
+        This function is called after the backward pass to synchronize the gradients across all GPUs.
+        """
+        # Create a tensor to store the gradients
+        grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None]
+        if self.rnd:
+            grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None]
+        all_grads = torch.cat(grads)
+
+        # Average the gradients across all GPUs
+        torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
+        all_grads /= self.gpu_world_size
+
+        # Get all parameters
+        all_params = self.policy.parameters()
+        if self.rnd:
+            all_params = chain(all_params, self.rnd.parameters())
+
+        # Update the gradients for all parameters with the reduced gradients
+        offset = 0
+        for param in all_params:
+            if param.grad is not None:
+                numel = param.numel()
+                # copy data back from shared buffer
+                param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data))
+                # update the offset for the next parameter
+                offset += numel
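
Finally, a sketch of how a runner could drive the new helpers end to end. This is an assumption about the calling side rather than part of the PR: `alg` stands for the PPO instance, `collect_rollouts` is a placeholder for whatever fills the rollout storage, and the process group is assumed to be initialised as in the launcher sketch above.

def train_distributed(alg, collect_rollouts, num_iterations: int) -> None:
    # start all ranks from identical policy (and RND predictor) weights
    if alg.is_multi_gpu:
        alg.broadcast_parameters()
    for _ in range(num_iterations):
        collect_rollouts(alg)       # each rank gathers its own experience
        loss_dict = alg.update()    # backward pass, reduce_parameters(), optimizer steps
        if alg.gpu_global_rank == 0:
            print(loss_dict)        # log only on the main process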