Skip to content

Commit 46579a1

Browse files
authored
Patch for Dask (#74)
* Weird errors * Casting * Add logs * Docs
1 parent 9729bcb commit 46579a1

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

racs_tools/beamcon_2D.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import logging
99
import sys
10+
import traceback
1011
from pathlib import Path
1112
from typing import Literal, NamedTuple, Optional
1213

@@ -318,13 +319,19 @@ def get_common_beam(
318319
beams = Beams(beams=beams_list)
319320

320321
# Init flags - False is good, True is bad
321-
flags = np.array([False for beam in beams])
322+
flags = np.array([False for _ in beams])
322323

323324
# Flag zero beams
324325
flags = np.array([beam == ZERO_BEAM for beam in beams]) | flags
325326

326327
if cutoff is not None:
327-
flags = beams.major.to(u.arcsec).value > cutoff | flags
328+
# Make an uncessary copy to in case we're using with Dask
329+
# Whacky errors abound...
330+
_majors = beams.major.to(u.arcsec).value
331+
major_values = np.copy(_majors).astype(float)
332+
major_flags = np.array(major_values > cutoff).astype(bool)
333+
flags = major_flags | flags
334+
328335
if np.all(flags):
329336
logger.critical(
330337
"All beams are larger than cutoff. All outputs will be blanked!"

tests/test_2d.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
import pytest
1313
from astropy.io import fits
14-
from radio_beam import Beam
14+
from radio_beam import Beam, Beams
1515

1616
from racs_tools import au2, beamcon_2D
1717
from racs_tools.convolve_uv import smooth
@@ -279,3 +279,14 @@ def test_get_common_beam(make_2d_image, make_2d_image_smaller):
279279
common_beam.major.to(u.arcsec).value
280280
== make_2d_image_smaller.beam.major.to(u.arcsec).value
281281
)
282+
283+
284+
def test_flags():
285+
beams = Beams([1, 1, 1, 1])
286+
flags = np.array([False for beam in beams])
287+
cutoff = 1
288+
289+
with pytest.raises(TypeError) as e_info:
290+
flags = beams.major.to(u.arcsec) > cutoff * u.arcsec | flags
291+
292+
flags = beams.major.to(u.arcsec).value > cutoff | flags

0 commit comments

Comments
 (0)