Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions config/hallucination_detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ training:
learning_rate: 1e-5
weight_decay: 0.01
push_to_hub: True
max_length: 8192
max_length: 4096

models:
hallu_detect_model: mmBERT-small
Expand All @@ -40,17 +40,6 @@ models:

language: da

selfcheckgpt:
num_samples: 10
sampling_temperature: 1.0
reference_temperature: 0.0
reference_do_sample: false
prompt_model: gpt-4o-mini
output_dir: data/final/selfcheckgpt
max_retries: 3
request_timeout: null
context_char_limit: null

generation:
max_examples: 1000
max_new_tokens: 32768
max_new_tokens: 256
7 changes: 6 additions & 1 deletion src/factuality_eval/dataset_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,13 @@ def load_qa_data(

if len(ds.keys()) > 1: # Dataset is already split
ds = ds[split]
elif "train" in ds:
ds = ds["train"].train_test_split(test_size=0.2, seed=42)[split]
else:
ds = ds[split].train_test_split(test_size=0.2, seed=42)[split]
raise Exception(
"Dataset can not be split into test and train. Please check if"
Comment thread
FrejaThoresen marked this conversation as resolved.
Outdated
"'train' is a subset of the dataset."
)

logger.info("Preparing dataset...")
contexts: list[list[str]] = [[ctx] for ctx in ds[context_key]]
Expand Down
10 changes: 3 additions & 7 deletions src/factuality_eval/hallucination_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,9 @@ def detect_hallucinations(
dataset["context"], dataset["question"], dataset["answer"]
):
# Use the detector to predict if the answer is hallucinated
try:
predict_answer = detector.predict(
context=context, question=question, answer=answer
)
except Exception as e:
logger.error(f"Error during hallucination detection: {e}. Skipping...")
continue
predict_answer = detector.predict(
context=context, question=question, answer=answer
)
predict_answers.append(predict_answer)
Comment thread
FrejaThoresen marked this conversation as resolved.

if "hallucinated_parts" in dataset.column_names:
Expand Down
35 changes: 28 additions & 7 deletions src/factuality_eval/model_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Iterable, cast
Expand All @@ -24,6 +25,18 @@

logger = logging.getLogger(__name__)

# Pattern to strip markdown bold/italic markers from model output
_MD_MARKERS_RE = re.compile(r"(\*{1,3}|_{1,3})(.+?)\1")


def _strip_markdown(text: str) -> str:
"""Remove markdown bold/italic markers from text.

Replaces ``**bold**``, ``*italic*``, ``***both***`` (and underscore
equivalents) with just the inner text.
"""
return _MD_MARKERS_RE.sub(r"\2", text)


def generate_single_answer(
tokenizer: PreTrainedTokenizerBase,
Expand Down Expand Up @@ -53,10 +66,11 @@ def generate_single_answer(
prompt = PromptUtils.format_context(list(context), question, lang=lang)
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
input_length = model_inputs["input_ids"].shape[-1]

# Only include temperature in generation parameters if it's specified
generation_kwargs: dict[str, int | float] = {"max_new_tokens": max_new_tokens}
Expand All @@ -66,17 +80,21 @@ def generate_single_answer(
generated_ids = model.generate( # type: ignore[operator]
**model_inputs, **generation_kwargs
)
output_ids: list[int] = cast(torch.Tensor, generated_ids)[0].tolist()

content = tokenizer.decode(output_ids, skip_special_tokens=False)
# Only decode newly generated tokens, excluding the input prompt
output_ids = cast(torch.Tensor, generated_ids)[0][input_length:].tolist()

content = tokenizer.decode(output_ids, skip_special_tokens=True)

# Clear generated content of special tokens
if tokenizer.bos_token is not None:
content = content.split(tokenizer.bos_token)[-1]
if "</think>" in content:
content = content.split("</think>")[-1]
content.replace(tokenizer.eos_token, "")
content.replace("\n", "")
eos_token = tokenizer.eos_token
if eos_token:
content = content.replace(eos_token, "")
for special_token in tokenizer.all_special_tokens:
content = content.replace(special_token, "")
content = content.strip()

return content

Expand Down Expand Up @@ -203,6 +221,9 @@ def generate_answers_from_qa_data(
logger.error(f"Error during generation: {e}. Skipping...")
continue

# Strip markdown bold/italic markers that models sometimes add
answer = _strip_markdown(answer)

record = dict(hash=hash_, context=context, question=question, answer=answer)
records.append(record)
hashes.add(hash_)
Expand Down
115 changes: 2 additions & 113 deletions src/factuality_eval/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,38 +39,6 @@
"uk",
]

LANG_TO_PASSAGE = {
"bs": "odlomak", # Bosnian
"bg": "пасаж", # Bulgarian
"ca": "passatge", # Catalan
"hr": "odlomak", # Croatian
"cs": "pasáž", # Czech
"da": "afsnit", # Danish
"nl": "passage", # Dutch
"en": "passage", # English
"et": "lõik", # Estonian
"fo": "grein", # Faroese
"fi": "kappale", # Finnish
"fr": "passage", # French
"de": "Passage", # German
"el": "απόσπασμα", # Greek
"hu": "szövegrészlet", # Hungarian
"is": "efnisgrein", # Icelandic
"it": "brano", # Italian
"lv": "posms", # Latvian
"lt": "ištrauka", # Lithuanian
"no": "avsnitt", # Norwegian
"pl": "fragment", # Polish
"pt": "passagem", # Portuguese
"ro": "pasaj", # Romanian
"sr": "одломак", # Serbian
"sk": "pasáž", # Slovak
"sl": "odlomek", # Slovenian
"es": "pasaje", # Spanish
"sv": "stycke", # Swedish
"uk": "уривок", # Ukrainian
}

LANG_TO_FULL_NAME = {
"bs": "Bosnian",
"bg": "Bulgarian",
Expand Down Expand Up @@ -103,70 +71,6 @@
"uk": "Ukrainian",
}

YES_WORDS = {
"bs": "da",
"bg": "да",
"ca": "sí",
"hr": "da",
"cs": "ano",
"da": "ja",
"nl": "ja",
"en": "yes",
"et": "jah",
"fo": "ja",
"fi": "kyllä",
"fr": "oui",
"de": "ja",
"el": "ναι",
"hu": "igen",
"is": "já",
"it": "sì",
"lv": "jā",
"lt": "taip",
"no": "ja",
"pl": "tak",
"pt": "sim",
"ro": "da",
"sr": "да",
"sk": "áno",
"sl": "da",
"es": "sí",
"sv": "ja",
"uk": "так",
}

NO_WORDS = {
"bs": "ne",
"bg": "не",
"ca": "no",
"hr": "ne",
"cs": "ne",
"da": "nej",
"nl": "nee",
"en": "no",
"et": "ei",
"fo": "nei",
"fi": "ei",
"fr": "non",
"de": "nein",
"el": "όχι",
"hu": "nem",
"is": "nei",
"it": "no",
"lv": "nē",
"lt": "ne",
"no": "nei",
"pl": "nie",
"pt": "não",
"ro": "nu",
"sr": "не",
"sk": "nie",
"sl": "ne",
"es": "no",
"sv": "nej",
"uk": "ні",
}

PROMPT_DIR = Path(__file__).parent.parent / "prompts"


Expand All @@ -193,16 +97,6 @@ def load_prompt(filename: str) -> Template:
raise FileNotFoundError(f"Prompt file not found: {path}")
return Template(path.read_text(encoding="utf-8"))

@staticmethod
def load_selfcheckgpt_prompt(context: str, sentence: str, lang: Lang) -> str:
"""Load the SelfCheckGPT prompt template.

Returns:
Template object for the SelfCheckGPT prompt.
"""
tmpl = PromptUtils.load_prompt(f"selfcheckgpt_prompt_{lang.lower()}.txt")
return tmpl.substitute(context=context, sentence=sentence)

@staticmethod
def format_context(context: list[str], question: str | None, lang: Lang) -> str:
"""Format context and question into a prompt.
Expand All @@ -218,19 +112,14 @@ def format_context(context: list[str], question: str | None, lang: Lang) -> str:
Returns:
Formatted prompt.
"""
passage_word = LANG_TO_PASSAGE[lang]
ctx_block = "\n".join(
f"{passage_word} {i + 1}: {p}" for i, p in enumerate(context)
)
ctx_block = "\n".join(context)

if question is None:
tmpl = PromptUtils.load_prompt(f"summary_prompt_{lang.lower()}.txt")
return tmpl.substitute(text=ctx_block)

tmpl = PromptUtils.load_prompt(f"qa_prompt_{lang.lower()}.txt")
return tmpl.substitute(
question=question, num_passages=len(context), context=ctx_block
)
return tmpl.substitute(question=question, text=ctx_block)
Comment thread
FrejaThoresen marked this conversation as resolved.
Outdated

@staticmethod
def get_full_language_name(lang: Lang) -> str:
Expand Down
105 changes: 0 additions & 105 deletions src/factuality_eval/selfcheckgpt.py

This file was deleted.

Loading