
Commit 0ae1927 (parent: 53abc6b)

Refactored code for AddBiomechanics and previous datasets

8 files changed: +83 -164 lines

Dockerfile

Lines changed: 2 additions & 6 deletions

@@ -1,11 +1,11 @@
 # Use an official Ubuntu runtime as a parent image
-FROM ubuntu:22.04
+FROM ubuntu:25.04
 
 # Set the working directory
 WORKDIR /P-BIGE
 
 # Install necessary dependencies
-RUN apt-get update && apt-get install -y wget git htop xvfb
+RUN apt-get update && apt-get install -y wget git htop xvfb build-essential
 
 # Install Miniconda
 RUN MINICONDA_INSTALLER_SCRIPT=Miniconda3-py38_23.1.0-1-Linux-x86_64.sh && \
@@ -27,10 +27,6 @@ RUN conda env create -f environment.yml
 # Activate the conda environment for subsequent RUN commands
 SHELL ["conda", "run", "-n", "P-BIGE", "/bin/bash", "-c"]
 
-# Download the model and extractor
-RUN bash dataset/prepare/download_model.sh && \
-    bash dataset/prepare/download_extractor.sh
-
 # Install additional Python packages (if needed)
 RUN pip install --user ipykernel polyscope easydict trimesh
 RUN pip install --user --force-reinstall numpy==1.22.0
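Because numpy 1.22.0 is force-reinstalled in the final pip layer, later layers can no longer bump it silently. A quick in-container sanity check, as a sketch (nothing here is repo-specific):

    import numpy as np

    # The image force-reinstalls numpy 1.22.0 last; verify the pin held
    # after all other pip installs resolved their dependencies.
    assert np.__version__ == "1.22.0", f"unexpected numpy version: {np.__version__}"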

dataset/dataset_MOT_segmented.py

Lines changed: 1 addition & 1 deletion

@@ -205,7 +205,7 @@ def __getitem__(self, item):
         return subsequences, subsequence_lengths, names
 
 class AddBiomechanicsDataset(data.Dataset):
-    def __init__(self, window_size=64, unit_length=4, mode='train', data_dir='/home/mnt/data/addb_dataset_publication'):
+    def __init__(self, window_size=64, unit_length=4, mode='train', data_dir='/home/kingn450/Datasets/addb_dataset_publication'):
         self.window_size = window_size
         self.unit_length = unit_length
         self.data_dir = data_dir
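The new default data_dir is a user-specific path, so callers on other machines will want to pass it explicitly. A minimal usage sketch (the environment-variable fallback is an illustrative assumption, not part of this commit):

    import os
    from dataset.dataset_MOT_segmented import AddBiomechanicsDataset

    # Hypothetical override: prefer an env var over the hardcoded default path.
    data_dir = os.environ.get("ADDB_DATA_DIR",
                              "/home/kingn450/Datasets/addb_dataset_publication")
    train_set = AddBiomechanicsDataset(window_size=64, unit_length=4,
                                       mode="train", data_dir=data_dir)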

environment.yml

Lines changed: 2 additions & 3 deletions

@@ -10,7 +10,6 @@ dependencies:
   - bzip2=1.0.8=h7b6447c_0
   - ca-certificates=2021.7.5=h06a4308_1
   - certifi=2021.5.30=py38h06a4308_0
-  - cudatoolkit=10.1.243=h6bb024c_0
   - ffmpeg=4.3=hf484d3e_0
   - freetype=2.10.4=h5ab3b9f_0
   - gmp=6.2.1=h2531618_2
@@ -56,8 +55,8 @@ dependencies:
   - six=1.16.0=pyhd3eb1b0_0
   - sqlite=3.36.0=hc218d9a_0
   - tk=8.6.10=hbc83047_0
-  - torchaudio=0.8.1=py38
-  - torchvision=0.9.1=py38_cu101
+  - torchaudio=2.3.0
+  - torchvision
   - typing_extensions=3.10.0.0=pyh06a4308_0
   - wheel=0.37.0=pyhd3eb1b0_0
   - xz=5.2.5=h7b6447c_0
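torchaudio releases track torch releases one-to-one, so pinning torchaudio=2.3.0 should pull in a torch 2.3.0 build, and the now-unpinned torchvision lets the solver pick a compatible version (0.18.x for torch 2.3). After recreating the environment, a quick sanity check, as a sketch:

    import torch
    import torchaudio
    import torchvision

    # Confirm the solver produced a mutually compatible torch stack.
    print(torch.__version__, torchaudio.__version__, torchvision.__version__)
    print("CUDA available:", torch.cuda.is_available())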
3 binary files changed (contents not shown).

train_vq.py

Lines changed: 72 additions & 124 deletions

@@ -1,41 +1,32 @@
-import os
+import deepspeed
 import json
-# from osim_sequence import OSIMSequence,load_osim
+import models.vqvae as vqvae
+import nimblephysics as nimble
+import options.option_vq as option_vq
+import os
 import torch
+import torch.distributed as dist
 import torch.optim as optim
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-
-import models.vqvae as vqvae
+import utils.eval_trans as eval_trans
 import utils.losses as losses
-import options.option_vq as option_vq
 import utils.utils_model as utils_model
+import warnings
 from dataset import dataset_MOT_MCS, dataset_TM_eval, dataset_MOT_segmented
-import utils.eval_trans as eval_trans
-from options.get_eval_option import get_opt
 from models.evaluator_wrapper import EvaluatorModelWrapper
-import warnings
-warnings.filterwarnings('ignore')
+from options.get_eval_option import get_opt
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
 from utils.word_vectorizer import WordVectorizer
-import nimblephysics as nimble
-import deepspeed
-
+warnings.filterwarnings('ignore')
 
 def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr):
-
     current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
     for param_group in optimizer.param_groups:
         param_group["lr"] = current_lr
-
     return optimizer, current_lr
 
 def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
-    # y_translation = 0.0
     min_height, idx = motion[..., 1].min(dim=-1)
-
-    # y_translation = -min_height.median() # Change reference to median (Other set of experiments determine this. See paper)
-
-    # print(min_height,idx,motion[..., 1].shape)
     min_height = min_height + y_translation
     pn = -torch.minimum(min_height, torch.zeros_like(min_height)) # penetration
     pn[pn < feet_threshold] = 0.0
@@ -64,16 +55,25 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 
     return loss_pn.sum()/bs, loss_fl.sum()/bs, loss_sk.sum()/bs
 
-##### ---- Exp dirs ---- #####
+# --- Robust device and distributed/deepspeed setup ---
 args = option_vq.get_args_parser()
 torch.manual_seed(args.seed)
+
+args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
+
 if torch.cuda.is_available():
     torch.cuda.set_device(args.local_rank)
+    if dist.is_available() and not dist.is_initialized() and world_size > 1:
+        dist.init_process_group(backend="nccl")
+    device = torch.device(f"cuda:{args.local_rank}")
+else:
+    device = torch.device("cpu")
 
 args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
 os.makedirs(args.out_dir, exist_ok = True)
 
-##### ---- Logger ---- #####
+# Logger
 logger = utils_model.get_logger(args.out_dir)
 writer = SummaryWriter(args.out_dir)
 logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
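The new device block only creates a process group when a multi-worker launcher is in play. The same pattern as a self-contained helper, for reference (a sketch built on the standard torchrun environment variables, not code from this repo):

    import os
    import torch
    import torch.distributed as dist

    def setup_device(backend: str = "nccl") -> torch.device:
        # torchrun exports LOCAL_RANK and WORLD_SIZE; default to
        # single-process values when launched with plain `python`.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
            if dist.is_available() and world_size > 1 and not dist.is_initialized():
                dist.init_process_group(backend=backend)
            return torch.device(f"cuda:{local_rank}")
        return torch.device("cpu")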
@@ -83,54 +83,37 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 if args.dataname == 'kit' :
     dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt'
     args.nb_joints = 21
-
+elif args.dataname == 'addb' :
+    args.nb_joints = 23
 else :
     dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
     args.nb_joints = 22
 
-args.nb_joints = 23 # fixed issues
-
 logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-wrapper_opt = get_opt(dataset_opt_path, device)
-#eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
-
-
-##### ---- Dataloader ---- #####
-# train_loader = dataset_MOT_MCS.DATALoader(args.dataname,
-#                                           args.batch_size,
-#                                           window_size=args.window_size,
-#                                           unit_length=2**args.down_t)
-
 train_loader = dataset_MOT_segmented.addb_data_loader(
     window_size=args.window_size,
     unit_length=2**args.down_t,
     batch_size=args.batch_size,
     mode=args.dataname
 )
 
-# train_loader_iter = dataset_MOT_MCS.cycle(train_loader)
 train_loader_iter = dataset_MOT_segmented.cycle(train_loader)
 
-# val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
-#                                         32,
-#                                         w_vectorizer,
-#                                         unit_length=2**args.down_t)
-
-##### ---- Network ---- #####
-net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
-                       args.nb_code,
-                       args.code_dim,
-                       args.output_emb_width,
-                       args.down_t,
-                       args.stride_t,
-                       args.width,
-                       args.depth,
-                       args.dilation_growth_rate,
-                       args.vq_act,
-                       args.vq_norm)
-
+# Setup VQ-VAE model
+net = vqvae.HumanVQVAE(
+    args,
+    args.nb_code,
+    args.code_dim,
+    args.output_emb_width,
+    args.down_t,
+    args.stride_t,
+    args.width,
+    args.depth,
+    args.dilation_growth_rate,
+    args.vq_act,
+    args.vq_norm
+)
 
 if args.resume_pth :
     logger.info('loading checkpoint from {}'.format(args.resume_pth))
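dataset_MOT_segmented.cycle wraps the DataLoader so the training loop can draw one batch per step with next(). A common implementation of such a helper looks like this (a sketch; the repo's own version may differ):

    def cycle(iterable):
        # Re-enter the DataLoader forever so next() never raises StopIteration;
        # each pass re-shuffles when the loader was built with shuffle=True.
        while True:
            for batch in iterable:
                yield batch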
@@ -139,67 +122,47 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 net.train()
 net.to(device)
 
-##### ---- Optimizer & Scheduler ---- #####
+# Optimizer and scheduler setup
 optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
 scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)
 
-deepspeed_config = {
-    "train_micro_batch_size_per_gpu": args.batch_size,
-    "optimizer": {
-        "type": "AdamW",
-        "params": {
-            "lr": args.lr,
-            "betas": [
-                0.9,
-                0.99
-            ],
-            "weight_decay":args.weight_decay
-        }
-    },
-    "gradient_accumulation_steps": 1,
-    # "fp16": {
-    #     "enabled": True
-    # },
-    "zero_optimization": {
-        "stage": 0
+# Run DeepSpeed if available (2+ GPUs)
+if torch.cuda.is_available() and world_size > 1:
+    deepspeed_config = {
+        "train_micro_batch_size_per_gpu": args.batch_size,
+        "optimizer": {
+            "type": "AdamW",
+            "params": {
+                "lr": args.lr,
+                "betas": [0.9, 0.99],
+                "weight_decay": args.weight_decay
+            }
+        },
+        "gradient_accumulation_steps": 1,
+        "zero_optimization": {"stage": 0}
     }
-}
-net, optimizer, _, _ = deepspeed.initialize(model=net, optimizer=optimizer, args=args, config_params=deepspeed_config)
+    net, optimizer, _, _ = deepspeed.initialize(
+        model=net,
+        optimizer=optimizer,
+        args=args,
+        config_params=deepspeed_config
+    )
+else:
+    logger.info("Running without DeepSpeed (single GPU or CPU).")
 
 Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints)
 
-##### ------ warm-up ------- #####
+# Warm up
 avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.
-
 for nb_iter in range(1, args.warm_up_iter):
-
     optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)
-
     gt_motion,_, names = next(train_loader_iter)
-    gt_motion = gt_motion.to(device).float() # (bs, 64, dim)
+    gt_motion = gt_motion.to(device).float()
 
     pred_motion, loss_commit, perplexity = net(gt_motion)
     loss_motion = Loss(pred_motion, gt_motion)
-
-    loss_temp = torch.mean((pred_motion[:,1:,:]-pred_motion[:,:-1,:])**2)
-
-    # loss_vel = Loss.forward_vel(pred_motion, gt_motion)
-    # loss_pn, loss_fl, loss_sk = get_foot_losses(pred_motion)
-    # print(loss_pn, loss_fl, loss_sk)
-
-    # # hip flexion 7, 15
-    # hip_flexion_l = -pred_motion[:,:,7].max(dim=1).values.mean()
-    # hip_flexion_r = -pred_motion[:,:,15].max(dim=1).values.mean()
-    # # knee angle 10, 18
-    # knee_angle_l = -pred_motion[:,:,10].max(dim=1).values.mean()
-    # knee_angle_r = -pred_motion[:,:,18].max(dim=1).values.mean()
-    # # ankle angle 12, 20
-    # ankle_angle_l = -pred_motion[:,:,12].max(dim=1).values.mean()
-    # ankle_angle_r = -pred_motion[:,:,20].max(dim=1).values.mean()
-
-    # print(hip_flexion_l, hip_flexion_r, knee_angle_l, knee_angle_r, ankle_angle_l, ankle_angle_r)
-
-    loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp #+ args.loss_vel * loss_vel
+    loss_temp = torch.mean((pred_motion[:,1:,:] - pred_motion[:,:-1,:])**2)
+    loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp
 
     optimizer.zero_grad()
    loss.backward()
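The warm-up loop above (and the training loop below) keeps the first-difference temporal term weighted at 0.5; in isolation it is just the mean squared frame-to-frame change. A minimal sketch of that term (shapes are illustrative):

    import torch

    def temporal_smoothness(motion: torch.Tensor) -> torch.Tensor:
        # Mean squared first difference along the time axis (dim 1);
        # penalizes frame-to-frame jitter in the reconstruction.
        return torch.mean((motion[:, 1:, :] - motion[:, :-1, :]) ** 2)

    pred = torch.randn(4, 64, 69)  # (batch, frames, features): illustrative shape
    print(temporal_smoothness(pred))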
@@ -220,24 +183,17 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 
 avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.
 
-##### ---- Training ---- #####
+# Training Loop
 avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.
 torch.save({'net' : net.state_dict()}, os.path.join(args.out_dir, 'warmup.pth'))
-# best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper)
-
 for nb_iter in range(1, args.total_iter + 1):
-
     gt_motion,_,_ = next(train_loader_iter)
-    gt_motion = gt_motion.to(device).float() # bs, nb_joints, joints_dim, seq_len
+    gt_motion = gt_motion.to(device).float()
 
     pred_motion, loss_commit, perplexity = net(gt_motion)
     loss_motion = Loss(pred_motion, gt_motion)
     loss_temp = torch.mean((pred_motion[:,1:,:]-pred_motion[:,:-1,:])**2)
-    # loss_vel = Loss.forward_vel(pred_motion, gt_motion)
-    # loss_pn, loss_fl, loss_sk = get_foot_losses(pred_motion)
-    # print(loss_pn, loss_fl, loss_sk)
-
-    loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp #+ args.loss_vel * loss_vel # Need to remove/change loss_vel since its not SMPL
+    loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp
 
     optimizer.zero_grad()
     loss.backward()
@@ -261,16 +217,8 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 
         logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f} \t Temporal. {avg_temporal:.5f}")
 
-        avg_recons, avg_perplexity, avg_commit = 0., 0., 0.,
-
-    if nb_iter % (10*args.eval_iter) == 0:
+        avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.
+
+    if nb_iter % (10 * args.eval_iter) == 0:
         torch.save({'net' : net.state_dict()}, os.path.join(args.out_dir, str(nb_iter) + '.pth'))
 
-    # if nb_iter % args.eval_iter==0 :
-    #     # The line `best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching,
-    #     # writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer,
-    #     # nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching,
-    #     # eval_wrapper=eval_wrapper)` is calling a function named `evaluation_vqvae` from the
-    #     # `eval_trans` module.
-    #     best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper)
-
