Skip to content

Commit 37f0636

Browse files
authored
Merge pull request #21 from AlecThomson/args
Better main args
2 parents 417fe5a + c1942dc commit 37f0636

File tree

3 files changed

+160
-92
lines changed

3 files changed

+160
-92
lines changed

racs_tools/beamcon_2D.py

Lines changed: 94 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import schwimmbad
2020
import psutil
2121
from tqdm import tqdm
22-
import logging as log
22+
import logging as logger
2323

2424
#############################################
2525
#### ADAPTED FROM SCRIPT BY T. VERNSTROM ####
@@ -75,7 +75,7 @@ def getbeam(
7575
Returns:
7676
Tuple[Beam, float]: Convolving beam and scaling factor.
7777
"""
78-
log.info(f"Current beam is {old_beam!r}")
78+
logger.info(f"Current beam is {old_beam!r}")
7979

8080
if cutoff is not None and old_beam.major.to(u.arcsec) > cutoff * u.arcsec:
8181
return np.nan, np.nan
@@ -87,14 +87,14 @@ def getbeam(
8787
pa=0 * u.deg,
8888
)
8989
fac = 1.0
90-
log.warning(
90+
logger.warning(
9191
f"New beam {new_beam!r} and old beam {old_beam!r} are the same. Won't attempt convolution."
9292
)
9393
return conbm, fac
9494
try:
9595
conbm = new_beam.deconvolve(old_beam)
9696
except Exception as err:
97-
log.warning(f"Could not deconvolve. New: {new_beam!r}, Old: {old_beam!r}")
97+
logger.warning(f"Could not deconvolve. New: {new_beam!r}, Old: {old_beam!r}")
9898
raise err
9999
fac, amp, outbmaj, outbmin, outbpa = au2.gauss_factor(
100100
beamConv=[
@@ -123,7 +123,7 @@ def getimdata(cubenm: str) -> dict:
123123
Returns:
124124
dict: Data and metadata.
125125
"""
126-
log.info(f"Getting image data from {cubenm}")
126+
logger.info(f"Getting image data from {cubenm}")
127127
with fits.open(cubenm, memmap=True, mode="denywrite") as hdu:
128128

129129
w = astropy.wcs.WCS(hdu[0])
@@ -172,7 +172,7 @@ def savefile(
172172
outdir (str, optional): Output directory. Defaults to ".".
173173
"""
174174
outfile = f"{outdir}/{filename}"
175-
log.info(f"Saving to {outfile}")
175+
logger.info(f"Saving to {outfile}")
176176
beam = final_beam
177177
header = beam.attach_to_header(header)
178178
fits.writeto(outfile, newimage.astype(np.float32), header=header, overwrite=True)
@@ -203,7 +203,7 @@ def worker(
203203
Returns:
204204
dict: Output data.
205205
"""
206-
log.info(f"Working on {file}")
206+
logger.info(f"Working on {file}")
207207

208208
if outdir is None:
209209
outdir = os.path.dirname(file)
@@ -227,7 +227,10 @@ def worker(
227227

228228
datadict.update({"conbeam": conbeam, "final_beam": new_beam, "sfactor": sfactor})
229229
if not dryrun:
230-
if (
230+
if np.isnan(sfactor) or np.isnan(conbeam):
231+
logger.warning(f"Setting {outfile} to NaN")
232+
newim = datadict["image"] * np.nan
233+
elif (
231234
conbeam == Beam(major=0 * u.deg, minor=0 * u.deg, pa=0 * u.deg)
232235
and sfactor == 1
233236
):
@@ -309,7 +312,7 @@ def getmaxbeam(
309312
tolerance=tolerance, epsilon=epsilon, nsamps=nsamps
310313
)
311314
except BeamError:
312-
log.warning(
315+
logger.warning(
313316
"Couldn't find common beam with defaults\nTrying again with smaller tolerance"
314317
)
315318
cmn_beam = beams[~flags].common_beam(
@@ -352,14 +355,14 @@ def getmaxbeam(
352355
* u.arcsec,
353356
pa=round_up(nyq_beam.pa.to(u.deg), decimals=2),
354357
)
355-
log.info(f"Smallest common Nyquist sampled beam is: {nyq_beam!r}")
358+
logger.info(f"Smallest common Nyquist sampled beam is: {nyq_beam!r}")
356359
if target_beam is not None:
357360
if target_beam < nyq_beam:
358-
log.warning("TARGET BEAM WILL BE UNDERSAMPLED!")
361+
logger.warning("TARGET BEAM WILL BE UNDERSAMPLED!")
359362
raise Exception("CAN'T UNDERSAMPLE BEAM - EXITING")
360363
else:
361-
log.warning("COMMON BEAM WILL BE UNDERSAMPLED!")
362-
log.warning("SETTING COMMON BEAM TO NYQUIST BEAM")
364+
logger.warning("COMMON BEAM WILL BE UNDERSAMPLED!")
365+
logger.warning("SETTING COMMON BEAM TO NYQUIST BEAM")
363366
cmn_beam = nyq_beam
364367

365368
return cmn_beam, beams
@@ -424,55 +427,78 @@ def writelog(output: List[Dict], commonbeam_log: str):
424427
ascii.write(
425428
commonbeam_tab, output=commonbeam_log, format="commented_header", overwrite=True
426429
)
427-
log.info(f"Convolving log written to {commonbeam_log}")
430+
logger.info(f"Convolving log written to {commonbeam_log}")
428431

429432

430-
def main(pool, args):
433+
def main(
434+
pool,
435+
infile: list = [],
436+
prefix: str = None,
437+
suffix: str = None,
438+
outdir: str = None,
439+
conv_mode: str = "robust",
440+
dryrun: bool = False,
441+
bmaj: float = None,
442+
bmin: float = None,
443+
bpa: float = None,
444+
log: str = None,
445+
circularise: bool = False,
446+
cutoff: float = None,
447+
tolerance: float = 0.0001,
448+
nsamps: int = 200,
449+
epsilon: float = 0.0005,
450+
):
431451
"""Main script.
432452
433453
Args:
434-
pool (method): Multiprocessing or schwimmbad Pool.
435-
args (Namespace): Commandline args.
454+
pool (mp.Pool): Multiprocessing pool.
455+
infile (list, optional): List of images to convolve. Defaults to [].
456+
prefix (str, optional): Output prefix. Defaults to None.
457+
suffix (str, optional): Output suffix. Defaults to None.
458+
outdir (str, optional): Output directory. Defaults to None.
459+
conv_mode (str, optional): Colvolution mode. Defaults to "robust".
460+
dryrun (bool, optional): Do a dryrun. Defaults to False.
461+
bmaj (float, optional): Target BMAJ. Defaults to None.
462+
bmin (float, optional): Target BMIN. Defaults to None.
463+
bpa (float, optional): Target BPA. Defaults to None.
464+
log (str, optional): Input beamlog. Defaults to None.
465+
circularise (bool, optional): Make beam circular. Defaults to False.
466+
cutoff (float, optional): Cutoff beams. Defaults to None.
467+
tolerance (float, optional): Common tolerance. Defaults to 0.0001.
468+
nsamps (int, optional): Common samples. Defaults to 200.
469+
epsilon (float, optional): Common epsilon. Defaults to 0.0005.
470+
436471
437472
Raises:
438473
Exception: If no files are found.
439474
Exception: If invalid convolution mode is specified.
440475
Exception: If partial target beam is specified.
441476
Exception: If target beam cannot be used.
442477
"""
443-
if args.dryrun:
444-
log.info("Doing a dry run -- no files will be saved")
478+
479+
if dryrun:
480+
logger.info("Doing a dry run -- no files will be saved")
445481
# Fix up outdir
446-
outdir = args.outdir
447482
if outdir is not None:
448-
if outdir[-1] == "/":
449-
outdir = outdir[:-1]
450-
else:
451-
outdir = None
483+
outdir = os.path.abspath(outdir)
452484

453485
# Get file list
454-
files = sorted(args.infile)
486+
files = sorted(infile)
455487
if files == []:
456488
raise Exception("No files found!")
457489

458490
# Parse args
459-
460-
conv_mode = args.conv_mode
461-
log.info(f"Convolution mode: {conv_mode}")
491+
logger.info(f"Convolution mode: {conv_mode}")
462492
if not conv_mode in ["robust", "scipy", "astropy", "astropy_fft"]:
463493
raise Exception("Please select valid convolution method!")
464494

465-
log.info(f"Using convolution method {conv_mode}")
495+
logger.info(f"Using convolution method {conv_mode}")
466496
if conv_mode == "robust":
467-
log.info("This is the most robust method. And fast!")
497+
logger.info("This is the most robust method. And fast!")
468498
elif conv_mode == "scipy":
469-
log.info("This fast, but not robust to NaNs or small PSF changes")
499+
logger.info("This fast, but not robust to NaNs or small PSF changes")
470500
else:
471-
log.info("This is slower, but robust to NaNs, but not to small PSF changes")
472-
473-
bmaj = args.bmaj
474-
bmin = args.bmin
475-
bpa = args.bpa
501+
logger.info("This is slower, but robust to NaNs, but not to small PSF changes")
476502

477503
nonetest = [test is None for test in [bmaj, bmin, bpa]]
478504

@@ -484,21 +510,21 @@ def main(pool, args):
484510

485511
elif not all(nonetest) and not any(nonetest):
486512
target_beam = Beam(bmaj * u.arcsec, bmin * u.arcsec, bpa * u.deg)
487-
log.info(f"Target beam is {target_beam!r}")
513+
logger.info(f"Target beam is {target_beam!r}")
488514

489515
# Find smallest common beam
490516
big_beam, allbeams = getmaxbeam(
491517
files,
492518
conv_mode=conv_mode,
493519
target_beam=target_beam,
494-
cutoff=args.cutoff,
495-
tolerance=args.tolerance,
496-
nsamps=args.nsamps,
497-
epsilon=args.epsilon,
520+
cutoff=cutoff,
521+
tolerance=tolerance,
522+
nsamps=nsamps,
523+
epsilon=epsilon,
498524
)
499525

500526
if target_beam is not None:
501-
log.info("Checking that target beam will deconvolve...")
527+
logger.info("Checking that target beam will deconvolve...")
502528

503529
mask_count = 0
504530
failed = []
@@ -507,7 +533,7 @@ def main(pool, args):
507533
zip(allbeams, files),
508534
total=len(allbeams),
509535
desc="Deconvolving",
510-
disable=(log.root.level > log.INFO),
536+
disable=(logger.root.level > logger.INFO),
511537
)
512538
):
513539
try:
@@ -516,8 +542,8 @@ def main(pool, args):
516542
mask_count += 1
517543
failed.append(file)
518544
if mask_count > 0:
519-
log.warning("The following images could not reach target resolution:")
520-
log.warning(failed)
545+
logger.warning("The following images could not reach target resolution:")
546+
logger.warning(failed)
521547

522548
raise Exception("Please choose a larger target beam!")
523549

@@ -527,15 +553,15 @@ def main(pool, args):
527553
else:
528554
new_beam = big_beam
529555

530-
if args.circularise:
531-
log.info("Circular beam requested, setting BMIN=BMAJ and BPA=0")
556+
if circularise:
557+
logger.info("Circular beam requested, setting BMIN=BMAJ and BPA=0")
532558
new_beam = Beam(
533559
major=new_beam.major,
534560
minor=new_beam.major,
535561
pa=0 * u.deg,
536562
)
537563

538-
log.info(f"Final beam is {new_beam!r}")
564+
logger.info(f"Final beam is {new_beam!r}")
539565

540566
output = list(
541567
pool.map(
@@ -544,19 +570,20 @@ def main(pool, args):
544570
outdir=outdir,
545571
new_beam=new_beam,
546572
conv_mode=conv_mode,
547-
suffix=args.suffix,
548-
prefix=args.prefix,
549-
cutoff=args.cutoff,
550-
dryrun=args.dryrun,
573+
suffix=suffix,
574+
prefix=prefix,
575+
cutoff=cutoff,
576+
dryrun=dryrun,
551577
),
552578
files,
553579
)
554580
)
555581

556-
if args.log is not None:
557-
writelog(output, args.log)
582+
if log is not None:
583+
writelog(output, log)
558584

559-
log.info("Done!")
585+
logger.info("Done!")
586+
return new_beam
560587

561588

562589
def cli():
@@ -748,16 +775,16 @@ def cli():
748775
except AttributeError:
749776
myPE = 0
750777
if args.verbosity == 1:
751-
log.basicConfig(
778+
logger.basicConfig(
752779
filename=args.logfile,
753-
level=log.INFO,
780+
level=logger.INFO,
754781
format=f"[{myPE}] %(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
755782
datefmt="%Y-%m-%d %H:%M:%S",
756783
)
757784
elif args.verbosity >= 2:
758-
log.basicConfig(
785+
logger.basicConfig(
759786
filename=args.logfile,
760-
level=log.DEBUG,
787+
level=logger.DEBUG,
761788
format=f"[{myPE}] %(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
762789
datefmt="%Y-%m-%d %H:%M:%S",
763790
)
@@ -767,7 +794,14 @@ def cli():
767794
pool.wait()
768795
sys.exit(0)
769796

770-
main(pool, args)
797+
arg_dict = vars(args)
798+
# pop unwanted arguments
799+
_ = arg_dict.pop("mpi")
800+
_ = arg_dict.pop("n_cores")
801+
_ = arg_dict.pop("verbosity")
802+
_ = arg_dict.pop("logfile")
803+
804+
_ = main(pool, **arg_dict)
771805
pool.close()
772806

773807

0 commit comments

Comments
 (0)