
Commit 6d7761a

feature(pu): add decode_loss for unizero atari
1 parent 3788eb7 commit 6d7761a

5 files changed: +104 −31 lines changed

lzero/model/unizero_model.py

Lines changed: 15 additions & 3 deletions
@@ -126,8 +126,9 @@ def __init__(
             self.decoder_network_tokenizer = None
         else:
             raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}")
+
         self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer,
-                                   with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option'])
+                                   with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option'])
         self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
         print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
         print('==' * 20)

@@ -168,8 +169,19 @@ def __init__(
         if world_model_cfg.analysis_sim_norm:
             self.encoder_hook = FeatureAndGradientHook()
             self.encoder_hook.setup_hooks(self.representation_network)
-
-        self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=None, with_lpips=False, obs_type=world_model_cfg.obs_type)
+
+        if world_model_cfg.latent_recon_loss_weight == 0:
+            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=None, with_lpips=False, obs_type=world_model_cfg.obs_type)
+        else:
+            # TODO =============
+            self.decoder_network = LatentDecoder(
+                embedding_dim=world_model_cfg.embed_dim,
+                output_shape=[3, 64, 64],
+                num_channels=64,
+                activation=self.activation,
+            )
+            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, with_lpips=True, obs_type=world_model_cfg.obs_type)
+
         self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
         print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
         print('==' * 20)
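
The net effect of this change, distilled into a minimal standalone sketch (Cfg and build_tokenizer below are illustrative stand-ins, not the repo's API): with latent_recon_loss_weight == 0 the tokenizer stays encoder-only; otherwise a LatentDecoder plus LPIPS is attached so the decode losses can be computed.

# Illustrative sketch only; simplified stand-ins for world_model_cfg / Tokenizer.
from dataclasses import dataclass

@dataclass
class Cfg:
    latent_recon_loss_weight: float = 0.0
    embed_dim: int = 768

def build_tokenizer(cfg, encoder, make_decoder):
    if cfg.latent_recon_loss_weight == 0:
        # Old default path: encoder-only tokenizer, no decoder, no LPIPS.
        return dict(encoder=encoder, decoder=None, with_lpips=False)
    # Decode path: a latent decoder maps embed_dim -> 3x64x64 frames; LPIPS enabled.
    return dict(encoder=encoder, decoder=make_decoder(cfg.embed_dim), with_lpips=True)

print(build_tokenizer(Cfg(latent_recon_loss_weight=1.0), 'encoder',
                      lambda d: f'LatentDecoder(embedding_dim={d})'))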

lzero/model/unizero_world_models/lpips.py

Lines changed: 47 additions & 10 deletions
@@ -13,6 +13,20 @@
 from torchvision import models
 from tqdm import tqdm

+# ==================================================================================
+# ============================== Core modification ================================
+# ==================================================================================
+# Set the TORCH_HOME environment variable after importing torch and torchvision,
+# but before instantiating any model. This tells PyTorch to store every model
+# downloaded through torch.hub (including the pretrained models in
+# torchvision.models) under the directory specified here.
+# PyTorch automatically creates the 'hub/checkpoints' subfolder inside it.
+custom_torch_home = "/mnt/shared-storage-user/puyuan/code_20250828/LightZero/tokenizer_pretrained_vgg"
+os.environ['TORCH_HOME'] = custom_torch_home
+# Make sure the directory exists; torch.hub would also try to create it, but
+# creating it up front is safer.
+os.makedirs(os.path.join(custom_torch_home, 'hub', 'checkpoints'), exist_ok=True)
+# ==================================================================================
+# ==================================================================================

 class LPIPS(nn.Module):
     # Learned perceptual metric
@@ -22,19 +36,23 @@ def __init__(self, use_dropout: bool = True):
         self.chns = [64, 128, 256, 512, 512]  # vgg16 features

         # Comment out the following lines if you don't need perceptual loss
-        # self.net = vgg16(pretrained=True, requires_grad=False)
-        # self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
-        # self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
-        # self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
-        # self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
-        # self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
-        # self.load_from_pretrained()
-        # for param in self.parameters():
-        #     param.requires_grad = False
+        # This line now automatically uses the path specified by TORCH_HOME.
+        self.net = vgg16(pretrained=True, requires_grad=False)
+        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+        self.load_from_pretrained()
+        for param in self.parameters():
+            param.requires_grad = False

     def load_from_pretrained(self) -> None:
-        ckpt = get_ckpt_path(name="vgg_lpips", root=Path.home() / ".cache/iris/tokenizer_pretrained_vgg")  # Download VGG if necessary
+        # This call loads the LPIPS linear-layer weights (vgg.pth).
+        # It uses the same root directory as TORCH_HOME for consistency.
+        ckpt = get_ckpt_path(name="vgg_lpips", root=custom_torch_home)
         self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+        print(f"Loaded LPIPS pretrained weights from: {ckpt}")

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
@@ -74,7 +92,10 @@ def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) ->
 class vgg16(torch.nn.Module):
     def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
         super(vgg16, self).__init__()
+        # Because TORCH_HOME is set, pretrained=True looks the model up (or
+        # downloads it) in the specified directory.
+        print("Loading vgg16 backbone...")
         vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+        print("vgg16 backbone loaded.")
         self.slice1 = torch.nn.Sequential()
         self.slice2 = torch.nn.Sequential()
         self.slice3 = torch.nn.Sequential()
@@ -160,10 +181,26 @@ def md5_hash(path: str) -> str:

 def get_ckpt_path(name: str, root: str, check: bool = False) -> str:
     assert name in URL_MAP
+    # This function now only serves vgg.pth; the path is correct as-is.
     path = os.path.join(root, CKPT_MAP[name])
     if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
         print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
         download(URL_MAP[name], path)
         md5 = md5_hash(path)
         assert md5 == MD5_MAP[name], md5
     return path
+
+# =======================
+# ==== Usage example ====
+# =======================
+if __name__ == '__main__':
+    print(f"PyTorch Hub directory set to: {os.environ['TORCH_HOME']}")
+
+    # On the first run you will see two downloads:
+    # 1. vgg16-397923af.pth into /mnt/shared-storage-user/puyuan/code_20250828/LightZero/tokenizer_pretrained_vgg/hub/checkpoints/
+    # 2. vgg.pth into /mnt/shared-storage-user/puyuan/code_20250828/LightZero/tokenizer_pretrained_vgg/
+    # On later runs there are no download prompts; weights load directly from the specified directory.
+
+    print("\nInitializing LPIPS model...")
+    model = LPIPS()
+    print("\nLPIPS model initialized successfully.")
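
For reference, a minimal sketch of the TORCH_HOME mechanism the block above relies on; the /tmp path here is a placeholder for illustration, not the shared-storage path from this commit:

import os

# Set TORCH_HOME before any pretrained model is instantiated; torch.hub (and
# therefore torchvision's pretrained models) caches weights under
# <TORCH_HOME>/hub/checkpoints/.
cache_root = '/tmp/my_torch_cache'  # placeholder path
os.environ['TORCH_HOME'] = cache_root
os.makedirs(os.path.join(cache_root, 'hub', 'checkpoints'), exist_ok=True)

from torchvision import models

vgg = models.vgg16(pretrained=True)  # weights land in <cache_root>/hub/checkpoints/
print(sorted(os.listdir(os.path.join(cache_root, 'hub', 'checkpoints'))))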

lzero/model/unizero_world_models/world_model.py

Lines changed: 6 additions & 5 deletions
@@ -1759,7 +1759,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar

         if self.obs_type == 'image':
             # Reconstruct observations from latent state representations
-            # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)
+            reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)

             # ========== for visualization ==========
             # Uncomment the lines below for visual analysis
@@ -1772,11 +1772,12 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
             # ========== for visualization ==========

             # ========== Calculate reconstruction loss and perceptual loss ============
-            # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)  # NOTE: for stack=1
-            # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)  # NOTE: for stack=1
+            latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)  # NOTE: for stack=1
+            perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images)  # NOTE: for stack=1

-            latent_recon_loss = self.latent_recon_loss
-            perceptual_loss = self.perceptual_loss
+            # TODO:
+            # latent_recon_loss = self.latent_recon_loss
+            # perceptual_loss = self.perceptual_loss

         elif self.obs_type == 'vector':
             perceptual_loss = torch.tensor(0., device=batch['observations'].device,
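
A rough sketch of the two re-enabled terms, assuming stack=1 frames of shape (B, 3, 64, 64) and using MSE and LPIPS as stand-ins for the tokenizer's reconstruction_loss and perceptual_loss:

import torch
import torch.nn.functional as F

def decode_losses(observations, obs_embeddings, decoder, lpips_fn):
    # observations: (B, 3, 64, 64) ground-truth frames (stack=1, per the NOTE)
    # obs_embeddings: (B, E) latent states from the encoder
    reconstructed_images = decoder(obs_embeddings)                         # (B, 3, 64, 64)
    latent_recon_loss = F.mse_loss(reconstructed_images, observations)     # pixel-space term
    perceptual_loss = lpips_fn(reconstructed_images, observations).mean()  # LPIPS term
    return latent_recon_loss, perceptual_loss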

lzero/policy/unizero.py

Lines changed: 28 additions & 10 deletions
@@ -840,9 +840,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         # logits_policy_mean=self.intermediate_losses['logits_policy_mean']
         # logits_policy_max=self.intermediate_losses['logits_policy_max']
         # logits_policy_min=self.intermediate_losses['logits_policy_min']
-        # temperature_value=self.intermediate_losses['temperature_value']
-        # temperature_reward=self.intermediate_losses['temperature_reward']
-        # temperature_policy=self.intermediate_losses['temperature_policy']
+
+        temperature_value = self.intermediate_losses['temperature_value']
+        temperature_reward = self.intermediate_losses['temperature_reward']
+        temperature_policy = self.intermediate_losses['temperature_policy']

         assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
         assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"

@@ -898,13 +899,30 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         self.reward_loss_weight = 1.
         self.policy_loss_weight = 1.
         self.ends_loss_weight = 0.
-        total_loss = (
-            self.reward_loss_weight * reward_loss +
-            self.value_loss_weight * value_loss +
-            self.policy_loss_weight * weighted_policy_loss +
-            self.obs_loss_weight * obs_loss  # assumes ssl_loss_weight is the weight for obs_loss
-            # ... add any further loss terms here as well ...
-        )
+
+        self.latent_recon_loss_weight = self._cfg.model.world_model_cfg.latent_recon_loss_weight  # fixed value by default
+        self.perceptual_loss_weight = self._cfg.model.world_model_cfg.perceptual_loss_weight  # fixed value by default
+
+        if self.latent_recon_loss_weight > 0:
+            total_loss = (
+                self.reward_loss_weight * reward_loss +
+                self.value_loss_weight * value_loss +
+                self.policy_loss_weight * weighted_policy_loss +
+                self.obs_loss_weight * obs_loss +  # assumes ssl_loss_weight is the weight for obs_loss
+                self.latent_recon_loss_weight * latent_recon_loss +
+                self.perceptual_loss_weight * perceptual_loss
+                # ... add any further loss terms here as well ...
+            )
+        else:
+
+            total_loss = (
+                self.reward_loss_weight * reward_loss +
+                self.value_loss_weight * value_loss +
+                self.policy_loss_weight * weighted_policy_loss +
+                self.obs_loss_weight * obs_loss  # assumes ssl_loss_weight is the weight for obs_loss
+
+                # ... add any further loss terms here as well ...
+            )
         weighted_total_loss = (weights * total_loss).mean()
         # ===================== END: target-entropy regularization update logic =====================

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 8 additions & 3 deletions
@@ -89,6 +89,9 @@ def main(env_id, seed):
             num_res_blocks=2,
             num_channels=128,
             world_model_cfg=dict(
+                latent_recon_loss_weight=1,
+                perceptual_loss_weight=1,
+
                 # use_new_cache_manager=True,
                 use_new_cache_manager=False,

@@ -240,7 +243,9 @@ def main(env_id, seed):

     # ============ use muzero_segment_collector instead of muzero_collector =============
     from lzero.entry import train_unizero_segment
-    main_config.exp_name = f'data_unizero_st_refactor1024/{env_id[3:-3]}/{env_id[3:-3]}_uz_cossimloss_nokvcachemanager_ch128-res2_aug_targetentropy-alpha-100k-098-07-lr1e-3-encoder-clip30-10-100k_adamw-wd1e-2-encoder5-trans1-head0_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    main_config.exp_name = f'data_unizero_st_refactor1024/{env_id[3:-3]}/{env_id[3:-3]}_uz_recon-perc-w1_cossimloss_nokvcachemanager_ch128-res2_aug_targetentropy-alpha-100k-098-07-lr1e-3-encoder-clip30-10-100k_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+
+    # main_config.exp_name = f'data_unizero_st_refactor1024/{env_id[3:-3]}/{env_id[3:-3]}_uz_recon-perc-w1_cossimloss_nokvcachemanager_ch128-res2_aug_targetentropy-alpha-100k-098-07-lr1e-3-encoder-clip30-10-100k_adamw-wd1e-2-encoder5-trans1-head0_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'

     # main_config.exp_name = f'data_unizero_st_refactor1024/{env_id[3:-3]}/{env_id[3:-3]}_uz_cossimloss_nokvcachemanager_ch64-res1_targetentropy-alpha-100k-098-07-encoder-clip30-10-100k_adamw-wd1e-2-encoder5-trans1-head0_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'

@@ -265,7 +270,7 @@ def main(env_id, seed):

     # 4 of the base environments from the atari8 test set
     # args.env = 'PongNoFrameskip-v4'  # reactive environment, dense reward
-    args.env = 'MsPacmanNoFrameskip-v4'  # memory/planning environment, sparse reward
+    # args.env = 'MsPacmanNoFrameskip-v4'  # memory/planning environment, sparse reward

     # args.env = 'ALE/Pong-v5'  # memory/planning environment, sparse reward

@@ -293,7 +298,7 @@ def main(env_id, seed):
    """
    tmux new -s uz-st-refactor-boxing

-   export CUDA_VISIBLE_DEVICES=1
+   export CUDA_VISIBLE_DEVICES=0
    cd /mnt/shared-storage-user/puyuan/code_20250828/LightZero/
    /mnt/shared-storage-user/puyuan/lz/bin/python /mnt/shared-storage-user/puyuan/code_20250828/LightZero/zoo/atari/config/atari_unizero_segment_config.py
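
To recover the previous behavior (no decoder, no LPIPS download), both new weights can be set back to 0 in the config shown above; a hedged toggle example, assuming only the keys this diff introduces:

world_model_cfg = dict(
    latent_recon_loss_weight=1,  # > 0 attaches LatentDecoder and enables the reconstruction term
    perceptual_loss_weight=1,    # weight of the LPIPS perceptual term
    # set both to 0 to restore the encoder-only tokenizer
)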
