Changed API of Key Value Extraction task to use Dict and not List[Tuple] (NON BACKWARD COMPATIBLE CHANGE) #1675

Merged · 34 commits · Mar 19, 2025

Commits
d0890f9
Moved key value extraction task and metrics to use Dict[str,str] to …
yoavkatz Mar 12, 2025
583f4d3
Updated metric to better address in case keys are used in prediction …
yoavkatz Mar 12, 2025
d70b4cc
Added ability to customize inner metric used to compare each entity i…
yoavkatz Mar 12, 2025
abd0d7f
Fixed unitest and bug with handling nones
yoavkatz Mar 12, 2025
e4e32f4
Fixed KeyValueExtraction prepare
yoavkatz Mar 12, 2025
223f3a7
Added example of multiple metrics
yoavkatz Mar 12, 2025
25ed3ea
Fix catalog
yoavkatz Mar 12, 2025
1ddb3bb
Fix some bugs in inference engine tests
elronbandel Mar 13, 2025
ecb1391
Fix some bugs in inference engine tests
elronbandel Mar 13, 2025
bbdbfd8
Updated key value extraction metric names
yoavkatz Mar 16, 2025
989da3f
Updated key value extraction metric names
yoavkatz Mar 16, 2025
f8ca4fe
Updated documentation string
yoavkatz Mar 16, 2025
01774f5
Fixed unit test.
yoavkatz Mar 16, 2025
7b4e0af
Updated to use metric as artifact and not string
yoavkatz Mar 17, 2025
1ae750c
Merge branch 'main' into entity_squad_metric
yoavkatz Mar 17, 2025
4ce63b8
Fix bug in name conversion in rits
elronbandel Mar 17, 2025
5e15eef
Add engine id
elronbandel Mar 17, 2025
1794474
Improved output message when using inference cache
yoavkatz Mar 17, 2025
955fb85
Fixed bug due to indentation change
yoavkatz Mar 17, 2025
9900113
Merge branch 'main' into fix-inference-tests
elronbandel Mar 17, 2025
8b91761
fix
elronbandel Mar 17, 2025
d7d9fb6
fix
elronbandel Mar 17, 2025
810f26b
Removed warning of legacy name.
yoavkatz Mar 17, 2025
91fad1f
Merge remote-tracking branch 'origin/main' into fix-inference-tests
elronbandel Mar 17, 2025
049ec22
Merge branch 'main' into entity_squad_metric
yoavkatz Mar 17, 2025
4f17a6c
Use greedy decoding and remove redundant cache
elronbandel Mar 17, 2025
125ad9c
Merge branch 'main' into fix-inference-tests
elronbandel Mar 17, 2025
47b74a9
Merge branch 'fix-inference-tests' into entity_squad_metric
yoavkatz Mar 17, 2025
d7959fc
Merge branch 'improve_inference_log' into entity_squad_metric
yoavkatz Mar 18, 2025
3b3ea38
Merge branch 'improve_inference_log' into entity_squad_metric
yoavkatz Mar 18, 2025
e75e625
Ensure temperature is 0 in extraction task
yoavkatz Mar 18, 2025
f78af39
Merge remote-tracking branch 'origin/main' into entity_squad_metric
yoavkatz Mar 18, 2025
440120c
Merge remote-tracking branch 'origin/main' into entity_squad_metric
yoavkatz Mar 19, 2025
3ea0056
Removed unneeded changes from past merge
yoavkatz Mar 19, 2025
21 changes: 10 additions & 11 deletions examples/api_call_evaluation.py
@@ -1,5 +1,5 @@
import json
from typing import List, Tuple
from typing import Dict, List, Tuple

from unitxt import get_logger
from unitxt.api import create_dataset, evaluate
@@ -205,7 +205,7 @@ class CurlStrToListOfKeyValuePairs(FieldOperator):

becomes

[('url', 'curl -X GET -H "Content-Type: application/json" https://petstore.swagger.io/v2/pets'), ('tags', 'dogs'), ('limit', '5')]
{ 'url' : 'curl -X GET -H "Content-Type: application/json" https://petstore.swagger.io/v2/pets', 'tags' : 'dogs', 'limit' : '5'}

"""

@@ -217,11 +217,11 @@ def process_value(self, text: str) -> List[Tuple[str, str]]:

splits = text.split("?")
split_command = re.split(r"((?=GET|POST|DELETE)GET|POST|DELETE)", splits[0])
result = [
("command", split_command[0]),
("operation", split_command[1]),
("url", split_command[2]),
]
result = {
"command": split_command[0],
"operation": split_command[1],
"url": split_command[2],
}
if len(splits) == 1:
return result
params = splits[1]
@@ -234,7 +234,7 @@ def process_value(self, text: str) -> List[Tuple[str, str]]:
(key, value) = key_value_splits
value_splits = value.split(",")
if len(value_splits) == 1:
result.append((f"query_param_{key}", f"{value}"))
result[f"query_param_{key}"]= f"{value}"

return result

@@ -249,10 +249,9 @@ def process_value(self, text: str) -> List[Tuple[str, str]]:
task = Task(
input_fields={"user_request": str, "api_spec": str},
reference_fields={"reference_query": str},
prediction_type=List[Tuple[str, str]],
prediction_type=Dict[str,str],
metrics=[
"metrics.accuracy",
"metrics.key_value_extraction",
"metrics.key_value_extraction.accuracy","metrics.key_value_extraction.token_overlap",
],
)

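For orientation, here is a rough, standalone sketch of the dictionary the updated operator now produces for a request with query parameters. It is not the operator's actual code: the command/operation/url split mirrors the regex shown above, but the "&"/"=" parameter splitting is a simplifying assumption, since that part of the diff is collapsed.

import re

def curl_to_dict(text: str) -> dict:
    """Illustrative approximation of CurlStrToListOfKeyValuePairs.process_value after this change."""
    splits = text.split("?")
    split_command = re.split(r"((?=GET|POST|DELETE)GET|POST|DELETE)", splits[0])
    result = {
        "command": split_command[0],
        "operation": split_command[1],
        "url": split_command[2],
    }
    if len(splits) > 1:
        # Assumed simplification: split query parameters on "&" and "=".
        for pair in splits[1].split("&"):
            key, value = pair.split("=", 1)
            result[f"query_param_{key}"] = value
    return result

print(curl_to_dict("curl -X GET https://petstore.swagger.io/v2/pets?tags=dogs&limit=5"))
# {'command': 'curl -X ', 'operation': 'GET', 'url': ' https://petstore.swagger.io/v2/pets',
#  'query_param_tags': 'dogs', 'query_param_limit': '5'}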
7 changes: 4 additions & 3 deletions examples/key_value_extraction_evaluation.py
@@ -35,9 +35,9 @@ def text_to_image(text: str):

test_set = [
{
"input": text_to_image("John lives in Texas."),
"input": text_to_image("John lives in New York."),
"keys": keys,
"key_value_pairs_answer": {"Worker": "John", "LivesIn": "Texas"},
"key_value_pairs_answer": {"Worker": "John", "LivesIn": "New York"},
},
{
"input": text_to_image("Phil works at Apple and eats an apple."),
@@ -53,10 +53,11 @@ def text_to_image(text: str):
test_set=test_set,
split="test",
format="formats.chat_api",
metrics=["metrics.key_value_extraction.accuracy","metrics.key_value_extraction.token_overlap"]
)

model = CrossProviderInferenceEngine(
model="llama-3-2-11b-vision-instruct", provider="watsonx"
model="llama-3-2-90b-vision-instruct", provider="watsonx", temperature=0
)

predictions = model(dataset)
49 changes: 1 addition & 48 deletions prepare/metrics/custom_f1.py
@@ -1,5 +1,5 @@
from unitxt import add_to_catalog
from unitxt.metrics import NER, KeyValueExtraction
from unitxt.metrics import NER
from unitxt.test_utils.metrics import test_metric

metric = NER()
@@ -434,50 +434,3 @@ class NERWithoutClassReporting(NER):
)

add_to_catalog(metric, "metrics.ner", overwrite=True)


metric = KeyValueExtraction()

predictions = [
[("key1", "value1"), ("key2", "value2"), ("unknown_key", "unknown_value")]
]

references = [[[("key1", "value1"), ("key2", "value3")]]]
#
instance_targets = [
{
"f1_key1": 1.0,
"f1_key2": 0.0,
"f1_macro": 0.5,
"f1_micro": 0.4,
"in_classes_support": 0.67,
"precision_macro": 0.5,
"precision_micro": 0.33,
"recall_macro": 0.5,
"recall_micro": 0.5,
"score": 0.4,
"score_name": "f1_micro",
}
]
global_target = {
"f1_key1": 1.0,
"f1_key2": 0.0,
"f1_macro": 0.5,
"in_classes_support": 0.67,
"f1_micro": 0.4,
"recall_micro": 0.5,
"recall_macro": 0.5,
"precision_micro": 0.33,
"precision_macro": 0.5,
"score": 0.4,
"score_name": "f1_micro",
"num_of_instances": 1,
}
outputs = test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)
add_to_catalog(metric, "metrics.key_value_extraction", overwrite=True)
36 changes: 36 additions & 0 deletions prepare/metrics/key_value_extraction.py
@@ -0,0 +1,36 @@
from unitxt import add_to_catalog
from unitxt.metrics import KeyValueExtraction
from unitxt.test_utils.metrics import test_metric

metric = KeyValueExtraction(__description__ = """Metric that evaluates key value pairs predictions (provided as dictionaries)
with reference key value pairs (also provided as dictionaries). By default uses an accuracy (exact match) between each for the fields.
Reports average accuracy for each key , as well as micro and macro averages across all keys.
""", metric="metrics.accuracy",)

predictions = [
{"key1": "value1", "key2": "value2", "unknown_key": "unknown_value"}
]

references = [[{"key1": "value1", "key2" : "value3"}]]
#
instance_targets = [
{"accuracy_key1": 1.0, "accuracy_key2": 0.0, "accuracy_legal_keys_in_predictions": 0.67, "accuracy_macro": 0.5, "accuracy_micro": 0.5, "score": 0.5, "score_name": "accuracy_micro"}
]
global_target = {"accuracy_key1": 1.0, "accuracy_key2": 0.0, "accuracy_legal_keys_in_predictions": 0.67, "accuracy_macro": 0.5, "accuracy_micro": 0.5, "score": 0.5, "score_name": "accuracy_micro", "num_of_instances" : 1}
outputs = test_metric(
metric=metric,
predictions=predictions,
references=references,
instance_targets=instance_targets,
global_target=global_target,
)
add_to_catalog(metric, "metrics.key_value_extraction.accuracy", overwrite=True)

metric = KeyValueExtraction(__description__ = """Metric that evaluates key value pairs predictions (provided as dictionary)
with reference key value pairs (also provided as dictionary).
Calculates token overlap between values of corresponding value in reference and prediction.
Reports f1 per key, micro f1 averages across all key/value pairs, and macro f1 averages across keys.
""",
metric="metrics.token_overlap",score_prefix="token_overlap_")

add_to_catalog(metric, "metrics.key_value_extraction.token_overlap", overwrite=True)
6 changes: 3 additions & 3 deletions prepare/tasks/key_value_extraction.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List

from unitxt.blocks import Task
from unitxt.catalog import add_to_catalog
@@ -8,8 +8,8 @@
__description__="This is a key value extraction task, where a specific list of possible 'keys' need to be extracted from the input. The ground truth is provided key-value pairs in the form of the dictionary. The results are evaluating using F1 score metric, that expects the predictions to be converted into a list of (key,value) pairs. ",
input_fields={"input": Any, "keys": List[str]},
reference_fields={"key_value_pairs_answer": Dict[str, str]},
prediction_type=List[Tuple[str, str]],
metrics=["metrics.key_value_extraction"],
prediction_type=Dict[str, str],
metrics=["metrics.key_value_extraction.accuracy","metrics.key_value_extraction.token_overlap"],
default_template="templates.key_value_extraction.extract_in_json_format",
),
"tasks.key_value_extraction",
4 changes: 2 additions & 2 deletions prepare/templates/key_value_extraction/templates.py
@@ -6,7 +6,7 @@
ListSerializer,
MultiTypeSerializer,
)
from unitxt.struct_data_operators import JsonStrToListOfKeyValuePairs
from unitxt.struct_data_operators import JsonStrToDict
from unitxt.templates import (
InputOutputTemplate,
)
@@ -17,7 +17,7 @@
input_format="{input}",
output_format="{key_value_pairs_answer}",
postprocessors=[
PostProcess(JsonStrToListOfKeyValuePairs()),
PostProcess(JsonStrToDict()),
],
serializer=MultiTypeSerializer(
serializers=[ImageSerializer(), DictAsJsonSerializer(), ListSerializer()]
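Swapping JsonStrToListOfKeyValuePairs for JsonStrToDict keeps the template's post-processing aligned with the new Dict[str, str] prediction type. Roughly, the conversion it has to perform looks like this (an illustration, not the operator's actual implementation):

import json

def json_str_to_dict(text: str) -> dict:
    """Illustrative stand-in for the JsonStrToDict post-processing step."""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return {}
    return {str(k): str(v) for k, v in parsed.items()} if isinstance(parsed, dict) else {}

print(json_str_to_dict('{"Worker": "John", "LivesIn": "New York"}'))
# {'Worker': 'John', 'LivesIn': 'New York'}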
3 changes: 0 additions & 3 deletions src/unitxt/catalog/metrics/key_value_extraction.json

This file was deleted.

5 changes: 5 additions & 0 deletions src/unitxt/catalog/metrics/key_value_extraction/accuracy.json
@@ -0,0 +1,5 @@
{
"__type__": "key_value_extraction",
"__description__": "Metric that evaluates key value pairs predictions (provided as dictionaries)\nwith reference key value pairs (also provided as dictionaries). By default uses an accuracy (exact match) between each for the fields.\nReports average accuracy for each key , as well as micro and macro averages across all keys.\n",
"metric": "metrics.accuracy"
}
6 changes: 6 additions & 0 deletions src/unitxt/catalog/metrics/key_value_extraction/token_overlap.json
@@ -0,0 +1,6 @@
{
"__type__": "key_value_extraction",
"__description__": "Metric that evaluates key value pairs predictions (provided as dictionary)\nwith reference key value pairs (also provided as dictionary).\nCalculates token overlap between values of corresponding value in reference and prediction.\nReports f1 per key, micro f1 averages across all key/value pairs, and macro f1 averages across keys.\n",
"metric": "metrics.token_overlap",
"score_prefix": "token_overlap_"
}
5 changes: 3 additions & 2 deletions src/unitxt/catalog/tasks/key_value_extraction.json
@@ -8,9 +8,10 @@
"reference_fields": {
"key_value_pairs_answer": "Dict[str, str]"
},
"prediction_type": "List[Tuple[str, str]]",
"prediction_type": "Dict[str, str]",
"metrics": [
"metrics.key_value_extraction"
"metrics.key_value_extraction.accuracy",
"metrics.key_value_extraction.token_overlap"
],
"default_template": "templates.key_value_extraction.extract_in_json_format"
}
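Putting the pieces together, a minimal end-to-end sketch of the updated task, based on the example script in this PR. The evaluate() keyword arguments and the results.global_scores access are assumed from current unitxt usage rather than shown in this diff:

from unitxt.api import create_dataset, evaluate

keys = ["Worker", "LivesIn"]
test_set = [
    {
        "input": "John lives in New York.",
        "keys": keys,
        "key_value_pairs_answer": {"Worker": "John", "LivesIn": "New York"},
    },
]

dataset = create_dataset(
    task="tasks.key_value_extraction",
    test_set=test_set,
    split="test",
)

# Predictions are now plain dicts (the template post-processes the model's JSON
# output into Dict[str, str]), not lists of (key, value) tuples.
predictions = [{"Worker": "John", "LivesIn": "Texas"}]

results = evaluate(predictions=predictions, data=dataset)
print(results.global_scores["score"])  # micro-averaged accuracy under the default metrics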
src/unitxt/catalog/templates/key_value_extraction/extract_in_json_format.json
@@ -7,7 +7,7 @@
{
"__type__": "post_process",
"operator": {
"__type__": "json_str_to_list_of_key_value_pairs"
"__type__": "json_str_to_dict"
}
}
],
13 changes: 6 additions & 7 deletions src/unitxt/inference.py
@@ -233,8 +233,9 @@ def infer(
result = self._mock_infer(dataset)
else:
if self.use_cache:
number_of_batches = len(dataset) // self.cache_batch_size + 1
result = []
for batch_num, batch in enumerate(batched(dataset, self.cache_batch_size)):
for batch_index, batch in enumerate(batched(dataset, self.cache_batch_size)):
cached_results = []
missing_examples = []
for i, item in enumerate(batch):
@@ -245,7 +246,8 @@
else:
missing_examples.append((i, item)) # each element is index in batch and example
# infare on missing examples only, without indices
logger.info(f"Inferring batch {batch_num} / {len(dataset) // self.cache_batch_size} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})")

logger.info(f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})")
if (len(missing_examples) > 0):
inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
# recombined to index and value
@@ -257,9 +259,7 @@
cache_key = self._get_cache_key(item)
self._cache[cache_key] = prediction
else:

inferred_results = []

inferred_results=[]
# Combine cached and inferred results in original order
batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
result.extend(batch_predictions)
@@ -3313,8 +3313,7 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
}

def get_engine_id(self):
return get_model_and_label_id(self.model, self.label)

return get_model_and_label_id(self.model_name, self.label)

def prepare_engine(self):
from transformers import AutoModelForCausalLM, AutoTokenizer
2 changes: 1 addition & 1 deletion src/unitxt/loaders.py
@@ -845,7 +845,7 @@ def verify(self):

def _maybe_set_classification_policy(self):
self.set_default_data_classification(
self.data_classification_policy or ["proprietary"], "when loading from python dictionary"
["proprietary"], "when loading from python dictionary"
)

def load_iterables(self) -> MultiStream:
78 changes: 68 additions & 10 deletions src/unitxt/metrics.py
@@ -3414,25 +3414,83 @@ def add_macro_scores(self, f1_result, recall_result, precision_result, result):
result["precision_macro"] = self.zero_division


class NER(CustomF1):
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
class KeyValueExtraction(GlobalMetric):

prediction_type = List[Tuple[str, str]]
prediction_type = Dict[str,str]
metric : Metric
single_reference_per_prediction = True
main_score = ""
def prepare(self):
super().prepare()
self.main_score = f"{self.metric.main_score}_micro"

def get_element_group(self, element, additional_input):
return element[1]
def compute(
self,
references: List[List[Any]],
predictions: List[Any],
task_data: List[Dict],
) -> dict:
references = [element[0] for element in references]

def get_element_representation(self, element, additional_input):
return str(element)
key_statistics = {}
all_reference_keys = set()
for reference in references:
all_reference_keys.update(list(reference.keys()))
for key in all_reference_keys:
key_statistics[key]= []

num_prediction_keys=0
illegal_prediction_keys=0
for reference, prediction in zip(references, predictions):
for key in all_reference_keys:
if (key not in reference and key not in prediction):
continue
if (key in reference and key in prediction):
multi_stream = MultiStream.from_iterables({"test": [{"prediction" : prediction[key],
"references" : [reference[key]]}
]})
output_multi_stream = self.metric(multi_stream)
output_stream = output_multi_stream["test"]
score = next(iter(output_stream))["score"]["global"]["score"]
key_statistics[key].append(score)
else:
key_statistics[key].append(0.0)

for key in prediction.keys():
num_prediction_keys += 1
if key not in all_reference_keys:
illegal_prediction_keys += 1

result={}

class KeyValueExtraction(CustomF1):
"""F1 Metrics that receives as input a list of (Key,Value) pairs."""
average = 0
total = 0

weighted_average = 0
for key in key_statistics:
mean_for_key = numpy.mean(key_statistics[key])
num = len(key_statistics[key])
total += num
average += mean_for_key
weighted_average += mean_for_key * num
result[f"{self.metric.main_score}_{key}"] = mean_for_key

result[f"{self.metric.main_score}_micro"] = weighted_average / total
result[f"{self.metric.main_score}_macro"] = average / len(key_statistics)
if (num_prediction_keys !=0):
result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 1 - 1.0 * illegal_prediction_keys / num_prediction_keys
else:
result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 0

return result

class NER(CustomF1):
"""F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""

prediction_type = List[Tuple[str, str]]

def get_element_group(self, element, additional_input):
return element[0]
return element[1]

def get_element_representation(self, element, additional_input):
return str(element)
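As a sanity check on the new aggregation logic, the expected scores asserted in prepare/metrics/key_value_extraction.py can be reproduced by hand. A small sketch of the arithmetic in compute() for that single instance (not the class itself):

prediction = {"key1": "value1", "key2": "value2", "unknown_key": "unknown_value"}
reference = {"key1": "value1", "key2": "value3"}

# Per-key exact-match scores for the one instance.
per_key = {key: 1.0 if prediction.get(key) == value else 0.0 for key, value in reference.items()}
# {'key1': 1.0, 'key2': 0.0}

key_counts = {key: 1 for key in per_key}  # how many instances mention each reference key
accuracy_macro = sum(per_key.values()) / len(per_key)                                          # 0.5
accuracy_micro = sum(per_key[k] * key_counts[k] for k in per_key) / sum(key_counts.values())   # 0.5
legal_keys_in_predictions = 1 - 1 / len(prediction)  # 0.67: one of three predicted keys is not a reference key

print(accuracy_macro, accuracy_micro, round(legal_keys_in_predictions, 2))  # 0.5 0.5 0.67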