feat: add support for long-context documents #179
All contributors have signed the DCO ✍️ ✅
I have read the DCO document and I hereby sign the DCO.
Greptile Summary
This PR adds a comprehensive windowing layer across every long-document bottleneck in the anonymizer pipeline, replacing single-call
Confidence Score: 4/5
Safe to merge with minor fixes; all correctness issues are limited to long-document paths and affect output cosmetics or serialization consistency rather than data loss or pipeline crashes.
The windowing design is sound and well-tested for core paths. The most visible issue is the stitching in chunked_rewrite: because iter_boundary_windows aligns cuts to newlines, each chunk already ends with a newline, and models naturally mirror that, so every chunk boundary in a long anonymized document will have a spurious blank line. The fast-path cap measurement is also off (preamble overhead counted but not sent), silently routing near-cap documents into the chunked path. In qa_generation.py both paths call .model_dump() without mode='json', inconsistently with every other windowed generator in the PR.
Files to check: chunked_rewrite.py (stitching and cap measurement) and qa_generation.py (model_dump mode and private import of _compile_template).
Reviews (1): Last reviewed commit: "format scripts"
        _clip(summary),
    )

stitched = "\n".join(part for part in rewritten_parts if part)
Chunk boundaries are aligned to newlines by iter_boundary_windows, so each tagged[start:end] slice already ends with "\n". When the LLM mirrors that structure in its output (natural for paragraph-aware models), every rewritten_chunk also ends with "\n", and "\n".join(...) then inserts a second newline, producing a blank line at every chunk boundary in the final anonymized document. Joining with "" is sufficient because the delimiter is already part of each chunk.
Suggested change:
- stitched = "\n".join(part for part in rewritten_parts if part)
+ stitched = "".join(part for part in rewritten_parts if part)
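To make the effect of the delimiter concrete, here is a minimal sketch with made-up chunk strings. It assumes, as the comment describes, that every rewritten chunk keeps the trailing newline of its source slice; the chunk contents are purely illustrative.

```python
# Hypothetical rewritten chunks; each keeps the trailing newline of its source slice.
rewritten_parts = ["First paragraph.\n", "Second paragraph.\n", "Third paragraph.\n"]

# Joining with "\n" adds a second newline at every chunk boundary.
with_newline = "\n".join(part for part in rewritten_parts if part)

# Joining with "" relies on the newline each chunk already carries.
with_empty = "".join(part for part in rewritten_parts if part)

# Count the blank lines (two consecutive newlines) in each result.
blank_lines_newline = with_newline.count("\n\n")
blank_lines_empty = with_empty.count("\n\n")
print(blank_lines_newline, blank_lines_empty)
```

If a model strips the trailing newline from a chunk, joining with "" would fuse two lines instead, so a stitcher may want to normalize chunk endings first; this sketch does not cover that case.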
# Fast path: the full single-call rewrite prompt fits under the cap.
single_rendered = _render_chunk_prompt(template=params.single_call_prompt_template, chunk_row=row, summary="")
if len(single_rendered) <= cap:
    logger.debug("rewrite: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap)
    text = _rewrite_chunk(
        facade=facade,
        prompt=_compile_template(params.single_call_prompt_template).render(**row),
        system_prompt=params.system_prompt,
        purpose="rewrite-generation",
    )
The fast path measures single_rendered as _render_chunk_prompt(..., summary=""), which prepends the ~270-char continuity preamble, but then the actual LLM call omits that preamble. This means a document whose body-only prompt falls in (cap - 270, cap] chars will be routed into the chunked path unnecessarily. Measure with just the body to match what is actually sent.
Suggested change:
- # Fast path: the full single-call rewrite prompt fits under the cap.
- single_rendered = _render_chunk_prompt(template=params.single_call_prompt_template, chunk_row=row, summary="")
+ # Fast path: measure body-only prompt (no continuity preamble) since that is what is sent.
+ single_rendered = _compile_template(params.single_call_prompt_template).render(**row)
  if len(single_rendered) <= cap:
      logger.debug("rewrite: single-call fast path (rendered=%d chars <= cap=%d)", len(single_rendered), cap)
      text = _rewrite_chunk(
          facade=facade,
-         prompt=_compile_template(params.single_call_prompt_template).render(**row),
+         prompt=single_rendered,
          system_prompt=params.system_prompt,
          purpose="rewrite-generation",
      )
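The misrouting described above can be reproduced with a small sketch. Everything here is a hypothetical stand-in: the 270-character preamble, the toy template, and the helper names `render_body`, `render_chunk`, and `takes_fast_path` are assumptions for illustration, not the functions in chunked_rewrite.py.

```python
# Hypothetical stand-ins; the real preamble and templates live in chunked_rewrite.py.
CONTINUITY_PREAMBLE = "x" * 270  # assumed ~270-char preamble, per the review comment


def render_body(text: str) -> str:
    """Render the prompt that the single-call fast path actually sends."""
    return f"Rewrite the following text:\n{text}"


def render_chunk(text: str) -> str:
    """Render a chunk prompt, which prepends the continuity preamble."""
    return CONTINUITY_PREAMBLE + render_body(text)


def takes_fast_path(measured_prompt: str, cap: int) -> bool:
    """The fast path is taken when the measured prompt fits under the cap."""
    return len(measured_prompt) <= cap


cap = 1_000
document = "a" * 900  # the body-only prompt fits under the cap

# Measuring the preamble-laden prompt rejects a document that would have fit.
measured_with_preamble = takes_fast_path(render_chunk(document), cap)
# Measuring the prompt that is actually sent accepts it.
measured_body_only = takes_fast_path(render_body(document), cap)
print(measured_with_preamble, measured_body_only)
```

Measuring and sending the same string, as the suggestion does, also avoids rendering the template twice.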
)
from anonymizer.engine.ndd.model_loader import resolve_model_alias
from anonymizer.engine.prompt_utils import substitute_placeholders
Private symbol imported across module boundary.
_compile_template is module-private (underscore-prefixed) in chunked_steps.py. Importing it here creates a hidden coupling: if the function is renamed or inlined, qa_generation.py breaks without any clear contract. Consider exposing it as a public helper in chunked_steps.py or defining a local copy with its own lru_cache in this module.
    row[COL_QUALITY_QA] = _generate(full_rendered, "quality-qa-generation").model_dump()
    return row

units = json.loads(row.get(COL_MEANING_UNITS_SERIALIZED) or "[]")
base_len = len(compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: "[]"}))
batches = _batch_units_by_size(units, base_len, max_render_chars - safety_margin_chars)
items: list[dict[str, Any]] = []
for batch_idx, batch in enumerate(batches):
    rendered = compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: json.dumps(batch, ensure_ascii=False)})
    out = _generate(rendered, f"quality-qa-generation-batch-{batch_idx}")
    for item in out.items:
        items.append({**item.model_dump(mode="json"), "id": len(items) + 1})
row[COL_QUALITY_QA] = QualityQAPairsSchema.model_validate({"items": items}).model_dump()
The fast path stores the result via .model_dump() (no mode="json"), while every other windowed generator in this PR consistently uses .model_dump(mode="json"). Without mode="json", Pydantic returns native Python objects rather than JSON-serializable equivalents, which can cause downstream serialization failures. The batched path has the same inconsistency.
Suggested change:
-     row[COL_QUALITY_QA] = _generate(full_rendered, "quality-qa-generation").model_dump()
+     row[COL_QUALITY_QA] = _generate(full_rendered, "quality-qa-generation").model_dump(mode="json")
      return row
  units = json.loads(row.get(COL_MEANING_UNITS_SERIALIZED) or "[]")
  base_len = len(compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: "[]"}))
  batches = _batch_units_by_size(units, base_len, max_render_chars - safety_margin_chars)
  items: list[dict[str, Any]] = []
  for batch_idx, batch in enumerate(batches):
      rendered = compiled.render(**{**row, COL_MEANING_UNITS_SERIALIZED: json.dumps(batch, ensure_ascii=False)})
      out = _generate(rendered, f"quality-qa-generation-batch-{batch_idx}")
      for item in out.items:
          items.append({**item.model_dump(mode="json"), "id": len(items) + 1})
- row[COL_QUALITY_QA] = QualityQAPairsSchema.model_validate({"items": items}).model_dump()
+ row[COL_QUALITY_QA] = QualityQAPairsSchema.model_validate({"items": items}).model_dump(mode="json")
_DEFAULT_MAX_RENDER_CHARS: int = _DetectConfig.model_fields["detection_window_max_render_chars"].default
_DEFAULT_SAFETY_MARGIN_CHARS: int = _DetectConfig.model_fields["detection_window_safety_margin_chars"].default
Unguarded index on potentially empty list.
_first_output calls outputs[0] without checking length. In run_windowed_step with first_only=True, if iter_boundary_windows returns an empty list, outputs is empty and this raises IndexError. The fast path makes this unreachable today, but a defensive guard would make the failure mode explicit.
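A defensive guard of the kind suggested could look like the sketch below. The function name `first_output`, the exception type, and the message are assumptions for illustration; the PR's `_first_output` may have a different signature.

```python
from typing import Any, Sequence


def first_output(outputs: Sequence[Any]) -> Any:
    """Return the first window output, failing with an explicit message when there is none."""
    if not outputs:
        # Explicit failure instead of a bare IndexError from outputs[0].
        raise ValueError("windowed step produced no outputs; expected at least one window")
    return outputs[0]
```

Raising a descriptive error keeps the failure visible; returning a default instead would be an alternative if an empty window list is a legitimate state.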
andreatgretel left a comment
Thanks for taking this on. This is a substantial first PR, and the overall direction makes sense: split long records into bounded windows, carry forward the state needed for consistency, and keep the replacement map explicit.
I left a few comments on edge cases I think are worth tightening before merge. The main themes are:
- thread the user-supplied window sizing through every windowed stage
- make per-window failures local where possible, instead of dropping the whole record
- validate overlap settings early so a bad config cannot explode into thousands of model calls
- avoid silently accepting empty rewrite chunks as successful output
The tests and docs coverage are in good shape, and I think the feature is close. These changes should make it more reliable on real long documents.
    ),
    *self._qa_wf.columns(selected_models=selected_models),
    *self._rewrite_gen_wf.columns(
        window_max_render_chars=window_max_render_chars,
this only threads the user-supplied window cap into rewrite generation. domain classification, sensitivity disposition, QA generation, and final judge still build their window params from module defaults, so a user who lowers Detect.detection_window_max_render_chars still gets ~128k prompts in those stages. Could pass the same kwargs through those columns() calls and _run_final_judge too?
if params.first_only:
    windows = windows[:1]
outputs = []
for start, end in windows:
Claude Code caught this one: once this takes the windowed path, a single transient model error or a chunk that legitimately has no meaning units can drop the whole record. Could wrap each window call, skip/log failed windows, and handle the all-failed case explicitly?
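One way to keep failures local to a window is sketched below. The helper `run_windows`, its `call_model` callback, and the `flaky` stand-in are all hypothetical; the sketch only assumes that a window is a (start, end) pair and that an empty result list is a legitimate output rather than an error.

```python
import logging
from typing import Callable, Sequence

logger = logging.getLogger(__name__)


def run_windows(
    text: str,
    windows: Sequence[tuple[int, int]],
    call_model: Callable[[str], list[str]],
) -> tuple[list[list[str]], int]:
    """Run call_model on each window; return per-window outputs and the failed-window count."""
    outputs: list[list[str]] = []
    failed = 0
    for start, end in windows:
        try:
            # An empty list is a valid result (e.g. a chunk with no meaning units).
            outputs.append(call_model(text[start:end]))
        except Exception:  # broad on purpose: any model error stays local to this window
            failed += 1
            logger.warning("window [%d, %d) failed; skipping", start, end, exc_info=True)
    if windows and failed == len(windows):
        # Handle the all-failed case explicitly instead of returning nothing silently.
        raise RuntimeError(f"all {failed} windows failed")
    return outputs, failed


def flaky(chunk: str) -> list[str]:
    """Stand-in model call: fails on one chunk, returns nothing for a blank one."""
    if "boom" in chunk:
        raise TimeoutError("transient model error")
    return [chunk.upper()] if chunk.strip() else []


outputs, failed = run_windows("abc boom   ", [(0, 4), (4, 8), (8, 11)], flaky)
print(outputs, failed)
```

Returning the failed count alongside the outputs lets the caller record it in a trace column rather than dropping the record.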
        "prompt scaffolding and tags when sizing augmentation/latent windows."
    ),
)
detection_window_overlap_chars: int = Field(
suggestion: can we validate that detection_window_overlap_chars is smaller than the effective window size? Right now overlap == window is accepted and the planners advance one character at a time. My smoke test turned a 20k-char row into 16,001 windows.
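The arithmetic behind that window count, and an early check that would prevent it, can be sketched as follows. The function names are hypothetical, and the planner model (advance by max(1, window - overlap) characters) and the effective window of 4,000 characters are assumptions; 4,000 is used because it reproduces the reported figure (20,000 - 4,000 + 1 = 16,001).

```python
def validate_overlap(window_chars: int, overlap_chars: int) -> None:
    """Reject configs whose overlap leaves no forward progress per window."""
    if window_chars <= 0:
        raise ValueError("window_chars must be positive")
    if not 0 <= overlap_chars < window_chars:
        raise ValueError(
            f"overlap_chars ({overlap_chars}) must be >= 0 and < window_chars ({window_chars})"
        )


def count_windows(text_len: int, window_chars: int, overlap_chars: int) -> int:
    """Count windows for a planner that advances by max(1, window - overlap) chars."""
    stride = max(1, window_chars - overlap_chars)
    if text_len <= window_chars:
        return 1
    # Windows start at 0, stride, 2*stride, ... until one reaches the end of the text.
    # -(-a // b) is ceiling division on integers.
    return -(-(text_len - window_chars) // stride) + 1


degenerate = count_windows(20_000, 4_000, 4_000)  # overlap == window, stride falls to 1
sensible = count_windows(20_000, 4_000, 1_000)    # stride of 3,000 chars
print(degenerate, sensible)
```

Running `validate_overlap` when the config is constructed would turn the degenerate case into an immediate error instead of thousands of model calls.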
        _clip(summary),
    )

stitched = "\n".join(part for part in rewritten_parts if part)
separate from the newline-stitching comment already here: filtering with if part also hides an empty rewrite chunk. If one window returns {"rewritten_text": ""}, that section disappears with no failed-window count or review signal. Maybe count/flag empty chunks instead of treating them as successful output?
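A stitcher that reports empty chunks instead of filtering them silently might look like this sketch. The name `stitch_chunks` and its return shape are assumptions; the separator is a parameter so the sketch stays neutral on the newline question raised in the other comment.

```python
def stitch_chunks(rewritten_parts: list[str], sep: str = "") -> tuple[str, list[int]]:
    """Join rewritten chunks and report the indices of chunks that came back empty."""
    # Treat whitespace-only chunks as empty too, since they carry no rewritten text.
    empty_indices = [i for i, part in enumerate(rewritten_parts) if not part.strip()]
    stitched = sep.join(part for part in rewritten_parts if part)
    return stitched, empty_indices


stitched, empty_indices = stitch_chunks(["A\n", "", "C\n"])
print(repr(stitched), empty_indices)
```

The caller can then write `len(empty_indices)` to a failed-window column or flag the record for review, rather than treating the shortened output as a success.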
Summary
Several stages embedded the whole document in a single prompt and hit DataDesigner's 512K (MAX_RENDERED_LEN) render cap, failing outright on long inputs. Every such stage is now windowed: each chunked generator renders its own per-window prompt and calls the model directly, bypassing the cap. Stages keep a single-call fast path when the rendered prompt already fits, so short-document behavior is unchanged.
Per-stage windowing
- chunked_detection.py (new): Overlapping fixed-size character windows; each window is a raw text slice sent to the detector. Per-window offsets are rebased to global, boundary-touching spans are dropped, and overlaps are resolved (resolve_overlaps).
- chunked_validation.py: Not a text window; it batches candidate entities (≤100 per call), each with a ±500-character excerpt. Calls run in parallel across the validator pool with round-robin + failover. Decisions are merged per row; the row is dropped only if every pool member fails.
- chunked_augmentation.py: Overlapping character windows over tagged text plus seed JSON. A window dynamically shrinks if its rendered prompt exceeds the cap. Outputs are unioned and deduped by (value, label).
- chunked_latent.py: Same mechanism as augmentation (rewrite mode only); deduped by (label, value).
- chunked_replace.py: Abutting newline-aligned windows, no overlap. Each chunk carries the accumulated replacement map and a rolling summary, proposing replacements only for new entities so mappings stay consistent across chunks.
- chunked_rewrite.py: Abutting newline-aligned windows, no overlap. Runs sequentially, passing a continuity preamble and rolling summary between chunks; rewritten parts are stitched.
- chunked_final_judge.py (new): Splits original and rewritten text into N positionally-paired slices, scores each, and aggregates per-dimension by minimum. Rubric scales are embedded in the prompt with structured output. Replaces the non-windowed LLMJudgeColumnConfig.
Parallel processing
ThreadPoolExecutor; the per-alias rate limit still governs real in-flight calls) and merge afterward.
Window sizing
detection_window_max_render_chars (default 128 KiB, clamped ≤ NDD's render cap) is the single knob; it is threaded into detection, augmentation, latent, substitute-map, rewrite, and judge. detection_window_safety_margin_chars (8K) leaves headroom for prompt scaffolding; detection_window_overlap_chars (1K) sets the overlap for the overlapping stages; a 4K floor prevents pathological shrinking.
Fault tolerance & failure tracking
trace_dataframe (COL_AUGMENTATION_FAILED_WINDOWS/COL_LATENT_FAILED_WINDOWS); the judge degrades to defaults if all windows fail.
Observability
Per-window debug logging across all chunked stages: window ranges/sizes, rendered length vs cap, shrink events, rolling-summary contents, and per-stage entity/replacement/window counts.
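As an illustration of the final judge's per-dimension minimum aggregation mentioned under per-stage windowing, here is a minimal sketch. The function name and the dimension names ("privacy", "fluency") are hypothetical, and raising on an empty list is a choice made for this sketch only; the description says the judge falls back to defaults when all windows fail.

```python
def aggregate_by_minimum(window_scores: list[dict[str, int]]) -> dict[str, int]:
    """Combine per-window rubric scores by taking the minimum for each dimension."""
    if not window_scores:
        # This sketch raises; a caller could substitute default scores instead.
        raise ValueError("no window scores to aggregate")
    dimensions = window_scores[0].keys()
    return {dim: min(scores[dim] for scores in window_scores) for dim in dimensions}


# Hypothetical scores for three positionally-paired slices.
scores = [
    {"privacy": 5, "fluency": 4},
    {"privacy": 3, "fluency": 5},
    {"privacy": 4, "fluency": 4},
]
overall = aggregate_by_minimum(scores)
print(overall)
```

Taking the minimum means one weak slice determines the document's score for that dimension, which is the conservative choice for a privacy check.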
Type of Change
Testing
- make test passes locally
- make check passes locally (format + lint + typecheck + lock-check)
Documentation
- make docs-build passes locally