
Commit 3788eb7

polish(pu): polish monitor-log and adapt to ale/xxx-v5 style game

1 parent b4c3ba8 commit 3788eb7

File tree

10 files changed: +265 −213 lines

lzero/model/unizero_world_models/world_model.py

Lines changed: 21 additions & 1 deletion
@@ -1911,7 +1911,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
             # F.cosine_similarity computes a similarity in the range [-1, 1]. We want to maximize it,
             # so we minimize 1 - similarity.
             # reduction='none' lets us handle the mask exactly as before
-            print("predict_latent_loss_type == 'cos_sim'")
+            # print("predict_latent_loss_type == 'cos_sim'")
             cosine_sim_loss = 1 - F.cosine_similarity(logits_observations, labels_observations, dim=-1)
             loss_obs = cosine_sim_loss
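The hunk above only silences a debug print; for context, here is a minimal, self-contained sketch of the cosine-similarity latent loss it sits inside (shapes and tensors are illustrative assumptions, not taken from the repository):

import torch
import torch.nn.functional as F

# Hypothetical shapes: (batch, seq_len, embed_dim) predictions vs. targets.
logits_observations = torch.randn(4, 16, 768)
labels_observations = torch.randn(4, 16, 768)

# Similarity lies in [-1, 1]; minimizing (1 - similarity) maximizes alignment.
# The result keeps shape (batch, seq_len), so a mask can still be applied
# element-wise afterwards, matching the reduction='none' behaviour noted above.
cosine_sim_loss = 1 - F.cosine_similarity(logits_observations, labels_observations, dim=-1)
loss_obs = cosine_sim_loss
print(loss_obs.shape)  # torch.Size([4, 16])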

@@ -2034,6 +2034,16 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
                 value_priority=value_priority,
                 intermediate_tensor_x=intermediate_tensor_x,
                 obs_embeddings=detached_obs_embeddings,  # <-- newly added
+
+                # logits_value_mean=outputs.logits_value.mean(),
+                # logits_value_max=outputs.logits_value.max(),
+                # logits_value_min=outputs.logits_value.min(),
+                # logits_policy_mean=outputs.logits_policy.mean(),
+                # logits_policy_max=outputs.logits_policy.max(),
+                # logits_policy_min=outputs.logits_policy.min(),
+                logits_value=outputs.logits_value.detach(),  # detach() because it is only used for analysis and clipping, not gradient computation
+                logits_reward=outputs.logits_rewards.detach(),
+                logits_policy=outputs.logits_policy.detach(),
             )
         else:
             return LossWithIntermediateLosses(
@@ -2064,6 +2074,16 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
                 value_priority=value_priority,
                 intermediate_tensor_x=intermediate_tensor_x,
                 obs_embeddings=detached_obs_embeddings,  # <-- newly added
+
+                # logits_value_mean=outputs.logits_value.mean(),
+                # logits_value_max=outputs.logits_value.max(),
+                # logits_value_min=outputs.logits_value.min(),
+                # logits_policy_mean=outputs.logits_policy.mean(),
+                # logits_policy_max=outputs.logits_policy.max(),
+                # logits_policy_min=outputs.logits_policy.min(),
+                logits_value=outputs.logits_value.detach(),  # detach() because it is only used for analysis and clipping, not gradient computation
+                logits_reward=outputs.logits_rewards.detach(),
+                logits_policy=outputs.logits_policy.detach(),
             )
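Since the three logits_* tensors above are returned purely for analysis, they are detached so the returned loss object does not keep the full autograd graph alive. A tiny self-contained illustration of that pattern (module and names are hypothetical):

import torch

head = torch.nn.Linear(8, 3)
x = torch.randn(2, 8, requires_grad=True)
logits = head(x)

loss = logits.sum()      # participates in backprop
stats = logits.detach()  # analysis-only view: no grad, no graph retention
assert not stats.requires_grad

loss.backward()          # unaffected by the detached copy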

lzero/model/unizero_world_models/world_model_multitask.py

Lines changed: 18 additions & 0 deletions
@@ -1934,6 +1934,15 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
 
                 value_priority=value_priority,
                 obs_embeddings=detached_obs_embeddings,  # <-- newly added
+                # logits_value_mean=outputs.logits_value.mean(),
+                # logits_value_max=outputs.logits_value.max(),
+                # logits_value_min=outputs.logits_value.min(),
+                # logits_policy_mean=outputs.logits_policy.mean(),
+                # logits_policy_max=outputs.logits_policy.max(),
+                # logits_policy_min=outputs.logits_policy.min(),
+                logits_value=outputs.logits_value.detach(),  # detach() because it is only used for analysis and clipping, not gradient computation
+                logits_reward=outputs.logits_rewards.detach(),
+                logits_policy=outputs.logits_policy.detach(),
 
             )
         else:
@@ -1964,6 +1973,15 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
 
                 value_priority=value_priority,
                 obs_embeddings=detached_obs_embeddings,  # <-- newly added
+                # logits_value_mean=outputs.logits_value.mean(),
+                # logits_value_max=outputs.logits_value.max(),
+                # logits_value_min=outputs.logits_value.min(),
+                # logits_policy_mean=outputs.logits_policy.mean(),
+                # logits_policy_max=outputs.logits_policy.max(),
+                # logits_policy_min=outputs.logits_policy.min(),
+                logits_value=outputs.logits_value.detach(),  # detach() because it is only used for analysis and clipping, not gradient computation
+                logits_reward=outputs.logits_rewards.detach(),
+                logits_policy=outputs.logits_policy.detach(),
 
 
             )

lzero/policy/unizero.py

Lines changed: 26 additions & 22 deletions
@@ -833,15 +833,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']
 
         latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms']
-        logits_value_mean=self.intermediate_losses['logits_value_mean']
-        logits_value_max=self.intermediate_losses['logits_value_max']
-        logits_value_min=self.intermediate_losses['logits_value_min']
-        logits_policy_mean=self.intermediate_losses['logits_policy_mean']
-        logits_policy_max=self.intermediate_losses['logits_policy_max']
-        logits_policy_min=self.intermediate_losses['logits_policy_min']
-        temperature_value=self.intermediate_losses['temperature_value']
-        temperature_reward=self.intermediate_losses['temperature_reward']
-        temperature_policy=self.intermediate_losses['temperature_policy']
+
+        # logits_value_mean=self.intermediate_losses['logits_value_mean']
+        # logits_value_max=self.intermediate_losses['logits_value_max']
+        # logits_value_min=self.intermediate_losses['logits_value_min']
+        # logits_policy_mean=self.intermediate_losses['logits_policy_mean']
+        # logits_policy_max=self.intermediate_losses['logits_policy_max']
+        # logits_policy_min=self.intermediate_losses['logits_policy_min']
+        # temperature_value=self.intermediate_losses['temperature_value']
+        # temperature_reward=self.intermediate_losses['temperature_reward']
+        # temperature_policy=self.intermediate_losses['temperature_policy']
 
         assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
         assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
@@ -875,7 +876,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         # --- [Optimization suggestion] clamp log_alpha as a safety measure ---
         with torch.no_grad():
             # Restrict alpha to a range such as [1e-4, 10.0]
-            self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+            # self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+            self.log_alpha.clamp_(np.log(5e-3), np.log(10.0))
+
 
         # --- Use the current, updated alpha (gradient flow is cut) ---
         current_alpha = self.log_alpha.exp().detach()
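The change above raises alpha's floor from 1e-4 to 5e-3. A minimal sketch of this in-place clamp pattern for an adaptive entropy temperature (initial value and learning rate are assumptions):

import numpy as np
import torch

log_alpha = torch.nn.Parameter(torch.tensor(float(np.log(0.01))))
optimizer = torch.optim.Adam([log_alpha], lr=1e-3)

# ... after the optimizer step on the alpha loss ...
with torch.no_grad():
    # Keep alpha within [5e-3, 10.0]. Clamping log_alpha in-place avoids
    # rebuilding the parameter and preserves the optimizer state.
    log_alpha.clamp_(np.log(5e-3), np.log(10.0))

current_alpha = log_alpha.exp().detach()  # gradient flow is cut here
print(float(current_alpha))  # 0.01 is already inside the clamp range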
@@ -1047,12 +1050,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             'analysis/l2_norm_after': self.l2_norm_after,
             'analysis/grad_norm_before': self.grad_norm_before,
             'analysis/grad_norm_after': self.grad_norm_after,
-            "logits_value_mean":logits_value_mean,
-            "logits_value_max":logits_value_max,
-            "logits_value_min":logits_value_min,
-            "logits_policy_mean":logits_policy_mean,
-            "logits_policy_max":logits_policy_max,
-            "logits_policy_min":logits_policy_min,
+
+            # "logits_value_mean":logits_value_mean,
+            # "logits_value_max":logits_value_max,
+            # "logits_value_min":logits_value_min,
+            # "logits_policy_mean":logits_policy_mean,
+            # "logits_policy_max":logits_policy_max,
+            # "logits_policy_min":logits_policy_min,
 
             "temperature_value":temperature_value,
             "temperature_reward":temperature_reward,
@@ -1621,12 +1625,12 @@ def _monitor_vars_learn(self) -> List[str]:
             'total_grad_norm_before_clip_wm',
 
             # ==================== Logits Statistics ====================
-            'logits_value_mean',
-            'logits_value_max',
-            'logits_value_min',
-            'logits_policy_mean',
-            'logits_policy_max',
-            'logits_policy_min',
+            # 'logits_value_mean',
+            # 'logits_value_max',
+            # 'logits_value_min',
+            # 'logits_policy_mean',
+            # 'logits_policy_max',
+            # 'logits_policy_min',
 
             # ==================== Temperature Parameters ====================
             'temperature_value',
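Note the pairing: keys dropped from the `_forward_learn` return dict should also be dropped here, since each name listed by `_monitor_vars_learn` is looked up in that dict when logging. A hedged sketch of that contract (dict contents are illustrative):

# Keys actually produced by the learn step (the logits_* entries are gone).
learn_outputs = {'temperature_value': 0.5, 'temperature_reward': 0.4, 'temperature_policy': 0.3}

monitored_vars = ['temperature_value', 'temperature_reward', 'temperature_policy']
for name in monitored_vars:
    assert name in learn_outputs, f"monitored var '{name}' missing from learn outputs"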

lzero/policy/unizero_multitask.py

Lines changed: 9 additions & 7 deletions
@@ -622,9 +622,9 @@ def _monitor_model_norms(self) -> Dict[str, float]:
         module_groups = {
             'encoder': world_model.tokenizer.encoder,
             'transformer': world_model.transformer,
-            'head_value': world_model.head_values,  # Note: multi-task uses head_values (plural)
-            'head_reward': world_model.head_rewards,
-            'head_policy': world_model.head_policies,  # Note: multi-task uses head_policies (plural)
+            'head_value': world_model.head_value_multi_task,  # Note: multi-task now uses the *_multi_task heads
+            'head_reward': world_model.head_rewards_multi_task,
+            'head_policy': world_model.head_policy_multi_task,
         }
 
         for group_name, group_module in module_groups.items():
@@ -669,9 +669,9 @@ def _monitor_gradient_norms(self) -> Dict[str, float]:
         module_groups = {
             'encoder': world_model.tokenizer.encoder,
             'transformer': world_model.transformer,
-            'head_value': world_model.head_values,
-            'head_reward': world_model.head_rewards,
-            'head_policy': world_model.head_policies,
+            'head_value': world_model.head_value_multi_task,
+            'head_reward': world_model.head_rewards_multi_task,
+            'head_policy': world_model.head_policy_multi_task,
         }
 
         for group_name, group_module in module_groups.items():
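For reference, a sketch of what a per-group norm monitor over these module groups typically computes; the helper below is illustrative, not the LightZero implementation:

import torch

def monitor_group_norms(module_groups: dict) -> dict:
    """Return the global L2 parameter norm of each named module group."""
    norms = {}
    for group_name, group_module in module_groups.items():
        total_sq = sum(p.detach().float().norm(2).item() ** 2 for p in group_module.parameters())
        norms[f'norm/{group_name}'] = total_sq ** 0.5
    return norms

# Hypothetical usage with stand-in modules:
groups = {'encoder': torch.nn.Linear(4, 4), 'head_value': torch.nn.Linear(4, 2)}
print(monitor_group_norms(groups))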
@@ -1169,7 +1169,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_ite
         # --- [Optimization suggestion] clamp log_alpha as a safety measure ---
         with torch.no_grad():
             # Restrict alpha to a range such as [1e-4, 10.0]
-            self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+            # self.log_alpha.clamp_(np.log(1e-4), np.log(10.0))
+            self.log_alpha.clamp_(np.log(5e-3), np.log(10.0))
+
 
         # --- Use the current, updated alpha (gradient flow is cut) ---
         current_alpha = self.log_alpha.exp().detach()

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 DI-engine>=0.5.3
-gymnasium[atari]==0.28.0
+# gymnasium[atari]==0.28.0
 numpy>=1.24.1,<2
 pympler
 minigrid
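With the gymnasium[atari]==0.28.0 pin commented out, the ALE/xxx-v5 IDs used throughout this commit depend on whichever gymnasium/ale-py combination is installed. A hedged sanity check (assumes gymnasium and ale-py are present; register_envs only exists on newer gymnasium versions, hence the guard):

import gymnasium as gym
import ale_py  # importing registers the ALE namespace on older plugin-based versions

if hasattr(gym, "register_envs"):
    gym.register_envs(ale_py)  # explicit registration required on newer combinations

env = gym.make('ALE/Pong-v5')
print(env.action_space)  # Discrete(6), matching the atari_env_action_space_map entry
env.close()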
Lines changed: 29 additions & 29 deletions
@@ -1,33 +1,33 @@
 from easydict import EasyDict
 
 atari_env_action_space_map = EasyDict({
-    'AlienNoFrameskip-v4': 18,
-    'AmidarNoFrameskip-v4': 10,
-    'AssaultNoFrameskip-v4': 7,
-    'AsterixNoFrameskip-v4': 9,
-    'BankHeistNoFrameskip-v4': 18,
-    'BattleZoneNoFrameskip-v4': 18,
-    'ChopperCommandNoFrameskip-v4': 18,
-    'CrazyClimberNoFrameskip-v4': 9,
-    'DemonAttackNoFrameskip-v4': 6,
-    'FreewayNoFrameskip-v4': 3,
-    'FrostbiteNoFrameskip-v4': 18,
-    'GopherNoFrameskip-v4': 8,
-    'HeroNoFrameskip-v4': 18,
-    'JamesbondNoFrameskip-v4': 18,
-    'KangarooNoFrameskip-v4': 18,
-    'KrullNoFrameskip-v4': 18,
-    'KungFuMasterNoFrameskip-v4': 14,
-    'PrivateEyeNoFrameskip-v4': 18,
-    'RoadRunnerNoFrameskip-v4': 18,
-    'UpNDownNoFrameskip-v4': 6,
-    'PongNoFrameskip-v4': 6,
-    'MsPacmanNoFrameskip-v4': 9,
-    'QbertNoFrameskip-v4': 6,
-    'SeaquestNoFrameskip-v4': 18,
-    'BoxingNoFrameskip-v4': 18,
-    'BreakoutNoFrameskip-v4': 4,
-    'SpaceInvadersNoFrameskip-v4': 6,
-    'BeamRiderNoFrameskip-v4': 9,
-    'GravitarNoFrameskip-v4': 18,
+    'ALE/Alien-v5': 18,
+    'ALE/Amidar-v5': 10,
+    'ALE/Assault-v5': 7,
+    'ALE/Asterix-v5': 9,
+    'ALE/BankHeist-v5': 18,
+    'ALE/BattleZone-v5': 18,
+    'ALE/ChopperCommand-v5': 18,
+    'ALE/CrazyClimber-v5': 9,
+    'ALE/DemonAttack-v5': 6,
+    'ALE/Freeway-v5': 3,
+    'ALE/Frostbite-v5': 18,
+    'ALE/Gopher-v5': 8,
+    'ALE/Hero-v5': 18,
+    'ALE/Jamesbond-v5': 18,
+    'ALE/Kangaroo-v5': 18,
+    'ALE/Krull-v5': 18,
+    'ALE/KungFuMaster-v5': 14,
+    'ALE/PrivateEye-v5': 18,
+    'ALE/RoadRunner-v5': 18,
+    'ALE/UpNDown-v5': 6,
+    'ALE/Pong-v5': 6,
+    'ALE/MsPacman-v5': 9,
+    'ALE/Qbert-v5': 6,
+    'ALE/Seaquest-v5': 18,
+    'ALE/Boxing-v5': 18,
+    'ALE/Breakout-v5': 4,
+    'ALE/SpaceInvaders-v5': 6,
+    'ALE/BeamRider-v5': 9,
+    'ALE/Gravitar-v5': 18,
 })
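A hedged sketch of how this mapping is typically consumed, wiring each game's discrete action count into a per-task config (the config fields here are illustrative):

from easydict import EasyDict

atari_env_action_space_map = EasyDict({
    'ALE/Pong-v5': 6,
    'ALE/MsPacman-v5': 9,
})

env_id = 'ALE/Pong-v5'
config = EasyDict(dict(policy=dict(model=dict())))
# Look up the per-game action count instead of hard-coding it.
config.policy.model.action_space_size = atari_env_action_space_map[env_id]
print(config.policy.model.action_space_size)  # 6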

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 27 additions & 16 deletions
@@ -93,6 +93,7 @@ def create_config(
     """
     return EasyDict(dict(
         env=dict(
+            frame_skip=1,  # TODO
             stop_value=int(1e6),
             env_id=env_id,
             observation_shape=(3, 64, 64),
@@ -162,8 +163,8 @@ def create_config(
             # use_priority=False,  # TODO=====
             priority_prob_alpha=1,
             priority_prob_beta=1,
-            # encoder_type='vit',
-            encoder_type='resnet',
+            encoder_type='vit',
+            # encoder_type='resnet',
             use_normal_head=True,
             use_softmoe_head=False,
             use_moe_head=False,
@@ -195,7 +196,8 @@ def create_config(
             # use_adaptive_entropy_weight=False,
 
             # (float) learning rate of the adaptive-alpha optimizer
-            adaptive_entropy_alpha_lr=1e-4,
+            # adaptive_entropy_alpha_lr=1e-4,
+            adaptive_entropy_alpha_lr=1e-3,
             target_entropy_start_ratio=0.98,
             # target_entropy_end_ratio=0.9,  # TODO=====
             # target_entropy_end_ratio=0.7,
@@ -289,15 +291,18 @@ def generate_configs(
     # --- Experiment Name Template ---
     # Replace placeholders like [BENCHMARK_TAG] and [MODEL_TAG] to define the experiment name.
     # benchmark_tag = "data_unizero_mt_refactor1010_debug"  # e.g., unizero_atari_mt_20250612
-    benchmark_tag = "data_unizero_mt_refactor1012"  # e.g., unizero_atari_mt_20250612
+    benchmark_tag = "data_unizero_mt_refactor1024"  # e.g., unizero_atari_mt_20250612
 
     # model_tag = f"vit-small_moe8_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}_not-share-head"
     # model_tag = f"resnet_noprior_noalpha_nomoe_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
 
     # model_tag = f"vit_prior_alpha-100k-098-07_encoder-100k-30-10_moe8_head-inner-ln_adamw-wd1e-2_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
 
     # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-100k-098-07_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
-    model_tag = f"resnet_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln"
+    model_tag = f"vit_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln"
+
+    # model_tag = f"resnet_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln"
+
     # model_tag = f"resnet_encoder-100k-30-10-true_label-smooth_prior_alpha-150k-098-05_moe8_head-inner-ln_adamw-wd1e-2-all_tbs512_tran-nlayer{num_layers}_brf{buffer_reanalyze_freq}"
 
     exp_name_prefix = f'{benchmark_tag}/atari_{len(env_id_list)}games_{model_tag}_seed{seed}/'
@@ -309,7 +314,10 @@ def generate_configs(
             buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers
         )
         config.policy.task_id = task_id
-        config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}"
+        # --- MODIFIED LINE ---
+        # Correctly extract the game name from the 'ALE/GameName-v5' format.
+        game_name = env_id.split('/')[1].split('-')[0]
+        config.exp_name = exp_name_prefix + f"{game_name}_seed{seed}"
         configs.append([task_id, [config, create_env_manager()]])
     return configs
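The renamed IDs break the old 'NoFrameskip' split, hence the new parsing above. A tiny check of both schemes:

env_id = 'ALE/MsPacman-v5'
game_name = env_id.split('/')[1].split('-')[0]
assert game_name == 'MsPacman'

# The previous scheme relied on the 'NoFrameskip' suffix instead:
old_env_id = 'MsPacmanNoFrameskip-v4'
assert old_env_id.split('NoFrameskip')[0] == 'MsPacman'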

@@ -348,6 +356,8 @@ def create_env_manager() -> EasyDict:
     export CUDA_VISIBLE_DEVICES=4,5,6,7
 
     cd /path/to/your/project/
+    /mnt/shared-storage-user/puyuan/lz/bin/python -m torch.distributed.launch --nproc_per_node=4 --master_port=29502 /mnt/shared-storage-user/puyuan/code_20250828/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /mnt/shared-storage-user/puyuan/code_20250828/LightZero/log/20251024_vit_nlayer4_alpha-100k-098-05.log
+
     python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /mnt/nfs/zhangjinouwen/puyuan/LightZero/log/20251012_resnet_nlayer4_alpha-100k-098-05.log
     /path/to/this/script.py 2>&1 | tee /path/to/your/log/file.log
     """
@@ -370,22 +380,23 @@ def create_env_manager() -> EasyDict:
     max_env_step = int(5e6)  # TODO
     reanalyze_ratio = 0.0
 
+    # --- MODIFIED SECTION: Standardized env_id_list formats ---
     if num_games == 3:
-        env_id_list = ['PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4']
+        env_id_list = ['ALE/Pong-v5', 'ALE/MsPacman-v5', 'ALE/Seaquest-v5']
     elif num_games == 8:
         env_id_list = [
-            'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4',
-            'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4',
+            'ALE/Pong-v5', 'ALE/MsPacman-v5', 'ALE/Seaquest-v5', 'ALE/Boxing-v5',
+            'ALE/Alien-v5', 'ALE/ChopperCommand-v5', 'ALE/Hero-v5', 'ALE/RoadRunner-v5',
         ]
     elif num_games == 26:
         env_id_list = [
-            'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4',
-            'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4',
-            'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4',
-            'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4',
-            'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4',
-            'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4',
-            'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4',
+            'ALE/Pong-v5', 'ALE/MsPacman-v5', 'ALE/Seaquest-v5', 'ALE/Boxing-v5',
+            'ALE/Alien-v5', 'ALE/ChopperCommand-v5', 'ALE/Hero-v5', 'ALE/RoadRunner-v5',
+            'ALE/Amidar-v5', 'ALE/Assault-v5', 'ALE/Asterix-v5', 'ALE/BankHeist-v5',
+            'ALE/BattleZone-v5', 'ALE/CrazyClimber-v5', 'ALE/DemonAttack-v5', 'ALE/Freeway-v5',
+            'ALE/Frostbite-v5', 'ALE/Gopher-v5', 'ALE/Jamesbond-v5', 'ALE/Kangaroo-v5',
+            'ALE/Krull-v5', 'ALE/KungFuMaster-v5', 'ALE/PrivateEye-v5', 'ALE/UpNDown-v5',
+            'ALE/Qbert-v5', 'ALE/Breakout-v5',
         ]
     else:
         raise ValueError(f"Unsupported number of environments: {num_games}")
