-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathbuild_model.py
More file actions
95 lines (80 loc) · 3.81 KB
/
Copy pathbuild_model.py
File metadata and controls
95 lines (80 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from tkinter.constants import RAISED
import timm
import torch
from model.vision_transformer import VisionTransformerPETL
from utils.log_utils import log_model_info
from timm.data import resolve_data_config
from utils.setup_logging import get_logger
import torch.distributed as dist
def is_main_process():
    """Report whether this process should perform rank-0-only work.

    Returns True when torch.distributed is unavailable, when no process group
    has been initialized, or when the current rank is 0; otherwise False.
    """
    # Only consult the rank when a process group actually exists.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True
# Module-wide logger, obtained from utils.setup_logging.get_logger under the name "Prompt_CAM".
logger = get_logger("Prompt_CAM")

# Parameter-name substrings that get_model() keeps trainable; any parameter
# whose name contains none of these has requires_grad set to False.
TUNE_MODULES = [
    "vpt",
    "head",
]
def get_model(params, visualize=False):
    """Build the backbone, freeze everything except the tunable modules, and place it on a device.

    Args:
        params: Config namespace. ``debug`` is read here (plus whatever
            ``get_base_model`` reads); ``device`` is written onto it.
        visualize: Forwarded to ``get_base_model``, which skips loading the
            pretrained checkpoint when this is True.

    Returns:
        Tuple ``(model, tune_parameters, model_grad_params_no_head)``:
        the model moved to ``params.device``, the list of parameters left
        trainable, and whatever ``log_model_info`` returns.
    """
    params.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {params.device}")
    model = get_base_model(params, visualize=visualize)

    # Decide once whether to log trainable parameters; only the main process logs.
    log_trainable = params.debug and is_main_process()
    if log_trainable:
        logger.info("Trainable params:")

    tune_parameters = []
    for name, parameter in model.named_parameters():
        # Substring match: a parameter is trained if its name contains any TUNE_MODULES entry.
        if any(m in name for m in TUNE_MODULES):
            parameter.requires_grad = True
            tune_parameters.append(parameter)
            if log_trainable:
                logger.info("\t{}, {}, {}".format(name, parameter.numel(), parameter.shape))
        else:
            parameter.requires_grad = False

    model_grad_params_no_head = log_model_info(model, logger)
    # .to() works for both the CUDA device and the CPU fallback; .cuda() raises on CPU.
    model = model.to(params.device)
    return model, tune_parameters, model_grad_params_no_head
# Plain ViT backbones handled uniformly by get_base_model:
# pretrained_weights key -> (timm model name, checkpoint path, patch size).
_VIT_BACKBONES = {
    "vit_base_patch16_224_in21k": (
        "vit_base_patch16_224_in21k_petl",
        "pretrained_weights/ViT-B_16_in21k.npz",
        16,
    ),
    "vit_base_mae": (
        "vit_base_patch16_224_in21k_petl",
        "pretrained_weights/mae_pretrain_vit_base.pth",
        16,
    ),
    "vit_base_patch14_dinov2": (
        "vit_base_patch14_dinov2_petl",
        "pretrained_weights/dinov2_vitb14_pretrain.pth",
        14,
    ),
    "vit_base_patch16_dino": (
        "vit_base_patch16_dino_petl",
        "pretrained_weights/dino_vitbase16_pretrain.pth",
        16,
    ),
}


def get_base_model(params, visualize=False):
    """Create the PETL ViT backbone selected by ``params.pretrained_weights``.

    Sets ``params.patch_size`` to the chosen backbone's patch size before the
    model is created. Unless ``visualize`` is True, the pretrained checkpoint
    under ``pretrained_weights/`` is loaded. Plain ViT backbones get a fresh
    classifier with ``params.class_num`` outputs. The CLIP backbone's head is
    replaced by ``Sequential(get_clip_proj(...), init_imagenet_clip(...))``.

    Args:
        params: Config namespace; reads ``pretrained_weights``, ``drop_path_rate``,
            ``class_num`` (plain ViT) or ``device`` (CLIP); writes ``patch_size``.
        visualize: When True, skip loading the pretrained checkpoint.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if ``params.pretrained_weights`` is not a supported key.
    """
    name = params.pretrained_weights

    if name in _VIT_BACKBONES:
        timm_name, checkpoint, patch_size = _VIT_BACKBONES[name]
        # Set before create_model since params is passed into the model constructor.
        params.patch_size = patch_size
        model = timm.create_model(timm_name, drop_path_rate=params.drop_path_rate,
                                  pretrained=False, params=params)
        if not visualize:
            model.load_pretrained(checkpoint)
        model.reset_classifier(params.class_num)
        return model

    if name == 'vit_base_patch16_clip_224':
        params.patch_size = 16
        model = timm.create_model("vit_base_patch16_clip_224_petl", drop_path_rate=params.drop_path_rate,
                                  pretrained=False, params=params)
        if not visualize:
            model.load_pretrained('pretrained_weights/ViT-B_16_clip.bin')
        # NOTE(review): init_imagenet_clip and get_clip_proj are not defined or
        # imported in this module, so this path raises NameError as written —
        # TODO confirm where they live and import them.
        fc = init_imagenet_clip(params.device)
        proj = get_clip_proj(params.device)
        model.head = torch.nn.Sequential(proj, fc)
        return model

    raise NotImplementedError(f"Unsupported pretrained_weights: {name!r}")