-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreviews.py
More file actions
428 lines (367 loc) · 15.1 KB
/
previews.py
File metadata and controls
428 lines (367 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
# utils/previews.py
from __future__ import annotations

import logging
import os
import shutil
import tempfile
import threading
import time
from pathlib import Path
from typing import List, Optional, Tuple

import imageio.v3 as iio
import numpy as np
import tifffile as tiff
from PIL import Image

from ai_agent.utils.image_io import load_any
from ai_agent.utils.image_meta import summarize_image_metadata
log = logging.getLogger("pipeline")

# Tunables; each can be overridden with an environment variable of the same name.
PREVIEW_CACHE_TTL_SECONDS: int = int(os.environ.get("PREVIEW_CACHE_TTL_SECONDS", "1800"))
PREVIEW_CACHE_MAX_ENTRIES: int = int(os.environ.get("PREVIEW_CACHE_MAX_ENTRIES", "64"))
PREVIEW_MAX_SIDE_PX: int = int(os.environ.get("PREVIEW_MAX_SIDE_PX", "500"))

# In-process preview cache:
#   fingerprint key -> (time.monotonic() expiry, preview path, metadata text)
# Every access in this module happens while holding _PREVIEW_CACHE_LOCK.
_PREVIEW_CACHE: dict[tuple[str, ...], tuple[float, str, Optional[str]]] = dict()
_PREVIEW_CACHE_LOCK = threading.Lock()
def _fingerprint_paths(paths: List[str]) -> tuple[str, ...]:
    """Build a cache key describing the current state of ``paths``.

    Each path contributes ``"<resolved path>::<mtime_ns>::<size>"``, so the
    key changes whenever a file is modified. If the path cannot be stat'ed or
    resolved, its plain string form is used instead.
    """

    def _describe(raw: str) -> str:
        path = Path(raw)
        try:
            info = path.stat()
            return f"{path.resolve()}::{info.st_mtime_ns}::{info.st_size}"
        except Exception:
            return str(path)

    return tuple(_describe(raw) for raw in paths)
def _evict_expired_preview_entries(now: float) -> None:
    """Drop every cache entry whose expiry time is at or before ``now``.

    ``now`` must be on the same clock as the stored expiries
    (``time.monotonic()``). Both callers in this module invoke this while
    holding ``_PREVIEW_CACHE_LOCK``.
    """
    # Materialise the keys first; the dict must not change size while iterated.
    stale = tuple(
        key
        for key, (expires_at, _path, _meta) in _PREVIEW_CACHE.items()
        if expires_at <= now
    )
    for key in stale:
        _PREVIEW_CACHE.pop(key, None)
def _clear_preview_cache_for_tests() -> None:
    """Empty the in-process preview cache.

    Meant for test suites, so cached previews cannot leak from one test case
    into the next.
    """
    lock = _PREVIEW_CACHE_LOCK
    with lock:
        _PREVIEW_CACHE.clear()
def _preview_cache_get(key: tuple[str, ...]) -> Tuple[Optional[str], Optional[str]]:
    """Look up a cached preview.

    Returns ``(preview_path, meta_text)`` on a hit. Returns ``(None, None)``
    when caching is disabled (TTL <= 0), when the entry is missing or expired,
    or when the cached preview file no longer exists on disk. In the last case
    the stale entry is also removed.
    """
    miss: Tuple[Optional[str], Optional[str]] = (None, None)
    if PREVIEW_CACHE_TTL_SECONDS <= 0:
        return miss
    current = time.monotonic()
    with _PREVIEW_CACHE_LOCK:
        _evict_expired_preview_entries(current)
        hit = _PREVIEW_CACHE.get(key)
        if not hit:
            return miss
        _expires_at, cached_path, cached_meta = hit
        if Path(cached_path).exists():
            return cached_path, cached_meta
        # The file vanished (e.g. removed by temp-dir cleanup); forget it.
        _PREVIEW_CACHE.pop(key, None)
        return miss
def _preview_cache_set(
    key: tuple[str, ...], preview_path: str, meta_text: Optional[str]
) -> None:
    """Store a preview in the cache for PREVIEW_CACHE_TTL_SECONDS.

    Does nothing when the TTL is <= 0. After inserting, if the cache holds more
    than PREVIEW_CACHE_MAX_ENTRIES entries, the entries that expire soonest are
    dropped until the limit is met.
    """
    ttl = PREVIEW_CACHE_TTL_SECONDS
    if ttl <= 0:
        return
    deadline = time.monotonic() + ttl
    with _PREVIEW_CACHE_LOCK:
        _evict_expired_preview_entries(time.monotonic())
        _PREVIEW_CACHE[key] = (deadline, preview_path, meta_text)
        excess = len(_PREVIEW_CACHE) - PREVIEW_CACHE_MAX_ENTRIES
        if excess <= 0:
            return
        # Oldest-expiring first; sorted() is stable, so ties keep insertion order.
        by_expiry = sorted(_PREVIEW_CACHE, key=lambda k: _PREVIEW_CACHE[k][0])
        for victim in by_expiry[:excess]:
            _PREVIEW_CACHE.pop(victim, None)
def _norm_uint8(a: np.ndarray) -> np.ndarray:
    """Rescale an array to uint8 in [0, 255] for display.

    Scaling:
    - The minimum finite value maps to 0.
    - The 99.5th percentile of the min-shifted finite values maps to 255;
      brighter values are clipped, so a few extreme pixels cannot darken the
      whole image.
    - If that percentile is 0 (e.g. mostly-zero data), the maximum is used
      instead. A constant array renders as all zeros.

    Non-finite values are left out of the range computation. NaN and -inf are
    rendered as 0 and +inf as 255. Empty or entirely non-finite input yields
    zeros of the same shape.
    """
    v = np.asarray(a, dtype=np.float32)
    finite = np.isfinite(v)
    if not finite.any():
        return np.zeros(v.shape, dtype=np.uint8)
    vals = v[finite]
    lo = vals.min()
    shifted = vals - lo
    vmax = np.percentile(shifted, 99.5)
    if not vmax > 0:
        vmax = shifted.max()
    if not vmax > 0:
        vmax = 1.0
    # Shift the full array (non-finite values propagate), then give them
    # defined values: NaN / -inf -> 0, +inf -> vmax (full white after scaling).
    out = np.nan_to_num(v - lo, nan=0.0, posinf=vmax, neginf=0.0)
    out = np.clip(out / vmax, 0.0, 1.0)
    return (out * 255).astype(np.uint8)
def _is_rgb_like(shape: tuple[int, ...]) -> bool:
    """Return True when ``shape`` looks like a 2D colour image (H, W, 3 or 4).

    Both spatial sides must be at least 16 px. Presumably this is so that tiny
    arrays with 3-4 planes are not mistaken for colour images.
    """
    if len(shape) != 3:
        return False
    height, width, channels = shape
    return channels in (3, 4) and min(height, width) >= 16
def _to_uint8_image(arr: np.ndarray) -> np.ndarray:
    """Convert any numeric array to a uint8 image without changing shape.

    Rules:
    - uint8 input is returned unchanged (the same object).
    - bool input maps False/True to 0/255.
    - Float input whose finite maximum is <= 1.0 is treated as [0, 1] data
      and scaled by 255.
    - If the maximum exceeds 255 (e.g. 16-bit RGB), the data is rescaled
      linearly so the maximum maps to 255. Plain clipping would saturate most
      pixels to white.
    - Otherwise values are clipped to [0, 255]; negatives become 0.
    - NaN and -inf become 0, and +inf becomes 255.
    """
    a = np.asarray(arr)
    if a.dtype == np.uint8:
        return a
    if a.dtype == np.bool_:
        return a.astype(np.uint8) * np.uint8(255)
    if a.size == 0:
        return a.astype(np.uint8)
    is_float = np.issubdtype(a.dtype, np.floating)
    v = a.astype(np.float64)
    finite = np.isfinite(v)
    peak = float(v[finite].max()) if finite.any() else 0.0
    if is_float and peak <= 1.0:
        v = v * 255.0
    elif peak > 255.0:
        # Multiply before dividing so the peak maps exactly to 255.0.
        v = v * 255.0 / peak
    v = np.nan_to_num(v, nan=0.0, posinf=255.0, neginf=0.0)
    return np.clip(v, 0.0, 255.0).astype(np.uint8)
def _resize_for_preview(img: Image.Image, max_side_px: int = PREVIEW_MAX_SIDE_PX) -> Image.Image:
    """Shrink ``img`` so its longest side is at most ``max_side_px`` pixels.

    The aspect ratio is preserved. An image that already fits is returned
    as-is (the same object). Otherwise a LANCZOS-resampled copy is returned and
    the input is left untouched. ``max_side_px`` is coerced to an int of at
    least 1.
    """
    limit = max(1, int(max_side_px))
    if max(img.size) > limit:
        shrunk = img.copy()
        shrunk.thumbnail((limit, limit), Image.Resampling.LANCZOS)
        return shrunk
    return img
def mip_montage(vol3d: np.ndarray, out_png: str | Path) -> str:
    """Save a 2x2 montage of maximum-intensity projections of a 3D volume.

    For a volume of shape (H, W, D), after uint8 normalisation, the layout is::

        [ max over axis 2 (H, W)            | max over axis 1 (H, D) ]
        [ max over axis 0, transposed (D, W) | zero padding   (D, D) ]

    The canvas is (H + D, W + D) for any H, W and D. It is shrunk via
    _resize_for_preview before saving.

    Returns:
        ``out_png`` as a string.
    """
    v = _norm_uint8(vol3d)
    depth = v.shape[2]
    axial = v.max(axis=2)  # (H, W)
    cor = v.max(axis=1)  # (H, D)
    sag = v.max(axis=0).T  # (D, W)
    h1 = np.hstack([axial, cor])
    # The padding tile must be (D, D) so the bottom row matches the top row's
    # width (W + D). A zeros_like(axial) tile, i.e. (H, W), only fits cubes.
    pad = np.zeros((depth, depth), dtype=v.dtype)
    img = np.vstack([h1, np.hstack([sag, pad])])
    _resize_for_preview(Image.fromarray(img)).save(str(out_png))
    return str(out_png)
def slice_gif(
    vol: np.ndarray, out_gif: str | Path, axis: int = 2, step: int = 1, fps: int = 10
) -> str:
    """Write an animated GIF that steps through ``vol`` along ``axis``.

    Frame selection and sizing:
    - The volume is normalised with _norm_uint8, and every ``step``-th slice
      along ``axis`` becomes one frame.
    - If the first frame's longest side exceeds PREVIEW_MAX_SIDE_PX, all frames
      are LANCZOS-resized to the same size, keeping the first frame's aspect
      ratio.

    Each frame is displayed for ``int(1000 / fps)`` ms and the GIF is written
    with ``loop=0``.

    Returns:
        ``out_gif`` as a string.
    """
    scaled = _norm_uint8(vol)
    frames = [
        np.take(scaled, index, axis=axis)
        for index in range(0, scaled.shape[axis], step)
    ]
    if frames:
        height, width = frames[0].shape[:2]
        limit = max(1, PREVIEW_MAX_SIDE_PX)
        longest = max(height, width)
        if longest > limit:
            ratio = limit / float(longest)
            target_size = (
                max(1, int(round(width * ratio))),
                max(1, int(round(height * ratio))),
            )
            frames = [
                np.asarray(
                    Image.fromarray(frame).resize(target_size, Image.Resampling.LANCZOS)
                )
                for frame in frames
            ]
    iio.imwrite(str(out_gif), frames, plugin="pillow", duration=int(1000 / fps), loop=0)
    return str(out_gif)
def contact_sheet_slices(
    vol3d: np.ndarray,
    out_png: str | Path,
    max_slices: int = 36,
    grid_cols: int = 6,
) -> str:
    """Save a grid ("contact sheet") of slices taken along axis 2 of a volume.

    Layout:
    - Up to ``max_slices`` slice indices are spread evenly from the first to
      the last slice, inclusive. All slices are used when the volume has no
      more than ``max_slices`` of them.
    - Slices are laid out row-major in ``grid_cols`` columns, and unused cells
      of the last row stay black.
    - The sheet is normalised to uint8 and shrunk via _resize_for_preview.

    Returns:
        ``out_png`` as a string.

    Raises:
        ValueError: if the volume has no slices along axis 2.
    """
    v = _norm_uint8(vol3d)
    depth = v.shape[2]
    if depth == 0:
        raise ValueError("contact_sheet_slices: volume has no slices along axis 2")
    # Evenly spaced indices that cover the whole depth. A fixed integer step
    # followed by truncation would show only the first max_slices slices
    # whenever max_slices < depth < 2 * max_slices. With count <= depth the
    # spacing is >= 1, so truncating to int keeps the indices unique.
    count = min(depth, max(1, int(max_slices)))
    indices = np.linspace(0, depth - 1, num=count).astype(int)
    frames = [v[:, :, i] for i in indices]
    cols = max(1, int(grid_cols))
    rows = -(-len(frames) // cols)  # ceiling division
    h, w = frames[0].shape
    canvas = np.zeros((rows * h, cols * w), dtype=np.uint8)
    for idx, frame in enumerate(frames):
        r, c = divmod(idx, cols)
        canvas[r * h : (r + 1) * h, c * w : (c + 1) * w] = frame
    _resize_for_preview(Image.fromarray(canvas)).save(str(out_png))
    return str(out_png)
def create_orthogonal_views(vol3d: np.ndarray, out_png: str | Path) -> str:
    """
    Create a comprehensive 3-view (axial, coronal, sagittal) visualization.

    The volume is normalised to uint8 and laid out as a 2x3 grid:

    - top row: maximum-intensity projections along axis 2, axis 1 and axis 0
      (the last one transposed);
    - bottom row: the middle slice along the same three axes.

    Every tile is zero-padded and centred in a square whose side is the
    largest of the volume's three extents, so tiles of different shapes can be
    stacked. The composite is shrunk via _resize_for_preview before saving.

    Args:
        vol3d: 3D volume array, unpacked as ``h, w, d``. Other ranks raise
            ValueError.
        out_png: Output path for PNG.

    Returns:
        ``out_png`` as a string.
    """
    v = _norm_uint8(vol3d)
    h, w, d = v.shape
    # Middle slices
    axial_slice = v[:, :, d // 2]  # (h, w)
    coronal_slice = v[:, w // 2, :]  # (h, d)
    sagittal_slice = v[h // 2, :, :].T  # (d, w)
    # MIP projections
    axial_mip = v.max(axis=2)  # (h, w)
    coronal_mip = v.max(axis=1)  # (h, d)
    sagittal_mip = v.max(axis=0).T  # (d, w)

    def pad_to_square(img: np.ndarray, target_size: int) -> np.ndarray:
        """Centre ``img`` on a ``target_size`` x ``target_size`` zero canvas."""
        ih, iw = img.shape
        # Only a tile that already has the target size may be returned as-is.
        # A smaller square tile must still be padded, or hstack below fails.
        if ih == target_size and iw == target_size:
            return img
        top = (target_size - ih) // 2
        left = (target_size - iw) // 2
        return np.pad(
            img,
            ((top, target_size - ih - top), (left, target_size - iw - left)),
            mode="constant",
        )

    # Every tile's sides are drawn from (h, w, d), so this bounds all of them.
    max_dim = max(h, w, d)
    # Create 2x3 grid: MIPs on top row, slices on bottom row
    top_row = np.hstack(
        [pad_to_square(tile, max_dim) for tile in (axial_mip, coronal_mip, sagittal_mip)]
    )
    bottom_row = np.hstack(
        [pad_to_square(tile, max_dim) for tile in (axial_slice, coronal_slice, sagittal_slice)]
    )
    composite = np.vstack([top_row, bottom_row])
    _resize_for_preview(Image.fromarray(composite)).save(str(out_png))
    return str(out_png)
def _resolve_shape(data, meta) -> Optional[Tuple[int, ...]]:
    """Return the shape from loader metadata (attribute or ``"shape"`` key), else ``data.shape``."""
    shp = getattr(meta, "shape", None)
    if not shp and hasattr(meta, "get"):
        shp = meta.get("shape")
    if shp is None:
        shp = getattr(data, "shape", None)
    return shp


def _tiff_is_color(path: str) -> bool:
    """True when the TIFF's first page has 3/4 samples per pixel and RGB/YCbCr photometric tags.

    Tags that cannot be read yield False, so the array is treated as a slice stack.
    """
    try:
        with tiff.TiffFile(path) as tf:
            page = tf.pages[0]
            spp = int(getattr(page, "samplesperpixel", 1))
            photometric = str(getattr(page, "photometric", "")).upper()
    except Exception:
        return False
    return spp in (3, 4) and ("RGB" in photometric or "YCBCR" in photometric)


def _save_color_preview(arr: np.ndarray, tmpdir: Path) -> str:
    """Save a colour (H, W, 3/4) array as a resized uint8 PNG inside ``tmpdir``."""
    out = tmpdir / "image_preview.png"
    _resize_for_preview(Image.fromarray(_to_uint8_image(arr))).save(str(out))
    return str(out)


def _render_volume_preview(arr: np.ndarray, shp, tmpdir: Path) -> Optional[str]:
    """Render a 3D array: orthogonal views, then contact sheet, then MIP montage; None if all fail."""
    png_path = tmpdir / "orthogonal_views.png"
    try:
        # Try orthogonal views first (best for VLM understanding)
        create_orthogonal_views(arr, png_path)
        if png_path.exists():
            log.info("Created orthogonal view composite for 3D volume %s", shp)
            return str(png_path)
    except Exception as e:
        log.warning("Orthogonal views failed: %s, falling back to contact sheet", e)
    png_path = tmpdir / "slices_grid.png"
    try:
        contact_sheet_slices(arr, png_path, max_slices=36, grid_cols=6)
        if png_path.exists():
            return str(png_path)
    except Exception as e:
        log.warning("Contact sheet preview failed: %s, falling back to MIP montage", e)
    # Final fallback: MIP montage (written to the same slices_grid.png path).
    try:
        mip_montage(arr, png_path)
        if png_path.exists():
            return str(png_path)
    except Exception:
        log.debug("MIP montage preview failed", exc_info=True)
    return None


def _render_4d_preview(arr: np.ndarray, shp, tmpdir: Path) -> str:
    """Average over the last axis, then render orthogonal views (animated GIF sweep as fallback)."""
    vol = arr.mean(axis=-1)
    out = tmpdir / "orthogonal_4d.png"
    try:
        create_orthogonal_views(vol, out)
        if out.exists():
            log.info("Created orthogonal view for 4D volume %s", shp)
            return str(out)
    except Exception as e:
        log.warning("4D orthogonal failed: %s, trying gif", e)
    out = tmpdir / "sweep.gif"
    step = max(1, vol.shape[2] // 64)
    slice_gif(vol, out, axis=2, step=step, fps=12)
    return str(out)


def _render_preview(path: str, arr: np.ndarray, shp, tmpdir: Path) -> Optional[str]:
    """Choose a rendering strategy for one loaded image; return the preview path or None."""
    ext = Path(path).suffix.lower()
    if _is_rgb_like(arr.shape):
        # For PNG/JPEG/WebP, (H,W,3/4) is almost certainly color.
        if ext in {".png", ".jpg", ".jpeg", ".webp"}:
            return _save_color_preview(arr, tmpdir)
        # For TIFF, (H,W,3) can be either RGB or a 3-slice stack; trust the tags.
        if ext in {".tif", ".tiff"} and _tiff_is_color(path):
            try:
                return _save_color_preview(arr, tmpdir)
            except Exception:
                # Best effort: fall through and render the array as a stack.
                log.debug("Colour TIFF preview failed for %s; treating as stack", path, exc_info=True)
    if len(shp) == 3:
        return _render_volume_preview(arr, shp, tmpdir)
    if len(shp) == 4:
        return _render_4d_preview(arr, shp, tmpdir)
    if len(shp) == 2:
        out = tmpdir / "image_preview.png"
        arr2 = _norm_uint8(arr)  # Use consistent normalization
        _resize_for_preview(Image.fromarray(arr2)).save(str(out))
        return str(out)
    return None


def _build_preview_for_vlm(
    image_paths: Optional[List[str]],
) -> Tuple[Optional[str], Optional[str]]:
    """
    Build an enhanced preview image optimized for VLM analysis.

    Paths are tried in order and the first one that yields a preview wins:
    - Colour 2D images (H, W, 3/4): always for PNG/JPEG/WebP; for TIFF only
      when the first page is tagged RGB/YCbCr. Saved as a colour PNG.
    - 3D volumes: orthogonal multi-view composite, falling back to a slice
      contact sheet and then to a MIP montage.
    - 4D data: mean over the last axis, then orthogonal views (animated GIF
      fallback).
    - 2D images: percentile-normalised greyscale PNG.

    Results are cached in-process, keyed by each path's resolved location,
    mtime and size. If no preview is produced for a path, its temp directory
    is removed right away.

    Returns:
        (preview_path, metadata_text). preview_path is None when no preview
        could be produced. metadata_text is None if summarisation failed.
    """
    if not image_paths:
        return None, None

    cache_key = _fingerprint_paths(image_paths)
    cached_preview, cached_meta = _preview_cache_get(cache_key)
    if cached_preview:
        log.info("Preview cache hit for %d file(s)", len(image_paths))
        return cached_preview, cached_meta

    meta_text = None
    try:
        meta_text = summarize_image_metadata(image_paths)
    except Exception:
        log.exception(
            "Image metadata summarization failed; continuing without metadata."
        )

    try:
        _cleanup_old_previews(hours=24)
    except Exception:
        # Cleanup is best-effort and must never block preview generation.
        log.debug("Preview temp-dir cleanup failed", exc_info=True)

    for p in image_paths:
        tmpdir: Optional[Path] = None
        preview: Optional[str] = None
        try:
            data, meta = load_any(p)
            shp = _resolve_shape(data, meta)
            if shp is None:
                continue
            tmpdir = Path(tempfile.mkdtemp(prefix="preview_"))
            preview = _render_preview(p, np.asarray(data), shp, tmpdir)
        except Exception as e:
            log.warning("Preview generation failed for %s: %s", p, e)
        if preview is not None:
            _preview_cache_set(cache_key, preview, meta_text)
            return preview, meta_text
        if tmpdir is not None:
            # Nothing usable was produced: drop the empty or partial temp dir
            # now instead of leaving it for the 24h cleanup sweep.
            shutil.rmtree(tmpdir, ignore_errors=True)
    return None, meta_text
def _cleanup_old_previews(hours: int = 24) -> None:
    """
    Delete preview_* folders older than `hours` from the system temp dir.

    Age is judged by each directory's own mtime. Matching directories are
    removed as whole trees, including any nested sub-directories. Symlinks
    whose names match are skipped, so cleanup never reaches outside the temp
    dir.

    Best-effort: errors for individual directories are ignored, and an
    unexpected failure while scanning the temp dir is logged.

    NOTE(review): this matches any preview_* directory in the shared temp dir,
    not only ones created by this module.
    """
    root = Path(tempfile.gettempdir())
    cutoff = time.time() - hours * 3600
    try:
        for p in root.glob("preview_*"):
            try:
                if p.is_symlink() or not p.is_dir():
                    continue
                if p.stat().st_mtime < cutoff:
                    shutil.rmtree(p, ignore_errors=True)
            except OSError:
                # Vanished or inaccessible entry (e.g. another user's dir); skip it.
                pass
    except Exception:
        logging.getLogger("api").exception("Preview cleanup failed")