Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion OUTPUT_SCHEMAS.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ Printed to stdout once per input example.
"start": 0,
"end": 5,
"text": "Alice",
"placeholder": "<PRIVATE_PERSON>"
"placeholder": "<PRIVATE_PERSON>",
"score": 0.9321
}
],
"redacted_text": "<PRIVATE_PERSON> was born on <PRIVATE_DATE>."
Expand All @@ -33,7 +34,9 @@ Printed to stdout once per input example.
Notes:

- In `--output-mode redacted`, every `detected_spans[*].label` becomes `redacted`.
- `detected_spans[*].score` is a confidence probability in `[0.0, 1.0]`.
- `warning` is present only when tokenizer decode does not exactly round-trip the input text.
- `score` is additive, existing clients can ignore it without changing `schema_version`.

## 2. `opf eval` Predictions Output (`--predictions-out`)

Expand Down
18 changes: 15 additions & 3 deletions opf/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def to_dict(self) -> dict[str, object]:
"end": span.end,
"text": span.text,
"placeholder": span.placeholder,
"score": span.score,
}
for span in self.detected_spans
],
Expand Down Expand Up @@ -195,6 +196,7 @@ def __init__(
decode_mode: Literal["viterbi", "argmax"] = "viterbi",
discard_overlapping_predicted_spans: bool = False,
output_text_only: bool = False,
min_score: float | None = None,
) -> None:
"""Create a reusable local OPF redactor.

Expand All @@ -214,12 +216,17 @@ def __init__(
predicted spans per label.
output_text_only: Whether :meth:`redact` should return only the
redacted text instead of a structured result.
min_score: Optional minimum confidence threshold in ``[0.0, 1.0]``
used to filter predicted spans before redaction.

Raises:
ValueError: If ``output_mode`` is unsupported.
"""
if output_mode not in OUTPUT_MODES:
raise ValueError(f"Unsupported output_mode: {output_mode!r}")
if min_score is not None and not (0.0 <= float(min_score) <= 1.0):
raise ValueError("min_score must be in [0.0, 1.0]")
self._min_score = float(min_score) if min_score is not None else None
self._checkpoint = resolve_checkpoint_path(model)
self._context_window_length = context_window_length
self._trim_whitespace = bool(trim_whitespace)
Expand Down Expand Up @@ -256,19 +263,24 @@ def redact(
"""
runtime, decoder = self.get_prediction_components(decode=decode)
prediction = predict_text(runtime, text, decoder=decoder)
redacted_text = _redact_text(prediction.text, prediction.spans)
filtered_spans = tuple(prediction.spans)
if self._min_score is not None:
filtered_spans = tuple(
span for span in filtered_spans if float(span.score) >= self._min_score
)
redacted_text = _redact_text(prediction.text, filtered_spans)
if self._output_text_only:
return redacted_text
summary = build_detection_summary(
output_mode=runtime.output_mode,
labels=[span.label for span in prediction.spans],
labels=[span.label for span in filtered_spans],
decoded_mismatch=prediction.decoded_mismatch,
)
return RedactionResult(
schema_version=SCHEMA_VERSION,
summary=summary,
text=prediction.text,
detected_spans=tuple(prediction.spans),
detected_spans=filtered_spans,
redacted_text=redacted_text,
warning=_warning_for_prediction(prediction),
)
Expand Down
2 changes: 2 additions & 0 deletions opf/_cli/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
add_decode_mode_arg,
add_device_arg,
add_discard_overlapping_predicted_spans_arg,
add_min_score_arg,
add_n_ctx_arg,
add_output_mode_arg,
add_trim_whitespace_args,
Expand Down Expand Up @@ -57,6 +58,7 @@ def add_common_redaction_args(
add_n_ctx_arg(runtime_group)
add_decode_mode_arg(decode_group)
add_discard_overlapping_predicted_spans_arg(decode_group)
add_min_score_arg(decode_group)
add_trim_whitespace_args(parser, decode_group)
add_viterbi_args(decode_group)
add_output_mode_arg(output_group)
Expand Down
21 changes: 21 additions & 0 deletions opf/_cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,24 @@ def add_viterbi_args(parser: object) -> None:
"uses all-zero transition biases."
),
)


def _min_score_arg(value: str) -> float:
"""Validate ``--min-score`` CLI input."""
try:
score = float(value)
except ValueError as exc:
raise argparse.ArgumentTypeError("--min-score must be a float") from exc
if not (0.0 <= score <= 1.0):
raise argparse.ArgumentTypeError("--min-score must be in [0.0, 1.0]")
return score


def add_min_score_arg(parser: object) -> None:
"""Add the shared minimum confidence threshold argument."""
parser.add_argument(
"--min-score",
type=_min_score_arg,
default=None,
help="Drop predicted spans with confidence score below this threshold.",
)
1 change: 1 addition & 0 deletions opf/_cli/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def build_redactor_from_args(
output_mode=args.output_mode,
discard_overlapping_predicted_spans=args.discard_overlapping_predicted_spans,
output_text_only=output_text_only,
min_score=args.min_score,
)
if args.decode_mode == "viterbi":
return redactor.set_viterbi_decoder(
Expand Down
82 changes: 74 additions & 8 deletions opf/_core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
decode_text_with_offsets,
discard_overlapping_spans_by_label,
labels_to_spans,
token_spans_to_char_spans,
trim_char_spans_whitespace,
)
from .sequence_labeling import (
ExampleAggregation,
Expand Down Expand Up @@ -67,6 +65,7 @@ class DetectedSpan:
end: int
text: str
placeholder: str
score: float


@dataclass(frozen=True)
Expand Down Expand Up @@ -113,6 +112,7 @@ def _apply_output_mode_to_detected_spans(
end=span.end,
text=span.text,
placeholder=REDACTED_OUTPUT_PLACEHOLDER,
score=span.score,
)
for span in spans
]
Expand Down Expand Up @@ -202,6 +202,47 @@ def _select_non_overlapping_spans(spans: Sequence[DetectedSpan]) -> list[Detecte
return kept


def _trim_char_spans_with_scores(
spans: Sequence[tuple[int, int, int, float]], text: str
) -> list[tuple[int, int, int, float]]:
"""Trim leading and trailing whitespace from character spans."""
trimmed: list[tuple[int, int, int, float]] = []
for label_idx, start, end, score in spans:
if not (0 <= start < end <= len(text)):
continue
while start < end and text[start].isspace():
start += 1
while end > start and text[end - 1].isspace():
end -= 1
if end > start:
trimmed.append((label_idx, start, end, score))
return trimmed


def _discard_overlapping_spans_with_scores(
spans: Sequence[tuple[int, int, int, float]],
) -> list[tuple[int, int, int, float]]:
"""Drop overlapping spans independently within each label id."""
if not spans:
return []
kept_without_scores = discard_overlapping_spans_by_label(
[(label_idx, start, end) for label_idx, start, end, _score in spans]
)
counts: dict[tuple[int, int, int], int] = {}
for key in kept_without_scores:
counts[key] = counts.get(key, 0) + 1
kept: list[tuple[int, int, int, float]] = []
for label_idx, start, end, score in spans:
key = (label_idx, start, end)
remaining = counts.get(key, 0)
if remaining <= 0:
continue
kept.append((label_idx, start, end, score))
counts[key] = remaining - 1
kept.sort(key=lambda span: (span[1], span[2], span[0]))
return kept


def load_inference_runtime(
*,
checkpoint: str,
Expand Down Expand Up @@ -348,25 +389,49 @@ def predict_text(
predicted_token_spans = labels_to_spans(
predicted_labels_by_index, runtime.label_info
)
token_row_by_index = {
int(token_idx): row_idx for row_idx, token_idx in enumerate(token_positions)
}

decoded_text, char_starts, char_ends = decode_text_with_offsets(
token_ids, runtime.encoding
)
decoded_mismatch = decoded_text != text
source_text = decoded_text if decoded_mismatch else text

predicted_char_spans = token_spans_to_char_spans(
predicted_token_spans, char_starts, char_ends
)
predicted_char_spans: list[tuple[int, int, int, float]] = []
for label_idx, token_start, token_end in predicted_token_spans:
if not (0 <= token_start < token_end <= len(char_starts)):
continue
char_start = char_starts[token_start]
char_end = char_ends[token_end - 1]
if char_end <= char_start:
continue
span_log_probs: list[float] = []
for token_idx in range(token_start, token_end):
row_idx = token_row_by_index.get(token_idx)
token_label_idx = predicted_labels_by_index.get(token_idx)
if row_idx is None or token_label_idx is None:
continue
span_log_probs.append(float(stacked_scores[row_idx, int(token_label_idx)]))
if not span_log_probs:
continue
mean_logprob = sum(span_log_probs) / float(len(span_log_probs))
predicted_char_spans.append(
(label_idx, char_start, char_end, float(math.exp(mean_logprob)))
)

if runtime.trim_span_whitespace:
predicted_char_spans = trim_char_spans_whitespace(
predicted_char_spans = _trim_char_spans_with_scores(
predicted_char_spans, source_text
)
if runtime.discard_overlapping_predicted_spans:
predicted_char_spans = discard_overlapping_spans_by_label(predicted_char_spans)
predicted_char_spans = _discard_overlapping_spans_with_scores(
predicted_char_spans
)

detected: list[DetectedSpan] = []
for label_idx, start, end in predicted_char_spans:
for label_idx, start, end, score in predicted_char_spans:
if not (0 <= start < end <= len(source_text)):
continue
label = (
Expand All @@ -381,6 +446,7 @@ def predict_text(
end=int(end),
text=source_text[start:end],
placeholder=_label_placeholder(label),
score=score,
)
)

Expand Down
Loading