Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"httpx>=0.28.1",
"hydra-core>=1.3.2",
"lettucedetect>=0.1.8",
"nltk>=3.9.1",
"nltk>=3.9.4",
"protobuf>=6.32.1",
"python-dotenv>=1.0.1",
"openai>=1.4.0",
Expand Down Expand Up @@ -116,6 +116,15 @@ ignore = [
"src/scripts/train_hallucination_detector.py" = [
"E501",
]
"src/scripts/translate.py" = [
"E501",
]
"src/scripts/generate_hallucination_dataset.py" = [
"E501",
]
"src/scripts/train_hallucination_detector.py" = [
"E501",
]

[tool.ruff.lint.isort]
split-on-trailing-comma = false
Expand Down
123 changes: 123 additions & 0 deletions src/scripts/evaluate_ground_truth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Validate the trained hallucination detector against token-level labels.

This script loads the synthetic-hallucinations dataset from the HuggingFace Hub
(it is not regenerated), uses its ``hallucinated_labels`` column as ground truth,
and computes token-level metrics via ``lettucedetect``'s ``evaluate_model``. Its
purpose is to check whether the trained detection method can identify
hallucinated tokens.

Usage:
uv run src/scripts/evaluate_ground_truth.py <config_key>=<config_value> ...
"""

import logging
import os

import hydra
import torch
from datasets import load_dataset
from dotenv import load_dotenv
from lettucedetect import HallucinationDataset
from lettucedetect.models.evaluator import evaluate_model, print_metrics
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
)

from factuality_eval.dataset_generation import (
generate_lettucedetect_hallucination_samples,
)
from factuality_eval.train import format_dataset_to_ragtruth

load_dotenv()

logger = logging.getLogger("evaluate_ground_truth")


def _training_sources_suffix(config: DictConfig) -> str:
"""Return the model-name suffix used by the training script."""
sources = []
if config.multiwikiqa.enable:
sources.append("mwqa")
if config.ragtruth.enable:
sources.append("ragtruth")
return f"-{'+'.join(sources)}" if sources else ""


def _resolve_model_path(config: DictConfig, target_dataset_name: str) -> str:
"""Return the local model directory if it exists, else the Hub repo id."""
suffix = _training_sources_suffix(config)
local_path = (
f"{config.training.output_dir}/"
f"{config.models.hallu_detect_model}-{target_dataset_name}-"
f"{config.language}{suffix}"
)
if os.path.isdir(local_path):
logger.info(f"Using local model checkpoint at {local_path}")
return local_path

hub_path = (
f"{config.hub_organisation}/"
f"{config.models.hallu_detect_model}-{target_dataset_name}-"
f"{config.language}{suffix}"
)
logger.info(f"Local checkpoint not found; using Hub model {hub_path}")
return hub_path


@hydra.main(
config_path="../../config", config_name="hallucination_detection", version_base=None
)
def main(config: DictConfig) -> None:
"""Run token-level validation for the trained hallucination detector."""
logging.getLogger("httpx").setLevel(logging.WARNING)

target_dataset_name = f"{config.base_dataset.id}-synthetic-hallucinations"
model_path = _resolve_model_path(config, target_dataset_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info("Running token-level ground-truth evaluation")

dataset = load_dataset(
f"{config.hub_organisation}/{target_dataset_name}", name=config.language
)
test_split = dataset["train"].train_test_split(
test_size=0.2, seed=42, shuffle=False
)["test"]

test_ragtruth = format_dataset_to_ragtruth(
test_split, language=config.language, split="test"
)

tokenizer = AutoTokenizer.from_pretrained(
config.models.pretrained_model, trust_remote_code=True
)
data_collator = DataCollatorForTokenClassification(
tokenizer=tokenizer, label_pad_token_id=-100
)
test_hallu_dataset = HallucinationDataset(
generate_lettucedetect_hallucination_samples(test_ragtruth),
tokenizer,
max_length=config.training.max_length,
)
test_loader = DataLoader(
test_hallu_dataset,
batch_size=config.training.batch_size,
shuffle=False,
collate_fn=data_collator,
)

model = AutoModelForTokenClassification.from_pretrained(
model_path, trust_remote_code=True
)
model.to(device)

metrics = evaluate_model(model, test_loader, device)
print_metrics(metrics)


if __name__ == "__main__":
main()
175 changes: 175 additions & 0 deletions src/scripts/generate_hallucination_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Generate hallucination-annotated dataset.

Usage:
uv run src/scripts/generate_hallucination_dataset.py <config_key>=<config_value> ...

The script generates answers with the eval model and then tags hallucinated
segments using the hallucination classifier. The output JSONL contains per-row
question, context, answer, and hallucinated tokens with their character spans
(relative to the answer string).
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

import hydra
from datasets import Dataset
from dotenv import load_dotenv
from lettucedetect.models.inference import HallucinationDetector
from omegaconf import DictConfig

from factuality_eval.dataset_generation import load_qa_data
from factuality_eval.model_generation import generate_answers_from_qa_data
from factuality_eval.prompt_utils import Lang, PromptUtils

load_dotenv()

logger = logging.getLogger(__name__)


def _build_detector_model_path(config: DictConfig) -> str:
target_dataset_name = f"{config.base_dataset.id}-synthetic-hallucinations"
return (
f"{config.hub_organisation}/"
f"{config.models.hallu_detect_model}-{target_dataset_name}-{config.language}"
)


def _format_context(context: list | tuple | str) -> list[str]:
"""Normalize an arbitrary context value to a list of strings."""
if isinstance(context, (list, tuple)):
return [str(c) for c in context]
return [str(context)]


def _extract_hallucinated_tokens(
detector: HallucinationDetector,
context: list[str],
question: str,
answer: str,
lang: Lang,
) -> list[dict[str, Any]]:
prompt = PromptUtils.format_context(context, question, lang)

# Use detector's span output to avoid offset drift between our formatter and the model tokenizer.
spans = detector.predict_prompt(prompt=prompt, answer=answer, output_format="spans")

# Normalize span schema without confidence/probability.
normalized: list[dict[str, Any]] = []
for span in spans:
normalized.append(
{
"text": span.get("text", ""),
"start": span.get("start", 0),
"end": span.get("end", 0),
}
)

return normalized


@hydra.main(
config_path="../../config", config_name="hallucination_detection", version_base=None
)
def main(config: DictConfig) -> None:
"""Run hallucination annotation over the configured QA dataset."""
logging.getLogger("httpx").setLevel(logging.WARNING)

target_dataset_name = (
f"{config.base_dataset.id}-{config.language}-"
f"{config.models.eval_model.split('/')[1]}"
)

logger.info(
"Loading dataset %s for hallucination annotation...",
f"{config.base_dataset.organisation}/{config.base_dataset.id}:{config.language}",
)

contexts, questions, answers = load_qa_data(
base_dataset_id=f"{config.base_dataset.organisation}/{config.base_dataset.id}:{config.language}",
split="test",
context_key=config.base_dataset.context_key,
question_key=config.base_dataset.question_key,
answer_key=config.base_dataset.answer_key,
squad_format=config.base_dataset.squad_format,
testing=config.testing,
max_examples=config.generation.max_examples,
)

generated_answers = generate_answers_from_qa_data(
eval_model=config.models.eval_model,
contexts=contexts,
questions=questions,
answers=answers,
lang=config.language,
max_new_tokens=config.generation.max_new_tokens,
output_jsonl_path=Path("data", "final", f"{target_dataset_name}.jsonl"),
)

detector_model_path = _build_detector_model_path(config)
logger.info("Loading hallucination detector: %s", detector_model_path)

if len(generated_answers) == 0:
logger.warning("No generated answers to annotate. Skipping.")
return

detector = HallucinationDetector(
method="transformer",
model_path=detector_model_path,
device_map="auto",
torch_dtype="auto",
)

# Build a lookup from question text to ground-truth answer so that
# the ground truth stays aligned even when generate_answers_from_qa_data
# reorders or skips entries (e.g. from caching or errors).
gt_lookup: dict[str, str] = {q: a for q, a in zip(questions, answers)}

records: list[dict[str, Any]] = []
logger.info("Annotating %d generated answers...", len(generated_answers))

for context, question, answer in zip(
generated_answers["context"],
generated_answers["question"],
generated_answers["answer"],
):
ground_truth_answer = gt_lookup.get(question, "")
formatted_context = _format_context(context)
hallucinated_tokens = _extract_hallucinated_tokens(
detector=detector,
context=formatted_context,
question=question,
answer=answer,
lang=config.language,
)

records.append(
{
"context": formatted_context,
"question": question,
"ground_truth_answer": ground_truth_answer,
"answer": answer,
"hallucinated_tokens": hallucinated_tokens,
}
)

output_path = Path("data", "final", f"{target_dataset_name}-hallucinations.jsonl")
output_path.parent.mkdir(parents=True, exist_ok=True)

logger.info("Writing hallucination-annotated dataset to %s", output_path)
with output_path.open("w") as f:
for record in records:
f.write(json.dumps(record, ensure_ascii=False) + "\n")

dataset = Dataset.from_list(records)
dataset.save_to_disk(str(output_path) + "_hf")

logger.info("Done. Wrote %d records.", len(records))


if __name__ == "__main__":
main()
Loading
Loading