-
Notifications
You must be signed in to change notification settings - Fork 188
Expand file tree
/
Copy pathtrain_unizero_multitask_ddp.py
More file actions
449 lines (387 loc) · 21.9 KB
/
train_unizero_multitask_ddp.py
File metadata and controls
449 lines (387 loc) · 21.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
import logging
import os
from collections import defaultdict
from functools import partial
from typing import Tuple, Optional, List, Dict
import concurrent.futures
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from tensorboardX import SummaryWriter
from ding.config import compile_config
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer
from ding.worker import BaseLearner
from lzero.entry.utils import (
EVALUATION_TIMEOUT,
TemperatureScheduler,
allocate_batch_size,
compute_task_weights,
compute_unizero_mt_normalized_stats,
log_buffer_memory_usage,
safe_eval,
symlog,
inv_symlog,
)
from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler, symlog, inv_symlog
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
# Module-level timer. `train_unizero_multitask_ddp` uses it as a context manager
# (`with timer:`) around replay-buffer reanalysis and logs `timer.value` as the
# reanalysis time cost.
timer = EasyTimer()
def train_unizero_multitask_ddp(
    input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
    seed: int = 0,
    model: Optional[torch.nn.Module] = None,
    model_path: Optional[str] = None,
    max_train_iter: Optional[int] = int(1e10),
    max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
    """
    Overview:
        Entry point for UniZero multi-task training (DDP version).

        ``input_cfg_list`` is split into contiguous slices, one per rank (the first
        ``len(input_cfg_list) % world_size`` ranks receive one extra task). Each rank builds one shared
        policy/learner plus a replay buffer, collector and evaluator per assigned task, then loops:
        collect (and periodically evaluate) every local task, synchronize all ranks, run
        ``update_per_collect`` multi-task training updates, and check the global stop conditions.
    Args:
        - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): One ``(task_id, (cfg, create_cfg))``
          entry per task.
        - seed (:obj:`int`): Base random seed; per-task seeds are derived from it and ``task_id``.
        - model (:obj:`Optional[torch.nn.Module]`): Optional model instance handed to ``create_policy``.
        - model_path (:obj:`Optional[str]`): Path to a pretrained checkpoint loaded into the learn mode.
        - max_train_iter (:obj:`Optional[int]`): Training stops once the learner on any rank reaches this
          many iterations.
        - max_env_step (:obj:`Optional[int]`): Training stops once every task's collector, on every rank,
          has reached this many environment steps.
    Returns:
        - policy (:obj:`Policy`): The trained policy.
    Raises:
        - RuntimeError: If this rank is assigned no task (more ranks than tasks). The shared policy,
          learner and config are built from a task config, so such a rank cannot proceed.
    """
    rank = get_rank()
    world_size = get_world_size()

    # ---- Task partitioning: contiguous slices; the first `remainder` ranks take one extra task. ----
    total_tasks = len(input_cfg_list)
    tasks_per_rank = total_tasks // world_size
    remainder = total_tasks % world_size
    if rank < remainder:
        start_idx = rank * (tasks_per_rank + 1)
        end_idx = start_idx + tasks_per_rank + 1
        num_tasks_for_this_rank = tasks_per_rank + 1
    else:
        start_idx = rank * tasks_per_rank + remainder
        end_idx = start_idx + tasks_per_rank
        num_tasks_for_this_rank = tasks_per_rank
    tasks_for_this_rank = input_cfg_list[start_idx:end_idx]

    if not tasks_for_this_rank:
        # Everything below (policy, learner, cfg) is built from a task config, so a rank without
        # tasks would only crash later with an UnboundLocalError. Fail fast with a clear message.
        raise RuntimeError(
            f"Rank {rank}: no tasks assigned ({total_tasks} tasks for {world_size} ranks); "
            f"launch at most as many ranks as there are tasks."
        )
    print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}")

    cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], []

    # ---- Shared policy and learner, created from the first local task's config. ----
    _, [cfg, create_cfg] = tasks_for_this_rank[0]
    for config in tasks_for_this_rank:
        # Each entry is (task_id, [cfg, create_cfg]); record how many tasks this rank handles.
        config[1][0].policy.task_num = num_tasks_for_this_rank
    assert create_cfg.policy.type in ['unizero_multitask',
                                      'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'"
    if create_cfg.policy.type == 'unizero_multitask':
        from lzero.mcts import UniZeroGameBuffer as GameBuffer
    if create_cfg.policy.type == 'sampled_unizero_multitask':
        from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer

    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
    logging.info(f'Configured device: {cfg.policy.device}')
    cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
    print(f"rank {rank} created the policy!")
    if model_path is not None:
        logging.info(f'Loading pretrained model: {model_path}')
        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
        logging.info(f'Finished loading pretrained model: {model_path}')

    log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}')
    tb_logger = SummaryWriter(log_dir)
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

    # ---- Per-task environments, buffers, collectors and evaluators (all share `policy`). ----
    for task_id, [cfg, create_cfg] in tasks_for_this_rank:
        cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
        cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
        policy_config = cfg.policy
        policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode
        policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode

        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
        collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
        evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
        collector_env.seed(cfg.seed + task_id)
        evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False)
        set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda)

        replay_buffer = GameBuffer(policy_config)
        collector = Collector(
            env=collector_env,
            policy=policy.collect_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            policy_config=policy_config,
            task_id=task_id
        )
        evaluator = Evaluator(
            eval_freq=cfg.policy.eval_freq,
            n_evaluator_episode=cfg.env.n_evaluator_episode,
            stop_value=cfg.env.stop_value,
            env=evaluator_env,
            policy=policy.eval_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            policy_config=policy_config,
            task_id=task_id
        )
        cfgs.append(cfg)
        # The configured batch_size is indexed by the global task_id here (assumed per-task list/mapping).
        replay_buffer.batch_size = cfg.policy.batch_size[task_id]
        game_buffers.append(replay_buffer)
        collector_envs.append(collector_env)
        evaluator_envs.append(evaluator_env)
        collectors.append(collector)
        evaluators.append(evaluator)

    learner.call_hook('before_run')
    value_priority_tasks = {}  # EMA of mean value priority, keyed 'running_mean_priority_task{task_id}'
    buffer_reanalyze_count = 0
    train_epoch = 0
    reanalyze_batch_size = cfg.policy.reanalyze_batch_size
    update_per_collect = cfg.policy.update_per_collect
    task_exploitation_weight = None
    task_rewards = {}  # {task_id: latest eval_episode_return_mean, or inf when evaluation failed}

    while True:
        # ---- Optional dynamic batch-size rebalancing across tasks. ----
        if cfg.policy.allocated_batch_sizes:
            clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4)
            allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale)
            if rank == 0:
                print("Allocated batch_sizes: ", allocated_batch_sizes)
            for cfg in cfgs:
                task_id = cfg.policy.task_id
                if isinstance(allocated_batch_sizes, dict):
                    cfg.policy.batch_size = allocated_batch_sizes.get(task_id, cfg.policy.batch_size)
                elif isinstance(allocated_batch_sizes, list):
                    # The list is treated as the full per-task list (it is also assigned wholesale to
                    # policy._cfg.batch_size below), so index it by the global task_id, exactly like
                    # batch_size lists are indexed elsewhere, not by this rank's local position.
                    if 0 <= task_id < len(allocated_batch_sizes):
                        cfg.policy.batch_size = allocated_batch_sizes[task_id]
                else:
                    logging.warning(f"Unexpected type for allocated_batch_sizes: {type(allocated_batch_sizes)}")
            # The shared policy keeps the full allocation.
            policy._cfg.batch_size = allocated_batch_sizes

        # ---- Evaluation (when due) and data collection for every task on this rank. ----
        for cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers):
            policy_config = cfg.policy
            log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id)

            collect_kwargs = {
                'temperature': visit_count_temperature(
                    policy_config.manual_temperature_decay,
                    policy_config.fixed_temperature_value,
                    policy_config.threshold_training_steps_for_final_temperature,
                    trained_steps=learner.train_iter
                ),
                'epsilon': 0.0
            }
            if policy_config.eps.eps_greedy_exploration_in_collect:
                epsilon_greedy_fn = get_epsilon_greedy_fn(
                    start=policy_config.eps.start,
                    end=policy_config.eps.end,
                    decay=policy_config.eps.decay,
                    type_=policy_config.eps.type
                )
                collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

            if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
                print('=' * 20)
                print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')
                evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
                # safe_eval yields None for stop/reward when evaluation did not complete.
                stop, reward = safe_eval(evaluator, learner, collector, rank, world_size)
                if stop is None or reward is None:
                    print(f"Rank {rank} encountered an issue during evaluation, continuing training...")
                    task_rewards[cfg.policy.task_id] = float('inf')  # treat as maximally difficult
                else:
                    try:
                        eval_mean_reward = reward.get('eval_episode_return_mean', float('inf'))
                        print(f"Evaluation reward for task {cfg.policy.task_id}: {eval_mean_reward}")
                        task_rewards[cfg.policy.task_id] = eval_mean_reward
                    except Exception as e:
                        print(f"Error extracting evaluation reward: {e}")
                        task_rewards[cfg.policy.task_id] = float('inf')
                print('=' * 20)

            print(f'Starting data collection for Rank {rank}, task_id: {cfg.policy.task_id}...')
            print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ')
            # The policy is shared across tasks, so reset its state for this task before collecting.
            collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
            logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments')
            replay_buffer.push_game_segments(new_data)
            replay_buffer.remove_oldest_data_to_fit()

            reanalyze_freq = cfg.policy.buffer_reanalyze_freq
            if reanalyze_freq >= 1:
                # Reanalysis happens inside the training loop every `reanalyze_interval` updates.
                # Clamp to 1: a frequency above update_per_collect would otherwise give 0 and make
                # `i % reanalyze_interval` raise ZeroDivisionError.
                reanalyze_interval = max(1, int(update_per_collect // reanalyze_freq))
            elif reanalyze_freq > 0:
                # Fractional frequency: reanalyze once every int(1 / freq) training epochs.
                if train_epoch > 0 and train_epoch % int(1 / reanalyze_freq) == 0 and \
                        replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
                            reanalyze_batch_size / cfg.policy.reanalyze_partition):
                    with timer:
                        replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                    buffer_reanalyze_count += 1
                    logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                    logging.info(f'Buffer reanalyze time cost: {timer.value}')
            logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}')

        try:
            logging.info(f'Rank {rank}: Waiting at post-collection barrier...')
            dist.barrier()
            logging.info(f'Rank {rank}: All ranks completed data collection, proceeding...')
        except Exception as e:
            logging.error(f'Rank {rank}: Post-collection barrier failed, error: {e}')
            raise

        # ---- Train only when every rank has enough data (global MAX of the local "not enough" flag). ----
        local_not_enough_data = any(
            buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size
            for buffer in game_buffers
        )
        logging.info(f"Rank {rank} local_not_enough_data:{local_not_enough_data}")
        flag_tensor = torch.tensor(1.0 if local_not_enough_data else 0.0, device=cfg.policy.device)
        dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX)
        not_enough_data = (flag_tensor.item() > 0.5)
        if rank == 0:
            logging.info(f"Global not_enough_data status: {not_enough_data}")

        if not not_enough_data:
            for i in range(update_per_collect):
                train_data_multi_task = []
                envstep_multi_task = 0
                for cfg, collector, replay_buffer in zip(cfgs, collectors, game_buffers):
                    envstep_multi_task += collector.envstep
                    if isinstance(cfg.policy.batch_size, (list, tuple, dict)):
                        batch_size = cfg.policy.batch_size[cfg.policy.task_id]
                    else:
                        batch_size = cfg.policy.batch_size
                    if replay_buffer.get_num_of_transitions() > batch_size:
                        if cfg.policy.buffer_reanalyze_freq >= 1:
                            if i % reanalyze_interval == 0 and \
                                    replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
                                        reanalyze_batch_size / cfg.policy.reanalyze_partition):
                                with timer:
                                    replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                                buffer_reanalyze_count += 1
                                logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                                logging.info(f'Buffer reanalyze time cost: {timer.value}')
                        train_data = replay_buffer.sample(batch_size, policy)
                        train_data.append(cfg.policy.task_id)  # the learner identifies the task from the last element
                        train_data_multi_task.append(train_data)
                    else:
                        logging.warning(
                            f'Not enough data in replay buffer to sample a mini-batch: '
                            f'batch_size: {batch_size}, replay_buffer: {replay_buffer}'
                        )
                        # Stop at the first short buffer: train_data_multi_task is then a prefix of cfgs.
                        break

                # NOTE(review): if one rank ends up with no train data while others train, the barrier
                # below (and any collective inside learner.train) would not be matched — confirm.
                if train_data_multi_task:
                    learn_kwargs = {'task_weights': None, "train_iter": learner.train_iter}
                    log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs)

                    # On the first update of each epoch, optionally derive task weights from obs losses.
                    if i == 0:
                        try:
                            dist.barrier()
                            if cfg.policy.use_task_exploitation_weight:
                                all_obs_loss = [None for _ in range(world_size)]
                                merged_obs_loss_task = {}
                                for cfg in cfgs:
                                    task_id = cfg.policy.task_id
                                    obs_loss_key = f'noreduce_obs_loss_task{task_id}'
                                    if obs_loss_key in log_vars[0]:
                                        merged_obs_loss_task[task_id] = log_vars[0][obs_loss_key]
                                dist.all_gather_object(all_obs_loss, merged_obs_loss_task)
                                global_obs_loss_task = {}
                                for obs_loss_task in all_obs_loss:
                                    if obs_loss_task:
                                        global_obs_loss_task.update(obs_loss_task)
                                if global_obs_loss_task:
                                    task_exploitation_weight = compute_task_weights(
                                        global_obs_loss_task,
                                        option="rank",
                                        temperature=1,
                                    )
                                    # broadcast_object_list fills the passed list in place on non-source
                                    # ranks, so keep a reference to it to read rank 0's result back.
                                    weight_holder = [task_exploitation_weight]
                                    dist.broadcast_object_list(weight_holder, src=0)
                                    task_exploitation_weight = weight_holder[0]
                                    print(f"Rank {rank}, task_exploitation_weight (by task_id): {task_exploitation_weight}")
                                else:
                                    logging.warning(f"Rank {rank}: Failed to compute global obs_loss task weights, obs_loss data is empty.")
                                    task_exploitation_weight = None
                            else:
                                task_exploitation_weight = None
                            # NOTE(review): this runs after learner.train() and learn_kwargs is rebuilt with
                            # 'task_weights': None on the next update, so the weight is not consumed yet.
                            learn_kwargs['task_weight'] = task_exploitation_weight
                        except Exception as e:
                            logging.error(f'Rank {rank}: Failed to synchronize task weights, error: {e}')
                            raise

                    if cfg.policy.use_priority:
                        # zip stops at the shorter list, so only tasks that contributed a mini-batch this
                        # update are touched (train_data_multi_task may be a prefix after the `break`).
                        for cfg, replay_buffer, train_data in zip(cfgs, game_buffers, train_data_multi_task):
                            task_id = cfg.policy.task_id
                            current_priorities = log_vars[0][f'value_priority_task{task_id}']
                            replay_buffer.update_priority(train_data, current_priorities)

                            mean_priority = np.mean(current_priorities)
                            std_priority = np.std(current_priorities)
                            alpha = 0.1  # EMA smoothing factor for the running mean priority
                            running_key = f'running_mean_priority_task{task_id}'
                            if running_key not in value_priority_tasks:
                                value_priority_tasks[running_key] = mean_priority
                            else:
                                value_priority_tasks[running_key] = (
                                    alpha * mean_priority + (1 - alpha) * value_priority_tasks[running_key]
                                )
                            running_mean_priority = value_priority_tasks[running_key]
                            if cfg.policy.print_task_priority_logs:
                                print(f"Task {task_id} - Mean priority: {mean_priority:.8f}, "
                                      f"Running mean priority: {running_mean_priority:.8f}, "
                                      f"Std: {std_priority:.8f}")

        train_epoch += 1
        policy.recompute_pos_emb_diff_and_clear_cache()

        try:
            dist.barrier()
            logging.info(f'Rank {rank}: passed training synchronization barrier')
        except Exception as e:
            logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
            break

        # ---- Termination: all tasks reached max_env_step, or any rank reached max_train_iter. ----
        try:
            local_envsteps = [c.envstep for c in collectors]
            total_envsteps = [None for _ in range(world_size)]
            dist.all_gather_object(total_envsteps, local_envsteps)
            all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps])
            max_envstep_reached = torch.all(all_envsteps >= max_env_step)

            global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device)
            all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)]
            dist.all_gather(all_train_iters, global_train_iter)
            max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter)

            if max_envstep_reached.item() or max_train_iter_reached.item():
                logging.info(f'Rank {rank}: termination condition reached')
                dist.barrier()
                break
        except Exception as e:
            logging.error(f'Rank {rank}: termination check failed, error: {e}')
            break

    learner.call_hook('after_run')
    return policy