-
Notifications
You must be signed in to change notification settings - Fork 0
[Multi-Agent Privacy] Detection tools implementation #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
794e23e
d8503b1
11205df
ec9126f
9797115
8be4dc8
58b8572
66b0c77
f1060f9
5814294
d175fd4
6a91ebb
3eae01f
df24d81
825ff24
53e384b
a050c18
fd784c6
39d6a58
ec8001e
46bc5ab
fd49131
4c34eab
07848bb
f2ae78f
1d9c58e
9e0b74d
e5723a5
52f4e8c
a9b21bc
8f16202
570b013
e3890d1
6281137
03f3426
88ee637
42dbcdf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| """PII detection engines exposed as agent tools.""" | ||
|
|
||
| from .base import DetectionEngine, PIISpan | ||
| from .config import DetectionConfig | ||
| from .gliner_engine import GLiNEREngine, detect_pii_gliner | ||
| from .llm_engine import LLMDetectionEngine, detect_pii_llm | ||
| from .openai_filter_engine import OpenAIFilterEngine, detect_pii_openai_filter | ||
| from .presidio_engine import PresidioEngine, detect_pii_presidio | ||
|
|
||
# Public API of the detection package: the config and span types, the engine
# classes, and one ``detect_pii_*`` tool function per engine.
__all__ = [
    "DetectionConfig",
    "DetectionEngine",
    "GLiNEREngine",
    "LLMDetectionEngine",
    "OpenAIFilterEngine",
    "PIISpan",
    "PresidioEngine",
    "detect_pii_gliner",
    "detect_pii_llm",
    "detect_pii_openai_filter",
    "detect_pii_presidio",
]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| """PII detection interface. | ||
|
|
||
| Each engine implements ``DetectionEngine.detect`` and returns a list of | ||
| ``PIISpan`` records. Engines are independently registered as agent tools so a | ||
| sanitizer agent can resolve them by name from YAML. | ||
| """ | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from dataclasses import dataclass | ||
| from typing import List | ||
|
|
||
|
|
||
@dataclass
class PIISpan:
    """One PII occurrence reported by a detection engine for a given text.

    Fields are positional in the generated constructor, in declaration order:
    ``PIISpan(start, end, label, score)``.
    """

    # Position of the occurrence inside the analysed text, as reported by the
    # engine.
    # NOTE(review): presumably character offsets with an exclusive ``end`` --
    # confirm against each engine.
    start: int
    end: int
    # Entity type name, e.g. "PERSON" or "EMAIL".
    label: str
    # Score the engine attached to this detection.
    score: float
|
|
||
|
|
||
class DetectionEngine(ABC):
    """Common interface implemented by every PII detection backend.

    Subclasses must override :meth:`detect`; until they do, the class cannot
    be instantiated.
    """

    @abstractmethod
    def detect(self, text: str) -> List[PIISpan]:
        """Find the PII occurrences in ``text``.

        Args:
            text: The text to analyse.

        Returns:
            One ``PIISpan`` per detected occurrence.
        """
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| """Configuration dataclass for the PII detection toolkit.""" | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from typing import List, Optional | ||
|
|
||
| from ...rag.llm import LLMConfig | ||
| from .defaults import DEFAULT_CONFIDENCE_THRESHOLD | ||
|
|
||
|
|
||
@dataclass
class DetectionConfig:
    """Settings read from the ``privacy.detection`` block of a YAML config."""

    # Name of the detection engine to use.
    # NOTE(review): per the review discussion on this PR, nothing reads this
    # field yet. It is kept so that, once the privacy layer is wired into the
    # RAG pipeline, a user can pick a specific engine instead of falling back
    # to the default one (or to one inferred by the Analyzer agent). A
    # reviewer suggested restricting it to an enum of the supported engines.
    engine: str
    # Entity labels to look for. Empty by default; ``GLiNEREngine.from_config``
    # treats an empty list as "use the engine's default labels".
    entity_types: List[str] = field(default_factory=list)
    # Score threshold handed to the engine.
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD
    # Optional LLM settings.
    # NOTE(review): presumably only used by the LLM-based engine -- confirm.
    llm: Optional[LLMConfig] = None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| """Shared defaults for the PII detection engines.""" | ||
|
|
||
| from ...rag.llm import LLMConfig | ||
|
|
||
# Default language code.
DEFAULT_LANGUAGE = "en"

# Default model identifiers for the model-backed engines.
DEFAULT_GLINER_MODEL = "nvidia/gliner-PII"
DEFAULT_OPENAI_FILTER_MODEL = "openai/privacy-filter"
DEFAULT_PRESIDIO_SPACY_MODEL = "en_core_web_lg"

# Default LLM settings.
# NOTE(review): presumably the fallback when a config supplies no ``llm``
# block -- confirm against the LLM engine.
DEFAULT_LLM_CONFIG = LLMConfig(
    llm_name="Qwen/Qwen2.5-3B-Instruct",
    max_new_tokens=512,
)

# Default score threshold for detections.
DEFAULT_CONFIDENCE_THRESHOLD = 0.7
|
|
||
# Entity labels an engine looks for when the caller does not supply any.
# TODO: Later add new labels to the list
DEFAULT_LABELS = [
    "PERSON",
    "PHONE",
    "EMAIL",
    "MRN",
    "DATE",
    "LOCATION",
    "SSN",
    "INSURANCE_ID",
]
|
|
||
# Regex-based recognizer definitions for clinical identifiers.
# Each entry names an entity, lists its patterns as 3-tuples, and gives a list
# of context phrases.
# NOTE(review): the tuples look like Presidio ``Pattern(name, regex, score)``
# arguments and ``context`` like recognizer context words -- confirm against
# the Presidio engine that consumes this table.
# TODO: Later add new patterns to the list
PRESIDIO_CLINICAL_PATTERNS = [
    {
        "entity": "MRN",
        "patterns": [
            ("mrn_with_prefix", r"\bMRN[\s:#]*\d{6,10}\b", 0.9),
            # Any bare 8-digit number: given a much lower score than the
            # prefixed form.
            ("mrn_bare_8_digits", r"\b\d{8}\b", 0.4),
        ],
        "context": ["mrn", "medical record", "record number", "patient id"],
    },
    {
        "entity": "HOSPITAL_DATE",
        "patterns": [
            ("iso_date", r"\b\d{4}-\d{2}-\d{2}\b", 0.6),
            ("us_date", r"\b\d{1,2}/\d{1,2}/\d{4}\b", 0.6),
        ],
        "context": ["admission", "discharge", "appointment", "hospital", "clinic"],
    },
    {
        "entity": "INSURANCE_ID",
        "patterns": [
            ("insurance_alnum", r"\b[A-Z]{2,3}\d{6,12}\b", 0.7),
        ],
        "context": ["insurance", "policy", "member id", "subscriber"],
    },
]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| """GLiNER-based PII detection engine.""" | ||
|
|
||
| import logging | ||
| import threading | ||
| from typing import Any, Dict, List, Optional, Sequence | ||
|
|
||
| from typing_extensions import Self | ||
|
|
||
| from ..agents.registry import register_tool | ||
| from .base import DetectionEngine, PIISpan | ||
| from .config import DetectionConfig | ||
| from .defaults import ( | ||
| DEFAULT_CONFIDENCE_THRESHOLD, | ||
| DEFAULT_GLINER_MODEL, | ||
| DEFAULT_LABELS, | ||
| ) | ||
|
|
||
logger = logging.getLogger(__name__)

# Module-level cache of loaded GLiNER models, keyed by model name, so that
# engines asking for the same model share one instance.
_model_cache: Dict[str, Any] = {}
# Held while a model is loaded into, or the contents are cleared from,
# ``_model_cache``.
_model_cache_lock = threading.Lock()
|
|
||
|
|
||
def _load_gliner_model(model_name: str) -> Any:
    """Load the GLiNER model identified by ``model_name``.

    ``gliner`` is imported inside the function, so the package is only
    imported once a model is actually requested.
    """
    from gliner import GLiNER

    model = GLiNER.from_pretrained(model_name)
    return model
|
|
||
|
|
||
def _get_or_load_model(model_name: str) -> Any:
    """Return the cached model for ``model_name``, loading it if absent.

    The cache is first read without the lock. On a miss the lock is taken and
    the cache is read again before loading, so a model that another thread
    stored in the meantime is reused instead of being loaded twice.
    """
    model = _model_cache.get(model_name)
    if model is None:
        with _model_cache_lock:
            # Re-check under the lock: the entry may have been filled while
            # this thread was waiting.
            model = _model_cache.get(model_name)
            if model is None:
                model = _load_gliner_model(model_name)
                _model_cache[model_name] = model
    return model
|
|
||
|
|
||
def clear_gliner_cache() -> None:
    """Empty the module-level GLiNER model cache.

    The cache lock is held while clearing. Later model accesses load the
    model again.
    """
    with _model_cache_lock:
        _model_cache.clear()
|
|
||
|
|
||
class GLiNEREngine(DetectionEngine):
    """Detect PII spans with a GLiNER model.

    Each instance carries its own ``entity_types`` and
    ``confidence_threshold``; loaded models are shared between instances with
    the same ``model_name`` through the module-level ``_model_cache``.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_GLINER_MODEL,
        entity_types: Optional[Sequence[str]] = None,
        confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    ):
        """Store the engine settings; the model itself is not loaded here.

        Args:
            model_name: Identifier handed to ``GLiNER.from_pretrained`` when
                the model is first needed.
            entity_types: Labels to detect. ``None`` or an empty sequence
                selects ``DEFAULT_LABELS``. A single non-empty string is
                treated as one label.
            confidence_threshold: Value passed as ``threshold`` to the
                model's ``predict_entities``.
        """
        self._model_name = model_name
        if isinstance(entity_types, str):
            # A bare string is itself a sequence of strings; without this it
            # would be split into one-character labels by ``list(...)``.
            entity_types = [entity_types] if entity_types else None
        self._entity_types: List[str] = (
            list(entity_types) if entity_types else list(DEFAULT_LABELS)
        )
        self._confidence_threshold = confidence_threshold

    @classmethod
    def from_config(cls, config: DetectionConfig) -> Self:
        """Build an engine from a ``DetectionConfig``.

        Only ``entity_types`` and ``confidence_threshold`` are read from the
        config; the model name keeps its default.
        """
        return cls(
            entity_types=config.entity_types or None,
            confidence_threshold=config.confidence_threshold,
        )

    @property
    def model(self) -> Any:
        """Return the GLiNER model, loading and caching it on first access."""
        return _get_or_load_model(self._model_name)

    def detect(self, text: str) -> List[PIISpan]:
        """Return the PII spans the model finds in ``text``.

        Args:
            text: The text to analyse.

        Returns:
            One ``PIISpan`` per entity returned by ``predict_entities``; an
            empty list when ``text`` is empty or only whitespace.
        """
        if not text or not text.strip():
            # Nothing to scan: avoid loading and running the model at all.
            return []
        raw = self.model.predict_entities(
            text=text,
            labels=self._entity_types,
            threshold=self._confidence_threshold,
            multi_label=False,
        )
        # Coerce every field so callers get plain Python types regardless of
        # what the model returns.
        return [
            PIISpan(
                start=int(r["start"]),
                end=int(r["end"]),
                label=str(r["label"]),
                score=float(r["score"]),
            )
            for r in raw
        ]
|
|
||
|
|
||
@register_tool("detect_pii_gliner")
def detect_pii_gliner(text: str) -> List[PIISpan]:
    """Detect PII spans in ``text`` using a default-configured GLiNER engine.

    Agents needing per-config behavior should be wired by setup code that
    builds a ``GLiNEREngine.from_config(detection_cfg)`` and registers its
    ``detect()`` function under a distinct tool name, e.g.::

        engine = GLiNEREngine.from_config(detection_cfg)
        register_tool("detect_pii_gliner_custom", engine.detect)
    """
    # NOTE(review): the docstring example passes the function as a second
    # argument, whereas this module applies ``register_tool(name)`` as a
    # decorator -- confirm which call forms ``register_tool`` supports.
    # A new engine object is built on every call; the loaded model is still
    # shared through the module-level model cache.
    default_engine = GLiNEREngine()
    return default_engine.detect(text)
Uh oh!
There was an error while loading. Please reload this page.