Skip to content

Commit 1bf1b0c

Browse files
committed
fix(pu): fix test
1 parent a2a7205 commit 1bf1b0c

File tree

10 files changed

+42
-12
lines changed

lzero/entry/train_unizero_multitask_segment_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def train_unizero_multitask_segment_ddp(
231231
# Process each task assigned to the current rank.
232232
for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank):
233233
# Set a unique random seed for each task.
234-
cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
234+
cfg.policy.device = 'cuda' if cfg.policy.device == 'cuda' and torch.cuda.is_available() else 'cpu'
235235
cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
236236
policy_config = cfg.policy
237237
policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode

lzero/mcts/tests/test_game_buffer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
use_priority=True,
1717
action_type='fixed_action_space',
1818
game_segment_length=20,
19+
model=dict(
20+
action_space_size=6,
21+
value_support_range=(-10, 10, 1),
22+
reward_support_range=(-10, 10, 1),
23+
),
1924
)
2025
)
2126

lzero/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ def __init__(
734734
self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type)
735735
else:
736736
self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
737-
self.norm = build_normalization(norm_type, dim=3)(num_channels, *observation_shape[1:])
737+
self.norm = build_normalization(norm_type, dim=2)(num_channels)
738738

739739
self.resblocks = nn.ModuleList([
740740
ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False)

lzero/policy/tests/config/atari_muzero_config_for_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
self_supervised_learning_loss=True, # default is False
5151
discrete_action_encoding_type='one_hot',
5252
norm_type='BN',
53+
value_support_range=(-300., 301., 1.),
54+
reward_support_range=(-300., 301., 1.),
5355
),
5456
cuda=True,
5557
env_type='not_board_games',

lzero/policy/tests/config/cartpole_muzero_config_for_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@
3030
model=dict(
3131
observation_shape=4,
3232
action_space_size=2,
33-
model_type='mlp',
33+
model_type='mlp',
3434
lstm_hidden_size=128,
3535
latent_state_dim=128,
3636
self_supervised_learning_loss=True, # NOTE: default is False.
3737
discrete_action_encoding_type='one_hot',
38-
norm_type='BN',
38+
norm_type='BN',
39+
value_support_range=(-300., 301., 1.),
40+
reward_support_range=(-300., 301., 1.),
3941
),
4042
cuda=True,
4143
env_type='not_board_games',

lzero/policy/unizero.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,29 @@
33
from collections import defaultdict
44
from typing import Any, Dict, List, Tuple, Union
55

6+
import numpy as np
7+
import torch
68
import torch.nn.functional as F
79
import wandb
810
from ding.model import model_wrap
911
from ding.utils import POLICY_REGISTRY
12+
from lzero.mcts import UniZeroMCTSCtree as MCTSCtree
1013
from lzero.model import ImageTransforms
11-
from lzero.policy import (DiscreteSupport, InverseScalarTransform, from,
12-
import, lzero.policy, mz_network_output_unpack,
13-
phi_transform, prepare_obs,
14+
from lzero.policy import (DiscreteSupport, InverseScalarTransform,
15+
mz_network_output_unpack, phi_transform, prepare_obs,
1416
prepare_obs_stack_for_unizero, scalar_transform,
1517
select_action, to_torch_float_tensor)
1618
from lzero.policy.head_clip_manager import (HeadClipConfig, HeadClipManager,
1719
create_head_clip_manager_from_dict)
20+
from lzero.policy.muzero import MuZeroPolicy
1821
from lzero.policy.utils import initialize_pad_batch
1922
from torch.nn.utils.convert_parameters import (parameters_to_vector,
2023
vector_to_parameters)
2124

2225
from .utils import configure_optimizers_nanogpt
2326

27+
28+
def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float):
2429
"""
2530
Efficiently scale all weights of a module using vectorized operations.
2631
"""
@@ -129,6 +134,8 @@ class UniZeroPolicy(MuZeroPolicy):
129134
# (int) The save interval of the model.
130135
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ),
131136
world_model_cfg=dict(
137+
# (str) The encoder type, e.g., 'resnet' or 'vit'.
138+
encoder_type='resnet',
132139
# (bool) If True, the action space of the environment is continuous, otherwise discrete.
133140
continuous_action_space=False,
134141
# (int) The number of tokens per block.
@@ -142,7 +149,7 @@ class UniZeroPolicy(MuZeroPolicy):
142149
# (bool) Whether to use GRU gating mechanism.
143150
gru_gating=False,
144151
# (str) The device to be used for computation, e.g., 'cpu' or 'cuda'.
145-
device='cuda',
152+
device='cpu',
146153
# (bool) Whether to analyze simulation normalization.
147154
analysis_sim_norm=False,
148155
# (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent.
@@ -235,6 +242,9 @@ class UniZeroPolicy(MuZeroPolicy):
235242
num_experts_per_tok=1,
236243
# (int) Total number of experts in the transformer MoE.
237244
num_experts_of_moe_in_transformer=8,
245+
# ****** Priority ******
246+
# (bool) Whether to use priority when sampling training data from the buffer.
247+
use_priority=False,
238248
),
239249
),
240250
# ****** common ******
@@ -298,6 +308,9 @@ class UniZeroPolicy(MuZeroPolicy):
298308
policy_ls_eps_end=0.01,
299309
# (int) Number of training steps to decay label smoothing epsilon from start to end
300310
policy_ls_eps_decay_steps=50000,
311+
312+
label_smoothing_eps=0.1, # TODO: For value
313+
301314
# (bool) Whether to use continuous (fixed) label smoothing throughout training
302315
use_continuous_label_smoothing=False,
303316
# (float) Fixed epsilon value for continuous label smoothing (only used when use_continuous_label_smoothing=True)

lzero/policy/unizero_multitask.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torch
88
from ding.model import model_wrap
99
from ding.utils import POLICY_REGISTRY
10-
from lzero.entry.utils import initialize_zeros_batch
1110
from lzero.mcts import UniZeroMCTSCtree as MCTSCtree
1211
from lzero.model import ImageTransforms
1312
from lzero.policy import (DiscreteSupport, InverseScalarTransform,
@@ -16,7 +15,7 @@
1615
select_action, to_torch_float_tensor)
1716
from lzero.policy.unizero import UniZeroPolicy, scale_module_weights_vectorized
1817

19-
from .utils import configure_optimizers_nanogpt
18+
from .utils import configure_optimizers_nanogpt, initialize_zeros_batch
2019

2120
# Please replace the path with the actual location of your LibMTL library.
2221
sys.path.append('/path/to/your/LibMTL')
@@ -254,7 +253,7 @@ class UniZeroMTPolicy(UniZeroPolicy):
254253
analysis_dormant_ratio_weight_rank=False,
255254
# (float) The threshold for a dormant neuron.
256255
dormant_threshold=0.01,
257-
256+
share_head=False,
258257
),
259258
),
260259
# ****** common ******

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,12 @@ def create_config(
130130
max_tokens=2 * num_unroll_steps,
131131
context_length=2 * infer_context_length,
132132
encoder_type='vit',
133+
device='cuda',
134+
game_segment_length=20,
133135
),
134136
),
137+
device='cuda',
138+
game_segment_length=20,
135139
learning_rate=0.0001,
136140
weight_decay=1e-2,
137141
batch_size=batch_size,

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def main(env_id, seed):
7171
env_num=max(collector_env_num, evaluator_env_num),
7272
num_simulations=num_simulations,
7373
game_segment_length=game_segment_length,
74+
device='cuda',
75+
use_priority=True,
7476
),
7577
),
7678
# Learning settings

zoo/atari/envs/atari_lightzero_env.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,11 @@ def reset(self) -> dict:
139139
),
140140
})
141141

142+
# self._reward_space = gym.spaces.Box(
143+
# low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1,), dtype=np.float32
144+
# )
142145
self._reward_space = gym.spaces.Box(
143-
low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1,), dtype=np.float32
146+
low=-9999, high=9999, shape=(1,), dtype=np.float32
144147
)
145148

146149
self._init_flag = True

0 commit comments

Comments (0)