Skip to content

Commit ab746d1

Browse files
committed
fix(pu): fix some merge typo
1 parent e7a8796 commit ab746d1

File tree

12 files changed

+48
-43
lines changed

12 files changed

+48
-43
lines changed

lzero/entry/train_muzero_multitask_segment_ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
import torch.distributed as dist
1010
from ding.config import compile_config
11-
from ding.envs import IEnvManager, create_env_manager, get_vec_env_setting
11+
from ding.envs import create_env_manager, get_vec_env_setting
1212
from ding.policy import Policy, create_policy
1313
from ding.rl_utils import get_epsilon_greedy_fn
1414
from ding.utils import EasyTimer, set_pkg_seed, get_rank, get_world_size

lzero/entry/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _is_lora_param(name: str) -> bool:
121121
return bool(_LORA_PAT.search(name))
122122

123123

124-
def freeze_non_lora(
124+
def freeze_non_lora_parameters(
125125
module: nn.Module,
126126
freeze: bool = True,
127127
*,

lzero/model/common.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,9 +623,6 @@ def __init__(
623623
self.norm_before_last_linear = nn.LayerNorm([num_channels, spatial_size, spatial_size], eps=1e-5)
624624
self.last_linear = nn.Linear(linear_in_dim, embedding_dim, bias=False)
625625

626-
elif self.observation_shape[1] in [84, 96]:
627-
self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
628-
629626
self.final_norm_option_in_encoder = final_norm_option_in_encoder
630627
if self.final_norm_option_in_encoder == 'LayerNorm':
631628
self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5)

lzero/model/unizero_model_multitask.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _init_vector_components(self, world_model_cfg: EasyDict, obs_act_embed_dim:
106106
self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25)
107107
self.tokenizer = Tokenizer(
108108
encoder=self.representation_network,
109-
decoder_network=self.decoder_network,
109+
decoder=self.decoder_network,
110110
with_lpips=False,
111111
obs_type=world_model_cfg.obs_type
112112
)
@@ -162,7 +162,7 @@ def _init_image_components(self, world_model_cfg: EasyDict, observation_shape: S
162162
self.decoder_network = None
163163
self.tokenizer = Tokenizer(
164164
encoder=self.representation_network,
165-
decoder_network=self.decoder_network,
165+
decoder=self.decoder_network,
166166
with_lpips=False,
167167
obs_type=world_model_cfg.obs_type
168168
)
@@ -192,7 +192,7 @@ def _init_image_memory_components(self, world_model_cfg: EasyDict) -> None:
192192
)
193193
self.tokenizer = Tokenizer(
194194
encoder=self.representation_network,
195-
decoder_network=self.decoder_network,
195+
decoder=self.decoder_network,
196196
with_lpips=True,
197197
obs_type=world_model_cfg.obs_type
198198
)

lzero/model/unizero_world_models/tokenizer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,12 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id: int = 0) -> torch.T
115115
# This handles both single-task (a single nn.Module) and multi-task (an nn.ModuleList) scenarios.
116116
if isinstance(self.encoder, nn.ModuleList):
117117
if not 0 <= task_id < len(self.encoder):
118-
raise ValueError(
119-
f"Provided task_id {task_id} is invalid for the encoder list of size {len(self.encoder)}."
120-
)
121-
encoder_module = self.encoder[task_id]
118+
# raise ValueError(
119+
# f"Provided task_id {task_id} is invalid for the encoder list of size {len(self.encoder)}."
120+
# )
121+
encoder_module = self.encoder
122+
else:
123+
encoder_module = self.encoder[task_id]
122124
else:
123125
encoder_module = self.encoder
124126

lzero/model/unizero_world_models/world_model.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform
1010

1111
from lzero.model.common import SimNorm
12-
from lzero.model.utils import cal_dormant_ratio, compute_average_weight_magnitude, cal_effective_rank
12+
from lzero.model.utils import calculate_dormant_ratio, compute_average_weight_magnitude, compute_effective_rank
1313
from .kv_caching import KeysValues
1414
from .slicer import Head, PolicyHeadCont
1515
from .tokenizer import Tokenizer
@@ -45,6 +45,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
4545

4646
self.transformer = Transformer(self.config)
4747
self.task_num = 1
48+
self.env_num = self.config.env_num
4849
if self.config.device == 'cpu':
4950
self.device = torch.device('cpu')
5051
else:
@@ -70,7 +71,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
7071
print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
7172

7273
self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4
73-
74+
if self.task_embed_option == "concat_task_embed":
75+
self.obs_per_embdding_dim = self.config.embed_dim - self.task_embed_dim
76+
else:
77+
self.obs_per_embdding_dim = self.config.embed_dim
7478
self.continuous_action_space = self.config.continuous_action_space
7579

7680
# Initialize action embedding table
@@ -1352,7 +1356,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
13521356
# E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64)
13531357
inputs = batch['observations'].contiguous().view(-1, *shape[-3:])
13541358

1355-
dormant_ratio_encoder_dict = cal_dormant_ratio(
1359+
dormant_ratio_encoder_dict = calculate_dormant_ratio(
13561360
self.tokenizer.encoder, inputs.detach(), dormant_threshold=self.dormant_threshold
13571361
)
13581362
dormant_ratio_encoder = dormant_ratio_encoder_dict['global']
@@ -1370,11 +1374,11 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
13701374
# The 'representation_layer_name' argument specifies the target layer within the model's named modules.
13711375

13721376
# Effective rank for the final linear layer of the encoder.
1373-
e_rank_last_linear = cal_effective_rank(
1377+
e_rank_last_linear = compute_effective_rank(
13741378
self.tokenizer.encoder, inputs, representation_layer_name="last_linear"
13751379
)
13761380
# Effective rank for the SimNorm layer of the encoder.
1377-
e_rank_sim_norm = cal_effective_rank(
1381+
e_rank_sim_norm = compute_effective_rank(
13781382
self.tokenizer.encoder, inputs, representation_layer_name="sim_norm"
13791383
)
13801384

@@ -1485,7 +1489,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
14851489
# ========= logging for analysis =========
14861490
if self.analysis_dormant_ratio_weight_rank:
14871491
# Calculate dormant ratio of the world model
1488-
dormant_ratio_world_model = cal_dormant_ratio(self, {
1492+
dormant_ratio_world_model = calculate_dormant_ratio(self, {
14891493
'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())},
14901494
dormant_threshold=self.dormant_threshold)
14911495
dormant_ratio_transformer = dormant_ratio_world_model['transformer']

lzero/model/unizero_world_models/world_model_multitask.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from lzero.model.common import SimNorm
2020
from lzero.model.unizero_world_models.world_model import WorldModel
2121
from lzero.model.utils import (
22-
cal_dormant_ratio,
23-
cal_effective_rank,
22+
calculate_dormant_ratio,
23+
compute_effective_rank,
2424
compute_average_weight_magnitude,
2525
)
2626

@@ -224,7 +224,7 @@ def __init__(self, config: TransformerConfig, tokenizer: Tokenizer) -> None:
224224

225225
# Apply weight initialization. The order of initialization is important.
226226
self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type))
227-
self._initialize_last_layer()
227+
self._initialize_last_layer_mt()
228228

229229
# --- Cache and State Initialization ---
230230
self._initialize_cache_structures()
@@ -415,7 +415,7 @@ def create_head_modules_softmoe(self) -> None:
415415
self.head_policy = self._create_head_softmoe(self.value_policy_tokens_pattern, self.action_space_size, soft_moe=self.get_soft_moe("policy_soft_moe"))
416416
self.head_value = self._create_head_softmoe(self.value_policy_tokens_pattern, self.support_size, soft_moe=self.get_soft_moe("value_soft_moe"))
417417

418-
def _initialize_last_layer(self) -> None:
418+
def _initialize_last_layer_mt(self) -> None:
419419
"""Initializes the last linear layer of prediction heads to zero for training stability."""
420420
last_linear_layer_init_zero = True
421421
print(f'world_model_mt.py:self.task_num:{self.task_num}')
@@ -1555,7 +1555,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
15551555
encoder_index = task_id
15561556
else:
15571557
encoder_index = 0
1558-
dormant_ratio_encoder_dict = cal_dormant_ratio(self.tokenizer.encoder[encoder_index], inputs.detach(),
1558+
dormant_ratio_encoder_dict = calculate_dormant_ratio(self.tokenizer.encoder[encoder_index], inputs.detach(),
15591559
dormant_threshold=self.dormant_threshold)
15601560

15611561
dormant_ratio_encoder = dormant_ratio_encoder_dict['global']
@@ -1564,9 +1564,9 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
15641564
avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer)
15651565
avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict)
15661566

1567-
e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear")
1567+
e_rank_last_linear = compute_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear")
15681568
try:
1569-
e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm")
1569+
e_rank_sim_norm = compute_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm")
15701570
except Exception as e:
15711571
e_rank_sim_norm = torch.tensor(0.)
15721572

@@ -1658,7 +1658,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
16581658
# if self.analysis_dormant_ratio_weight_rank:
16591659
if self.do_analysis:
16601660
# Calculate dormant ratio of the world model
1661-
dormant_ratio_world_model = cal_dormant_ratio(self, {
1661+
dormant_ratio_world_model = calculate_dormant_ratio(self, {
16621662
'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())},
16631663
dormant_threshold=self.dormant_threshold)
16641664
dormant_ratio_transformer = dormant_ratio_world_model['transformer']

lzero/policy/muzero.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
1616
from lzero.mcts import MuZeroMCTSPtree as MCTSPtree
1717
from lzero.model import ImageTransforms
18-
from lzero.model.utils import cal_dormant_ratio
18+
from lzero.model.utils import calculate_dormant_ratio
1919
from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \
2020
DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \
2121
prepare_obs, configure_optimizers
@@ -113,7 +113,7 @@ class MuZeroPolicy(Policy):
113113
# This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically.
114114
eval_offline=False,
115115
# (bool) Whether to calculate the dormant ratio.
116-
cal_dormant_ratio=False,
116+
calculate_dormant_ratio=False,
117117
# (bool) Whether to analyze simulation normalization.
118118
analysis_sim_norm=False,
119119
# (bool) Whether to analyze dormant ratio.
@@ -423,8 +423,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
423423

424424
# ========= logging for analysis =========
425425
# calculate dormant ratio of encoder
426-
if self._cfg.cal_dormant_ratio:
427-
self.dormant_ratio_encoder = cal_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(),
426+
if self._cfg.calculate_dormant_ratio:
427+
self.dormant_ratio_encoder = calculate_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(),
428428
percentage=self._cfg.dormant_threshold)
429429
# calculate L2 norm of latent state
430430
latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean()
@@ -470,7 +470,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
470470
latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output)
471471

472472
# ========= logging for analysis ===============
473-
if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio:
473+
if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.calculate_dormant_ratio:
474474
# calculate dormant ratio of encoder
475475
action_tmp = action_batch[:, step_k]
476476
if len(action_tmp.shape) == 1:
@@ -486,7 +486,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
486486
latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3]
487487
)
488488
state_action_encoding = torch.cat((latent_state, action_encoding), dim=1)
489-
self.dormant_ratio_dynamics = cal_dormant_ratio(self._learn_model.dynamics_network,
489+
self.dormant_ratio_dynamics = calculate_dormant_ratio(self._learn_model.dynamics_network,
490490
state_action_encoding.detach(),
491491
percentage=self._cfg.dormant_threshold)
492492
# ========= logging for analysis ===============

lzero/policy/scaling_transform.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Union
22
import torch
3-
3+
import numpy as np
44

55
class DiscreteSupport(object):
66

@@ -11,7 +11,6 @@ def __init__(self, start: float, stop: float, step: float = 1., device: Union[st
1111
assert self.size > 0, "DiscreteSupport size must be greater than 0"
1212
self.step = step
1313

14-
1514
def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor:
1615
"""
1716
Overview:

lzero/policy/unizero_multitask.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -522,11 +522,12 @@ def _init_learn(self) -> None:
522522
self._cfg.augmentation,
523523
image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2])
524524
)
525-
self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1)
526-
self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1)
527-
self.inverse_scalar_transform_handle = InverseScalarTransform(
528-
self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
529-
)
525+
526+
self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device)
527+
self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device)
528+
self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution)
529+
self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution)
530+
530531
self.intermediate_losses = defaultdict(float)
531532
self.l2_norm_before = 0.
532533
self.l2_norm_after = 0.

0 commit comments

Comments (0)