# vision_model.py
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Two vision encoders are defined here: VisionEncoder (a plain CNN) and ResNetFe (ResNet-style).
Note: batch norm is replaced by group norm before use with the noise predictor,
as group norm works better with the UNet for action noise prediction.
"""
# A simple vision encoder (plain CNN, without residual blocks).
class VisionEncoder(nn.Module):
    def __init__(self):
        super(VisionEncoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # (x+2*3-7)/2+1 = (x-1)/2+1 = 48 : 64*48*48
        self.gn1 = nn.GroupNorm(32, 64)  # num_groups, num_channels
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # (x+2*1-3)/2+1 = (x-1)/2+1 = 24 : 64*24*24
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, 512)

    def forward(self, x):
        x = self.conv1(x)
        x = self.gn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
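
# A minimal usage sketch (illustrative; the 96x96 input size is an assumption inferred
# from the 48x48 shape comment above -- the adaptive average pool makes the encoder
# work for other input sizes as well):
#
#   encoder = VisionEncoder()
#   images = torch.randn(8, 3, 96, 96)   # batch of RGB frames
#   features = encoder(images)           # -> (8, 512)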
# Standard two-convolution residual block with an optional downsample path for the skip connection.
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = F.relu(out)
        return out
class ResNetFe(nn.Module):
    def __init__(self, block, layers):
        super(ResNetFe, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # (x+2*3-7)/2+1 = (x-1)/2+1 = 48 : 64*48*48
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # (x+2*1-3)/2+1 = (x-1)/2+1 = 24 : 64*24*24
        # Two ResNet layers
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 512)  # Output 512-dimensional features

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Accept both (B, C, H, W) and (B, T, C, H, W); fold any sequence dimension into the batch.
        if len(x.shape) == 4:
            x = x.unsqueeze(1)
        batch_size, seq_len, channels, height, width = x.shape
        x = x.view(batch_size * seq_len, channels, height, width)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
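
# A minimal usage sketch (illustrative; the layer counts [2, 2] and the 96x96 frame size
# are assumptions, not values fixed by this module). ResNetFe.forward folds any sequence
# dimension into the batch, so both 4D and 5D inputs are accepted:
#
#   resnet_fe = ResNetFe(ResidualBlock, layers=[2, 2])
#   obs_seq = torch.randn(4, 2, 3, 96, 96)   # (batch, seq_len, C, H, W)
#   feats = resnet_fe(obs_seq)               # -> (4 * 2, 512)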
### ----------------- utility functions from the original diffusion policy codebase ---------------------- ###
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace all submodules selected by the predicate with
    the output of func.
    predicate: Return true if the module is to be replaced.
    func: Return new module to use.
    """
    if predicate(root_module):
        return func(root_module)
    bn_list = [k.split('.') for k, m
               in root_module.named_modules(remove_duplicate=True)
               if predicate(m)]
    for *parent, k in bn_list:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule('.'.join(parent))
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all modules are replaced
    bn_list = [k.split('.') for k, m
               in root_module.named_modules(remove_duplicate=True)
               if predicate(m)]
    assert len(bn_list) == 0
    return root_module
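
# replace_bn_with_gn below is how this helper is used in this module, but it is generic.
# A hypothetical sketch (the ReLU-to-GELU swap is purely illustrative and not something
# this codebase does): replace every ReLU in a model with GELU:
#
#   model = replace_submodules(
#       root_module=model,
#       predicate=lambda m: isinstance(m, nn.ReLU),
#       func=lambda m: nn.GELU())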
def replace_bn_with_gn(
        root_module: nn.Module,
        features_per_group: int = 16) -> nn.Module:
    """
    Replace all BatchNorm layers with GroupNorm.
    """
    replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=lambda x: nn.GroupNorm(
            num_groups=x.num_features // features_per_group,
            num_channels=x.num_features)
    )
    return root_module
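

# A small smoke-test sketch (the layer counts and input size are assumptions): build
# ResNetFe, swap its BatchNorm2d layers for GroupNorm as described in the module
# docstring, and check that the output feature dimension is 512.
if __name__ == "__main__":
    fe = ResNetFe(ResidualBlock, layers=[2, 2])
    fe = replace_bn_with_gn(fe, features_per_group=16)
    assert not any(isinstance(m, nn.BatchNorm2d) for m in fe.modules())
    out = fe(torch.randn(2, 3, 96, 96))  # 4D input: a sequence dim of 1 is added internally
    print(out.shape)                     # torch.Size([2, 512])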