- import os
+ import deepspeed
import json
- # from osim_sequence import OSIMSequence,load_osim
+ import models.vqvae as vqvae
+ import nimblephysics as nimble
+ import options.option_vq as option_vq
+ import os
import torch
+ import torch.distributed as dist
import torch.optim as optim
- from torch.utils.tensorboard import SummaryWriter
- from tqdm import tqdm
-
- import models.vqvae as vqvae
+ import utils.eval_trans as eval_trans
import utils.losses as losses
- import options.option_vq as option_vq
import utils.utils_model as utils_model
+ import warnings
from dataset import dataset_MOT_MCS, dataset_TM_eval, dataset_MOT_segmented
- import utils.eval_trans as eval_trans
- from options.get_eval_option import get_opt
from models.evaluator_wrapper import EvaluatorModelWrapper
- import warnings
- warnings.filterwarnings('ignore')
+ from options.get_eval_option import get_opt
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
from utils.word_vectorizer import WordVectorizer
- import nimblephysics as nimble
- import deepspeed
-
+ warnings.filterwarnings('ignore')

def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr):
-
    current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
-
    return optimizer, current_lr

def get_foot_losses(motion, y_translation=0.0, feet_threshold=0.01):
-     # y_translation = 0.0
    min_height, idx = motion[..., 1].min(dim=-1)
-
-     # y_translation = -min_height.median() # Change reference to median (Other set of experiments determine this. See paper)
-
-     # print(min_height,idx,motion[..., 1].shape)
    min_height = min_height + y_translation
    pn = -torch.minimum(min_height, torch.zeros_like(min_height))  # penetration
    pn[pn < feet_threshold] = 0.0
@@ -64,16 +55,25 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):

    return loss_pn.sum()/bs, loss_fl.sum()/bs, loss_sk.sum()/bs

- ##### ---- Exp dirs ---- #####
+ # --- Robust device and distributed/deepspeed setup ---
args = option_vq.get_args_parser()
torch.manual_seed(args.seed)
+
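+ # LOCAL_RANK and WORLD_SIZE are set by multi-process launchers such as
+ # torchrun or the deepspeed CLI; the defaults keep single-process runs working.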
+ args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+
if torch.cuda.is_available():
    torch.cuda.set_device(args.local_rank)
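+     # Only create a process group for multi-process launches; NCCL is the
+     # standard backend for multi-GPU CUDA training.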
+     if dist.is_available() and not dist.is_initialized() and world_size > 1:
+         dist.init_process_group(backend="nccl")
+     device = torch.device(f"cuda:{args.local_rank}")
+ else:
+     device = torch.device("cpu")

args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(args.out_dir, exist_ok=True)

- ##### ---- Logger ---- #####
+ # Logger
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
@@ -83,54 +83,37 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
if args.dataname == 'kit':
    dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt'
    args.nb_joints = 21
-
+ elif args.dataname == 'addb':
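+     # addb motions use a 23-joint skeleton (vs. 21 for kit, 22 for t2m)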
+     args.nb_joints = 23
else:
    dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
    args.nb_joints = 22

- args.nb_joints = 23 # fixed issues
-
logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')

- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- wrapper_opt = get_opt(dataset_opt_path, device)
- #eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
-
-
- ##### ---- Dataloader ---- #####
- # train_loader = dataset_MOT_MCS.DATALoader(args.dataname,
- #                                           args.batch_size,
- #                                           window_size=args.window_size,
- #                                           unit_length=2**args.down_t)
-
train_loader = dataset_MOT_segmented.addb_data_loader(
    window_size=args.window_size,
    unit_length=2**args.down_t,
    batch_size=args.batch_size,
    mode=args.dataname
)

- # train_loader_iter = dataset_MOT_MCS.cycle(train_loader)
train_loader_iter = dataset_MOT_segmented.cycle(train_loader)

- # val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
- #                                         32,
- #                                         w_vectorizer,
- #                                         unit_length=2**args.down_t)
-
- ##### ---- Network ---- #####
- net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
-                        args.nb_code,
-                        args.code_dim,
-                        args.output_emb_width,
-                        args.down_t,
-                        args.stride_t,
-                        args.width,
-                        args.depth,
-                        args.dilation_growth_rate,
-                        args.vq_act,
-                        args.vq_norm)
-
+ # Setup VQ-VAE model
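+ # (args also carries the quantizer-specific parameters, as noted in the
+ # original constructor call)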
+ net = vqvae.HumanVQVAE(
+     args,
+     args.nb_code,
+     args.code_dim,
+     args.output_emb_width,
+     args.down_t,
+     args.stride_t,
+     args.width,
+     args.depth,
+     args.dilation_growth_rate,
+     args.vq_act,
+     args.vq_norm
+ )

if args.resume_pth:
    logger.info('loading checkpoint from {}'.format(args.resume_pth))
@@ -139,67 +122,47 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
net.train()
net.to(device)

- ##### ---- Optimizer & Scheduler ---- #####
+ # Optimizer and scheduler setup
optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)

- deepspeed_config = {
-     "train_micro_batch_size_per_gpu": args.batch_size,
-     "optimizer": {
-         "type": "AdamW",
-         "params": {
-             "lr": args.lr,
-             "betas": [
-                 0.9,
-                 0.99
-             ],
-             "weight_decay": args.weight_decay
-         }
-     },
-     "gradient_accumulation_steps": 1,
-     # "fp16": {
-     #     "enabled": True
-     # },
-     "zero_optimization": {
-         "stage": 0
+ # Use DeepSpeed if available (2+ GPUs)
+ if torch.cuda.is_available() and world_size > 1:
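+     # ZeRO stage 0 disables parameter/optimizer-state partitioning, so
+     # DeepSpeed is used here purely for data-parallel training.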
+     deepspeed_config = {
+         "train_micro_batch_size_per_gpu": args.batch_size,
+         "optimizer": {
+             "type": "AdamW",
+             "params": {
+                 "lr": args.lr,
+                 "betas": [0.9, 0.99],
+                 "weight_decay": args.weight_decay
+             }
+         },
+         "gradient_accumulation_steps": 1,
+         "zero_optimization": {"stage": 0}
    }
- }
- net, optimizer, _, _ = deepspeed.initialize(model=net, optimizer=optimizer, args=args, config_params=deepspeed_config)
+     net, optimizer, _, _ = deepspeed.initialize(
+         model=net,
+         optimizer=optimizer,
+         args=args,
+         config_params=deepspeed_config
+     )
+ else:
+     logger.info("Running without DeepSpeed (single GPU or CPU).")
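+     # In this branch the AdamW optimizer and MultiStepLR scheduler defined
+     # above are used as-is.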

Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints)

- ##### ------ warm-up ------- #####
+ # Warm up
avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.
-
for nb_iter in range(1, args.warm_up_iter):
-
    optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)
-
    gt_motion, _, names = next(train_loader_iter)
-     gt_motion = gt_motion.to(device).float() # (bs, 64, dim)
+     gt_motion = gt_motion.to(device).float()

    pred_motion, loss_commit, perplexity = net(gt_motion)
    loss_motion = Loss(pred_motion, gt_motion)
-
-     loss_temp = torch.mean((pred_motion[:,1:,:]-pred_motion[:,:-1,:])**2)
-
-     # loss_vel = Loss.forward_vel(pred_motion, gt_motion)
-     # loss_pn, loss_fl, loss_sk = get_foot_losses(pred_motion)
-     # print(loss_pn, loss_fl, loss_sk)
-
-     # # hip flexion 7, 15
-     # hip_flexion_l = -pred_motion[:,:,7].max(dim=1).values.mean()
-     # hip_flexion_r = -pred_motion[:,:,15].max(dim=1).values.mean()
-     # # knee angle 10, 18
-     # knee_angle_l = -pred_motion[:,:,10].max(dim=1).values.mean()
-     # knee_angle_r = -pred_motion[:,:,18].max(dim=1).values.mean()
-     # # ankle angle 12, 20
-     # ankle_angle_l = -pred_motion[:,:,12].max(dim=1).values.mean()
-     # ankle_angle_r = -pred_motion[:,:,20].max(dim=1).values.mean()
-
-     # print(hip_flexion_l, hip_flexion_r, knee_angle_l, knee_angle_r, ankle_angle_l, ankle_angle_r)
-
-     loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp #+ args.loss_vel * loss_vel
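+     # Temporal smoothness: penalize frame-to-frame differences in the reconstruction.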
+     loss_temp = torch.mean((pred_motion[:,1:,:] - pred_motion[:,:-1,:])**2)
+     loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp

    optimizer.zero_grad()
    loss.backward()
@@ -220,24 +183,17 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
        avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.

- ##### ---- Training ---- #####
+ # Training Loop
avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.
torch.save({'net': net.state_dict()}, os.path.join(args.out_dir, 'warmup.pth'))
- # best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper)
-
for nb_iter in range(1, args.total_iter + 1):
-
    gt_motion, _, _ = next(train_loader_iter)
-     gt_motion = gt_motion.to(device).float() # bs, nb_joints, joints_dim, seq_len
+     gt_motion = gt_motion.to(device).float()

    pred_motion, loss_commit, perplexity = net(gt_motion)
    loss_motion = Loss(pred_motion, gt_motion)
    loss_temp = torch.mean((pred_motion[:,1:,:]-pred_motion[:,:-1,:])**2)
-     # loss_vel = Loss.forward_vel(pred_motion, gt_motion)
-     # loss_pn, loss_fl, loss_sk = get_foot_losses(pred_motion)
-     # print(loss_pn, loss_fl, loss_sk)
-
-     loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp #+ args.loss_vel * loss_vel # Need to remove/change loss_vel since its not SMPL
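+     # Same objective as warm-up: reconstruction + weighted commitment + 0.5 * temporal smoothness.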
+     loss = loss_motion + args.commit * loss_commit + 0.5 * loss_temp

    optimizer.zero_grad()
    loss.backward()
@@ -261,16 +217,8 @@ def get_foot_losses(motion, y_translation=0.0,feet_threshold=0.01):
        logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f} \t Temporal. {avg_temporal:.5f}")

-         avg_recons, avg_perplexity, avg_commit = 0., 0., 0.,
-
-     if nb_iter % (10 * args.eval_iter) == 0 :
+         avg_recons, avg_perplexity, avg_commit, avg_temporal = 0., 0., 0., 0.
+
+     if nb_iter % (10 * args.eval_iter) == 0:
        torch.save({'net': net.state_dict()}, os.path.join(args.out_dir, str(nb_iter) + '.pth'))
-     # if nb_iter % args.eval_iter==0 :
-     #     # The line `best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching,
-     #     # writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer,
-     #     # nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching,
-     #     # eval_wrapper=eval_wrapper)` is calling a function named `evaluation_vqvae` from the
-     #     # `eval_trans` module.
-     #     best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper)
-