@@ -237,8 +237,6 @@ def predict(model, dataloader, device, output_logits=False):
237237 :rtype: (np.ndarray, np.ndarray, np.ndarray)
238238 """
239239
240- if verbose :
241- print ('Classifying windows...' )
242240
243241 predictions_list = []
244242 true_list = []
@@ -249,7 +247,7 @@ def predict(model, dataloader, device, output_logits=False):
249247 return np .array ([]), np .array ([]), np .array ([])
250248
251249 with torch .inference_mode ():
252- for x , y , pid in tqdm (dataloader , mininterval = 5 , disable = not verbose ):
250+ for x , y , pid in tqdm (dataloader , total = len ( dataloader ), mininterval = 5 , disable = not verbose , bar_format = 'Classifying segments: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}]' ):
253251 x = x .to (device , dtype = torch .float )
254252 logits = model (x )
255253 true_list .append (y )
@@ -314,7 +312,7 @@ def train(model, train_loader, val_loader, device, class_weights=None, weights_p
314312 model .train ()
315313 train_losses = []
316314 train_acces = []
317- for x , y , _ in tqdm (train_loader , disable = not verbose ):
315+ for x , y , _ in tqdm (train_loader , total = len ( train_loader ), disable = not verbose , bar_format = 'Training: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}]' ):
318316 x .requires_grad_ (True )
319317 x = x .to (device , dtype = torch .float )
320318 true_y = y .to (device , dtype = torch .long )
0 commit comments