|
15 | 15 | from ding.worker import BaseLearner |
16 | 16 | from tensorboardX import SummaryWriter |
17 | 17 |
|
18 | | -from lzero.entry.utils import log_buffer_memory_usage |
| 18 | +from lzero.entry.utils import log_buffer_memory_usage, safe_eval, allocate_batch_size |
19 | 19 | from lzero.policy import visit_count_temperature |
20 | 20 | from lzero.mcts import UniZeroGameBuffer as GameBuffer |
21 | 21 | from lzero.worker import MuZeroEvaluator as Evaluator |
|
28 | 28 | ) |
29 | 29 |
|
30 | 30 |
|
def safe_eval(
    evaluator: Evaluator,
    learner: BaseLearner,
    collector: Collector,
    rank: int,
    world_size: int,
    timeout: int = 12000
) -> Tuple[Optional[bool], Optional[float]]:
    """
    Overview:
        Run one evaluation round with a hard time limit so that a hung or crashing
        evaluator cannot take down the whole training process.
    Arguments:
        - evaluator (:obj:`Evaluator`): Evaluator instance whose ``eval`` method is executed.
        - learner (:obj:`BaseLearner`): Supplies ``save_checkpoint`` and the current ``train_iter``.
        - collector (:obj:`Collector`): Supplies the current environment step count (``envstep``).
        - rank (:obj:`int`): Rank of this process in distributed training (used for logging only).
        - world_size (:obj:`int`): Total number of distributed processes (used for logging only).
        - timeout (:obj:`int`): Seconds to wait for ``evaluator.eval`` before giving up.
    Returns:
        - (:obj:`Tuple[Optional[bool], Optional[float]]`): ``(stop, reward)`` from the evaluator on
          success; ``(None, None)`` on timeout or any exception.
    """
    try:
        logging.info(f"Rank {rank}/{world_size}: Starting evaluation.")
        # A stop flag left over from a previous (timed-out) round would make this
        # evaluation abort immediately, so reset it first.
        evaluator.stop_event.clear()
        with concurrent.futures.ThreadPoolExecutor() as pool:
            pending = pool.submit(
                evaluator.eval,
                learner.save_checkpoint,
                learner.train_iter,
                collector.envstep
            )
            try:
                stop_flag, eval_reward = pending.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                # Ask the evaluator to wind down gracefully via its stop event.
                # NOTE(review): pool shutdown still waits for the worker thread to
                # observe the event and return — presumably quick; confirm if not.
                evaluator.stop_event.set()
                logging.warning(f"Rank {rank}/{world_size}: Evaluation timed out after {timeout} seconds.")
                return None, None
            else:
                logging.info(f"Rank {rank}/{world_size}: Evaluation finished successfully.")
                return stop_flag, eval_reward
    except Exception as e:
        # Boundary handler: never let an evaluation failure propagate into training.
        logging.error(f"Rank {rank}/{world_size}: An error occurred during evaluation: {e}", exc_info=True)
        return None, None
78 | | - |
79 | | - |
def allocate_batch_size(
    cfgs: List[Any],
    game_buffers: List[GameBuffer],
    alpha: float = 1.0,
    clip_scale: int = 1
) -> List[int]:
    """
    Overview:
        Split the global batch budget across tasks, giving larger batches to tasks
        with fewer collected episodes. Weights are inversely proportional to each
        task's episode count (smoothed by ``alpha``) and the resulting sizes are
        clipped to a band around the per-rank average to keep training stable.
    Arguments:
        - cfgs (:obj:`List[Any]`): Per-task configs; ``cfgs[0].policy.total_batch_size`` defines the budget.
        - game_buffers (:obj:`List[GameBuffer]`): Per-task replay buffers (read for episode counts).
        - alpha (:obj:`float`): Exponent controlling how strongly allocation favors data-poor tasks. Defaults to 1.0.
        - clip_scale (:obj:`int`): Half-width factor of the clipping band around the average size. Defaults to 1.
    Returns:
        - (:obj:`List[int]`): One integer batch size per task (across all ranks, in gather order).
    """
    world_size = get_world_size()
    rank = get_rank()

    # Episode counts for the tasks hosted on this rank.
    local_episode_counts = [buf.num_of_collected_episodes for buf in game_buffers]

    # Collect every rank's counts, then flatten into one global list.
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, local_episode_counts)
    all_task_num_of_collected_episodes = [count for per_rank in gathered for count in per_rank]
    if rank == 0:
        logging.info(f'All task collected episodes: {all_task_num_of_collected_episodes}')

    # Inverse-proportional weights; the +1 guards against division by zero,
    # and alpha smooths (or sharpens) the resulting distribution.
    inverse_counts = np.array([1.0 / (count + 1) for count in all_task_num_of_collected_episodes])
    task_weights = (inverse_counts / np.sum(inverse_counts)) ** alpha

    # The first task's config carries the global batch budget.
    total_batch_size = cfgs[0].policy.total_batch_size

    # Clip each allocation to [avg / clip_scale, avg * clip_scale] around the
    # per-rank average so no task gets an extreme share.
    avg_batch_size = total_batch_size / world_size
    clipped = np.clip(
        total_batch_size * task_weights,
        avg_batch_size / clip_scale,
        avg_batch_size * clip_scale,
    )

    # Batch sizes must be plain ints for downstream samplers.
    return [int(size) for size in clipped]
137 | | - |
138 | | - |
139 | 31 | def train_unizero_multitask_segment_eval( |
140 | 32 | input_cfg_list: List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]], |
141 | 33 | seed: int = 0, |
|
0 commit comments