Skip to content

Commit 92f1636

Browse files
authored
Merge pull request #932 from punch-mission/starfield
Properly interpolate starfield models and use correct time windows
2 parents 152e016 + 22279ca commit 92f1636

4 files changed

Lines changed: 108 additions & 28 deletions

File tree

changelog/932.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allows interpolation between starfield models.

punchbowl/auto/flows/level3.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,13 @@ def level3_PTM_query_ready_files(session, pipeline_config: dict, reference_time=
9191
File.file_type == "PI")).order_by(File.date_obs.asc()).all()
9292
logger.info(f"{len(all_ready_files)} Level 3 PIM files need to be processed.")
9393

94+
starfield_window = pipeline_config["flows"]["level3_PTM"]["starfield_window"]
95+
9496
actually_ready_files = []
9597
for f in all_ready_files:
96-
# TODO put magic numbers in config
97-
valid_starfields = get_valid_starfields(session, f, timedelta_window=timedelta(days=14), file_type="PS")
98+
valid_starfields = get_valid_starfields(session, f, timedelta_window=timedelta(days=starfield_window), file_type="PS")
9899

99-
if len(valid_starfields) >= 1:
100+
if len(valid_starfields) >= 2:
100101
actually_ready_files.append(f)
101102
if len(actually_ready_files) >= max_n:
102103
break
@@ -113,15 +114,21 @@ def level3_PTM_construct_flow_info(level2_files: list[File], level3_file: File,
113114
state = "planned"
114115
creation_time = datetime.now()
115116
priority = pipeline_config["flows"][flow_type]["priority"]["initial"]
117+
starfield_window = pipeline_config["flows"][flow_type]["starfield_window"]
118+
119+
starfield_files = get_valid_starfields(session,
120+
level2_files[0],
121+
timedelta_window=timedelta(days=starfield_window),
122+
file_type="PS")
123+
124+
before_starfield_path = get_closest_before_file(level2_files[0], starfield_files).filename()
125+
after_starfield_path = get_closest_after_file(level2_files[0], starfield_files).filename()
116126

117-
starfield = get_closest_file(level2_files[0],
118-
get_valid_starfields(session,
119-
level2_files[0],
120-
timedelta_window=timedelta(days=14)))
121127
call_data = json.dumps(
122128
{
123129
"data_list": [level2_file.filename() for level2_file in level2_files],
124-
"starfield_background_path": starfield.filename(),
130+
"before_starfield_path": before_starfield_path,
131+
"after_starfield_path": after_starfield_path,
125132
},
126133
)
127134
return Flow(
@@ -166,7 +173,7 @@ def level3_PTM_scheduler_flow(pipeline_config_path=None, session=None, reference
166173

167174

168175
def level3_PTM_call_data_processor(call_data: dict, pipeline_config, session=None) -> dict:
169-
for key in ["data_list", "starfield_background_path"]:
176+
for key in ["data_list", "before_starfield_path", "after_starfield_path"]:
170177
call_data[key] = file_name_to_full_path(call_data[key], pipeline_config["root"])
171178
return call_data
172179

@@ -408,12 +415,13 @@ def level3_CTM_query_ready_files(session, pipeline_config: dict, reference_time=
408415
File.file_type == "CI")).order_by(File.date_obs.asc()).all()
409416
logger.info(f"{len(all_ready_files)} Level 3 CIM files need to be processed.")
410417

418+
starfield_window = pipeline_config["flows"]["level3_CTM"]["starfield_window"]
419+
411420
actually_ready_files = []
412421
for f in all_ready_files:
413-
# # TODO put magic numbers in config
414-
valid_starfields = get_valid_starfields(session, f, timedelta_window=timedelta(days=14), file_type="CS")
422+
valid_starfields = get_valid_starfields(session, f, timedelta_window=timedelta(days=starfield_window), file_type="CS")
415423

416-
if len(valid_starfields) >= 1:
424+
if len(valid_starfields) >= 2:
417425
actually_ready_files.append(f)
418426
if len(actually_ready_files) >= max_n:
419427
break
@@ -430,16 +438,21 @@ def level3_CTM_construct_flow_info(level2_files: list[File], level3_file: File,
430438
state = "planned"
431439
creation_time = datetime.now()
432440
priority = pipeline_config["flows"][flow_type]["priority"]["initial"]
441+
starfield_window = pipeline_config["flows"][flow_type]["starfield_window"]
442+
443+
starfield_files = get_valid_starfields(session,
444+
level2_files[0],
445+
timedelta_window=timedelta(days=starfield_window),
446+
file_type="CS")
447+
448+
before_starfield_path = get_closest_before_file(level2_files[0], starfield_files).filename()
449+
after_starfield_path = get_closest_after_file(level2_files[0], starfield_files).filename()
433450

434-
starfield = get_closest_file(level2_files[0],
435-
get_valid_starfields(session,
436-
level2_files[0],
437-
timedelta_window=timedelta(days=90),
438-
file_type="CS"))
439451
call_data = json.dumps(
440452
{
441453
"data_list": [level2_file.filename() for level2_file in level2_files],
442-
"starfield_background_path": starfield.filename(),
454+
"before_starfield_path": before_starfield_path,
455+
"after_starfield_path": after_starfield_path,
443456
},
444457
)
445458
return Flow(
@@ -484,7 +497,7 @@ def level3_CTM_scheduler_flow(pipeline_config_path=None, session=None, reference
484497

485498

486499
def level3_CTM_call_data_processor(call_data: dict, pipeline_config, session=None) -> dict:
487-
for key in ["data_list" , "starfield_background_path"]:
500+
for key in ["data_list" , "before_starfield_path", "after_starfield_path"]:
488501
call_data[key] = file_name_to_full_path(call_data[key], pipeline_config["root"])
489502
return call_data
490503

punchbowl/level3/stellar.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from prefect import get_run_logger
1515
from remove_starfield import ImageHolder, ImageProcessor, Starfield
1616
from remove_starfield.reducers import GaussianReducer
17+
from reproject import reproject_interp
18+
from reproject.mosaicking import find_optimal_celestial_wcs
1719
from scipy.ndimage import percentile_filter
1820
from scipy.stats import circmean
1921
from solpolpy import resolve
@@ -25,8 +27,9 @@
2527
calculate_helio_wcs_from_celestial,
2628
celestial_north_from_wcs,
2729
)
30+
from punchbowl.exceptions import InvalidDataError
2831
from punchbowl.prefect import punch_flow, punch_task
29-
from punchbowl.util import average_datetime
32+
from punchbowl.util import average_datetime, interpolate_data
3033

3134
warnings.filterwarnings("ignore")
3235

@@ -334,7 +337,8 @@ def generate_starfield_background(
334337

335338
@punch_task
336339
def subtract_starfield_background_task(data_object: NDCube,
337-
starfield_background_path: str | None,
340+
before_starfield_path: str | None,
341+
after_starfield_path: str | None,
338342
is_polarized: bool = False) -> NDCube:
339343
"""
340344
Subtracts a background starfield from an input data frame.
@@ -346,8 +350,10 @@ def subtract_starfield_background_task(data_object: NDCube,
346350
----------
347351
data_object : NDCube
348352
An NDCube data frame to be background subtracted
349-
starfield_background_path : str
350-
path to a NDCube background starfield map
353+
before_starfield_path : str | None
354+
path to an NDCube background starfield map centered before the observation
355+
after_starfield_path : str | None
356+
path to an NDCube background starfield map centered after the observation
351357
is_polarized : bool
352358
whether the data is polarized
353359
@@ -360,17 +366,69 @@ def subtract_starfield_background_task(data_object: NDCube,
360366
logger = get_run_logger()
361367
logger.info("subtract_starfield_background started")
362368

363-
if starfield_background_path is None:
369+
if before_starfield_path is None and after_starfield_path is None:
364370
output = data_object
365371
output.meta.history.add_now("LEVEL3-subtract_starfield_background",
366372
"starfield subtraction skipped since path is empty")
373+
elif before_starfield_path is None or after_starfield_path is None:
374+
raise InvalidDataError("subtract_starfield_background requires two input starfield models.")
367375
else:
368-
star_datacube = load_ndcube_from_fits(starfield_background_path)
369-
wcs_celestial = calculate_celestial_wcs_from_helio(star_datacube.wcs)
370-
wcs_celestial.wcs.cdelt[0] = wcs_celestial.wcs.cdelt[0] * -1
376+
star_datacube_before = load_ndcube_from_fits(before_starfield_path)
377+
star_datacube_after = load_ndcube_from_fits(after_starfield_path)
378+
379+
shape_before = star_datacube_before.data.shape[-2:]
380+
shape_after = star_datacube_after.data.shape[-2:]
381+
382+
wcs_celestial_before = calculate_celestial_wcs_from_helio(star_datacube_before.wcs)
383+
wcs_celestial_before.wcs.cdelt[0] = wcs_celestial_before.wcs.cdelt[0] * -1
384+
385+
wcs_celestial_after = calculate_celestial_wcs_from_helio(star_datacube_after.wcs)
386+
wcs_celestial_after.wcs.cdelt[0] = wcs_celestial_after.wcs.cdelt[0] * -1
387+
388+
# TODO - Test with polarized data...
389+
union_wcs, union_shape = find_optimal_celestial_wcs(
390+
[(shape_before, wcs_celestial_before),
391+
(shape_after, wcs_celestial_after)],
392+
auto_rotate=False, projection="CAR")
393+
394+
starfield_reprojected_before = reproject_interp(
395+
(np.stack([star_datacube_before.data, star_datacube_before.uncertainty.array], axis=0),
396+
wcs_celestial_before),
397+
union_wcs,
398+
shape_out=union_shape,
399+
return_footprint=False)
400+
401+
starfield_reprojected_after = reproject_interp(
402+
(np.stack([star_datacube_after.data, star_datacube_after.uncertainty.array], axis=0),
403+
wcs_celestial_after),
404+
union_wcs,
405+
shape_out=union_shape,
406+
return_footprint=False)
407+
408+
starfield_before = NDCube(data=starfield_reprojected_before[0],
409+
uncertainty = StdDevUncertainty(starfield_reprojected_before[1]),
410+
wcs = union_wcs, meta=star_datacube_before.meta)
411+
starfield_after = NDCube(data=starfield_reprojected_after[0],
412+
uncertainty = StdDevUncertainty(starfield_reprojected_after[1]),
413+
wcs = union_wcs, meta=star_datacube_after.meta)
414+
415+
starfield_data_interpolated, starfield_uncert_interpolated = interpolate_data(starfield_before,
416+
starfield_after,
417+
data_object.meta.datetime,
418+
allow_extrapolation=False,
419+
and_uncertainty=True,
420+
infill_nans=True)
421+
# TODO - metadata...
422+
star_datacube = NDCube(data=starfield_data_interpolated,
423+
uncertainty=StdDevUncertainty(starfield_uncert_interpolated),
424+
wcs = union_wcs,
425+
meta=star_datacube_before.meta)
426+
wcs_celestial = union_wcs
371427

372428
original_mask = data_object.data == 0
373429

430+
# TODO - Think about where to do the interpolation at this stage...
431+
# Is this going to require a change in the subtraction code to avoid more reprojections back and forth?
374432
if is_polarized:
375433
starfield_model_m = Starfield(np.stack((star_datacube.data[0], star_datacube.uncertainty.array[0])),
376434
wcs_celestial[0])

punchbowl/util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def nan_gaussian(image: np.ndarray, sigma: float) -> np.ndarray:
270270

271271
def interpolate_data(data_before: NDCube, data_after:NDCube, reference_time: datetime, time_key: str = "DATE-OBS",
272272
allow_extrapolation: bool = False, and_uncertainty: bool = False,
273-
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
273+
infill_nans: bool = False) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
274274
"""Interpolates between two data objects."""
275275
before_date = parse_datetime(data_before.meta[time_key].value + " UTC").timestamp()
276276
after_date = parse_datetime(data_after.meta[time_key].value + " UTC").timestamp()
@@ -296,10 +296,18 @@ def interpolate_data(data_before: NDCube, data_after:NDCube, reference_time: dat
296296
data_interpolated = ((data_after.data - data_before.data)
297297
* (observation_date - before_date) / (after_date - before_date)
298298
+ data_before.data)
299+
if infill_nans:
300+
data_before_nan = np.isnan(data_before.data)
301+
data_after_nan = np.isnan(data_after.data)
302+
data_interpolated[data_before_nan] = data_after.data[data_before_nan]
303+
data_interpolated[data_after_nan] = data_before.data[data_after_nan]
299304
if and_uncertainty:
300305
uncert_interpolated = ((data_after.uncertainty.array - data_before.uncertainty.array)
301306
* (observation_date - before_date) / (after_date - before_date)
302307
+ data_before.uncertainty.array)
308+
if infill_nans:
309+
uncert_interpolated[data_before_nan] = data_after.uncertainty.array[data_before_nan]
310+
uncert_interpolated[data_after_nan] = data_before.uncertainty.array[data_after_nan]
303311

304312
if and_uncertainty:
305313
return data_interpolated, uncert_interpolated

0 commit comments

Comments
 (0)