-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathjwst_image_processor.py
More file actions
362 lines (302 loc) · 11.8 KB
/
jwst_image_processor.py
File metadata and controls
362 lines (302 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
JWST Image Processor
Processes and enhances James Webb Space Telescope images
"""
import os
import numpy as np
from astropy.io import fits
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from astropy.stats import sigma_clipped_stats
from photutils.segmentation import detect_sources
from photutils.segmentation import SourceCatalog
from scipy import ndimage
from scipy.ndimage import gaussian_filter
from skimage import exposure, filters
from skimage.restoration import denoise_tv_chambolle
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, PowerNorm
import warnings
# NOTE(review): called with no category/module filter, this silences *every*
# warning process-wide (numpy RuntimeWarnings, astropy header fixes, ...) for any
# program that imports this module. Consider scoping it with
# warnings.catch_warnings() or filtering only specific categories.
warnings.filterwarnings('ignore')
class JWSTImageProcessor:
    """Process and enhance JWST images.

    Pipeline: load a FITS image, subtract a sigma-clipped background,
    denoise, stretch contrast, optionally detect sources, and combine three
    filters into an RGB composite.
    """

    # Identifying keywords copied from the primary header when the science
    # data comes from an extension whose own header lacks them.
    _PRIMARY_KEYWORDS = ('TELESCOP', 'INSTRUME', 'DETECTOR', 'FILTER', 'PUPIL')

    def __init__(self, data_dir: str = "jwst_data"):
        """
        Create the processor and make sure ``<data_dir>/processed`` exists.

        Parameters:
        -----------
        data_dir : str
            Root data directory; a ``processed`` subdirectory is created in it
        """
        self.data_dir = data_dir
        self.processed_dir = os.path.join(data_dir, "processed")
        os.makedirs(self.processed_dir, exist_ok=True)

    @staticmethod
    def _select_image_hdu(hdul):
        """Return the HDU holding the science image, or None if no HDU has image data.

        Prefers an extension named ``SCI`` (JWST calibrated products keep the
        science array there and leave the primary HDU data-less), otherwise
        the first image HDU that actually contains data.
        """
        for hdu in hdul:
            if hdu.name == 'SCI' and hdu.data is not None:
                return hdu
        for hdu in hdul:
            if getattr(hdu, 'is_image', False) and hdu.data is not None:
                return hdu
        return None

    def load_fits_file(self, filepath: str) -> tuple:
        """
        Load FITS file and extract data, header, and WCS

        The image HDU is chosen by ``_select_image_hdu`` rather than assuming
        the primary HDU. When an extension is used, identifying keywords
        (FILTER, INSTRUME, ...) missing from its header are copied from the
        primary header so callers can still look them up.

        Parameters:
        -----------
        filepath : str
            Path to FITS file

        Returns:
        --------
        data : numpy.ndarray
            Image data (an in-memory copy, valid after the file is closed)
        header : astropy.io.fits.header.Header
            FITS header of the selected HDU
        wcs : astropy.wcs.WCS
            World Coordinate System

        ``(None, None, None)`` is returned (and the error printed) when the
        file cannot be read or contains no image data.
        """
        try:
            with fits.open(filepath) as hdul:
                hdu = self._select_image_hdu(hdul)
                if hdu is None:
                    print(f"Error loading {filepath}: no HDU contains image data")
                    return None, None, None
                # Copy so the array does not depend on the (possibly
                # memory-mapped) file after it is closed.
                data = np.array(hdu.data, copy=True)
                header = hdu.header.copy()
                if hdu is not hdul[0]:
                    primary_header = hdul[0].header
                    for key in self._PRIMARY_KEYWORDS:
                        if key not in header and key in primary_header:
                            header[key] = primary_header[key]
                wcs = WCS(header)
            print(f"Loaded {filepath}")
            print(f"Data shape: {data.shape}")
            print(f"Data type: {data.dtype}")
            print(f"Data range: {np.nanmin(data):.2e} to {np.nanmax(data):.2e}")
            return data, header, wcs
        except Exception as e:
            # Best-effort contract: callers check for ``data is None``.
            print(f"Error loading {filepath}: {e}")
            return None, None, None

    def basic_calibration(self, data: np.ndarray, header: dict) -> np.ndarray:
        """
        Apply basic calibration steps

        The background is the 3-sigma-clipped median of the *finite* pixels
        only, so NaN/inf pixels do not bias it toward zero. The background is
        subtracted, non-finite pixels are set to 0, and negatives are clipped
        to 0.

        Parameters:
        -----------
        data : numpy.ndarray
            Raw image data
        header : dict
            FITS header (currently unused)

        Returns:
        --------
        calibrated_data : numpy.ndarray
            Background-subtracted, non-negative image, same shape as ``data``
        """
        finite = np.isfinite(data)
        if not finite.any():
            # Nothing to estimate a background from.
            return np.zeros(np.shape(data), dtype=float)
        _, median, _ = sigma_clipped_stats(data[finite], sigma=3.0)
        calibrated_data = np.where(finite, data - median, 0.0)
        return np.maximum(calibrated_data, 0)

    def enhance_contrast(self, data: np.ndarray, method: str = 'log') -> np.ndarray:
        """
        Enhance image contrast

        The input is first rescaled to [0, 1] using its (NaN-ignoring) min and
        max; a constant input maps to all zeros instead of 0/0 = NaN.

        Parameters:
        -----------
        data : numpy.ndarray
            Image data
        method : str
            Enhancement method ('log', 'sqrt', 'power', 'histogram');
            any other value returns the normalized data unchanged

        Returns:
        --------
        enhanced_data : numpy.ndarray
            Enhanced image data
        """
        lo = np.nanmin(data)
        span = np.nanmax(data) - lo
        if span > 0:
            data_norm = (data - lo) / span
        else:
            # Constant (or all-NaN) input has no contrast to stretch.
            data_norm = np.zeros(np.shape(data), dtype=float)

        if method == 'log':
            # Logarithmic scaling; dividing by log(1001) maps 1.0 back to 1.0
            return np.log1p(data_norm * 1000) / np.log(1001)
        if method == 'sqrt':
            return np.sqrt(data_norm)
        if method == 'power':
            # Power law (gamma = 0.5) correction
            return np.power(data_norm, 0.5)
        if method == 'histogram':
            return exposure.equalize_hist(data_norm)
        return data_norm

    def denoise_image(self, data: np.ndarray, method: str = 'gaussian') -> np.ndarray:
        """
        Denoise image using various methods

        Parameters:
        -----------
        data : numpy.ndarray
            Image data
        method : str
            Denoising method ('gaussian' sigma=1, 'median' 3x3,
            'tv' weight=0.1); any other value returns ``data`` unchanged

        Returns:
        --------
        denoised_data : numpy.ndarray
            Denoised image data
        """
        if method == 'gaussian':
            return gaussian_filter(data, sigma=1.0)
        if method == 'median':
            return ndimage.median_filter(data, size=3)
        if method == 'tv':
            return denoise_tv_chambolle(data, weight=0.1)
        return data

    def detect_sources(self, data: np.ndarray, threshold: float = 3.0,
                       npixels: int = 5) -> "SourceCatalog | None":
        """
        Detect sources in the image

        Parameters:
        -----------
        data : numpy.ndarray
            Image data
        threshold : float
            Detection threshold in sigma above the sigma-clipped median
        npixels : int
            Minimum number of connected pixels for a detection

        Returns:
        --------
        catalog : photutils.segmentation.SourceCatalog or None
            Source catalog, or None when nothing is detected
        """
        _, median, std = sigma_clipped_stats(data, sigma=3.0)
        threshold_value = median + threshold * std
        # Inside this method body, ``detect_sources`` resolves to the
        # module-level photutils function (method names are not in scope).
        segment_map = detect_sources(data, threshold_value, npixels=npixels)
        if segment_map is None:
            print("No sources detected")
            return None
        catalog = SourceCatalog(data, segment_map)
        print(f"Detected {len(catalog)} sources")
        return catalog

    def create_color_composite(self,
                               red_data: np.ndarray,
                               green_data: np.ndarray,
                               blue_data: np.ndarray,
                               red_filter: str = "F444W",
                               green_filter: str = "F277W",
                               blue_filter: str = "F090W") -> np.ndarray:
        """
        Create RGB color composite from three filter images

        Each channel is stretched independently between its 1st and 99th
        percentiles (NaNs ignored) and clipped to [0, 1]; a channel with no
        spread between those percentiles becomes all zeros.

        Parameters:
        -----------
        red_data, green_data, blue_data : numpy.ndarray
            Image data for each filter; all must have the same shape
        red_filter, green_filter, blue_filter : str
            Filter names for labeling

        Returns:
        --------
        rgb_image : numpy.ndarray
            RGB composite image (H, W, 3)

        Raises:
        -------
        ValueError
            If the three channels do not share a shape
        """
        def normalize_channel(data):
            lo, hi = np.nanpercentile(data, [1, 99])
            span = hi - lo
            if not span > 0:
                return np.zeros(np.shape(data), dtype=float)
            return np.clip((data - lo) / span, 0, 1)

        shapes = {np.shape(red_data), np.shape(green_data), np.shape(blue_data)}
        if len(shapes) != 1:
            raise ValueError(
                f"Channel shapes differ: {red_filter}={np.shape(red_data)}, "
                f"{green_filter}={np.shape(green_data)}, "
                f"{blue_filter}={np.shape(blue_data)}; resample to a common grid first"
            )

        rgb_image = np.stack(
            [normalize_channel(red_data),
             normalize_channel(green_data),
             normalize_channel(blue_data)],
            axis=2,
        )
        print(f"Created RGB composite: {red_filter} (R), {green_filter} (G), {blue_filter} (B)")
        return rgb_image

    def process_single_image(self,
                             filepath: str,
                             enhance_method: str = 'log',
                             denoise_method: str = 'gaussian',
                             detect_sources_flag: bool = True) -> dict:
        """
        Process a single JWST image

        Parameters:
        -----------
        filepath : str
            Path to FITS file
        enhance_method : str
            Contrast enhancement method
        denoise_method : str
            Denoising method
        detect_sources_flag : bool
            Whether to detect sources (run on the denoised, not the
            contrast-stretched, data)

        Returns:
        --------
        result : dict or None
            Processed arrays and metadata, or None if the file failed to load
        """
        data, header, wcs = self.load_fits_file(filepath)
        if data is None:
            return None

        calibrated_data = self.basic_calibration(data, header)
        denoised_data = self.denoise_image(calibrated_data, method=denoise_method)
        enhanced_data = self.enhance_contrast(denoised_data, method=enhance_method)

        catalog = None
        if detect_sources_flag:
            catalog = self.detect_sources(denoised_data)

        return {
            'original_data': data,
            'calibrated_data': calibrated_data,
            'denoised_data': denoised_data,
            'enhanced_data': enhanced_data,
            'catalog': catalog,
            'header': header,
            'wcs': wcs,
            'filter_name': header.get('FILTER', 'Unknown'),
            'instrument': header.get('INSTRUME', 'Unknown'),
            'filename': os.path.basename(filepath),
        }

    def process_multiple_filters(self,
                                 filepaths: list,
                                 filters: list = None) -> dict:
        """
        Process multiple filter images and create composite

        Parameters:
        -----------
        filepaths : list
            List of FITS file paths
        filters : list
            Optional labels parallel to ``filepaths``. When ``filters[i]`` is
            given (not None) it is used as the key for ``filepaths[i]``
            instead of the header's FILTER value; missing entries fall back
            to the header.

        Returns:
        --------
        result : dict
            Results of ``process_single_image`` keyed by filter label. A label
            seen more than once is stored as ``<label>_2``, ``<label>_3``, ...
            rather than overwriting the earlier result. With at least three
            images, a ``'composite'`` entry holds ``rgb_data`` and
            ``filters_used`` built from the first three (in processing order)
            as R, G, B.
        """
        processed_filters = {}

        for i, filepath in enumerate(filepaths):
            if not os.path.exists(filepath):
                print(f"File not found: {filepath}")
                continue

            print(f"\nProcessing filter {i+1}/{len(filepaths)}: {filepath}")
            result = self.process_single_image(filepath)
            if result is None:
                continue

            if filters is not None and i < len(filters) and filters[i] is not None:
                label = filters[i]
            else:
                label = result['filter_name']

            # Never silently overwrite an earlier image with the same filter.
            key = label
            suffix = 2
            while key in processed_filters:
                key = f"{label}_{suffix}"
                suffix += 1
            if key != label:
                print(f"Duplicate filter name {label!r}: storing {result['filename']} as {key!r}")
            processed_filters[key] = result

        if len(processed_filters) >= 3:
            filter_names = list(processed_filters.keys())
            print(f"\nCreating composite from filters: {filter_names}")
            red, green, blue = filter_names[:3]
            rgb_data = self.create_color_composite(
                processed_filters[red]['enhanced_data'],
                processed_filters[green]['enhanced_data'],
                processed_filters[blue]['enhanced_data'],
                red, green, blue,
            )
            processed_filters['composite'] = {
                'rgb_data': rgb_data,
                'filters_used': filter_names[:3],
            }

        return processed_filters
# Example usage
if __name__ == "__main__":
    processor = JWSTImageProcessor()

    # Process one file from <data_dir>/raw, if that directory holds any .fits files.
    raw_dir = os.path.join(processor.data_dir, "raw")
    if os.path.isdir(raw_dir):
        # os.listdir returns entries in arbitrary order; sort so repeated runs
        # pick the same file.
        files = sorted(f for f in os.listdir(raw_dir) if f.endswith('.fits'))
        if files:
            filepath = os.path.join(raw_dir, files[0])
            result = processor.process_single_image(filepath)
            if result:
                print(f"Processed {result['filename']} with {result['filter_name']} filter")