Skip to content

Commit 841c086

Browse files
authored
Merge pull request #63 from punch-mission/develop
Merge develop
2 parents 93ab305 + 0d79050 commit 841c086

8 files changed

+70
-34
lines changed

CODE_OF_CONDUCT.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ We welcome and encourage contributions to this project in the form of pull reque
22

33
Contributors are expected
44

5-
to be respectful and constructive; and
6-
to enforce 1.
5+
1. to be respectful and constructive; and
6+
2. to enforce 1.
7+
78
This code of conduct applies to all project-related communication that takes place on or at mailing lists, forums, social media, conferences, meetings, and social events.
89

9-
This code of conduct is from [the HAPI-Server project](https://github.com/hapi-server/client-python/blob/master/CODE_OF_CONDUCT.md).
10+
This code of conduct is adapted from [the HAPI-Server project](https://github.com/hapi-server/client-python/blob/master/CODE_OF_CONDUCT.md).

regularizepsf/corrector.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from pathlib import Path
55
from typing import Any, Tuple
66

7-
import deepdish as dd
87
import dill
8+
import h5py
99
import numpy as np
1010
from numpy.fft import fft2, ifft2, ifftshift
1111

@@ -264,11 +264,21 @@ def __getitem__(self, xy: Tuple[int, int]) -> np.ndarray:
264264
raise UnevaluatedPointError(f"Model not evaluated at {xy}.")
265265

266266
def save(self, path: str) -> None:
267-
dd.io.save(path, (self._evaluations, self._target_evaluation))
267+
with h5py.File(path, 'w') as f:
268+
eval_grp = f.create_group('evaluations')
269+
for key, val in self._evaluations.items():
270+
eval_grp.create_dataset(f'{key}', data=val)
271+
f.create_dataset('target', data=self._target_evaluation)
268272

269273
@classmethod
270274
def load(cls, path: str) -> ArrayCorrector:
271-
evaluations, target_evaluation = dd.io.load(path)
275+
with h5py.File(path, 'r') as f:
276+
target_evaluation = f['target'][:].copy()
277+
278+
evaluations = dict()
279+
for key, val in f['evaluations'].items():
280+
parsed_key = tuple(int(val) for val in key.replace("(", "").replace(")", "").split(","))
281+
evaluations[parsed_key] = val[:].copy()
272282
return cls(evaluations, target_evaluation)
273283

274284
def simulate_observation(self, image: np.ndarray) -> np.ndarray:

regularizepsf/fitter.py

+43-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numbers import Real
88
from typing import Any, Dict, Generator, List, Optional, Tuple
99

10-
import deepdish as dd
10+
import h5py
1111
import numpy as np
1212
import sep
1313
from astropy.io import fits
@@ -172,6 +172,7 @@ def fit(self,
172172
173173
"""
174174

175+
@abc.abstractmethod
175176
def save(self, path: str) -> None:
176177
"""Save the PatchCollection to a file
177178
@@ -184,9 +185,10 @@ def save(self, path: str) -> None:
184185
-------
185186
None
186187
"""
187-
dd.io.save(path, self.patches)
188+
188189

189190
@classmethod
191+
@abc.abstractmethod
190192
def load(cls, path: str) -> PatchCollectionABC:
191193
"""Load a PatchCollection from a file
192194
@@ -200,7 +202,6 @@ def load(cls, path: str) -> PatchCollectionABC:
200202
PatchCollectionABC
201203
the new patch collection
202204
"""
203-
return cls(dd.io.load(path))
204205

205206
def keys(self) -> List:
206207
"""Gets identifiers for all patches"""
@@ -476,7 +477,7 @@ def average(self, corners: np.ndarray, patch_size: int, psf_size: int, # noqa:
476477

477478
for identifier, patch in self.patches.items():
478479
# Normalize the patch
479-
patch = patch / np.max(patch)
480+
patch = patch / patch[psf_size//2, psf_size//2]
480481

481482
# Determine which average region it belongs to
482483
center_x = identifier.x + self.size // 2
@@ -569,3 +570,41 @@ def to_array_corrector(self, target_evaluation: np.array) -> ArrayCorrector:
569570

570571
return ArrayCorrector(evaluation_dictionary, target_evaluation)
571572

573+
def save(self, path: str) -> None:
574+
"""Save the CoordinatePatchCollection to a file
575+
576+
Parameters
577+
----------
578+
path : str
579+
where to save the patch collection
580+
581+
Returns
582+
-------
583+
None
584+
"""
585+
with h5py.File(path, 'w') as f:
586+
patch_grp = f.create_group('patches')
587+
for key, val in self.patches.items():
588+
patch_grp.create_dataset(f"({key.image_index, key.x, key.y})", data=val)
589+
590+
@classmethod
591+
def load(cls, path: str) -> PatchCollectionABC:
592+
"""Load a PatchCollection from a file
593+
594+
Parameters
595+
----------
596+
path : str
597+
file path to load from
598+
599+
Returns
600+
-------
601+
PatchCollectionABC
602+
the new patch collection
603+
"""
604+
patches = dict()
605+
with h5py.File(path, "r") as f:
606+
for key, val in f['patches'].items():
607+
parsed_key = tuple(int(val) for val in key.replace("(", "").replace(")", "").split(","))
608+
coord_id = CoordinateIdentifier(image_index=parsed_key[0], x=parsed_key[1], y=parsed_key[2])
609+
patches[coord_id] = val[:].copy()
610+
return cls(patches)

requirements.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
numpy==1.23.4
1+
numpy>=1.25.2
22
dill==0.3.6
3-
deepdish==0.3.7
3+
h5py>=3.9.0
44
lmfit==1.2.2
55
cython==3.0.0
6-
astropy=5.3.1
6+
astropy==5.3.1
77
scipy>=1.10.0
88
scikit-image==0.19.3
99
sep==1.2.1

requirements_dev.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
numpy==1.23.4
1+
numpy>=1.25.2
22
dill==0.3.6
3-
deepdish==0.3.7
3+
h5py>=3.9.0
44
lmfit==1.2.2
55
cython==3.0.0
66
astropy==5.3.1

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
setup(
1313
name='regularizepsf',
14-
version='0.2.1',
14+
version='0.2.2',
1515
description='Point spread function modeling and regularization',
1616
long_description=long_description,
1717
long_description_content_type='text/markdown',
@@ -21,7 +21,7 @@
2121
author='J. Marcus Hughes',
2222
author_email='[email protected]',
2323
ext_modules=cythonize(ext_modules, annotate=True, compiler_directives={'language_level': 3}),
24-
install_requires=["numpy", "dill", "deepdish", "lmfit", "sep", "cython", "astropy", "scipy", "scikit-image", "matplotlib"],
24+
install_requires=["numpy", "dill", "h5py", "lmfit", "sep", "cython", "astropy", "scipy", "scikit-image", "matplotlib"],
2525
package_data={"regularizepsf": ["helper.pyx"]},
2626
setup_requires=["cython"],
2727
extras_require={"test": ['pytest', 'coverage', 'pytest-runner', 'pytest-mpl']}

tests/test_corrector.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,6 @@ def padded_100by100_image_psf_10_with_pattern():
5555
img_padded = np.pad(img, padding_shape, mode='constant')
5656
return img_padded
5757

58-
#
59-
# @pytest.mark.parametrize("coord, value",
60-
# [((0, 0), 2),
61-
# ((10, 10), 1),
62-
# ((-10, -10), 0)])
63-
# def test_get_padded_img_section(coord, value, padded_100by100_image_psf_10_with_pattern):
64-
# img_i = get_padded_img_section(padded_100by100_image_psf_10_with_pattern, coord[0], coord[1], 10)
65-
# assert np.all(img_i == np.zeros((10, 10)) + value)
66-
#
67-
#
68-
# def test_set_padded_img_section(padded_100by100_image_psf_10_with_pattern):
69-
# test_img = np.pad(np.ones((100, 100)), ((20, 20), (20, 20)), mode='constant')
70-
# for coord, value in [((0, 0), 2), ((10, 10), 1), ((-10, -10), 0)]:
71-
# set_padded_img_section(test_img, coord[0], coord[1], 10, np.zeros((10, 10))+value)
72-
# assert np.all(test_img == padded_100by100_image_psf_10_with_pattern)
73-
7458

7559
def test_create_array_corrector():
7660
example = ArrayCorrector({(0, 0): np.zeros((10, 10))},
@@ -210,6 +194,8 @@ def test_save_load_array_corrector(tmp_path):
210194
assert os.path.isfile(fname)
211195
loaded = example.load(fname)
212196
assert isinstance(loaded, ArrayCorrector)
197+
assert np.all(loaded._target_evaluation == np.ones((100, 100)))
198+
assert np.all(loaded._evaluations[(0,0)] == np.ones((100, 100)))
213199

214200

215201
def test_array_corrector_simulate_observation_with_zero_stars():

tests/test_fitter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_coordinate_patch_average():
8181
})
8282
for patch in collection.values():
8383
# Make the normalization of each patch a no-op
84-
patch[-1, -1] = 1
84+
patch[5, 5] = 1
8585

8686
averaged_collection = collection.average(
8787
np.array([[0, 0]]), 10, 10, mode='median')

0 commit comments

Comments
 (0)