Skip to content

Commit 2297767

Browse files
committed
refactor and clean up of the code
1 parent d72d2ad commit 2297767

8 files changed

Lines changed: 137 additions & 133 deletions

File tree

src/ai_agent/api/pipeline.py

Lines changed: 6 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,9 @@
33

44
import os
55
import logging
6-
import tempfile
76
from pathlib import Path
87
from typing import Any, Dict, List, Optional, Tuple
98

10-
import numpy as np
11-
import imageio.v3 as iio
12-
139
import re
1410
from utils.tags import strip_tags, parse_exclusions, has_no_rerank, has_refine
1511

@@ -24,9 +20,9 @@
2420
from generator.schema import CandidateDoc, NoToolReason
2521
from utils.file_validator import FileValidator
2622

27-
from utils.image_meta import summarize_image_metadata, detect_ext_token
28-
from utils.image_io import load_any
29-
from utils.previews import mip_montage, slice_gif, stack_sweep_gif, contact_sheet_slices
23+
from utils.image_meta import detect_ext_token
24+
from utils.previews import _build_preview_for_vlm, _cleanup_old_previews
25+
from utils.utils import _best_runnable_link
3026

3127
log = logging.getLogger("pipeline")
3228

@@ -35,7 +31,6 @@ class RAGImagingPipeline:
3531
def __init__(
3632
self,
3733
docs: List[SoftwareDoc],
38-
hf_token: Optional[str] = None,
3934
index_dir: Optional[str] = None,
4035
):
4136
self.index_dir = Path(index_dir or os.getenv("RAG_INDEX_DIR", "artifacts/rag_index"))
@@ -44,10 +39,9 @@ def __init__(
4439
self.embedder = LocalBGEEmbedder()
4540
self.reranker = CrossEncoderReranker()
4641
self.selector_vlm = VLMToolSelector()
47-
self.hf_token = hf_token
4842

4943
try:
50-
self._cleanup_old_previews(hours=24)
44+
_cleanup_old_previews(hours=24)
5145
except Exception:
5246
logging.getLogger("api").exception("Preview cleanup at init failed; continuing")
5347

@@ -183,90 +177,6 @@ def _norm(s: str) -> str:
183177
return hits, {"top": top, "second": second, "margin": margin}
184178

185179

186-
def _build_preview_for_vlm(self, image_paths: Optional[List[str]]) -> Tuple[Optional[str], Optional[str]]:
187-
if not image_paths:
188-
return None, None
189-
190-
meta_text = None
191-
try:
192-
meta_text = summarize_image_metadata(image_paths)
193-
except Exception:
194-
log.exception("Image metadata summarization failed; continuing without metadata.")
195-
196-
for p in image_paths:
197-
try:
198-
data, meta = load_any(p)
199-
shp = getattr(meta, "shape", None) or meta.get("shape")
200-
if shp is None:
201-
shp = getattr(data, "shape", None)
202-
if shp is None:
203-
continue
204-
205-
tmpdir = Path(tempfile.mkdtemp(prefix="preview_"))
206-
207-
if len(shp) == 3:
208-
png_path = tmpdir / "slices_grid.png"
209-
gif_path = tmpdir / "sweep.gif"
210-
try:
211-
contact_sheet_slices(data, png_path, max_slices=36, grid_cols=6)
212-
except Exception:
213-
try:
214-
mip_montage(data, png_path)
215-
except Exception:
216-
pass
217-
try:
218-
stack_sweep_gif(data, gif_path, fps=12, max_frames=64)
219-
except Exception:
220-
pass
221-
if png_path.exists():
222-
return str(png_path), meta_text
223-
if gif_path.exists():
224-
return str(gif_path), meta_text
225-
226-
if len(shp) == 4:
227-
vol = np.asarray(data).mean(axis=-1)
228-
out = tmpdir / "sweep.gif"
229-
step = max(1, vol.shape[2] // 64)
230-
slice_gif(vol, out, axis=2, step=step, fps=12)
231-
return str(out), meta_text
232-
233-
if len(shp) == 2:
234-
out = tmpdir / "image.png"
235-
arr = data
236-
if arr.dtype != np.uint8:
237-
arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
238-
iio.imwrite(str(out), arr)
239-
return str(out), meta_text
240-
except Exception:
241-
continue
242-
243-
return None, meta_text
244-
245-
def _cleanup_old_previews(self, hours: int = 24) -> None:
246-
"""
247-
Delete preview_* folders older than `hours` from the system temp dir.
248-
Best-effort; ignore errors.
249-
"""
250-
import time, tempfile
251-
root = Path(tempfile.gettempdir())
252-
cutoff = time.time() - hours * 3600
253-
try:
254-
for p in root.glob("preview_*"):
255-
try:
256-
if p.is_dir() and p.stat().st_mtime < cutoff:
257-
for sub in p.glob("**/*"):
258-
try:
259-
if sub.is_file():
260-
sub.unlink()
261-
except Exception:
262-
pass
263-
p.rmdir()
264-
except Exception:
265-
pass
266-
except Exception:
267-
logging.getLogger("api").exception("Preview cleanup failed")
268-
269-
270180
def _select(self, hits, image_meta_text, user_task, preview_path):
271181
num_choices = int(os.getenv("NUM_CHOICES", "3"))
272182

@@ -352,38 +262,6 @@ def _select(self, hits, image_meta_text, user_task, preview_path):
352262
sel_json["choices"] = sel_json.get("choices", [])[:num_choices]
353263
return sel_json
354264

355-
def _best_runnable_link(self, doc: SoftwareDoc) -> Optional[str]:
356-
def priority(item) -> float:
357-
if isinstance(item, dict) and "priority" in item:
358-
try:
359-
return float(item["priority"])
360-
except Exception:
361-
pass
362-
return 1e9
363-
364-
def extract_url(item) -> Optional[str]:
365-
if isinstance(item, str):
366-
u = item.strip()
367-
return u or None
368-
if isinstance(item, dict):
369-
for k in ("url", "href", "link", "contentUrl"):
370-
u = item.get(k)
371-
if isinstance(u, str) and u.strip():
372-
return u.strip()
373-
return None
374-
375-
for items in (getattr(doc, "runnable_example", None) or [], getattr(doc, "has_executable_notebook", None) or []):
376-
try:
377-
items_sorted = sorted(items, key=priority)
378-
except Exception:
379-
items_sorted = items
380-
for it in items_sorted:
381-
url = extract_url(it)
382-
if url:
383-
return url
384-
385-
return None
386-
387265
def recommend_and_link(
388266
self,
389267
image_paths: Optional[List[str]],
@@ -423,7 +301,7 @@ def _norm(s: str) -> str:
423301
preview_path = None
424302
image_meta_text = ""
425303
try:
426-
preview_path, image_meta_text = self._build_preview_for_vlm(image_paths or [])
304+
preview_path, image_meta_text = _build_preview_for_vlm(image_paths or [])
427305
except Exception:
428306
image_meta_text = ""
429307

@@ -548,7 +426,7 @@ def _fallback_score(i: int, hit: dict) -> float:
548426
for choice in result["choices"]:
549427
doc = next((h["doc"] for h in hits if getattr(h["doc"], "name", "") == choice["name"]), None)
550428
if doc:
551-
link = self._best_runnable_link(doc)
429+
link = _best_runnable_link(doc)
552430
if link:
553431
choice["demo_link"] = link
554432

src/ai_agent/generator/generator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(self, model: Optional[str] = None):
2323
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
2424
self.model = (
2525
model
26-
or os.getenv("OPENAI_SELECTOR_MODEL")
2726
or os.getenv("OPENAI_VLM_MODEL")
2827
or os.getenv("OPENAI_MODEL")
2928
or "gpt-4o-mini"

src/ai_agent/generator/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from enum import Enum
5-
from typing import List, Optional, Dict, Any
5+
from typing import List, Optional
66
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
77

88

src/ai_agent/retriever/embedders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# retriever/embedders.py
22
from __future__ import annotations
3-
import os
3+
44
from dataclasses import dataclass
55
from typing import Iterable, List, Tuple, Dict, Any, Optional, Union
66

src/ai_agent/ui/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757

5858
from utils.file_validator import FileValidator
5959
from utils.tags import strip_tags, parse_exclusions
60+
from utils.previews import _build_preview_for_vlm
6061

6162
# --- config -------------------------------------------------------------------
6263
CATALOG_PATH = os.getenv("SOFTWARE_CATALOG", "data/sample.jsonl")
@@ -299,7 +300,7 @@ def handle_message(message: str,
299300

300301
preview_path = None
301302
try:
302-
preview_path, _meta_text = get_pipeline()._build_preview_for_vlm(paths)
303+
preview_path, _meta_text = _build_preview_for_vlm(paths)
303304
except Exception:
304305
preview_path = None
305306

src/ai_agent/utils/previews.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
from pathlib import Path
44
import numpy as np
55
import imageio.v3 as iio
6+
import tempfile
7+
import logging
8+
import time
9+
from typing import List, Optional, Tuple
10+
from utils.image_meta import summarize_image_metadata
11+
from utils.image_io import load_any
12+
13+
log = logging.getLogger("pipeline")
14+
615

716
def _norm_uint8(a: np.ndarray) -> np.ndarray:
817
v = a.astype(np.float32)
@@ -64,3 +73,85 @@ def contact_sheet_slices(
6473

6574
iio.imwrite(str(out_png), canvas)
6675
return str(out_png)
76+
77+
def _build_preview_for_vlm(image_paths: Optional[List[str]]) -> Tuple[Optional[str], Optional[str]]:
78+
if not image_paths:
79+
return None, None
80+
81+
meta_text = None
82+
try:
83+
meta_text = summarize_image_metadata(image_paths)
84+
except Exception:
85+
log.exception("Image metadata summarization failed; continuing without metadata.")
86+
87+
for p in image_paths:
88+
try:
89+
data, meta = load_any(p)
90+
shp = getattr(meta, "shape", None) or meta.get("shape")
91+
if shp is None:
92+
shp = getattr(data, "shape", None)
93+
if shp is None:
94+
continue
95+
96+
tmpdir = Path(tempfile.mkdtemp(prefix="preview_"))
97+
98+
if len(shp) == 3:
99+
png_path = tmpdir / "slices_grid.png"
100+
gif_path = tmpdir / "sweep.gif"
101+
try:
102+
contact_sheet_slices(data, png_path, max_slices=36, grid_cols=6)
103+
except Exception:
104+
try:
105+
mip_montage(data, png_path)
106+
except Exception:
107+
pass
108+
try:
109+
stack_sweep_gif(data, gif_path, fps=12, max_frames=64)
110+
except Exception:
111+
pass
112+
if png_path.exists():
113+
return str(png_path), meta_text
114+
if gif_path.exists():
115+
return str(gif_path), meta_text
116+
117+
if len(shp) == 4:
118+
vol = np.asarray(data).mean(axis=-1)
119+
out = tmpdir / "sweep.gif"
120+
step = max(1, vol.shape[2] // 64)
121+
slice_gif(vol, out, axis=2, step=step, fps=12)
122+
return str(out), meta_text
123+
124+
if len(shp) == 2:
125+
out = tmpdir / "image.png"
126+
arr = data
127+
if arr.dtype != np.uint8:
128+
arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
129+
iio.imwrite(str(out), arr)
130+
return str(out), meta_text
131+
except Exception:
132+
continue
133+
134+
return None, meta_text
135+
136+
def _cleanup_old_previews(hours: int = 24) -> None:
137+
"""
138+
Delete preview_* folders older than `hours` from the system temp dir.
139+
Best-effort; ignore errors.
140+
"""
141+
root = Path(tempfile.gettempdir())
142+
cutoff = time.time() - hours * 3600
143+
try:
144+
for p in root.glob("preview_*"):
145+
try:
146+
if p.is_dir() and p.stat().st_mtime < cutoff:
147+
for sub in p.glob("**/*"):
148+
try:
149+
if sub.is_file():
150+
sub.unlink()
151+
except Exception:
152+
pass
153+
p.rmdir()
154+
except Exception:
155+
pass
156+
except Exception:
157+
logging.getLogger("api").exception("Preview cleanup failed")

src/ai_agent/utils/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from retriever.embedders import SoftwareDoc
2+
from typing import Optional
3+
4+
5+
def _best_runnable_link(doc: SoftwareDoc) -> Optional[str]:
6+
def priority(item) -> float:
7+
if isinstance(item, dict) and "priority" in item:
8+
try:
9+
return float(item["priority"])
10+
except Exception:
11+
pass
12+
return 1e9
13+
14+
def extract_url(item) -> Optional[str]:
15+
if isinstance(item, str):
16+
u = item.strip()
17+
return u or None
18+
if isinstance(item, dict):
19+
for k in ("url", "href", "link", "contentUrl"):
20+
u = item.get(k)
21+
if isinstance(u, str) and u.strip():
22+
return u.strip()
23+
return None
24+
25+
for items in (getattr(doc, "runnable_example", None) or [], getattr(doc, "has_executable_notebook", None) or []):
26+
try:
27+
items_sorted = sorted(items, key=priority)
28+
except Exception:
29+
items_sorted = items
30+
for it in items_sorted:
31+
url = extract_url(it)
32+
if url:
33+
return url
34+
35+
return None

tests/full_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def test_pipeline_against_sheet_with_urls(tmp_path: Path, monkeypatch, subtests,
262262
docs = _load_catalog_docs()
263263
from api.pipeline import RAGImagingPipeline
264264

265-
pipe = RAGImagingPipeline(docs=docs, hf_token=None)
265+
pipe = RAGImagingPipeline(docs=docs)
266266

267267
# For determinism in this unit test, we also disable the (patched) reranker.
268268
pipe.reranker = None

0 commit comments

Comments
 (0)