Skip to content

Commit b6e10b1

Browse files
committed
Replace tqdm and coloredlogs with rich (GitHub issue #272)
1 parent b0fcc3c commit b6e10b1

File tree

12 files changed

+704
-731
lines changed

12 files changed

+704
-731
lines changed

bidscoin/bcoin.py

Lines changed: 41 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
import os
99
import types
10-
import coloredlogs
1110
import logging
1211
import shutil
1312
import sys
@@ -21,8 +20,10 @@
2120
from importlib.util import spec_from_file_location, module_from_spec
2221
from pathlib import Path
2322
from typing import Union
24-
from tqdm import tqdm
25-
from tqdm.contrib.logging import logging_redirect_tqdm
23+
from rich.progress import track, Progress, TextColumn, BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn, TimeElapsedColumn
24+
from rich.logging import RichHandler
25+
from rich.theme import Theme
26+
from rich.console import Console
2627
from importlib.util import find_spec
2728
if find_spec('bidscoin') is None:
2829
sys.path.append(str(Path(__file__).parents[1]))
@@ -31,22 +32,6 @@
3132
LOGGER = logging.getLogger(__name__)
3233

3334

34-
class TqdmUpTo(tqdm):
35-
36-
def update_to(self, b=1, bsize=1, tsize=None):
37-
"""
38-
Adds a tqdm progress bar to urllib.request.urlretrieve()
39-
https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
40-
41-
:param b: Number of blocks transferred so far [default: 1].
42-
:param bsize: Size of each block (in tqdm units) [default: 1].
43-
:param tsize: Total size (in tqdm units). If [default: None] remains unchanged.
44-
"""
45-
if tsize is not None:
46-
self.total = tsize
47-
self.update(b * bsize - self.n) # will also set self.n = b * bsize
48-
49-
5035
def drmaa_nativespec(specs: str, session) -> str:
5136
"""
5237
Converts (CLI default) native Torque walltime and memory specifications to the DRMAA implementation (currently only Slurm is supported)
@@ -75,42 +60,42 @@ def drmaa_nativespec(specs: str, session) -> str:
7560

7661
def synchronize(pbatch, jobids: list, event: str, wait: int=15):
7762
"""
78-
Shows tqdm progress bars for queued and running DRMAA jobs. Waits until all jobs have finished +
79-
some extra wait time to give NAS systems the opportunity to fully synchronize
63+
Shows Rich progress bars for queued and running DRMAA jobs. Waits until all jobs have finished +
64+
some extra wait time to give NAS systems the opportunity to fully synchronize.
8065
8166
:param pbatch: The DRMAA session
8267
:param jobids: The job ids
8368
:param event: The event that is passed to trackusage()
8469
:param wait: The extra wait time for the NAS
8570
"""
8671

87-
if jobids:
88-
match = re.search(r"(slurm|pbs|torque|sge|lsf|condor|uge)", pbatch.drmaaImplementation.lower())
89-
trackusage(f"{event}_{match.group(1) if match else 'drmaa'}")
90-
else:
72+
if not jobids:
9173
return
9274

93-
with logging_redirect_tqdm():
75+
match = re.search(r"(slurm|pbs|torque|sge|lsf|condor|uge)", pbatch.drmaaImplementation.lower())
76+
trackusage(f"{event}_{match.group(1) if match else 'drmaa'}")
77+
78+
with Progress(TextColumn('{task.description}'), BarColumn(), TextColumn('{task.completed}/{task.total}'), TimeElapsedColumn(), transient=True) as progress:
79+
80+
qtask = progress.add_task('[white]Queued ', total=len(jobids))
81+
rtask = progress.add_task('[green]Running ', total=len(jobids))
9482

95-
qbar = tqdm(total=len(jobids), desc='Queued ', unit='job', leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]')
96-
rbar = tqdm(total=len(jobids), desc='Running', unit='job', leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]', colour='green')
9783
done = 0
9884
while done < len(jobids):
9985
jobs = [pbatch.jobStatus(jobid) for jobid in jobids]
10086
done = sum(status in ('done', 'failed', 'undetermined') for status in jobs)
101-
qbar.n = sum(status == 'queued_active' for status in jobs)
102-
rbar.n = sum(status == 'running' for status in jobs)
103-
qbar.refresh(), rbar.refresh()
87+
qcount = sum(status == 'queued_active' for status in jobs)
88+
rcount = sum(status == 'running' for status in jobs)
89+
progress.update(qtask, completed=qcount)
90+
progress.update(rtask, completed=rcount)
10491
time.sleep(2)
105-
qbar.close(), rbar.close()
10692

107-
failedjobs = [jobid for jobid in jobids if pbatch.jobStatus(jobid)=='failed']
108-
if failedjobs:
93+
if failedjobs := [jobid for jobid in jobids if pbatch.jobStatus(jobid) == 'failed']:
10994
LOGGER.error(f"{len(failedjobs)} HPC jobs failed to run:\n{failedjobs}\nThis may well be due to an underspecified `--cluster` input option (e.g. not enough memory)")
11095

111-
# Give NAS systems some time to fully synchronize
112-
for t in tqdm(range(wait*100), desc='synchronizing', leave=False, bar_format='{l_bar}{bar}| [{elapsed}]'):
113-
time.sleep(.01)
96+
# Synchronization wait bar
97+
for t in track(range(wait*100), description='[cyan]Synchronizing', transient=True):
98+
time.sleep(0.01)
11499

115100

116101
def setup_logging(logfile: Path=Path()):
@@ -129,11 +114,11 @@ def setup_logging(logfile: Path=Path()):
129114

130115
# Set the default formats
131116
if DEBUG:
132-
fmt = '%(asctime)s - %(name)s - %(levelname)s | %(message)s'
133-
cfmt = '%(levelname)s - %(name)s | %(message)s'
117+
fmt = '%(asctime)s - %(name)s | %(message)s'
118+
cfmt = '%(name)s | %(message)s'
134119
else:
135120
fmt = '%(asctime)s - %(levelname)s | %(message)s'
136-
cfmt = '%(levelname)s | %(message)s'
121+
cfmt = '%(message)s'
137122
datefmt = '%Y-%m-%d %H:%M:%S'
138123

139124
# Register custom log levels
@@ -161,13 +146,16 @@ def success(self, message, *args, **kws):
161146
logger = logging.getLogger()
162147
logger.setLevel('BCDEBUG' if DEBUG else 'VERBOSE')
163148

164-
# Add the console streamhandler and bring some color to those boring logs! :-)
165-
coloredlogs.install(level='BCDEBUG' if DEBUG else 'VERBOSE' if not logfile.name else 'INFO', fmt=cfmt, datefmt=datefmt) # NB: Using tqdm sets the streamhandler level to 0, see: https://github.com/tqdm/tqdm/pull/1235
166-
coloredlogs.DEFAULT_LEVEL_STYLES['verbose']['color'] = 245 # = Gray
149+
# Add the Rich console handler and bring some color to those boring logs! :-)
150+
console = Console(theme=Theme({'logging.level.verbose': 'grey50', 'logging.level.success': 'green bold', 'logging.level.bcdebug': 'bright_yellow'}))
151+
consolehandler = RichHandler(console=console, show_time=False, show_level=True, show_path=DEBUG, rich_tracebacks=True, markup=True, level='BCDEBUG' if DEBUG else 'VERBOSE' if not logfile.name else 'INFO')
152+
consolehandler.set_name('console')
153+
consolehandler.setFormatter(logging.Formatter(fmt=cfmt, datefmt=datefmt))
154+
logger.addHandler(consolehandler)
167155

168156
if logfile.name:
169157

170-
# Add the log filehandler
158+
# Add the verbose filehandler
171159
logfile.parent.mkdir(parents=True, exist_ok=True) # Create the log dir if it does not exist
172160
formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
173161
loghandler = logging.FileHandler(logfile)
@@ -184,7 +172,7 @@ def success(self, message, *args, **kws):
184172
logger.addHandler(errorhandler)
185173

186174
if DEBUG:
187-
LOGGER.info('\t<<<<<<<<<< Running BIDScoin in DEBUG mode >>>>>>>>>>')
175+
LOGGER.info('\t[bold bright_yellow]<<<<<<<<<< Running BIDScoin in DEBUG mode >>>>>>>>>>')
188176
settracking('show')
189177

190178

@@ -550,8 +538,14 @@ def pulltutorialdata(tutorialfolder: str) -> None:
550538

551539
# Download the data
552540
LOGGER.info(f"Downloading the tutorial dataset...")
553-
with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=tutorialtargz.name) as t:
554-
urllib.request.urlretrieve(tutorialurl, tutorialtargz, reporthook=t.update_to) # NB: In case of ssl certificate issues use: with urllib.request.urlopen(tutorialurl, context=ssl.SSLContext()) as data, open(tutorialtargz, 'wb') as targz_fid: shutil.copyfileobj(data, targz_fid)
541+
with Progress(TextColumn('[bold blue]{task.fields[filename]}'), BarColumn(), DownloadColumn(), TransferSpeedColumn(), TimeRemainingColumn()) as progress:
542+
task = progress.add_task('[cyan]Download', filename=tutorialtargz.name, total=None)
543+
def reporthook(blocknum: int, blocksize: int, totalsize: int):
544+
if totalsize > 0 and progress.tasks[task].total is None:
545+
progress.update(task, total=totalsize)
546+
progress.update(task, completed=blocknum * blocksize)
547+
548+
urllib.request.urlretrieve(tutorialurl, tutorialtargz, reporthook=reporthook) # NB: In case of ssl certificate issues use: with urllib.request.urlopen(tutorialurl, context=ssl.SSLContext()) as data, open(tutorialtargz, 'wb') as targz_fid: shutil.copyfileobj(data, targz_fid)
555549

556550
# Unzip the data in the target folder
557551
LOGGER.info(f"Unpacking the downloaded data in: {tutorialfolder}")

bidscoin/bidsapps/deface.py

Lines changed: 71 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
import pandas as pd
1515
import pydeface.utils as pdu
1616
import tempfile
17-
from tqdm import tqdm
18-
from tqdm.contrib.logging import logging_redirect_tqdm
17+
from rich.progress import track
1918
from pathlib import Path
2019
from importlib.util import find_spec
2120
if find_spec('bidscoin') is None:
@@ -76,77 +75,76 @@ def deface(bidsfolder: str, pattern: str, participant: list, force: bool, output
7675
jt.joinFiles = True
7776

7877
# Loop over bids subject/session-directories
79-
with logging_redirect_tqdm():
80-
for n, subject in enumerate(tqdm(subjects, unit='subject', colour='green', leave=False), 1):
81-
82-
subid = subject.name
83-
sessions = lsdirs(subject, 'ses-*')
84-
if not sessions:
85-
sessions = [subject]
86-
for session in sessions:
87-
88-
LOGGER.info('--------------------------------------')
89-
LOGGER.info(f"Processing ({n}/{len(subjects)}): {session}")
90-
91-
# Search for images that need to be defaced
92-
sesid = session.name if session.name.startswith('ses-') else ''
93-
for match in sorted([match for match in session.glob(pattern) if '.nii' in match.suffixes]):
94-
95-
# Construct the output filename and relative path name (used in BIDS)
96-
match_rel = match.relative_to(session).as_posix()
97-
if not output:
98-
outputfile = match
99-
outputfile_rel = match_rel
100-
elif output == 'derivatives':
101-
srcent, suffix = match.with_suffix('').stem.rsplit('_', 1) # Name without suffix, suffix
102-
ext = ''.join(match.suffixes) # Account for e.g. '.nii.gz'
103-
outputfile = bidsdir/'derivatives'/'deface'/subid/sesid/match.parent.name/f"{srcent}_space-orig_{suffix}{ext}"
104-
outputfile_rel = outputfile.relative_to(bidsdir).as_posix()
105-
else:
106-
outputfile = session/output/match.name
107-
outputfile_rel = outputfile.relative_to(session).as_posix()
108-
outputfile.parent.mkdir(parents=True, exist_ok=True)
109-
110-
# Check the json "Defaced" field to see if it has already been defaced
111-
outputjson = outputfile.with_suffix('').with_suffix('.json')
112-
if not force and outputjson.is_file():
113-
with outputjson.open('r') as sidecar:
114-
metadata = json.load(sidecar)
115-
if metadata.get('Defaced'):
116-
LOGGER.info(f"Skipping already defaced image: {match_rel} -> {outputfile_rel}")
117-
continue
118-
119-
# Deface the image
120-
LOGGER.info(f"Defacing: {match_rel} -> {outputfile_rel}")
121-
if cluster:
122-
jt.args = [str(match), '--outfile', str(outputfile), '--force'] + [item for pair in [[f"--{key}",val] for key,val in args.items()] for item in pair]
123-
jt.jobName = f"deface_{subid}_{sesid}"
124-
jt.outputPath = f"{os.getenv('HOSTNAME')}:{Path.cwd() if DEBUG else tempfile.gettempdir()}/{jt.jobName}.out"
125-
jobids.append(pbatch.runJob(jt))
126-
LOGGER.info(f"Your deface job has been submitted with ID: {jobids[-1]}")
127-
else:
128-
pdu.deface_image(str(match), str(outputfile), force=True, forcecleanup=True, **args)
129-
130-
# Add a json sidecar-file with the "Defaced" field
131-
inputjson = match.with_suffix('').with_suffix('.json')
132-
if inputjson.is_file():
133-
with inputjson.open('r') as sidecar:
134-
metadata = json.load(sidecar)
135-
else:
136-
metadata = {}
137-
metadata['Defaced'] = True
138-
with outputjson.open('w') as sidecar:
139-
json.dump(metadata, sidecar, indent=4)
140-
141-
# Update the scans.tsv file
142-
scans_tsv = session/f"{subid}{'_'+sesid if sesid else ''}_scans.tsv"
143-
bidsignore = (bidsdir/'.bidsignore').read_text().splitlines() if (bidsdir/'.bidsignore').is_file() else ['extra_data/']
144-
if output and not bids.check_ignore(output, bidsignore) and scans_tsv.is_file():
145-
LOGGER.info(f"Adding {outputfile_rel} to {scans_tsv}")
146-
scans_table = pd.read_csv(scans_tsv, sep='\t', index_col='filename')
147-
scans_table.loc[outputfile_rel] = scans_table.loc[match_rel]
148-
scans_table.sort_values(by=['acq_time','filename'], inplace=True)
149-
scans_table.to_csv(scans_tsv, sep='\t', encoding='utf-8')
78+
for n, subject in enumerate(track(subjects, description='[green]Subjects', transient=True), 1):
79+
80+
subid = subject.name
81+
sessions = lsdirs(subject, 'ses-*')
82+
if not sessions:
83+
sessions = [subject]
84+
for session in sessions:
85+
86+
LOGGER.info('--------------------------------------')
87+
LOGGER.info(f"Processing ({n}/{len(subjects)}): {session}")
88+
89+
# Search for images that need to be defaced
90+
sesid = session.name if session.name.startswith('ses-') else ''
91+
for match in sorted([match for match in session.glob(pattern) if '.nii' in match.suffixes]):
92+
93+
# Construct the output filename and relative path name (used in BIDS)
94+
match_rel = match.relative_to(session).as_posix()
95+
if not output:
96+
outputfile = match
97+
outputfile_rel = match_rel
98+
elif output == 'derivatives':
99+
srcent, suffix = match.with_suffix('').stem.rsplit('_', 1) # Name without suffix, suffix
100+
ext = ''.join(match.suffixes) # Account for e.g. '.nii.gz'
101+
outputfile = bidsdir/'derivatives'/'deface'/subid/sesid/match.parent.name/f"{srcent}_space-orig_{suffix}{ext}"
102+
outputfile_rel = outputfile.relative_to(bidsdir).as_posix()
103+
else:
104+
outputfile = session/output/match.name
105+
outputfile_rel = outputfile.relative_to(session).as_posix()
106+
outputfile.parent.mkdir(parents=True, exist_ok=True)
107+
108+
# Check the json "Defaced" field to see if it has already been defaced
109+
outputjson = outputfile.with_suffix('').with_suffix('.json')
110+
if not force and outputjson.is_file():
111+
with outputjson.open('r') as sidecar:
112+
metadata = json.load(sidecar)
113+
if metadata.get('Defaced'):
114+
LOGGER.info(f"Skipping already defaced image: {match_rel} -> {outputfile_rel}")
115+
continue
116+
117+
# Deface the image
118+
LOGGER.info(f"Defacing: {match_rel} -> {outputfile_rel}")
119+
if cluster:
120+
jt.args = [str(match), '--outfile', str(outputfile), '--force'] + [item for pair in [[f"--{key}",val] for key,val in args.items()] for item in pair]
121+
jt.jobName = f"deface_{subid}_{sesid}"
122+
jt.outputPath = f"{os.getenv('HOSTNAME')}:{Path.cwd() if DEBUG else tempfile.gettempdir()}/{jt.jobName}.out"
123+
jobids.append(pbatch.runJob(jt))
124+
LOGGER.info(f"Your deface job has been submitted with ID: {jobids[-1]}")
125+
else:
126+
pdu.deface_image(str(match), str(outputfile), force=True, forcecleanup=True, **args)
127+
128+
# Add a json sidecar-file with the "Defaced" field
129+
inputjson = match.with_suffix('').with_suffix('.json')
130+
if inputjson.is_file():
131+
with inputjson.open('r') as sidecar:
132+
metadata = json.load(sidecar)
133+
else:
134+
metadata = {}
135+
metadata['Defaced'] = True
136+
with outputjson.open('w') as sidecar:
137+
json.dump(metadata, sidecar, indent=4)
138+
139+
# Update the scans.tsv file
140+
scans_tsv = session/f"{subid}{'_'+sesid if sesid else ''}_scans.tsv"
141+
bidsignore = (bidsdir/'.bidsignore').read_text().splitlines() if (bidsdir/'.bidsignore').is_file() else ['extra_data/']
142+
if output and not bids.check_ignore(output, bidsignore) and scans_tsv.is_file():
143+
LOGGER.info(f"Adding {outputfile_rel} to {scans_tsv}")
144+
scans_table = pd.read_csv(scans_tsv, sep='\t', index_col='filename')
145+
scans_table.loc[outputfile_rel] = scans_table.loc[match_rel]
146+
scans_table.sort_values(by=['acq_time','filename'], inplace=True)
147+
scans_table.to_csv(scans_tsv, sep='\t', encoding='utf-8')
150148

151149
if cluster and jobids:
152150
LOGGER.info('')

0 commit comments

Comments
 (0)