Skip to content

Commit 306e762

Browse files
committed
Fix torchvision pretrained deprecation warning
1 parent f61dde4 commit 306e762

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

Diff for: simclr/modules/resnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
def get_resnet(name, pretrained=False):
55
resnets = {
6-
"resnet18": torchvision.models.resnet18(pretrained=pretrained),
7-
"resnet50": torchvision.models.resnet50(pretrained=pretrained),
6+
"resnet18": torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT),
7+
"resnet50": torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT),
88
}
99
if name not in resnets.keys():
1010
raise KeyError(f"{name} is not a valid ResNet version")

0 commit comments

Comments
 (0)