1313from torchvision import models
1414from tqdm import tqdm
1515
16+ # ==================================================================================
17+ # ================================ 核心修改部分 ====================================
18+ # ==================================================================================
19+ # 在导入 torch 和 torchvision 之后,但在实例化任何模型之前,设置 TORCH_HOME 环境变量。
20+ # 这会告诉 PyTorch 将所有通过 torch.hub 下载的模型(包括 torchvision.models 中的预训练模型)
21+ # 存放到您指定的目录下。
22+ # PyTorch 会自动在此目录下创建 'hub/checkpoints' 子文件夹。
23+ custom_torch_home = "/mnt/shared-storage-user/puyuan/code_20250828/LightZero/tokenizer_pretrained_vgg"
24+ os .environ ['TORCH_HOME' ] = custom_torch_home
25+ # 确保目录存在,虽然 torch.hub 也会尝试创建,但提前创建更稳妥
26+ os .makedirs (os .path .join (custom_torch_home , 'hub' , 'checkpoints' ), exist_ok = True )
27+ # ==================================================================================
28+ # ==================================================================================
29+
1630
1731class LPIPS (nn .Module ):
1832 # Learned perceptual metric
@@ -22,19 +36,23 @@ def __init__(self, use_dropout: bool = True):
2236 self .chns = [64 , 128 , 256 , 512 , 512 ] # vg16 features
2337
2438 # Comment out the following line if you don't need perceptual loss
25- # self.net = vgg16(pretrained=True, requires_grad=False)
26- # self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
27- # self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
28- # self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
29- # self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
30- # self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
31- # self.load_from_pretrained()
32- # for param in self.parameters():
33- # param.requires_grad = False
39+ # 现在,这一行将自动使用 TORCH_HOME 指定的路径
40+ self .net = vgg16 (pretrained = True , requires_grad = False )
41+ self .lin0 = NetLinLayer (self .chns [0 ], use_dropout = use_dropout )
42+ self .lin1 = NetLinLayer (self .chns [1 ], use_dropout = use_dropout )
43+ self .lin2 = NetLinLayer (self .chns [2 ], use_dropout = use_dropout )
44+ self .lin3 = NetLinLayer (self .chns [3 ], use_dropout = use_dropout )
45+ self .lin4 = NetLinLayer (self .chns [4 ], use_dropout = use_dropout )
46+ self .load_from_pretrained ()
47+ for param in self .parameters ():
48+ param .requires_grad = False
3449
3550 def load_from_pretrained (self ) -> None :
36- ckpt = get_ckpt_path (name = "vgg_lpips" , root = Path .home () / ".cache/iris/tokenizer_pretrained_vgg" ) # Download VGG if necessary
51+ # 这一部分您已经修改正确,它用于加载 LPIPS 的线性层权重 (vgg.pth)
52+ # 我们让它和 TORCH_HOME 使用相同的根目录,以保持一致性。
53+ ckpt = get_ckpt_path (name = "vgg_lpips" , root = custom_torch_home )
3754 self .load_state_dict (torch .load (ckpt , map_location = torch .device ("cpu" )), strict = False )
55+ print (f"Loaded LPIPS pretrained weights from: { ckpt } " )
3856
3957 def forward (self , input : torch .Tensor , target : torch .Tensor ) -> torch .Tensor :
4058 in0_input , in1_input = (self .scaling_layer (input ), self .scaling_layer (target ))
@@ -74,7 +92,10 @@ def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) ->
7492class vgg16 (torch .nn .Module ):
7593 def __init__ (self , requires_grad : bool = False , pretrained : bool = True ) -> None :
7694 super (vgg16 , self ).__init__ ()
95+ # 由于设置了 TORCH_HOME,这里的 pretrained=True 会在指定目录中查找或下载模型
96+ print ("Loading vgg16 backbone..." )
7797 vgg_pretrained_features = models .vgg16 (pretrained = pretrained ).features
98+ print ("vgg16 backbone loaded." )
7899 self .slice1 = torch .nn .Sequential ()
79100 self .slice2 = torch .nn .Sequential ()
80101 self .slice3 = torch .nn .Sequential ()
@@ -160,10 +181,26 @@ def md5_hash(path: str) -> str:
160181
161182def get_ckpt_path (name : str , root : str , check : bool = False ) -> str :
162183 assert name in URL_MAP
184+ # 这个函数现在只为 vgg.pth 服务,路径是正确的
163185 path = os .path .join (root , CKPT_MAP [name ])
164186 if not os .path .exists (path ) or (check and not md5_hash (path ) == MD5_MAP [name ]):
165187 print ("Downloading {} model from {} to {}" .format (name , URL_MAP [name ], path ))
166188 download (URL_MAP [name ], path )
167189 md5 = md5_hash (path )
168190 assert md5 == MD5_MAP [name ], md5
169191 return path
192+
193+ # =======================
194+ # ===== 运行示例 ======
195+ # =======================
196+ if __name__ == '__main__' :
197+ print (f"PyTorch Hub directory set to: { os .environ ['TORCH_HOME' ]} " )
198+
199+ # 第一次运行时,你会看到两个下载过程:
200+ # 1. 下载 vgg16-397923af.pth 到 /mnt/shared-storage-user/puyuan/code_20250828/LightZero/tokenizer_pretrained_vgg/hub/checkpoints/
201+ # 2. 下载 vgg.pth 到 /mnt/shared-storage-user/puyuan/code_20250828/LightZero/tokenizer_pretrained_vgg/
202+ # 之后再次运行,将不会有任何下载提示,直接从指定目录加载。
203+
204+ print ("\n Initializing LPIPS model..." )
205+ model = LPIPS ()
206+ print ("\n LPIPS model initialized successfully." )
0 commit comments