Skip to content

Commit 1689e25

Browse files
authored
Updated question extraction (#32)
1 parent fc51389 commit 1689e25

13 files changed

+423
-57
lines changed

ai_feedback/code_processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def process_code(
8585
submission_file=submission_file,
8686
solution_file=solution_file,
8787
test_output=test_output_file,
88-
question_num=args.question,
88+
question=args.question,
8989
system_instructions=system_instructions,
9090
llama_mode=args.llama_mode,
9191
json_schema=args.json_schema,

ai_feedback/helpers/image_extractor.py

Lines changed: 232 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import base64
2+
import importlib
23
import json
34
import os
5+
import re
6+
import tempfile
7+
from pathlib import Path
8+
from typing import Any, Dict, List, Optional
49

510

6-
def extract_images(input_notebook_path: os.PathLike, output_directory: os.PathLike, output_name: str):
11+
def extract_images(input_notebook_path: os.PathLike, output_directory: os.PathLike, output_name: str) -> List[Path]:
12+
image_paths = []
713
with open(input_notebook_path, "r") as file:
814
notebook = json.load(file)
915
os.makedirs(output_directory, exist_ok=True)
@@ -42,6 +48,7 @@ def extract_images(input_notebook_path: os.PathLike, output_directory: os.PathLi
4248
image_data = base64.b64decode(data)
4349
with open(image_path, "wb") as img_file:
4450
img_file.write(image_data)
51+
image_paths.append(os.path.abspath(image_path))
4552

4653
# Save question context (source of previous cell)
4754
if cell_number >= 1 and notebook["cells"][cell_number - 1]["cell_type"] == "markdown":
@@ -51,3 +58,227 @@ def extract_images(input_notebook_path: os.PathLike, output_directory: os.PathLi
5158
question_context_path = os.path.join(output_directory, question_name, question_context_filename)
5259
with open(question_context_path, "w") as txt_file:
5360
txt_file.write(question_context_data)
61+
return image_paths
62+
63+
64+
def extract_qmd_python_chunks_with_context(qmd_path: str) -> List[Dict[str, Any]]:
    """
    Extract ONLY Python code chunks from a QMD and annotate each with context from # / ## headings.

    Supports ```{python ...}, ```python, and ~~~ variants. A leading YAML front matter block
    (``---`` ... ``---``) is skipped so that YAML comments are not mistaken for headings.

    Args:
        qmd_path: path to the .qmd file to scan

    Returns:
        List[Dict[str, Any]]: one dict per non-empty Python chunk, in file order, with keys
            "context": "<h1>__<h2>" when both headings are known, else the h1 text, else "unknown"
            "code": list of the chunk's source lines (fence lines excluded)
            "start_line": 1-based line number of the chunk's opening fence
        A chunk whose closing fence is missing is not returned.

    Raises:
        FileNotFoundError: if qmd_path does not exist
    """
    py_chunk_start = re.compile(
        r"""^(
        ```\{python\b[^}]*\}\s*$ | # ```{python ...}
        ```python\s*$ | # ```python
        ~~~\{python\b[^}]*\}\s*$ | # ~~~{python ...}
        ~~~python\s*$ # ~~~python
        )""",
        re.IGNORECASE | re.VERBOSE,
    )
    fence_end = {
        "```": re.compile(r"^```\s*$"),
        "~~~": re.compile(r"^~~~\s*$"),
    }
    h1 = re.compile(r"^#\s+(.*)$")
    h2 = re.compile(r"^##\s+(.*)$")
    front_matter_fence = re.compile(r"^---\s*$")

    qp = Path(qmd_path)
    if not qp.exists():
        raise FileNotFoundError(f"QMD file not found: {qmd_path}")

    lines = qp.read_text(encoding="utf-8", errors="ignore").splitlines()

    # Skip YAML front matter if present. Only skip when the closing '---' exists;
    # otherwise the leading '---' is just a horizontal rule and nothing is dropped.
    body_start = 0
    if lines and front_matter_fence.match(lines[0]):
        for idx in range(1, len(lines)):
            if front_matter_fence.match(lines[idx]):
                body_start = idx + 1
                break

    current_main = None
    current_sub = None
    chunks = []
    in_py = False
    cur = []
    start_line = 0
    fence_kind = None  # "```" or "~~~"

    # Keep indices relative to the whole file so start_line stays a true file line number.
    for i, line in enumerate(lines[body_start:], start=body_start):
        if not in_py:
            if py_chunk_start.match(line):
                in_py = True
                cur = []
                start_line = i
                fence_kind = "~~~" if line.strip().startswith("~~~") else "```"
                continue

            m1 = h1.match(line)
            if m1:
                current_main = _clean_heading_text(m1.group(1))
                current_sub = None  # a new main heading invalidates the old sub heading
                continue
            m2 = h2.match(line)
            if m2:
                current_sub = _clean_heading_text(m2.group(1))
            continue

        # Inside a Python chunk: only the matching fence style closes it.
        if not fence_end[fence_kind].match(line):
            cur.append(line)
            continue

        in_py = False
        fence_kind = None
        if cur:
            if current_main and current_sub:
                context = f"{current_main}__{current_sub}"
            else:
                context = current_main or "unknown"
            chunks.append(
                {
                    "context": context,
                    "code": list(cur),
                    "start_line": start_line + 1,  # 1-based
                }
            )

    return chunks
148+
149+
150+
def extract_qmd_python_images(qmd_path: str, output_dir: Optional[str] = None, dpi: int = 120) -> List[str]:
    """
    Runs the python blocks of code and saves images using matplotlib.

    Every chunk is executed in one shared namespace (so later chunks see names defined by
    earlier ones) with the non-interactive "Agg" backend. Figures are captured three ways:
    calls to plt.savefig are mirrored into the output directory, figures created through
    plt.figure / plt.subplots are recorded, and any other figure number that appears during
    the chunk is picked up afterwards. Files are named "plot__<context>__<NNN>.png".

    WARNING: this exec()s the code found in the document with the privileges of the current
    process. Only call it on trusted files or inside a sandbox.

    Args:
        qmd_path: path to qmd file to extract images from
        output_dir(Optional): directory to save images to; a fresh temp directory when omitted
        dpi: The resolution of the output image

    Returns:
        List[str]: Returns a list of image paths (posix style, de-duplicated, in save order)

    """
    chunks = extract_qmd_python_chunks_with_context(qmd_path)
    if not chunks:
        return []

    outdir = Path(output_dir or tempfile.mkdtemp(prefix="qmd_py_imgs_"))
    outdir.mkdir(parents=True, exist_ok=True)

    mpl = importlib.import_module("matplotlib")
    mpl.use("Agg", force=True)
    plt = importlib.import_module("matplotlib.pyplot")

    saved_files = []
    per_context_counter = {}
    current_context = "unknown"
    # Both are rebound for every chunk; the closures below read the current binding.
    created_figs = []
    saved_fignums_this_chunk = set()

    def _save_fig_with_context(fig, ctx: str):
        try:
            fig.canvas.draw()  # force render
        except Exception as e:
            print(f"[warn] could not draw canvas: {e}")

        cnt = per_context_counter.get(ctx, 0) + 1
        per_context_counter[ctx] = cnt
        path = outdir / f"plot__{ctx}__{cnt:03d}.png"

        fig.savefig(path.as_posix(), dpi=dpi, bbox_inches="tight")
        saved_files.append(path.as_posix())

    orig_savefig = plt.savefig
    orig_figure = plt.figure
    orig_subplots = plt.subplots

    def _mirror_savefig(*args, **kwargs):
        # Honour the user's own savefig first, then keep a copy in outdir (once per figure).
        ctx = current_context
        orig_savefig(*args, **kwargs)
        fig = plt.gcf()
        if fig.number not in saved_fignums_this_chunk:
            _save_fig_with_context(fig, ctx)
            saved_fignums_this_chunk.add(fig.number)

    def _wrapped_figure(*args, **kwargs):
        fig = orig_figure(*args, **kwargs)
        created_figs.append((fig.number, fig))
        return fig

    def _wrapped_subplots(*args, **kwargs):
        fig, ax = orig_subplots(*args, **kwargs)
        created_figs.append((fig.number, fig))
        return fig, ax

    exec_env = {"__name__": "__qmd_exec__", "plt": plt}

    plt.savefig = _mirror_savefig
    try:
        for ch in chunks:
            # The context becomes part of a file name, so neutralise path separators
            # (os.altsep covers "/" on Windows, where os.sep is "\\").
            current_context = ch["context"].replace(os.sep, "_")
            if os.altsep:
                current_context = current_context.replace(os.altsep, "_")
            per_context_counter.setdefault(current_context, 0)

            fignums_before = set(plt.get_fignums())
            created_figs = []
            saved_fignums_this_chunk = set()

            plt.figure = _wrapped_figure
            plt.subplots = _wrapped_subplots

            code = "\n".join(ch["code"])
            chunk_label = f"{qmd_path}:{ch['start_line']}"
            try:
                exec(compile(code, filename=chunk_label, mode="exec"), exec_env, exec_env)
            except Exception as e:
                # Best effort: a failing chunk must not abort extraction, but say so.
                print(f"[warn] chunk at {chunk_label} failed: {type(e).__name__}: {e}")
            finally:
                plt.figure = orig_figure
                plt.subplots = orig_subplots

            new_fignums = sorted(set(plt.get_fignums()) - fignums_before)

            for fnum, fig in created_figs:
                if fnum in saved_fignums_this_chunk:
                    continue
                _save_fig_with_context(fig, current_context)
                saved_fignums_this_chunk.add(fnum)

            # save any remaining new fig numbers (re-open only if not saved)
            for fnum in new_fignums:
                if fnum in saved_fignums_this_chunk:
                    continue
                _save_fig_with_context(orig_figure(fnum), current_context)
                saved_fignums_this_chunk.add(fnum)

            # close figs we touched so they do not leak into the next chunk
            for _, fig in created_figs:
                try:
                    plt.close(fig)
                except Exception as e:
                    print(f"[warn] could not close figure: {e}")
            for fnum in new_fignums:
                try:
                    plt.close(fnum)
                except Exception as e:
                    print(f"[warn] could not close figure {fnum}: {e}")
    finally:
        # Always undo the monkeypatches, even if saving a figure raised.
        plt.figure = orig_figure
        plt.subplots = orig_subplots
        plt.savefig = orig_savefig

    # De-duplicate while preserving save order.
    return list(dict.fromkeys(saved_files))
279+
280+
281+
def _clean_heading_text(s: str) -> str:
    """Return heading text with outer whitespace removed and inner whitespace runs collapsed to one space."""
    trimmed = s.strip()
    return re.sub(r"\s+", " ", trimmed)

0 commit comments

Comments
 (0)