From 0dc53d344298ca75c1f7f078e3588db8edabed68 Mon Sep 17 00:00:00 2001 From: MANU S PILLAI Date: Tue, 15 Feb 2022 12:49:10 +0530 Subject: [PATCH] Added support for resnet architectures Updated PNetLin to have resnet arch support --- lpips/networks_basic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lpips/networks_basic.py b/lpips/networks_basic.py index 1d23f059..337a73ce 100755 --- a/lpips/networks_basic.py +++ b/lpips/networks_basic.py @@ -10,6 +10,7 @@ from pdb import set_trace as st from skimage import color from IPython import embed +from functools import partial from . import pretrained_networks as pn import lpips as util @@ -25,7 +26,7 @@ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W # Learned perceptual metric class PNetLin(nn.Module): - def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): + def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True, num=None): super(PNetLin, self).__init__() self.pnet_type = pnet_type @@ -45,6 +46,11 @@ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropou elif(self.pnet_type=='squeeze'): net_type = pn.squeezenet self.chns = [64,128,256,384,384,512,512] + elif(self.pnet_type== 'resnet'): + if num is None: + raise ValueError(f'\'num\' must be specified for pnet_type == {self.pnet_type}') + net_type = partial(pn.resnet, num=num) + self.chns = [64, 256, 512, 1024, 2048] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)