14 changes: 14 additions & 0 deletions lm_eval/tasks/__init__.py
@@ -60,6 +60,7 @@
from . import pawsx
from . import xnli
from . import mgsm
from . import xquad

########################################
# Translation tasks
@@ -92,6 +93,19 @@


TASK_REGISTRY = {
    # XQuAD
    "xquad_ar": xquad.XquadAR,
    "xquad_de": xquad.XquadDE,
    "xquad_el": xquad.XquadEL,
    "xquad_en": xquad.XquadEN,
    "xquad_es": xquad.XquadES,
    "xquad_hi": xquad.XquadHI,
    "xquad_ro": xquad.XquadRO,
    "xquad_ru": xquad.XquadRU,
    "xquad_th": xquad.XquadTH,
    "xquad_tr": xquad.XquadTR,
    "xquad_vi": xquad.XquadVI,
    "xquad_zh": xquad.XquadZH,
    # GLUE
    "cola": glue.CoLA,
    "mnli": glue.MNLI,
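With these entries registered, each language split becomes addressable by task name. A minimal lookup sketch, assuming the harness's existing get_task helper in lm_eval.tasks (xquad_de is just one example; instantiating a task triggers the dataset download):

from lm_eval import tasks

task_class = tasks.get_task("xquad_de")  # resolves "xquad_de" via TASK_REGISTRY
task = task_class()                      # downloads the xquad.de config on first use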
285 changes: 285 additions & 0 deletions lm_eval/tasks/xquad.py
@@ -0,0 +1,285 @@
"""
XQuAD (Cross-lingual Question Answering Dataset) is a benchmark dataset for
evaluating cross-lingual question answering performance. The dataset consists of
a subset of 240 paragraphs and 1190 question-answer pairs from the development set
of SQuAD v1.1 (Rajpurkar et al., 2016) together with their professional translations
into ten languages: Spanish, German, Greek, Russian, Turkish, Arabic, Vietnamese,
Thai, Chinese, and Hindi. Consequently, the dataset is entirely parallel across 11 languages.

huggingface: https://huggingface.co/datasets/xquad
"""
import datasets
from math import exp
from functools import partial

from lm_eval.base import rf, Task


_CITATION = """
@article{Artetxe:etal:2019,
author = {Mikel Artetxe and Sebastian Ruder and Dani Yogatama},
title = {On the cross-lingual transferability of monolingual representations},
journal = {CoRR},
volume = {abs/1910.11856},
year = {2019},
archivePrefix = {arXiv},
eprint = {1910.11856}
}

@inproceedings{dumitrescu2021liro,
title={LiRo: Benchmark and leaderboard for Romanian language tasks},
author={Stefan Daniel Dumitrescu and Petru Rebeja and Beata Lorincz and Mihaela Gaman and Andrei Avram and Mihai Ilie and Andrei Pruteanu and Adriana Stan and Lorena Rosia and Cristina Iacobescu and Luciana Morogan and George Dima and Gabriel Marchidan and Traian Rebedea and Madalina Chitez and Dani Yogatama and Sebastian Ruder and Radu Tudor Ionescu and Razvan Pascanu and Viorica Patraucean},
booktitle={Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 1)},
year={2021},
url={https://openreview.net/forum?id=JH61CD7afTv}
}
"""


def _squad_metric(predictions, references):
    # The squad_v2 metric is used (rather than squad) because it also reports
    # HasAns/NoAns splits and threshold-based best_* scores.
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)


def _squad_agg(key, items):
    predictions, references = zip(*items)
    return _squad_metric(predictions=predictions, references=references).get(key, 0)
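For reference, the squad_v2 metric consumes lists of dicts with exactly the fields built in process_results below. A self-contained toy sketch (the id, span, and text are invented; this assumes a datasets version that still ships load_metric):

predictions = [{"id": "q0",
                "prediction_text": "Denver Broncos",
                "no_answer_probability": 0.01}]
references = [{"id": "q0",
               "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]
print(_squad_metric(predictions, references)["exact"])  # 100.0: prediction matches gold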


class XquadTask(Task):
    VERSION = 0
    DATASET_PATH = "xquad"
    DATASET_NAME = None

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return (
            "Background: "
            + doc["context"]
            + "\n\n"
            + "Question: "
            + doc["question"]
            + "\n\n"
            + "Answer:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["context"]

    def doc_to_target(self, doc):
        # XQuAD gold answers are always present; the "unanswerable" fallback
        # mirrors the SQuAD v2 convention and guards against empty answer lists.
        answer_list = doc["answers"]["text"]
        if len(answer_list) > 0:
            answer = answer_list[0]
        else:
            answer = "unanswerable"
        return " " + answer
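To make the prompt format concrete, here is what these two helpers produce on a toy record shaped like an xquad row (the field names follow the dataset schema; the text is invented, and the methods are called unbound so the sketch runs without the dataset download in Task.__init__):

doc = {
    "id": "0",
    "context": "The Amazon is the largest rainforest on Earth.",
    "question": "What is the largest rainforest on Earth?",
    "answers": {"text": ["The Amazon"], "answer_start": [0]},
}
print(XquadTask.doc_to_text(None, doc))
# Background: The Amazon is the largest rainforest on Earth.
#
# Question: What is the largest rainforest on Earth?
#
# Answer:
print(repr(XquadTask.doc_to_target(None, doc)))  # ' The Amazon' (note the leading space)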

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few-shot examples, and the question
            part of the document for `doc`.
        """
        # Two requests per document: greedy generation for the answer span, and a
        # loglikelihood query that scores the explicit no-answer option.
        continuation = rf.greedy_until(ctx, {"until": ["\n"]})
        is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
        return continuation, is_unanswerable

    def process_results(self, doc, results):
        """Takes a single document and the LM results and evaluates them, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        continuation, (logprob_unanswerable, _) = results

        no_answer_probability = exp(logprob_unanswerable)

        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
            "no_answer_probability": no_answer_probability,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }

        # Every submetric receives the same (predictions, references) pair;
        # the actual scoring is deferred to the corpus-level aggregation below.
        return {
            "exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly matches the gold answer)
            "f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": (
                predictions,
                references,
            ),  # Exact match, restricted to answerable questions
            "HasAns_f1": (
                predictions,
                references,
            ),  # F-score, restricted to answerable questions
            "NoAns_exact": (
                predictions,
                references,
            ),  # Exact match, restricted to unanswerable questions
            "NoAns_f1": (
                predictions,
                references,
            ),  # F-score, restricted to unanswerable questions
            "best_exact": (
                predictions,
                references,
            ),  # Best exact match (with varying no-answer threshold)
            "best_f1": (predictions, references),  # Best F1 (with varying no-answer threshold)
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "exact": partial(_squad_agg, "exact"),  # Exact match
            "f1": partial(_squad_agg, "f1"),  # Token-level F-score
            "HasAns_exact": partial(_squad_agg, "HasAns_exact"),  # Exact match on answerable questions
            "HasAns_f1": partial(_squad_agg, "HasAns_f1"),  # F-score on answerable questions
            "NoAns_exact": partial(_squad_agg, "NoAns_exact"),  # Exact match on unanswerable questions
            "NoAns_f1": partial(_squad_agg, "NoAns_f1"),  # F-score on unanswerable questions
            "best_exact": partial(_squad_agg, "best_exact"),  # Best exact match (varying threshold)
            "best_f1": partial(_squad_agg, "best_f1"),  # Best F1 (varying threshold)
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "exact": True,
            "f1": True,
            "HasAns_exact": True,
            "HasAns_f1": True,
            "NoAns_exact": True,
            "NoAns_f1": True,
            "best_exact": True,
            "best_f1": True,
        }
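Since process_results only collects (predictions, references) pairs, all scoring happens once per corpus inside aggregation. A toy sketch of that flow using the helpers above (ids and text invented):

items = [
    ({"id": "q0", "prediction_text": "Paris", "no_answer_probability": 0.0},
     {"id": "q0", "answers": {"text": ["Paris"], "answer_start": [0]}}),
]
f1_over_corpus = partial(_squad_agg, "f1")
print(f1_over_corpus(items))  # 100.0: the single prediction matches its gold span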


# Each language-specific task only needs to pin the HF config name;
# VERSION and DATASET_PATH are inherited from XquadTask.
class XquadAR(XquadTask):
    DATASET_NAME = "xquad.ar"


class XquadDE(XquadTask):
    DATASET_NAME = "xquad.de"


class XquadEL(XquadTask):
    DATASET_NAME = "xquad.el"


class XquadEN(XquadTask):
    DATASET_NAME = "xquad.en"


class XquadES(XquadTask):
    DATASET_NAME = "xquad.es"


class XquadHI(XquadTask):
    DATASET_NAME = "xquad.hi"


class XquadRO(XquadTask):
    DATASET_NAME = "xquad.ro"


class XquadRU(XquadTask):
    DATASET_NAME = "xquad.ru"


class XquadTH(XquadTask):
    DATASET_NAME = "xquad.th"


class XquadTR(XquadTask):
    DATASET_NAME = "xquad.tr"


class XquadVI(XquadTask):
    DATASET_NAME = "xquad.vi"


class XquadZH(XquadTask):
    DATASET_NAME = "xquad.zh"
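Finally, a hedged end-to-end sketch of running one of the new tasks through the harness's evaluator (this assumes the simple_evaluate signature current at the time of this PR; gpt2 is an illustrative model choice):

from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="gpt2",
    tasks=["xquad_en"],
    num_fewshot=1,
)
print(results["results"]["xquad_en"]["f1"])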