-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_modelnet40_partial.py
More file actions
164 lines (143 loc) · 5.81 KB
/
train_modelnet40_partial.py
File metadata and controls
164 lines (143 loc) · 5.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import h5py
import jax
import jax.numpy as jnp
import optax
from dataset_with_overlap_masks import PointCloudDataset, RandomCrop
from dataset_with_overlap_masks import RandomJitter, RandomShuffle
from dataset_with_overlap_masks import RandomRotateSource, RandomTranslateSource
from flax import nnx
from losses_and_metrics import get_overlap_accuracy, get_overlap_loss
from losses_and_metrics import get_rotation_errors, get_rotation_loss
from losses_and_metrics import get_translation_errors, get_translation_loss
from mmd_weighted import expm_skew
from set_transformer import SupervisedModel, count_params, save_model
from torch.utils.data import DataLoader, default_collate
from torchvision.transforms import v2
def get_modelnet40_partial_dataloaders(
    batch_size,
    train_data_path="Datasets/Processed/modelnet40_partial_train.hdf5",
    val_data_path="Datasets/Processed/modelnet40_partial_val.hdf5",
):
    """
    Build the training and validation PyTorch DataLoaders.

    The training dataset is wrapped in a chain of random transforms and its
    loader shuffles and drops the final incomplete batch. The validation
    dataset has no transform and its loader keeps every sample in order.

    Args:
        batch_size: Batch size used by both loaders.
        train_data_path: Path to HDF5 file containing processed training data.
        val_data_path: Path to HDF5 file containing processed validation data.

    Returns:
        A `(train_dl, val_dl)` tuple of PyTorch DataLoaders whose batches
        have every leaf converted with `jnp.asarray`.
    """

    def to_jax_batch(samples):
        # Collate with PyTorch's default rules, then convert each leaf.
        return jax.tree.map(jnp.asarray, default_collate(samples))

    # Settings shared by both loaders.
    shared_kwargs = dict(batch_size=batch_size, collate_fn=to_jax_batch)

    augmentations = v2.Compose(
        [
            RandomCrop(),
            RandomJitter(),
            RandomRotateSource(),
            RandomTranslateSource(),
            RandomShuffle(),
        ]
    )
    training_set = PointCloudDataset(train_data_path, transform=augmentations)
    training_loader = DataLoader(
        training_set,
        shuffle=True,
        drop_last=True,
        **shared_kwargs,
    )

    validation_set = PointCloudDataset(val_data_path, transform=None)
    validation_loader = DataLoader(
        validation_set,
        shuffle=False,
        drop_last=False,
        **shared_kwargs,
    )

    return training_loader, validation_loader
def loss_fn(model, batch):
    """
    Compute the total loss and the model predictions for one batch.

    Args:
        model: Callable invoked as `model(Xs, Ys)` that returns five outputs;
            the third one is not used here.
        batch: Sequence `(Xs, Ys, true_Rs, true_ts, true_Xos, true_Yos)`.

    Returns:
        A `(loss, predictions)` pair, where `loss` is the sum of the rotation,
        translation and two overlap losses, and `predictions` is the tuple
        `(pred_Rs, pred_ts, pred_Xo_logits, pred_Yo_logits)`.
    """
    pts_X, pts_Y, gt_Rs, gt_ts, gt_Xos, gt_Yos = batch

    rot_vecs, est_ts, _ignored, Xo_logits, Yo_logits = model(pts_X, pts_Y)
    # Apply `expm_skew` to each batch element to obtain the predicted Rs.
    est_Rs = jax.vmap(expm_skew)(rot_vecs)

    loss_terms = (
        get_rotation_loss(est_Rs, gt_Rs),
        get_translation_loss(est_ts, gt_ts),
        get_overlap_loss(Xo_logits, gt_Xos),
        get_overlap_loss(Yo_logits, gt_Yos),
    )
    # Left-to-right accumulation starting from 0.0.
    total_loss = sum(loss_terms, 0.0)

    return total_loss, (est_Rs, est_ts, Xo_logits, Yo_logits)
@nnx.jit
def train_step(model, optimizer, batch):
    """
    Run one optimization step on a single batch.

    Args:
        model: Model passed through to `loss_fn`.
        optimizer: Object whose `update(model, grads)` method is called with
            the gradients of the loss.
        batch: Batch passed through to `loss_fn`.

    Returns:
        The loss value computed by `loss_fn` for this batch.
    """
    # `loss_fn` returns (loss, predictions); only the loss is needed here.
    differentiated = nnx.value_and_grad(loss_fn, has_aux=True)
    (batch_loss, _aux), gradients = differentiated(model, batch)
    optimizer.update(model, gradients)
    return batch_loss
@nnx.jit
def eval_step(model, batch):
    """
    Evaluate the model on a single batch.

    Args:
        model: Model passed through to `loss_fn`.
        batch: Sequence `(Xs, Ys, true_Rs, true_ts, true_Xos, true_Yos)`.

    Returns:
        A tuple `(loss, overlap_accuracy, net_errors_R, net_errors_t)` built
        from `loss_fn`, `get_overlap_accuracy`, `get_rotation_errors` and
        `get_translation_errors` respectively.
    """
    # The two leading entries are only consumed inside `loss_fn`.
    _, _, gt_Rs, gt_ts, gt_Xos, gt_Yos = batch

    batch_loss, (est_Rs, est_ts, Xo_logits, Yo_logits) = loss_fn(model, batch)

    rotation_errors = get_rotation_errors(est_Rs, gt_Rs)
    translation_errors = get_translation_errors(est_ts, gt_ts)
    accuracy = get_overlap_accuracy(Xo_logits, Yo_logits, gt_Xos, gt_Yos)

    return batch_loss, accuracy, rotation_errors, translation_errors
if __name__ == "__main__":
    # Script entry point: train the supervised model on the partial
    # ModelNet40 data, evaluate on the validation set after every epoch,
    # append the per-epoch metrics to an HDF5 results file and save the
    # model parameters.
    batch_size = 16
    num_epochs = 20000

    train_dl, val_dl = get_modelnet40_partial_dataloaders(batch_size)
    N = len(val_dl.dataset)  # Number of validation samples.

    # One optimizer step per training batch.
    num_steps = num_epochs * len(train_dl)
    # Positional arguments: init_value, peak_value, warmup_steps, decay_steps.
    learning_rate_fn = optax.warmup_cosine_decay_schedule(1e-6, 1e-4, 1000, num_steps)

    model = SupervisedModel(
        dim_i=3,
        dim_h=512,
        num_h=4,
        drop=0.0,
        rngs=nnx.Rngs(0),
    )
    print(f"Number of model parameters: {count_params(model):,}")

    # Clip the global gradient norm before the Adam update.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate_fn),
    )
    optimizer = nnx.Optimizer(model, optimizer, wrt=nnx.Param)

    for epoch in range(num_epochs):
        # ---- Training pass ----
        model.train()
        train_loss_sum = jnp.asarray(0.0)
        num_train_seen = 0
        for batch in train_dl:
            B = jnp.size(batch[0], axis=0)
            loss = train_step(model, optimizer, batch)
            train_loss_sum += loss * B
            num_train_seen += B
        # The training loader drops the last incomplete batch, so the number
        # of samples seen can be smaller than the dataset size. Normalizing
        # by the samples actually seen keeps this a true per-sample average.
        # max(..., 1) avoids a division by zero if no batch was produced.
        train_loss = train_loss_sum / max(num_train_seen, 1)

        # ---- Validation pass ----
        model.eval()
        val_loss = jnp.asarray(0.0)
        val_overlap_acc = jnp.asarray(0.0)
        all_net_errors_R = jnp.zeros((N,))
        all_net_errors_t = jnp.zeros((N,))
        idx = 0  # Offset of the current batch within the per-sample arrays.
        for batch in val_dl:
            B = jnp.size(batch[0], axis=0)
            evals = eval_step(model, batch)
            loss, overlap_accuracy, net_errors_R, net_errors_t = evals
            # The validation loader keeps every sample, so the B / N weights
            # sum to one over the epoch.
            val_loss += loss * B / N
            val_overlap_acc += overlap_accuracy * B / N
            all_net_errors_R = all_net_errors_R.at[idx:idx+B].set(net_errors_R)
            all_net_errors_t = all_net_errors_t.at[idx:idx+B].set(net_errors_t)
            idx += B

        print(f"Epoch {epoch+1:,}/{num_epochs:,} Averages:")
        print(f" Train Loss: {train_loss:.5f}")
        print(f" Val Loss: {val_loss:.5f}")
        print(f" Val Net Overlap Accuracy: {val_overlap_acc:.5f}")
        print(f" Val Net Rotation Error: {jnp.mean(all_net_errors_R):.5f}")
        print(f" Val Net Translation Error: {jnp.mean(all_net_errors_t):.5f}")

        es = f"epoch_{epoch}"  # Epoch string (zero-based group name).
        # "x" creates the file and fails if it already exists, so an earlier
        # results file is never overwritten; later epochs reopen it with "r+".
        with h5py.File("results_partial.hdf5", "x" if epoch == 0 else "r+") as f:
            f.create_dataset(f"{es}/train_loss", data=train_loss)
            f.create_dataset(f"{es}/val_loss", data=val_loss)
            f.create_dataset(f"{es}/val_overlap_acc", data=val_overlap_acc)
            f.create_dataset(f"{es}/all_net_errors_R", data=all_net_errors_R)
            f.create_dataset(f"{es}/all_net_errors_t", data=all_net_errors_t)

        # NOTE(review): the nesting level of this call could not be recovered
        # from the reviewed copy. It is kept inside the epoch loop so that an
        # interrupted run still has the latest parameters on disk; the file
        # left after the final epoch is the same either way. Assumes
        # `save_model` overwrites an existing file - confirm.
        save_model(model, "trained_model_params_partial.msgpack")