Skip to content

Commit a2a7205

Browse files
committed
fix(pu): fix import bug in test
1 parent 9e3cd2a commit a2a7205

File tree

2 files changed

+2
-110
lines changed

2 files changed

+2
-110
lines changed

lzero/entry/train_unizero_multitask_segment_eval.py

Lines changed: 1 addition & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ding.worker import BaseLearner
1616
from tensorboardX import SummaryWriter
1717

18-
from lzero.entry.utils import log_buffer_memory_usage
18+
from lzero.entry.utils import log_buffer_memory_usage, safe_eval, allocate_batch_size
1919
from lzero.policy import visit_count_temperature
2020
from lzero.mcts import UniZeroGameBuffer as GameBuffer
2121
from lzero.worker import MuZeroEvaluator as Evaluator
@@ -28,114 +28,6 @@
2828
)
2929

3030

31-
def safe_eval(
32-
evaluator: Evaluator,
33-
learner: BaseLearner,
34-
collector: Collector,
35-
rank: int,
36-
world_size: int,
37-
timeout: int = 12000
38-
) -> Tuple[Optional[bool], Optional[float]]:
39-
"""
40-
Overview:
41-
Safely evaluates the policy using the evaluator with a specified timeout. This wrapper prevents
42-
the entire training process from crashing due to evaluation-related issues like deadlocks.
43-
Arguments:
44-
- evaluator (:obj:`Evaluator`): The evaluator instance to run.
45-
- learner (:obj:`BaseLearner`): The learner instance, used to access checkpoint saving and training iteration.
46-
- collector (:obj:`Collector`): The collector instance, used to access the environment step count.
47-
- rank (:obj:`int`): The rank of the current process in distributed training.
48-
- world_size (:obj:`int`): The total number of processes.
49-
- timeout (:obj:`int`): The maximum time in seconds to wait for the evaluation to complete.
50-
Returns:
51-
- (:obj:`Tuple[Optional[bool], Optional[float]]`): A tuple containing the stop flag and the reward.
52-
Returns (None, None) if evaluation times out or an exception occurs.
53-
"""
54-
try:
55-
logging.info(f"Rank {rank}/{world_size}: Starting evaluation.")
56-
# Ensure the stop_event is clear before starting a new evaluation.
57-
evaluator.stop_event.clear()
58-
with concurrent.futures.ThreadPoolExecutor() as executor:
59-
future = executor.submit(
60-
evaluator.eval,
61-
learner.save_checkpoint,
62-
learner.train_iter,
63-
collector.envstep
64-
)
65-
try:
66-
stop, reward = future.result(timeout=timeout)
67-
except concurrent.futures.TimeoutError:
68-
# If evaluation exceeds the timeout, set the evaluator's stop event to terminate it gracefully.
69-
evaluator.stop_event.set()
70-
logging.warning(f"Rank {rank}/{world_size}: Evaluation timed out after {timeout} seconds.")
71-
return None, None
72-
73-
logging.info(f"Rank {rank}/{world_size}: Evaluation finished successfully.")
74-
return stop, reward
75-
except Exception as e:
76-
logging.error(f"Rank {rank}/{world_size}: An error occurred during evaluation: {e}", exc_info=True)
77-
return None, None
78-
79-
80-
def allocate_batch_size(
81-
cfgs: List[Any],
82-
game_buffers: List[GameBuffer],
83-
alpha: float = 1.0,
84-
clip_scale: int = 1
85-
) -> List[int]:
86-
"""
87-
Overview:
88-
Allocates batch sizes inversely proportional to the number of collected episodes for each task.
89-
This dynamic adjustment helps balance training focus across multiple tasks, prioritizing those
90-
with less data. The batch sizes are clipped to a dynamic range to maintain stability.
91-
Arguments:
92-
- cfgs (:obj:`List[Any]`): List of configuration objects for each task.
93-
- game_buffers (:obj:`List[GameBuffer]`): List of replay buffer instances for each task.
94-
- alpha (:obj:`float`): A hyperparameter controlling the degree of inverse proportionality. Defaults to 1.0.
95-
- clip_scale (:obj:`int`): A scaling factor to define the clipping range for the batch size. Defaults to 1.
96-
Returns:
97-
- (:obj:`List[int]`): A list of allocated batch sizes for each task.
98-
"""
99-
# Extract the number of collected episodes from each task's buffer.
100-
buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers]
101-
102-
world_size = get_world_size()
103-
rank = get_rank()
104-
105-
# Gather the episode counts from all ranks.
106-
all_task_num_of_collected_episodes_obj = [None for _ in range(world_size)]
107-
dist.all_gather_object(all_task_num_of_collected_episodes_obj, buffer_num_of_collected_episodes)
108-
109-
# Concatenate the lists from all ranks into a single flat list.
110-
all_task_num_of_collected_episodes = [item for sublist in all_task_num_of_collected_episodes_obj for item in sublist]
111-
if rank == 0:
112-
logging.info(f'All task collected episodes: {all_task_num_of_collected_episodes}')
113-
114-
# Calculate the inverse weight for each task. Adding 1 to avoid division by zero.
115-
inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes])
116-
inv_sum = np.sum(inv_episodes)
117-
118-
# The total batch size is defined in the config of the first task.
119-
total_batch_size = cfgs[0].policy.total_batch_size
120-
121-
# Define a dynamic range for batch sizes to prevent extreme values.
122-
avg_batch_size = total_batch_size / world_size
123-
min_batch_size = avg_batch_size / clip_scale
124-
max_batch_size = avg_batch_size * clip_scale
125-
126-
# Calculate task weights based on inverse proportionality, smoothed by alpha.
127-
task_weights = (inv_episodes / inv_sum) ** alpha
128-
batch_sizes = total_batch_size * task_weights
129-
130-
# Clip the batch sizes to the calculated dynamic range.
131-
batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size)
132-
133-
# Ensure batch sizes are integers.
134-
batch_sizes = [int(size) for size in batch_sizes]
135-
136-
return batch_sizes
137-
138-
13931
def train_unizero_multitask_segment_eval(
14032
input_cfg_list: List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]],
14133
seed: int = 0,

lzero/policy/unizero.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class UniZeroPolicy(MuZeroPolicy):
142142
# (bool) Whether to use GRU gating mechanism.
143143
gru_gating=False,
144144
# (str) The device to be used for computation, e.g., 'cpu' or 'cuda'.
145-
device='cpu',
145+
device='cuda',
146146
# (bool) Whether to analyze simulation normalization.
147147
analysis_sim_norm=False,
148148
# (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent.

0 commit comments

Comments (0)