-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhubconf.py
139 lines (114 loc) · 4.8 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Pathology FM model hub from kaiko.ai."""
from typing import List
import timm
import torch
from torch import nn
# Package names that must be installed to load the models in this hub.
dependencies = [
    "torch",
    "timm",
]

# Tag of the GitHub release whose assets hold the pretrained weights
# (interpolated into the download URL of every entrypoint below).
RELEASE_TAG = "0.0.1"
"""The release tag to fetch the weights from."""
def vits16(dynamic_img_size: bool = True, out_indices: int | List[int] | None = None) -> nn.Module:
"""Initializes the vision transformer ViTS-16 pathology FM by kaiko.ai.
Args:
dynamic_img_size: Whether to allow the interpolation embedding
to be interpolated at `forward()` time when image grid changes
from original.
out_indices: Weather and which multi-level patch embeddings to return.
Returns:
The torch ViTS-16 based foundation model.
"""
return timm.create_model(
model_name="vit_small_patch16_224",
dynamic_img_size=dynamic_img_size,
pretrained_cfg={
"url": f"https://github.com/kaiko-ai/towards_large_pathology_fms/releases/download/{RELEASE_TAG}/vits16.pth",
"num_classes": 0
},
pretrained=True,
out_indices=out_indices,
features_only=out_indices is not None,
)
def vits8(dynamic_img_size: bool = True, out_indices: int | List[int] | None = None) -> nn.Module:
"""Initializes the vision transformer ViTS-8 pathology FM by kaiko.ai.
Args:
dynamic_img_size: Whether to allow the interpolation embedding
to be interpolated at `forward()` time when image grid changes
from original.
out_indices: Weather and which multi-level patch embeddings to return.
Returns:
The torch ViTS-8 based foundation model.
"""
return timm.create_model(
model_name="vit_small_patch8_224",
dynamic_img_size=dynamic_img_size,
pretrained_cfg={
"url": f"https://github.com/kaiko-ai/towards_large_pathology_fms/releases/download/{RELEASE_TAG}/vits8.pth",
"num_classes": 0
},
pretrained=True,
out_indices=out_indices,
features_only=out_indices is not None,
)
def vitb16(dynamic_img_size: bool = True, out_indices: int | List[int] | None = None) -> nn.Module:
"""Initializes the vision transformer ViTB-16 pathology FM by kaiko.ai.
Args:
dynamic_img_size: Whether to allow the interpolation embedding
to be interpolated at `forward()` time when image grid changes
from original.
out_indices: Weather and which multi-level patch embeddings to return.
Returns:
The torch ViTB-16 based foundation model.
"""
return timm.create_model(
model_name="vit_base_patch16_224",
dynamic_img_size=dynamic_img_size,
pretrained_cfg={
"url": f"https://github.com/kaiko-ai/towards_large_pathology_fms/releases/download/{RELEASE_TAG}/vitb16.pth",
"num_classes": 0
},
pretrained=True,
out_indices=out_indices,
features_only=out_indices is not None,
)
def vitb8(dynamic_img_size: bool = True, out_indices: int | List[int] | None = None) -> nn.Module:
"""Initializes the vision transformer ViTB-8 pathology FM by kaiko.ai.
Args:
dynamic_img_size: Whether to allow the interpolation embedding
to be interpolated at `forward()` time when image grid changes
from original.
out_indices: Weather and which multi-level patch embeddings to return.
Returns:
The torch ViTB-8 based foundation model.
"""
return timm.create_model(
model_name="vit_base_patch8_224",
dynamic_img_size=dynamic_img_size,
pretrained_cfg={
"url": f"https://github.com/kaiko-ai/towards_large_pathology_fms/releases/download/{RELEASE_TAG}/vitb8.pth",
"num_classes": 0
},
pretrained=True,
out_indices=out_indices,
features_only=out_indices is not None,
)
def vitl14(dynamic_img_size: bool = True, out_indices: int | List[int] | None = None) -> nn.Module:
"""Initializes the vision transformer ViTL-14 pathology FM by kaiko.ai.
Args:
dynamic_img_size: Whether to allow the interpolation embedding
to be interpolated at `forward()` time when image grid changes
from original.
out_indices: Weather and which multi-level patch embeddings to return.
Returns:
The torch ViTL-14 based foundation model.
"""
return timm.create_model(
model_name="vit_large_patch14_reg4_dinov2",
pretrained_cfg={
"url": f"https://github.com/kaiko-ai/towards_large_pathology_fms/releases/download/{RELEASE_TAG}/vitl14.pth",
"num_classes": 0,
},
pretrained=True,
out_indices=out_indices,
dynamic_img_size=dynamic_img_size,
features_only=out_indices is not None,
)