77
88import os
99import types
10- import coloredlogs
1110import logging
1211import shutil
1312import sys
2120from importlib .util import spec_from_file_location , module_from_spec
2221from pathlib import Path
2322from typing import Union
24- from tqdm import tqdm
25- from tqdm .contrib .logging import logging_redirect_tqdm
23+ from rich .progress import track , Progress , TextColumn , BarColumn , DownloadColumn , TransferSpeedColumn , TimeRemainingColumn , TimeElapsedColumn
24+ from rich .logging import RichHandler
25+ from rich .theme import Theme
26+ from rich .console import Console
2627from importlib .util import find_spec
2728if find_spec ('bidscoin' ) is None :
2829 sys .path .append (str (Path (__file__ ).parents [1 ]))
3132LOGGER = logging .getLogger (__name__ )
3233
3334
34- class TqdmUpTo (tqdm ):
35-
36- def update_to (self , b = 1 , bsize = 1 , tsize = None ):
37- """
38- Adds a tqdm progress bar to urllib.request.urlretrieve()
39- https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
40-
41- :param b: Number of blocks transferred so far [default: 1].
42- :param bsize: Size of each block (in tqdm units) [default: 1].
43- :param tsize: Total size (in tqdm units). If [default: None] remains unchanged.
44- """
45- if tsize is not None :
46- self .total = tsize
47- self .update (b * bsize - self .n ) # will also set self.n = b * bsize
48-
49-
5035def drmaa_nativespec (specs : str , session ) -> str :
5136 """
5237 Converts (CLI default) native Torque walltime and memory specifications to the DRMAA implementation (currently only Slurm is supported)
@@ -75,42 +60,42 @@ def drmaa_nativespec(specs: str, session) -> str:
7560
7661def synchronize (pbatch , jobids : list , event : str , wait : int = 15 ):
7762 """
78- Shows tqdm progress bars for queued and running DRMAA jobs. Waits until all jobs have finished +
79- some extra wait time to give NAS systems the opportunity to fully synchronize
63+ Shows Rich progress bars for queued and running DRMAA jobs. Waits until all jobs have finished +
64+ some extra wait time to give NAS systems the opportunity to fully synchronize.
8065
8166 :param pbatch: The DRMAA session
8267 :param jobids: The job ids
8368 :param event: The event that is passed to trackusage()
8469 :param wait: The extra wait time for the NAS
8570 """
8671
87- if jobids :
88- match = re .search (r"(slurm|pbs|torque|sge|lsf|condor|uge)" , pbatch .drmaaImplementation .lower ())
89- trackusage (f"{ event } _{ match .group (1 ) if match else 'drmaa' } " )
90- else :
72+ if not jobids :
9173 return
9274
93- with logging_redirect_tqdm ():
75+ match = re .search (r"(slurm|pbs|torque|sge|lsf|condor|uge)" , pbatch .drmaaImplementation .lower ())
76+ trackusage (f"{ event } _{ match .group (1 ) if match else 'drmaa' } " )
77+
78+ with Progress (TextColumn ('{task.description}' ), BarColumn (), TextColumn ('{task.completed}/{task.total}' ), TimeElapsedColumn (), transient = True ) as progress :
79+
80+ qtask = progress .add_task ('[white]Queued ' , total = len (jobids ))
81+ rtask = progress .add_task ('[green]Running ' , total = len (jobids ))
9482
95- qbar = tqdm (total = len (jobids ), desc = 'Queued ' , unit = 'job' , leave = False , bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]' )
96- rbar = tqdm (total = len (jobids ), desc = 'Running' , unit = 'job' , leave = False , bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]' , colour = 'green' )
9783 done = 0
9884 while done < len (jobids ):
9985 jobs = [pbatch .jobStatus (jobid ) for jobid in jobids ]
10086 done = sum (status in ('done' , 'failed' , 'undetermined' ) for status in jobs )
101- qbar .n = sum (status == 'queued_active' for status in jobs )
102- rbar .n = sum (status == 'running' for status in jobs )
103- qbar .refresh (), rbar .refresh ()
87+ qcount = sum (status == 'queued_active' for status in jobs )
88+ rcount = sum (status == 'running' for status in jobs )
89+ progress .update (qtask , completed = qcount )
90+ progress .update (rtask , completed = rcount )
10491 time .sleep (2 )
105- qbar .close (), rbar .close ()
10692
107- failedjobs = [jobid for jobid in jobids if pbatch .jobStatus (jobid )== 'failed' ]
108- if failedjobs :
93+ if failedjobs := [jobid for jobid in jobids if pbatch .jobStatus (jobid ) == 'failed' ]:
10994 LOGGER .error (f"{ len (failedjobs )} HPC jobs failed to run:\n { failedjobs } \n This may well be due to an underspecified `--cluster` input option (e.g. not enough memory)" )
11095
111- # Give NAS systems some time to fully synchronize
112- for t in tqdm (range (wait * 100 ), desc = 'synchronizing ' , leave = False , bar_format = '{l_bar}{bar}| [{elapsed}]' ):
113- time .sleep (.01 )
96+ # Synchronization wait bar
97+ for t in track (range (wait * 100 ), description = '[cyan]Synchronizing ' , transient = True ):
98+ time .sleep (0 .01 )
11499
115100
116101def setup_logging (logfile : Path = Path ()):
@@ -129,11 +114,11 @@ def setup_logging(logfile: Path=Path()):
129114
130115 # Set the default formats
131116 if DEBUG :
132- fmt = '%(asctime)s - %(name)s - %(levelname)s | %(message)s'
133- cfmt = '%(levelname)s - %( name)s | %(message)s'
117+ fmt = '%(asctime)s - %(name)s | %(message)s'
118+ cfmt = '%(name)s | %(message)s'
134119 else :
135120 fmt = '%(asctime)s - %(levelname)s | %(message)s'
136- cfmt = '%(levelname)s | %( message)s'
121+ cfmt = '%(message)s'
137122 datefmt = '%Y-%m-%d %H:%M:%S'
138123
139124 # Register custom log levels
@@ -161,13 +146,16 @@ def success(self, message, *args, **kws):
161146 logger = logging .getLogger ()
162147 logger .setLevel ('BCDEBUG' if DEBUG else 'VERBOSE' )
163148
164- # Add the console streamhandler and bring some color to those boring logs! :-)
165- coloredlogs .install (level = 'BCDEBUG' if DEBUG else 'VERBOSE' if not logfile .name else 'INFO' , fmt = cfmt , datefmt = datefmt ) # NB: Using tqdm sets the streamhandler level to 0, see: https://github.com/tqdm/tqdm/pull/1235
166- coloredlogs .DEFAULT_LEVEL_STYLES ['verbose' ]['color' ] = 245 # = Gray
149+ # Add the Rich console handler and bring some color to those boring logs! :-)
150+ console = Console (theme = Theme ({'logging.level.verbose' : 'grey50' , 'logging.level.success' : 'green bold' , 'logging.level.bcdebug' : 'bright_yellow' }))
151+ consolehandler = RichHandler (console = console , show_time = False , show_level = True , show_path = DEBUG , rich_tracebacks = True , markup = True , level = 'BCDEBUG' if DEBUG else 'VERBOSE' if not logfile .name else 'INFO' )
152+ consolehandler .set_name ('console' )
153+ consolehandler .setFormatter (logging .Formatter (fmt = cfmt , datefmt = datefmt ))
154+ logger .addHandler (consolehandler )
167155
168156 if logfile .name :
169157
170- # Add the log filehandler
158+ # Add the verbose filehandler
171159 logfile .parent .mkdir (parents = True , exist_ok = True ) # Create the log dir if it does not exist
172160 formatter = logging .Formatter (fmt = fmt , datefmt = datefmt )
173161 loghandler = logging .FileHandler (logfile )
@@ -184,7 +172,7 @@ def success(self, message, *args, **kws):
184172 logger .addHandler (errorhandler )
185173
186174 if DEBUG :
187- LOGGER .info ('\t <<<<<<<<<< Running BIDScoin in DEBUG mode >>>>>>>>>>' )
175+ LOGGER .info ('\t [bold bright_yellow] <<<<<<<<<< Running BIDScoin in DEBUG mode >>>>>>>>>>' )
188176 settracking ('show' )
189177
190178
@@ -550,8 +538,14 @@ def pulltutorialdata(tutorialfolder: str) -> None:
550538
551539 # Download the data
552540 LOGGER .info (f"Downloading the tutorial dataset..." )
553- with TqdmUpTo (unit = 'B' , unit_scale = True , unit_divisor = 1024 , miniters = 1 , desc = tutorialtargz .name ) as t :
554- urllib .request .urlretrieve (tutorialurl , tutorialtargz , reporthook = t .update_to ) # NB: In case of ssl certificate issues use: with urllib.request.urlopen(tutorialurl, context=ssl.SSLContext()) as data, open(tutorialtargz, 'wb') as targz_fid: shutil.copyfileobj(data, targz_fid)
541+ with Progress (TextColumn ('[bold blue]{task.fields[filename]}' ), BarColumn (), DownloadColumn (), TransferSpeedColumn (), TimeRemainingColumn ()) as progress :
542+ task = progress .add_task ('[cyan]Download' , filename = tutorialtargz .name , total = None )
543+ def reporthook (blocknum : int , blocksize : int , totalsize : int ):
544+ if totalsize > 0 and progress .tasks [task ].total is None :
545+ progress .update (task , total = totalsize )
546+ progress .update (task , completed = blocknum * blocksize )
547+
548+ urllib .request .urlretrieve (tutorialurl , tutorialtargz , reporthook = reporthook ) # NB: In case of ssl certificate issues use: with urllib.request.urlopen(tutorialurl, context=ssl.SSLContext()) as data, open(tutorialtargz, 'wb') as targz_fid: shutil.copyfileobj(data, targz_fid)
555549
556550 # Unzip the data in the target folder
557551 LOGGER .info (f"Unpacking the downloaded data in: { tutorialfolder } " )
0 commit comments