Commit 83e2386

Committed on May 1, 2025
Updated dataloader to work with AddBiomechanics dataset
1 parent 6ecf8b6 commit 83e2386

File tree

5 files changed: +138 −28 lines

- Dockerfile
- dataset/dataset_MOT_segmented.py
- environment.yml
- models/vqvae.py
- train_vq.py

Dockerfile

Lines changed: 4 additions & 3 deletions

@@ -22,15 +22,16 @@ ENV PATH=/usr/local/bin:$PATH

 # Clone UCSD-Github dataset
 # Set the working directory
-WORKDIR /
-RUN git -c http.sslVerify=false clone https://github.com/Rose-STL-Lab/UCSD-OpenCap-Fitness-Dataset.git
+#WORKDIR /
+#RUN git -c http.sslVerify=false clone https://github.com/Rose-STL-Lab/UCSD-OpenCap-Fitness-Dataset.git


 # Clone the digital-coach-anwesh repository
-RUN git -c http.sslVerify=false clone https://gitlab.nrp-nautilus.io/shmaheshwari/digital-coach-anwesh.git .
+#RUN git -c http.sslVerify=false clone https://gitlab.nrp-nautilus.io/shmaheshwari/digital-coach-anwesh.git .

 # Copy the environment.yml file and create the conda environment
 # COPY digital-coach-anwesh/environment.yml /T2M-GPT/environment.yml
+COPY . /T2M-GPT
 RUN conda env create -f environment.yml

 # Activate the conda environment

dataset/dataset_MOT_segmented.py

Lines changed: 109 additions & 7 deletions

@@ -1,13 +1,13 @@
+import codecs as cs
+import nimblephysics as nimble
+import numpy as np
 import os
+import random
 import torch
-from torch.utils import data
-import numpy as np
+from glob import glob
 from os.path import join as pjoin
-import random
-import codecs as cs
+from torch.utils import data
 from tqdm import tqdm
-from glob import glob
-

 class VQMotionDataset(data.Dataset):
     def __init__(self, dataset_name, window_size = 64, unit_length = 4, mode = 'train', mode2='embeddings', data_dirs=['/home/ubuntu/data/MCS_DATA', '/media/shubh/Elements/RoseYu/UCSD-OpenCap-Fitness-Dataset/MCS_DATA']):

@@ -204,6 +204,108 @@ def __getitem__(self, item):

         return subsequences, subsequence_lengths, names

+class AddBiomechanicsDataset(data.Dataset):
+    def __init__(self, window_size=64, unit_length=4, mode='train', data_dir='/home/mnt/data/addb_dataset_publication'):
+        self.window_size = window_size
+        self.unit_length = unit_length
+        self.data_dir = data_dir
+        self.mode = mode
+
+        # Define subdirectories for each paper
+        paper_dirs = [
+            "train/No_Arm/Falisse2016_Formatted_No_Arm",
+            "train/No_Arm/Uhlrich2023_Opencap_Formatted_No_Arm",
+            "train/No_Arm/Wang2023_Formatted_No_Arm",
+            "train/No_Arm/Han2023_Formatted_No_Arm",
+        ]
+
+        # Collect all .b3d files from the specified subdirectories
+        self.b3d_file_paths = []
+        for paper_dir in paper_dirs:
+            search_path = os.path.join(data_dir, paper_dir, '**', '*.b3d')
+            files = glob(search_path, recursive=True)
+            self.b3d_file_paths.extend(files)
+
+        self.motion_data = []
+        self.motion_lengths = []
+        self.motion_names = []
+        self.motion_fps = []
+
+        for b3d_file in tqdm(self.b3d_file_paths):
+            try:
+                if os.path.getsize(b3d_file) == 0:
+                    continue
+                subject = nimble.biomechanics.SubjectOnDisk(b3d_file)
+                num_trials = subject.getNumTrials()
+                for trial in range(num_trials):
+                    trial_length = subject.getTrialLength(trial)
+                    if trial_length < self.window_size:
+                        continue
+                    frames = subject.readFrames(
+                        trial=trial,
+                        startFrame=0,
+                        numFramesToRead=trial_length,
+                        includeSensorData=False,
+                        includeProcessingPasses=True
+                    )
+                    if not frames:
+                        continue
+                    kin_passes = [frame.processingPasses[0] for frame in frames]
+                    positions = np.array([kp.pos for kp in kin_passes])  # shape: (frames, dofs)
+                    # Get FPS for this trial
+                    seconds_per_frame = subject.getTrialTimestep(trial)
+                    fps = int(round(1.0 / seconds_per_frame)) if seconds_per_frame > 0 else 0
+
+                    # Downsample here, at load time
+                    if fps == 100:
+                        positions = positions[::2]  # Take every 2nd frame
+                    elif fps == 250:
+                        positions = positions[::5]  # Take every 5th frame
+
+                    # After downsampling, skip if too short
+                    if len(positions) < self.window_size:
+                        continue
+
+                    self.motion_data.append(positions)
+                    self.motion_lengths.append(len(positions))
+                    self.motion_names.append(f"{b3d_file}::trial{trial}")
+                    self.motion_fps.append(fps)
+            except Exception as e:
+                print(f"Skipping file {b3d_file} due to error: {e}")
+
+        print("Total number of motions:", len(self.motion_data))
+        print("Example motion shape:", self.motion_data[0].shape if self.motion_data else "None")
+
+    def __len__(self):
+        return len(self.motion_data)
+
+    def __getitem__(self, item):
+        motion = self.motion_data[item]
+        len_motion = len(motion) if len(motion) <= self.window_size else self.window_size
+        name = self.motion_names[item]
+
+        # Crop or pad to window_size (no downsampling here)
+        if len(motion) >= self.window_size:
+            idx = random.randint(0, len(motion) - self.window_size)
+            motion = motion[idx:idx + self.window_size]
+        else:
+            repeat_count = (self.window_size + len(motion) - 1) // len(motion)
+            motion = np.tile(motion, (repeat_count, 1))[:self.window_size]
+
+        return motion, len_motion, name
+
+
+def addb_data_loader(window_size=64, unit_length=4, batch_size=1, num_workers=4, mode='train'):
+    dataset = AddBiomechanicsDataset(window_size=window_size, unit_length=unit_length, mode=mode)
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers,
+        drop_last=True
+    )
+    return loader
+
 def DATALoader(dataset_name,
                batch_size,
                num_workers = 4,

@@ -231,4 +333,4 @@ def cycle(iterable):


 if __name__ == "__main__":
-    dataloader = DATALoader('mcs',1,window_size=64,unit_length=2**2,mode='limo')
+    dataloader = addb_data_loader(window_size=64, unit_length=4, batch_size=1, mode='train')
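
For orientation, a minimal, hedged usage sketch (not part of the commit) of the new loader on its own. It assumes the package layout used in train_vq.py (dataset/dataset_MOT_segmented.py importable as dataset.dataset_MOT_segmented), that nimblephysics is installed, and that .b3d files exist under the default data_dir.

# Hedged usage sketch: pull one batch from the AddBiomechanics loader and
# inspect its shape. Default collation stacks the (window_size, dofs) numpy
# windows returned by __getitem__ into a (batch_size, window_size, dofs) tensor.
from dataset import dataset_MOT_segmented

loader = dataset_MOT_segmented.addb_data_loader(window_size=64, unit_length=4,
                                                batch_size=4, mode='train')
for motion, lengths, names in loader:
    print(motion.shape)   # e.g. torch.Size([4, 64, dofs]) for the No_Arm skeleton
    print(lengths)        # per-sample lengths, capped at window_size
    print(names[0])       # "<path>.b3d::trial<N>"
    break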

environment.yml

Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ dependencies:
 - markdown==3.3.4
 - matplotlib==3.4.3
 - matplotlib-inline==0.1.2
+- nimblephysics
 - oauthlib==3.1.1
 - pandas==1.3.2
 - parso==0.8.2
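
Since nimblephysics is a new dependency, a quick, hedged sanity check (not part of the commit) can confirm the environment resolves it. It reuses only the SubjectOnDisk calls already present in dataset_MOT_segmented.py; the .b3d path below is a placeholder to be pointed at a real file.

# Hedged sanity check for the new nimblephysics dependency.
import nimblephysics as nimble

b3d_path = "/home/mnt/data/addb_dataset_publication/train/No_Arm/Falisse2016_Formatted_No_Arm/example.b3d"  # placeholder path
subject = nimble.biomechanics.SubjectOnDisk(b3d_path)
print("trials:", subject.getNumTrials())
print("trial 0 length:", subject.getTrialLength(0))
print("trial 0 timestep (s):", subject.getTrialTimestep(0))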

models/vqvae.py

Lines changed: 2 additions & 2 deletions

@@ -21,8 +21,8 @@ def __init__(self,
         self.code_dim = code_dim
         self.num_code = nb_code
         self.quant = args.quantizer
-        self.encoder = Encoder(33 if args.dataname == 'mcs' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
-        self.decoder = Decoder(33 if args.dataname == 'mcs' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
+        self.encoder = Encoder(23 if args.dataname == 'mcs' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
+        self.decoder = Decoder(23 if args.dataname == 'mcs' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
         if args.quantizer == "ema_reset":
             self.quantizer = QuantizeEMAReset(nb_code, code_dim, args)
         elif args.quantizer == "orig":
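
The change from 33 to 23 ties the Encoder/Decoder input width to the per-frame feature width the dataloader now produces. A small, hedged consistency check (not part of the commit; the loader call and the expected width of 23 are assumptions based on the other files in this commit):

# Hedged check: the dataloader's last dimension should match the width passed
# to Encoder/Decoder above (23 when args.dataname == 'mcs').
from dataset import dataset_MOT_segmented

loader = dataset_MOT_segmented.addb_data_loader(window_size=64, unit_length=4,
                                                batch_size=1, mode='train')
motion, _, _ = next(iter(loader))
print(motion.shape)              # expected (1, 64, 23) for the No_Arm skeleton
assert motion.shape[-1] == 23    # must equal the Encoder/Decoder input width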

train_vq.py

Lines changed: 22 additions & 16 deletions

@@ -17,7 +17,7 @@
 import warnings
 warnings.filterwarnings('ignore')
 from utils.word_vectorizer import WordVectorizer
-# import nimblephysics as nimble
+import nimblephysics as nimble
 import deepspeed


@@ -67,7 +67,8 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 ##### ---- Exp dirs ---- #####
 args = option_vq.get_args_parser()
 torch.manual_seed(args.seed)
-torch.cuda.set_device(args.local_rank)
+if torch.cuda.is_available():
+    torch.cuda.set_device(args.local_rank)

 args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
 os.makedirs(args.out_dir, exist_ok = True)

@@ -87,10 +88,13 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
 args.nb_joints = 22

+args.nb_joints = 23 # fixed issues
+
 logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')

-wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
-eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+wrapper_opt = get_opt(dataset_opt_path, device)
+#eval_wrapper = EvaluatorModelWrapper(wrapper_opt)


 ##### ---- Dataloader ---- #####

@@ -99,18 +103,20 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 # window_size=args.window_size,
 # unit_length=2**args.down_t)

-train_loader = dataset_MOT_segmented.DATALoader(args.dataname,
-                                                args.batch_size,
-                                                window_size=args.window_size,
-                                                unit_length=2**args.down_t)
+train_loader = dataset_MOT_segmented.addb_data_loader(
+    window_size=args.window_size,
+    unit_length=2**args.down_t,
+    batch_size=args.batch_size,
+    mode=args.dataname
+)

 # train_loader_iter = dataset_MOT_MCS.cycle(train_loader)
 train_loader_iter = dataset_MOT_segmented.cycle(train_loader)

-val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
-                                        32,
-                                        w_vectorizer,
-                                        unit_length=2**args.down_t)
+# val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
+#                                         32,
+#                                         w_vectorizer,
+#                                         unit_length=2**args.down_t)

 ##### ---- Network ---- #####
 net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers

@@ -128,10 +134,10 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):

 if args.resume_pth :
     logger.info('loading checkpoint from {}'.format(args.resume_pth))
-    ckpt = torch.load(args.resume_pth, map_location='cuda')
+    ckpt = torch.load(args.resume_pth, map_location=device)
     net.load_state_dict(ckpt['net'], strict=True)
 net.train()
-net.cuda()
+net.to(device)

 ##### ---- Optimizer & Scheduler ---- #####
 optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)

@@ -170,7 +176,7 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
     optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)

     gt_motion,_, names = next(train_loader_iter)
-    gt_motion = gt_motion.cuda().float() # (bs, 64, dim)
+    gt_motion = gt_motion.to(device).float() # (bs, 64, dim)

     pred_motion, loss_commit, perplexity = net(gt_motion)
     loss_motion = Loss(pred_motion, gt_motion)

@@ -222,7 +228,7 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
 for nb_iter in range(1, args.total_iter + 1):

     gt_motion,_,_ = next(train_loader_iter)
-    gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len
+    gt_motion = gt_motion.to(device).float() # bs, nb_joints, joints_dim, seq_len

     pred_motion, loss_commit, perplexity = net(gt_motion)
     loss_motion = Loss(pred_motion, gt_motion)
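
Taken together, the train_vq.py edits converge on a device-agnostic pattern: pick the device once, move the model and each batch to it, and leave the rest of the loop unchanged. The sketch below is hedged (not part of the commit) and assumes the net, Loss, and train_loader_iter objects already built in the script; how loss_motion and loss_commit are weighted into the total loss follows the unchanged parts of train_vq.py.

# Hedged sketch of one device-agnostic training step, mirroring the edits above.
# `net`, `Loss`, and `train_loader_iter` are assumed to come from train_vq.py.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
net.train()

gt_motion, _, names = next(train_loader_iter)
gt_motion = gt_motion.to(device).float()        # (bs, window_size, dofs)

pred_motion, loss_commit, perplexity = net(gt_motion)
loss_motion = Loss(pred_motion, gt_motion)
# The reconstruction and commitment terms are combined with the weights already
# configured elsewhere in train_vq.py.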
