33from collections import defaultdict
44from typing import Any , Dict , List , Tuple , Union
55
6+ import numpy as np
7+ import torch
68import torch .nn .functional as F
79import wandb
810from ding .model import model_wrap
911from ding .utils import POLICY_REGISTRY
12+ from lzero .mcts import UniZeroMCTSCtree as MCTSCtree
1013from lzero .model import ImageTransforms
11- from lzero .policy import (DiscreteSupport , InverseScalarTransform , from ,
12- import , lzero .policy , mz_network_output_unpack ,
13- phi_transform , prepare_obs ,
14+ from lzero .policy import (DiscreteSupport , InverseScalarTransform ,
15+ mz_network_output_unpack , phi_transform , prepare_obs ,
1416 prepare_obs_stack_for_unizero , scalar_transform ,
1517 select_action , to_torch_float_tensor )
1618from lzero .policy .head_clip_manager import (HeadClipConfig , HeadClipManager ,
1719 create_head_clip_manager_from_dict )
20+ from lzero .policy .muzero import MuZeroPolicy
1821from lzero .policy .utils import initialize_pad_batch
1922from torch .nn .utils .convert_parameters import (parameters_to_vector ,
2023 vector_to_parameters )
2124
2225from .utils import configure_optimizers_nanogpt
2326
27+
28+ def scale_module_weights_vectorized (module : torch .nn .Module , scale_factor : float ):
2429 """
2530 Efficiently scale all weights of a module using vectorized operations.
2631 """
@@ -129,6 +134,8 @@ class UniZeroPolicy(MuZeroPolicy):
129134 # (int) The save interval of the model.
130135 learn = dict (learner = dict (hook = dict (save_ckpt_after_iter = 10000 , ), ), ),
131136 world_model_cfg = dict (
137+ # (str) The encoder type, e.g., 'resnet' or 'vit'.
138+ encoder_type = 'resnet' ,
132139 # (bool) If True, the action space of the environment is continuous, otherwise discrete.
133140 continuous_action_space = False ,
134141 # (int) The number of tokens per block.
@@ -142,7 +149,7 @@ class UniZeroPolicy(MuZeroPolicy):
142149 # (bool) Whether to use GRU gating mechanism.
143150 gru_gating = False ,
144151 # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'.
145- device = 'cuda ' ,
152+ device = 'cpu ' ,
146153 # (bool) Whether to analyze simulation normalization.
147154 analysis_sim_norm = False ,
148155 # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent.
@@ -235,6 +242,9 @@ class UniZeroPolicy(MuZeroPolicy):
235242 num_experts_per_tok = 1 ,
236243 # (int) Total number of experts in the transformer MoE.
237244 num_experts_of_moe_in_transformer = 8 ,
245+ # ****** Priority ******
246+ # (bool) Whether to use priority when sampling training data from the buffer.
247+ use_priority = False ,
238248 ),
239249 ),
240250 # ****** common ******
@@ -298,6 +308,9 @@ class UniZeroPolicy(MuZeroPolicy):
298308 policy_ls_eps_end = 0.01 ,
299309 # (int) Number of training steps to decay label smoothing epsilon from start to end
300310 policy_ls_eps_decay_steps = 50000 ,
311+
312+ label_smoothing_eps = 0.1 , # TODO: presumably the label smoothing epsilon for the value target -- confirm and document its type and usage
313+
301314 # (bool) Whether to use continuous (fixed) label smoothing throughout training
302315 use_continuous_label_smoothing = False ,
303316 # (float) Fixed epsilon value for continuous label smoothing (only used when use_continuous_label_smoothing=True)
0 commit comments