Skip to content

Commit 8a153f4

Browse files
committed
.
1 parent 3d5c37a commit 8a153f4

File tree

5 files changed

+119
-6
lines changed

5 files changed

+119
-6
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
name: mesh500_2048_vqvae_mhvq
2+
exp_root_dir: outputs
3+
exp_dir: ""
4+
trial_dir: ""
5+
n_gpus: 1
6+
seed: 42
7+
8+
dataset_name: mesh500_2048
9+
dataset_source: hf
10+
dataset_img_key: null
11+
dataset_kwargs:
12+
root: ./.cache/${dataset_name}
13+
14+
model: VQVAE
15+
model_config: VQVAEConfig
16+
model_kwargs:
17+
dim: 128
18+
in_channel: 3
19+
out_channel: 3
20+
layers: 2
21+
layer_mults: null
22+
num_res_blocks: 1
23+
group: 8
24+
conv_type: "conv1d"
25+
enc_act_func: "LeakyReLU"
26+
dec_act_func: "GLU"
27+
enc_act_kwargs: {"negative_slope": 0.1}
28+
dec_act_kwargs: {"dim": 1}
29+
first_conv_kernel_size: 5
30+
quantizer: "VectorQuantize"
31+
codebook_size: 512
32+
quantizer_kwargs: {
33+
"codebook_dim": 64,
34+
"heads": 8,
35+
"decay" : 0.99,
36+
"commitment_weight": 0.25,
37+
"kmeans_init": True,
38+
"use_cosine_sim": True
39+
}
40+
l2_recon_loss: True
41+
42+
trainer: PCVQVAETrainer
43+
trainer_config: PCVQVAETrainerConfig
44+
trainer_kwargs:
45+
num_train_steps: 300000
46+
batch_size: 128
47+
num_workers: 16
48+
pin_memory: True
49+
grad_accum_every: 1
50+
learning_rate: 0.001
51+
weight_decay: 0.
52+
max_grad_norm: 0.5
53+
val_every: 1000
54+
val_num_batches: 20
55+
val_num_images: 32
56+
scheduler: CosineAnnealingLR
57+
scheduler_kwargs:
58+
T_max: "${sub: ${trainer_kwargs.num_train_steps}, ${trainer_kwargs.warmup_steps}}"
59+
eta_min: 0.0005
60+
ema_kwargs: null
61+
accelerator_kwargs: {}
62+
optimizer_name: Adam
63+
optimizer_kwargs: {}
64+
loss_lambda:
65+
recon_loss: 1.
66+
quantizer_loss: 1.
67+
checkpoint_every: null
68+
warmup_steps: 0
69+
use_wandb_tracking: False
70+
resume: False
71+
from_checkpoint: null
72+
from_checkpoint_type: null
73+
74+
wandb:
75+
project_name: "vitvqganvae"
76+
run_name: null
77+
kwargs:
78+
entity: "heartbeats"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
FROM kohido/base_dl_cuda129:v0.0.6

# Avoid interactive prompts from apt during image build.
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /vitvqganvae

# Copy only what the training entry point needs.
# (Removed a leftover `RUN ls` debug instruction that created a useless layer.)
COPY ./config /vitvqganvae/config
COPY ./vitvqganvae /vitvqganvae/vitvqganvae
COPY ./main.py /vitvqganvae/main.py

# Shell form is deliberate so ${WANDB_API_KEY} is expanded at container start.
CMD wandb login ${WANDB_API_KEY} && accelerate launch \
    --mixed_precision=no \
    --num_processes=1 \
    --num_machines=1 \
    --dynamo_backend=no \
    main.py \
    --config config/pointcloud/mesh500/mesh500_2048_vqvae_mhvq.yaml \
    --train \
    trainer_kwargs.use_wandb_tracking=True

main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,12 @@ def main(args, extras):
6767
model_module = model_cls(**model_config)
6868
try:
6969
sample: Tensor = train_ds[0].unsqueeze(0)
70-
summary(copy.deepcopy(model_module), input_size=sample.shape)
70+
summary(
71+
copy.deepcopy(model_module),
72+
input_data=sample,
73+
col_names=["input_size", "output_size", "num_params", "params_percent", "trainable"],
74+
# depth=2
75+
)
7176
except Exception as e:
7277
print(f"Cannot run model summary: {e}")
7378

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vitvqganvae"
3-
version = "0.4.7"
3+
version = "0.4.8"
44
authors = [
55
{ name="KhoiDOO", email="khoido8899@gmail.com" },
66
]

vitvqganvae/data/hf/mesh500.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,26 @@ class Mesh500(Dataset):
1313
def __init__(self, root: str, num_points: int = 1024):
    """Point-cloud dataset backed by the ``kohido/mesh500_*pts`` HF datasets.

    Args:
        root: Cache directory passed to ``load_dataset`` as ``cache_dir``.
        num_points: Points per cloud; one of 1024, 2048, or 4096. There is
            no dedicated 2048-point upstream dataset, so for 2048 the
            4096-point variant is loaded and subsampled at read time
            (see ``__getitem__``).

    Raises:
        ValueError: If ``num_points`` is not 1024, 2048, or 4096.
    """
    super().__init__()

    if num_points not in [1024, 2048, 4096]:
        raise ValueError("num_points should be one of 1024, 2048, or 4096 for Mesh500 dataset")

    self._root = root
    self._num_points = num_points
    # BUG FIX: previously only the 2048 and 4096 branches assigned
    # self._dataset, so the default num_points=1024 left it unset and
    # crashed later in __len__/__getitem__. The 2048 case intentionally
    # sources the 4096-point dataset; __getitem__ halves it.
    source_points = 4096 if self._num_points == 2048 else self._num_points
    self._dataset = load_dataset(
        f"kohido/mesh500_{source_points}pts", cache_dir=self._root
    )['train']['points']
2225

2326
def __len__(self) -> int:
    """Return the number of point clouds in the loaded split."""
    dataset_size = len(self._dataset)
    return dataset_size
2528

2629
def __getitem__(self, index: int) -> Tensor:
2730
points = self._dataset[index]
28-
points: np.ndarray = np.array(points) # (1024, 3)
31+
points: np.ndarray = np.array(points) # (self._num_points, 3)
2932
points = points[np.lexsort((points[:, 2], points[:, 1], points[:, 0]))]
33+
if self._num_points == 2048:
34+
# only take even indices
35+
points = points[::2]
3036
# scale to a [-0.5, 0.5] cube
3137
points = points - np.mean(points, axis=0, keepdims=True)
3238
max_abs = np.max(np.abs(points))
@@ -54,6 +60,9 @@ def get_mesh500(root: str | None = None, num_points: int = 1024, split: float =
5460
def get_mesh500_1024(root: str | None = None, split: float = 0.8) -> tuple[Mesh500, Mesh500]:
    """Convenience wrapper: train/val Mesh500 pair at 1024 points per cloud."""
    return get_mesh500(root=root, num_points=1024, split=split)
5662

63+
def get_mesh500_2048(root: str | None = None, split: float = 0.8) -> tuple[Mesh500, Mesh500]:
    """Convenience wrapper: train/val Mesh500 pair at 2048 points per cloud."""
    return get_mesh500(root=root, num_points=2048, split=split)
65+
5766
def get_mesh500_4096(root: str | None = None, split: float = 0.8) -> tuple[Mesh500, Mesh500]:
    """Convenience wrapper: train/val Mesh500 pair at 4096 points per cloud."""
    return get_mesh500(root=root, num_points=4096, split=split)
5968

0 commit comments

Comments
 (0)