Skip to content

Commit e8879e3

Browse files
committed
version 0.0.1
1 parent cdb3147 commit e8879e3

2 files changed

Lines changed: 198 additions & 0 deletions

File tree

pointnet/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .pointnet import STN, PointNetCls, PointNetSeg

pointnet/pointnet.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from einops import rearrange, repeat
5+
6+
7+
def exists(val):
    """Tell whether ``val`` holds an actual value, i.e. is not ``None``."""
    if val is None:
        return False
    return True
9+
10+
11+
def default(*vals):
    """Return the first argument that is not ``None``.

    Yields ``None`` when every argument is ``None`` or when called with no
    arguments at all.
    """
    return next((candidate for candidate in vals if exists(candidate)), None)
15+
16+
17+
class STN(nn.Module):
    """Spatial transformer that predicts one square matrix per sample.

    A shared point-wise MLP (1x1 convolutions) lifts every point to 1024
    features, a max over the point axis pools them into one vector per
    sample, and a small fully-connected head regresses an
    ``out_nd x out_nd`` matrix from it.

    Args:
        in_dim: number of channels of the input tensor.
        out_nd: side length of the predicted matrix; falls back to
            ``in_dim`` when ``None``.
        head_norm: when true, ``BatchNorm1d`` is applied to the pooled
            vector and inside the head; otherwise ``nn.Identity`` is used.
    """

    def __init__(self, in_dim=3, out_nd=None, head_norm=True):
        super().__init__()
        self.in_dim = in_dim
        self.out_nd = in_dim if out_nd is None else out_nd

        # Point-wise feature extractor: kernel size 1 means every point is
        # processed independently with shared weights.
        self.net = nn.Sequential(
            nn.Conv1d(in_dim, 64, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Conv1d(64, 128, 1, bias=False),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Conv1d(128, 1024, 1, bias=False),
        )

        norm = nn.BatchNorm1d if head_norm else nn.Identity
        self.norm = norm(1024)
        self.act = nn.GELU()

        self.head = nn.Sequential(
            nn.Linear(1024, 512, bias=False),
            norm(512),
            nn.GELU(),
            nn.Linear(512, 256, bias=False),
            norm(256),
            nn.GELU(),
            nn.Linear(256, self.out_nd ** 2),
        )

        # Start close to the identity transform: near-zero weights plus a
        # bias equal to the flattened identity matrix. The bias holds
        # out_nd ** 2 values, so the identity must be sized by out_nd
        # (sizing it by in_dim breaks as soon as out_nd != in_dim).
        final = self.head[-1]
        nn.init.normal_(final.weight, 0, 0.001)
        with torch.no_grad():
            final.bias.copy_(torch.eye(self.out_nd).flatten())

    def forward(self, x):
        """Predict the transform for a batch of point sets.

        Args:
            x: tensor of shape ``(b, d, n)`` with ``d == in_dim``.

        Returns:
            Tensor of shape ``(b, out_nd, out_nd)``.
        """
        x = self.net(x)
        # Max over the point axis gives one feature vector per sample.
        x = torch.max(x, dim=-1, keepdim=False)[0]
        x = self.act(self.norm(x))

        x = self.head(x)
        # Unflatten the out_nd ** 2 outputs row-major into a square matrix.
        return x.reshape(x.size(0), self.out_nd, self.out_nd)
61+
62+
63+
class PointNetCls(nn.Module):
    """PointNet classifier / global feature extractor.

    Args:
        in_dim: number of channels of the input points.
        out_dim: size of the classification output (only used when
            ``with_head`` is true).
        stn_3d: module predicting a 3x3 transform for the first three
            channels. When omitted, a fresh ``STN(in_dim=3)`` is created for
            this instance; pass ``None`` to disable the input transform.
            The whole input tensor is fed to this module, so for
            ``in_dim > 3`` it must accept all ``in_dim`` channels.
        with_head: when false, ``forward`` returns the pooled 1024-d
            feature instead of class scores.
        head_norm: use ``BatchNorm1d`` on the pooled feature and inside the
            fully-connected heads (``nn.Identity`` otherwise). Also applied
            to the feature-space STN and to the default ``stn_3d``.
        dropout: dropout probability before the final linear layer.
    """

    # Marks "stn_3d was not supplied". A sentinel is needed because None is
    # a meaningful value here (it switches the input transform off).
    _DEFAULT_STN = object()

    def __init__(
        self,
        *,
        in_dim,
        out_dim,
        stn_3d=_DEFAULT_STN,  # if None, no stn_3d
        with_head=True,
        head_norm=True,
        dropout=0.3,
    ):
        super().__init__()
        self.with_head = with_head

        # Build the default STN per instance. A module created in the
        # signature would be evaluated once and shared (weights included)
        # by every model constructed with the default.
        if stn_3d is self._DEFAULT_STN:
            stn_3d = STN(in_dim=3, head_norm=head_norm)

        # if using stn, put other features behind xyz
        self.stn_3d = stn_3d

        self.conv1 = nn.Sequential(
            nn.Conv1d(in_dim, 64, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Conv1d(64, 64, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.GELU(),
        )

        # Feature-space transform applied to the 64-channel point features.
        self.stn_nd = STN(in_dim=64, head_norm=head_norm)
        self.conv2 = nn.Sequential(
            nn.Conv1d(64, 64, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Conv1d(64, 128, 1, bias=False),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Conv1d(128, 1024, 1, bias=False),
        )

        norm = nn.BatchNorm1d if head_norm else nn.Identity
        self.norm = norm(1024)
        self.act = nn.GELU()

        if self.with_head:
            self.head = nn.Sequential(
                nn.Linear(1024, 512, bias=False),
                norm(512),
                nn.GELU(),
                nn.Linear(512, 256, bias=False),
                norm(256),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(256, out_dim),
            )

    def forward(self, x):
        """Run the network on a batch of point sets.

        Args:
            x: tensor of shape ``(b, d, n)``.

        Returns:
            ``(b, out_dim)`` scores when ``with_head`` is true, otherwise
            the pooled ``(b, 1024)`` feature.

        Raises:
            ValueError: if the input transform is enabled and ``x`` has
                fewer than three channels.
        """
        # x: (b, d, n)
        if exists(self.stn_3d):
            channels = x.size(1)
            # Validate before calling the STN so a malformed input yields
            # this message rather than a shape error from inside the STN.
            if channels < 3:
                raise ValueError(f"invalid input dimension: {x.size(1)}")

            transform_3d = self.stn_3d(x)
            if channels == 3:
                x = torch.bmm(transform_3d, x)
            else:
                # Only xyz is transformed; extra features pass through.
                x = torch.cat([torch.bmm(transform_3d, x[:, :3]), x[:, 3:]], dim=1)

        x = self.conv1(x)
        transform_nd = self.stn_nd(x)
        x = torch.bmm(transform_nd, x)
        x = self.conv2(x)

        # Max over the point axis gives one global feature per sample.
        x = torch.max(x, dim=-1, keepdim=False)[0]
        x = self.act(self.norm(x))

        if self.with_head:
            x = self.head(x)
        return x
138+
139+
140+
class PointNetSeg(nn.Module):
    """PointNet segmentation network producing per-point outputs.

    A head-less ``PointNetCls`` serves as backbone. Its 64-channel per-point
    features are concatenated with the pooled 1024-d global feature
    (broadcast to every point) and fed through a point-wise conv head.

    Args:
        in_dim: number of channels of the input points.
        out_dim: number of output channels per point.
        stn_3d: module predicting a 3x3 transform for the first three
            channels. When omitted, a fresh ``STN(in_dim=3)`` is created for
            this instance; pass ``None`` to disable the input transform.
            The whole input tensor is fed to this module, so for
            ``in_dim > 3`` it must accept all ``in_dim`` channels.
        global_head_norm: use normalization in the global head; disable it
            if batch size is 1. Also applied to the default ``stn_3d``.
    """

    # Marks "stn_3d was not supplied". A sentinel is needed because None is
    # a meaningful value here (it switches the input transform off).
    _DEFAULT_STN = object()

    def __init__(
        self,
        *,
        in_dim,
        out_dim,
        stn_3d=_DEFAULT_STN,  # if None, no stn_3d
        global_head_norm=True,  # if using normalization in the global head, disable it if batch size is 1
    ):
        super().__init__()

        # Build the default STN per instance. A module created in the
        # signature would be evaluated once and shared (weights included)
        # by every model constructed with the default.
        if stn_3d is self._DEFAULT_STN:
            stn_3d = STN(in_dim=3, head_norm=global_head_norm)

        self.backbone = PointNetCls(
            in_dim=in_dim,
            out_dim=out_dim,
            stn_3d=stn_3d,
            head_norm=global_head_norm,
            with_head=False,
        )

        # Input is 64 per-point channels + 1024 global channels.
        self.head = nn.Sequential(
            nn.Conv1d(1024 + 64, 512, 1, bias=False),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Conv1d(512, 256, 1, bias=False),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Conv1d(256, 128, 1, bias=False),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Conv1d(128, out_dim, 1),
        )

    def forward_backbone(self, x):
        """Compute per-point and global features with the backbone layers.

        Args:
            x: tensor of shape ``(b, d, n)``.

        Returns:
            Tuple ``(point_feat, global_feat)`` of shapes ``(b, 64, n)`` and
            ``(b, 1024)``.

        Raises:
            ValueError: if the input transform is enabled and ``x`` has
                fewer than three channels.
        """
        # x: (b, d, n)
        if exists(self.backbone.stn_3d):
            channels = x.size(1)
            # Validate before calling the STN so a malformed input yields
            # this message rather than a shape error from inside the STN.
            if channels < 3:
                raise ValueError(f"invalid input dimension: {x.size(1)}")

            transform_3d = self.backbone.stn_3d(x)
            if channels == 3:
                x = torch.bmm(transform_3d, x)
            else:
                # Only xyz is transformed; extra features pass through.
                x = torch.cat([torch.bmm(transform_3d, x[:, :3]), x[:, 3:]], dim=1)

        x = self.backbone.conv1(x)
        transform_nd = self.backbone.stn_nd(x)
        x = torch.bmm(transform_nd, x)

        # The global feature is pooled from conv2; x itself stays per-point.
        global_feat = self.backbone.conv2(x)
        global_feat = torch.max(global_feat, dim=-1, keepdim=False)[0]
        global_feat = self.backbone.act(self.backbone.norm(global_feat))
        return x, global_feat

    def forward(self, x):
        """Predict per-point outputs.

        Args:
            x: tensor of shape ``(b, d, n)``.

        Returns:
            Tensor of shape ``(b, out_dim, n)``.
        """
        # x: (b, d, n)
        x, global_feat = self.forward_backbone(x)
        # Copy the global feature to every point before concatenation.
        global_feat = repeat(global_feat, "b d -> b d n", n=x.size(-1))
        x = torch.cat([x, global_feat], dim=1)
        x = self.head(x)
        return x

0 commit comments

Comments
 (0)