Skip to content

Commit 74cc784

Browse files
committed
refactor: Improve verbosity messages and progress bars in data processing functions
1 parent 56fe9b2 commit 74cc784

File tree

3 files changed

+5
-11
lines changed

3 files changed

+5
-11
lines changed

src/stepcount/models.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,14 +493,14 @@ def make_windows(data, window_sec, fn=None, return_index=False, verbose=True):
493493
""" Split data into windows """
494494

495495
if verbose:
496-
print("Defining windows...")
496+
print("Defining segments...")
497497

498498
if fn is None:
499499
def fn(x):
500500
return x
501501

502502
X, T = [], []
503-
for t, x in tqdm(data.resample(f"{window_sec}s", origin="start"), mininterval=5, disable=not verbose):
503+
for t, x in data.resample(f"{window_sec}s", origin="start"):
504504
x = fn(x)
505505
X.append(x)
506506
T.append(t)
@@ -626,12 +626,10 @@ def smooth_mean_absolute_percentage_error(yt, yp, sample_weight=None):
626626
def batch_extract_features(X, sample_rate, to_numpy=True, n_jobs=1, verbose=False):
627627
""" Extract features for a list or array of windows """
628628

629-
if verbose:
630-
print("Extracting features...")
631629

632630
X_feats = Parallel(n_jobs=n_jobs)(
633631
delayed(features.extract_features)(x, sample_rate)
634-
for x in tqdm(X, mininterval=5, disable=not verbose)
632+
for x in tqdm(X, total=len(X), mininterval=5, disable=not verbose, bar_format='Extracting features: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}]')
635633
)
636634
X_feats = pd.DataFrame(X_feats)
637635

src/stepcount/sslmodel.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,6 @@ def predict(model, dataloader, device, output_logits=False):
237237
:rtype: (np.ndarray, np.ndarray, np.ndarray)
238238
"""
239239

240-
if verbose:
241-
print('Classifying windows...')
242240

243241
predictions_list = []
244242
true_list = []
@@ -249,7 +247,7 @@ def predict(model, dataloader, device, output_logits=False):
249247
return np.array([]), np.array([]), np.array([])
250248

251249
with torch.inference_mode():
252-
for x, y, pid in tqdm(dataloader, mininterval=5, disable=not verbose):
250+
for x, y, pid in tqdm(dataloader, total=len(dataloader), mininterval=5, disable=not verbose, bar_format='Classifying segments: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}]'):
253251
x = x.to(device, dtype=torch.float)
254252
logits = model(x)
255253
true_list.append(y)
@@ -314,7 +312,7 @@ def train(model, train_loader, val_loader, device, class_weights=None, weights_p
314312
model.train()
315313
train_losses = []
316314
train_acces = []
317-
for x, y, _ in tqdm(train_loader, disable=not verbose):
315+
for x, y, _ in tqdm(train_loader, total=len(train_loader), disable=not verbose, bar_format='Training: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}]'):
318316
x.requires_grad_(True)
319317
x = x.to(device, dtype=torch.float)
320318
true_y = y.to(device, dtype=torch.long)

src/stepcount/stepcount.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,6 @@ def main():
153153

154154
model.wd.device = args.pytorch_device
155155

156-
if verbose:
157-
print("Running step counter...")
158156
Y, W, T_steps = model.predict_from_frame(data)
159157

160158
# Save step counts

0 commit comments

Comments
 (0)