Commit 30a5d19

Changed API of Key Value Extraction task to use Dict and not List[Tuple] (NON BACKWARD COMPATIBLE CHANGE) (#1675)
* Moved key value extraction task and metrics to use Dict[str,str] for comparison instead of List[Tuple[str,str]]. Also created a dedicated metric. Signed-off-by: Yoav Katz <[email protected]>
* Updated metric to better handle the case where keys appear in the prediction but not in the references. Signed-off-by: Yoav Katz <[email protected]>
* Added ability to customize the inner metric used to compare each entity in KeyValueExtraction. Signed-off-by: Yoav Katz <[email protected]>
* Fixed unit test and a bug in handling Nones. Signed-off-by: Yoav Katz <[email protected]>
* Fixed KeyValueExtraction prepare. Signed-off-by: Yoav Katz <[email protected]>
* Added example of multiple metrics. Signed-off-by: Yoav Katz <[email protected]>
* Fix catalog. Signed-off-by: Yoav Katz <[email protected]>
* Fix some bugs in inference engine tests. Signed-off-by: elronbandel <[email protected]>
* Fix some bugs in inference engine tests. Signed-off-by: elronbandel <[email protected]>
* Updated key value extraction metric names. Signed-off-by: Yoav Katz <[email protected]>
* Updated key value extraction metric names. Signed-off-by: Yoav Katz <[email protected]>
* Updated documentation string. Signed-off-by: Yoav Katz <[email protected]>
* Fixed unit test.
* Updated to use metric as artifact and not string.
* Fix bug in name conversion in RITS. Signed-off-by: elronbandel <[email protected]>
* Add engine id. Signed-off-by: elronbandel <[email protected]>
* Improved output message when using the inference cache. Also fixed an issue where, when all data was in the cache, an empty list was passed to _infer. Signed-off-by: Yoav Katz <[email protected]>
* Fixed bug due to indentation change. Signed-off-by: Yoav Katz <[email protected]>
* fix. Signed-off-by: elronbandel <[email protected]>
* fix. Signed-off-by: elronbandel <[email protected]>
* Removed warning of legacy name. Signed-off-by: Yoav Katz <[email protected]>
* Use greedy decoding and remove redundant cache. Signed-off-by: elronbandel <[email protected]>
* Merge branch 'improve_inference_log' into entity_squad_metric
* Ensure temperature is 0 in extraction task
* Removed unneeded changes from past merge

---------

Signed-off-by: Yoav Katz <[email protected]>
Signed-off-by: elronbandel <[email protected]>
Co-authored-by: elronbandel <[email protected]>
1 parent f131b94 commit 30a5d19

17 files changed (+209 −108 lines)
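
The headline change is the prediction type of the key value extraction task: model output is now compared as a flat dict keyed by field name rather than as a list of (key, value) tuples, and the task is scored with two dedicated catalog metrics. A minimal before/after sketch (the values are illustrative, taken from the example further down; this is not code from the commit itself):

from typing import Dict, List, Tuple

# Before this commit: predictions were a list of (key, value) tuples.
old_prediction: List[Tuple[str, str]] = [("Worker", "John"), ("LivesIn", "New York")]

# After this commit: predictions are a dict, matching the Dict[str, str]
# reference field "key_value_pairs_answer".
new_prediction: Dict[str, str] = {"Worker": "John", "LivesIn": "New York"}

# The task now lists two catalog metrics instead of the single old one:
metrics = [
    "metrics.key_value_extraction.accuracy",       # exact match per key
    "metrics.key_value_extraction.token_overlap",  # token-overlap F1 per key
]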

examples/api_call_evaluation.py

+10 −11

@@ -1,5 +1,5 @@
 import json
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 from unitxt import get_logger
 from unitxt.api import create_dataset, evaluate
@@ -205,7 +205,7 @@ class CurlStrToListOfKeyValuePairs(FieldOperator):
 
     becomes
 
-    [('url', 'curl -X GET -H "Content-Type: application/json" https://petstore.swagger.io/v2/pets'), ('tags', 'dogs'), ('limit', '5')]
+    { 'url' : 'curl -X GET -H "Content-Type: application/json" https://petstore.swagger.io/v2/pets', 'tags' : 'dogs', 'limit' : '5'}
 
     """
@@ -217,11 +217,11 @@ def process_value(self, text: str) -> List[Tuple[str, str]]:
 
         splits = text.split("?")
         split_command = re.split(r"((?=GET|POST|DELETE)GET|POST|DELETE)", splits[0])
-        result = [
-            ("command", split_command[0]),
-            ("operation", split_command[1]),
-            ("url", split_command[2]),
-        ]
+        result = {
+            "command": split_command[0],
+            "operation": split_command[1],
+            "url": split_command[2],
+        }
         if len(splits) == 1:
             return result
         params = splits[1]
@@ -234,7 +234,7 @@ def process_value(self, text: str) -> List[Tuple[str, str]]:
            (key, value) = key_value_splits
            value_splits = value.split(",")
            if len(value_splits) == 1:
-                result.append((f"query_param_{key}", f"{value}"))
+                result[f"query_param_{key}"] = f"{value}"
 
        return result
@@ -249,10 +249,9 @@ def process_value(self, text: str) -> List[Tuple[str, str]]:
 task = Task(
     input_fields={"user_request": str, "api_spec": str},
     reference_fields={"reference_query": str},
-    prediction_type=List[Tuple[str, str]],
+    prediction_type=Dict[str, str],
     metrics=[
-        "metrics.accuracy",
-        "metrics.key_value_extraction",
+        "metrics.key_value_extraction.accuracy","metrics.key_value_extraction.token_overlap",
     ],
 )

examples/key_value_extraction_evaluation.py

+4 −3

@@ -35,9 +35,9 @@ def text_to_image(text: str):
 
 test_set = [
     {
-        "input": text_to_image("John lives in Texas."),
+        "input": text_to_image("John lives in New York."),
         "keys": keys,
-        "key_value_pairs_answer": {"Worker": "John", "LivesIn": "Texas"},
+        "key_value_pairs_answer": {"Worker": "John", "LivesIn": "New York"},
     },
     {
         "input": text_to_image("Phil works at Apple and eats an apple."),
@@ -53,10 +53,11 @@ def text_to_image(text: str):
     test_set=test_set,
     split="test",
     format="formats.chat_api",
+    metrics=["metrics.key_value_extraction.accuracy","metrics.key_value_extraction.token_overlap"]
 )
 
 model = CrossProviderInferenceEngine(
-    model="llama-3-2-11b-vision-instruct", provider="watsonx"
+    model="llama-3-2-90b-vision-instruct", provider="watsonx", temperature=0
 )
 
 predictions = model(dataset)

prepare/metrics/custom_f1.py

+1 −48

@@ -1,5 +1,5 @@
 from unitxt import add_to_catalog
-from unitxt.metrics import NER, KeyValueExtraction
+from unitxt.metrics import NER
 from unitxt.test_utils.metrics import test_metric
 
 metric = NER()
@@ -434,50 +434,3 @@ class NERWithoutClassReporting(NER):
 )
 
 add_to_catalog(metric, "metrics.ner", overwrite=True)
-
-
-metric = KeyValueExtraction()
-
-predictions = [
-    [("key1", "value1"), ("key2", "value2"), ("unknown_key", "unknown_value")]
-]
-
-references = [[[("key1", "value1"), ("key2", "value3")]]]
-#
-instance_targets = [
-    {
-        "f1_key1": 1.0,
-        "f1_key2": 0.0,
-        "f1_macro": 0.5,
-        "f1_micro": 0.4,
-        "in_classes_support": 0.67,
-        "precision_macro": 0.5,
-        "precision_micro": 0.33,
-        "recall_macro": 0.5,
-        "recall_micro": 0.5,
-        "score": 0.4,
-        "score_name": "f1_micro",
-    }
-]
-global_target = {
-    "f1_key1": 1.0,
-    "f1_key2": 0.0,
-    "f1_macro": 0.5,
-    "in_classes_support": 0.67,
-    "f1_micro": 0.4,
-    "recall_micro": 0.5,
-    "recall_macro": 0.5,
-    "precision_micro": 0.33,
-    "precision_macro": 0.5,
-    "score": 0.4,
-    "score_name": "f1_micro",
-    "num_of_instances": 1,
-}
-outputs = test_metric(
-    metric=metric,
-    predictions=predictions,
-    references=references,
-    instance_targets=instance_targets,
-    global_target=global_target,
-)
-add_to_catalog(metric, "metrics.key_value_extraction", overwrite=True)
New file (+36): prepare script registering metrics.key_value_extraction.accuracy and metrics.key_value_extraction.token_overlap

@@ -0,0 +1,36 @@
+from unitxt import add_to_catalog
+from unitxt.metrics import KeyValueExtraction
+from unitxt.test_utils.metrics import test_metric
+
+metric = KeyValueExtraction(__description__ = """Metric that evaluates key value pairs predictions (provided as dictionaries)
+with reference key value pairs (also provided as dictionaries). By default uses an accuracy (exact match) between each for the fields.
+Reports average accuracy for each key , as well as micro and macro averages across all keys.
+""", metric="metrics.accuracy",)
+
+predictions = [
+    {"key1": "value1", "key2": "value2", "unknown_key": "unknown_value"}
+]
+
+references = [[{"key1": "value1", "key2" : "value3"}]]
+#
+instance_targets = [
+    {"accuracy_key1": 1.0, "accuracy_key2": 0.0, "accuracy_legal_keys_in_predictions": 0.67, "accuracy_macro": 0.5, "accuracy_micro": 0.5, "score": 0.5, "score_name": "accuracy_micro"}
+]
+global_target = {"accuracy_key1": 1.0, "accuracy_key2": 0.0, "accuracy_legal_keys_in_predictions": 0.67, "accuracy_macro": 0.5, "accuracy_micro": 0.5, "score": 0.5, "score_name": "accuracy_micro", "num_of_instances" : 1}
+outputs = test_metric(
+    metric=metric,
+    predictions=predictions,
+    references=references,
+    instance_targets=instance_targets,
+    global_target=global_target,
+)
+add_to_catalog(metric, "metrics.key_value_extraction.accuracy", overwrite=True)
+
+metric = KeyValueExtraction(__description__ = """Metric that evaluates key value pairs predictions (provided as dictionary)
+with reference key value pairs (also provided as dictionary).
+Calculates token overlap between values of corresponding value in reference and prediction.
+Reports f1 per key, micro f1 averages across all key/value pairs, and macro f1 averages across keys.
+""",
+metric="metrics.token_overlap",score_prefix="token_overlap_")
+
+add_to_catalog(metric, "metrics.key_value_extraction.token_overlap", overwrite=True)
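
As a sanity check on the expected targets above, the per-key scores can be recomputed by hand; this is a minimal standalone sketch in plain Python (not the library's implementation) of how the accuracy variant arrives at 0.5 micro/macro and 0.67 legal keys for that test instance:

prediction = {"key1": "value1", "key2": "value2", "unknown_key": "unknown_value"}
reference = {"key1": "value1", "key2": "value3"}

# Exact-match score for every key that appears in the reference.
per_key = {key: float(prediction.get(key) == value) for key, value in reference.items()}
# -> {"key1": 1.0, "key2": 0.0}

accuracy_macro = sum(per_key.values()) / len(per_key)   # 0.5 (unweighted mean over keys)
accuracy_micro = sum(per_key.values()) / len(per_key)   # 0.5 (same here: each key has one sample)

# Fraction of predicted keys that are "legal", i.e. appear in the reference key set.
legal_keys = sum(key in reference for key in prediction)
accuracy_legal_keys_in_predictions = legal_keys / len(prediction)   # 2/3 ~= 0.67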

prepare/tasks/key_value_extraction.py

+3 −3

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List
 
 from unitxt.blocks import Task
 from unitxt.catalog import add_to_catalog
@@ -8,8 +8,8 @@
         __description__="This is a key value extraction task, where a specific list of possible 'keys' need to be extracted from the input. The ground truth is provided key-value pairs in the form of the dictionary. The results are evaluating using F1 score metric, that expects the predictions to be converted into a list of (key,value) pairs. ",
         input_fields={"input": Any, "keys": List[str]},
         reference_fields={"key_value_pairs_answer": Dict[str, str]},
-        prediction_type=List[Tuple[str, str]],
-        metrics=["metrics.key_value_extraction"],
+        prediction_type=Dict[str, str],
+        metrics=["metrics.key_value_extraction.accuracy","metrics.key_value_extraction.token_overlap"],
         default_template="templates.key_value_extraction.extract_in_json_format",
     ),
     "tasks.key_value_extraction",
"tasks.key_value_extraction",

prepare/templates/key_value_extraction/templates.py

+2 −2

@@ -6,7 +6,7 @@
     ListSerializer,
     MultiTypeSerializer,
 )
-from unitxt.struct_data_operators import JsonStrToListOfKeyValuePairs
+from unitxt.struct_data_operators import JsonStrToDict
 from unitxt.templates import (
     InputOutputTemplate,
 )
@@ -17,7 +17,7 @@
         input_format="{input}",
         output_format="{key_value_pairs_answer}",
         postprocessors=[
-            PostProcess(JsonStrToListOfKeyValuePairs()),
+            PostProcess(JsonStrToDict()),
         ],
         serializer=MultiTypeSerializer(
             serializers=[ImageSerializer(), DictAsJsonSerializer(), ListSerializer()]

src/unitxt/catalog/metrics/key_value_extraction.json

−3 This file was deleted.

New file (+5): catalog entry for the accuracy-based key value extraction metric

@@ -0,0 +1,5 @@
+{
+    "__type__": "key_value_extraction",
+    "__description__": "Metric that evaluates key value pairs predictions (provided as dictionaries)\nwith reference key value pairs (also provided as dictionaries). By default uses an accuracy (exact match) between each for the fields.\nReports average accuracy for each key , as well as micro and macro averages across all keys.\n",
+    "metric": "metrics.accuracy"
+}

New file (+6): catalog entry for the token-overlap key value extraction metric

@@ -0,0 +1,6 @@
+{
+    "__type__": "key_value_extraction",
+    "__description__": "Metric that evaluates key value pairs predictions (provided as dictionary)\nwith reference key value pairs (also provided as dictionary).\nCalculates token overlap between values of corresponding value in reference and prediction.\nReports f1 per key, micro f1 averages across all key/value pairs, and macro f1 averages across keys.\n",
+    "metric": "metrics.token_overlap",
+    "score_prefix": "token_overlap_"
+}

src/unitxt/catalog/tasks/key_value_extraction.json

+3 −2

@@ -8,9 +8,10 @@
     "reference_fields": {
         "key_value_pairs_answer": "Dict[str, str]"
     },
-    "prediction_type": "List[Tuple[str, str]]",
+    "prediction_type": "Dict[str, str]",
     "metrics": [
-        "metrics.key_value_extraction"
+        "metrics.key_value_extraction.accuracy",
+        "metrics.key_value_extraction.token_overlap"
    ],
    "default_template": "templates.key_value_extraction.extract_in_json_format"
 }

src/unitxt/catalog/templates/key_value_extraction/extract_in_json_format.json

+1 −1

@@ -7,7 +7,7 @@
         {
             "__type__": "post_process",
             "operator": {
-                "__type__": "json_str_to_list_of_key_value_pairs"
+                "__type__": "json_str_to_dict"
            }
        }
    ],

src/unitxt/inference.py

+6 −7

@@ -233,8 +233,9 @@ def infer(
             result = self._mock_infer(dataset)
         else:
             if self.use_cache:
+                number_of_batches = len(dataset) // self.cache_batch_size + 1
                 result = []
-                for batch_num, batch in enumerate(batched(dataset, self.cache_batch_size)):
+                for batch_index, batch in enumerate(batched(dataset, self.cache_batch_size)):
                     cached_results = []
                     missing_examples = []
                     for i, item in enumerate(batch):
@@ -245,7 +246,8 @@ def infer(
                         else:
                             missing_examples.append((i, item))  # each element is index in batch and example
                     # infer on missing examples only, without indices
-                    logger.info(f"Inferring batch {batch_num} / {len(dataset) // self.cache_batch_size} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})")
+
+                    logger.info(f"Inferring batch {batch_index + 1} / {number_of_batches} with {len(missing_examples)} instances (found {len(cached_results)} instances in {self._cache.directory})")
                     if (len(missing_examples) > 0):
                         inferred_results = self._infer([e[1] for e in missing_examples], return_meta_data)
                         # recombined to index and value
@@ -257,9 +259,7 @@ def infer(
                             cache_key = self._get_cache_key(item)
                             self._cache[cache_key] = prediction
                     else:
-
-                        inferred_results = []
-
+                        inferred_results=[]
                     # Combine cached and inferred results in original order
                     batch_predictions = [p[1] for p in sorted(cached_results + inferred_results)]
                     result.extend(batch_predictions)
@@ -3313,8 +3313,7 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
         }
 
     def get_engine_id(self):
-        return get_model_and_label_id(self.model, self.label)
-
+        return get_model_and_label_id(self.model_name, self.label)
 
     def prepare_engine(self):
         from transformers import AutoModelForCausalLM, AutoTokenizer
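
The cache logging fix above is easier to follow in isolation. Below is a minimal standalone sketch of the batch bookkeeping (generic Python, with a toy dict standing in for the engine's real cache and a local stand-in for the batched() helper; apart from number_of_batches, the names are ours, not the library's):

from itertools import islice

def batched(iterable, n):
    # Simple stand-in for the batched() helper used in inference.py.
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch

dataset = list(range(10))
cache = {0: "cached-0", 1: "cached-1"}   # toy cache keyed by item
cache_batch_size = 4

# Upper bound on the number of batches, mirroring the
# "batch_index + 1 / number_of_batches" progress message above.
number_of_batches = len(dataset) // cache_batch_size + 1

for batch_index, batch in enumerate(batched(dataset, cache_batch_size)):
    missing = [item for item in batch if item not in cache]
    print(f"Inferring batch {batch_index + 1} / {number_of_batches} "
          f"with {len(missing)} instances ({len(batch) - len(missing)} cached)")
    # When everything is already cached, nothing is sent to the backend,
    # which is the other behavior the commit fixes (no empty _infer call).
    inferred = {item: f"new-{item}" for item in missing} if missing else {}
    cache.update(inferred)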

src/unitxt/loaders.py

+1 −1

@@ -845,7 +845,7 @@ def verify(self):
 
     def _maybe_set_classification_policy(self):
         self.set_default_data_classification(
-            self.data_classification_policy or ["proprietary"], "when loading from python dictionary"
+            ["proprietary"], "when loading from python dictionary"
         )
 
     def load_iterables(self) -> MultiStream:

src/unitxt/metrics.py

+68 −10

@@ -3414,25 +3414,83 @@ def add_macro_scores(self, f1_result, recall_result, precision_result, result):
             result["precision_macro"] = self.zero_division
 
 
-class NER(CustomF1):
-    """F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
+class KeyValueExtraction(GlobalMetric):
 
-    prediction_type = List[Tuple[str, str]]
+    prediction_type = Dict[str,str]
+    metric : Metric
+    single_reference_per_prediction = True
+    main_score = ""
+
+    def prepare(self):
+        super().prepare()
+        self.main_score = f"{self.metric.main_score}_micro"
 
-    def get_element_group(self, element, additional_input):
-        return element[1]
+    def compute(
+        self,
+        references: List[List[Any]],
+        predictions: List[Any],
+        task_data: List[Dict],
+    ) -> dict:
+        references = [element[0] for element in references]
 
-    def get_element_representation(self, element, additional_input):
-        return str(element)
+        key_statistics = {}
+        all_reference_keys = set()
+        for reference in references:
+            all_reference_keys.update(list(reference.keys()))
+        for key in all_reference_keys:
+            key_statistics[key]= []
+
+        num_prediction_keys=0
+        illegal_prediction_keys=0
+        for reference, prediction in zip(references, predictions):
+            for key in all_reference_keys:
+                if (key not in reference and key not in prediction):
+                    continue
+                if (key in reference and key in prediction):
+                    multi_stream = MultiStream.from_iterables({"test": [{"prediction" : prediction[key],
+                                                                         "references" : [reference[key]]}
+                                                                        ]})
+                    output_multi_stream = self.metric(multi_stream)
+                    output_stream = output_multi_stream["test"]
+                    score = next(iter(output_stream))["score"]["global"]["score"]
+                    key_statistics[key].append(score)
+                else:
+                    key_statistics[key].append(0.0)
+
+            for key in prediction.keys():
+                num_prediction_keys += 1
+                if key not in all_reference_keys:
+                    illegal_prediction_keys += 1
 
+        result={}
 
-class KeyValueExtraction(CustomF1):
-    """F1 Metrics that receives as input a list of (Key,Value) pairs."""
+        average = 0
+        total = 0
+
+        weighted_average = 0
+        for key in key_statistics:
+            mean_for_key = numpy.mean(key_statistics[key])
+            num = len(key_statistics[key])
+            total += num
+            average += mean_for_key
+            weighted_average += mean_for_key * num
+            result[f"{self.metric.main_score}_{key}"] = mean_for_key
+
+        result[f"{self.metric.main_score}_micro"] = weighted_average / total
+        result[f"{self.metric.main_score}_macro"] = average / len(key_statistics)
+        if (num_prediction_keys !=0):
+            result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 1 - 1.0 * illegal_prediction_keys / num_prediction_keys
+        else:
+            result[f"{self.metric.main_score}_legal_keys_in_predictions"] = 0
+
+        return result
+
+
+class NER(CustomF1):
+    """F1 Metrics that receives as input a list of (Entity,EntityType) pairs."""
 
     prediction_type = List[Tuple[str, str]]
 
     def get_element_group(self, element, additional_input):
-        return element[0]
+        return element[1]
 
     def get_element_representation(self, element, additional_input):
         return str(element)
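
To make the micro/macro distinction in the new KeyValueExtraction metric concrete, here is a small standalone recalculation of the aggregation step (plain Python, with exact match standing in for the configurable inner metric; variable names echo the class above, but nothing here calls unitxt):

from statistics import mean

# Two instances; "zip" appears only in the second reference.
references = [{"name": "Ann", "city": "Paris"}, {"name": "Bob", "city": "Rome", "zip": "00100"}]
predictions = [{"name": "Ann", "city": "Lyon"}, {"name": "Bob", "city": "Rome"}]

all_reference_keys = set().union(*(r.keys() for r in references))

# Per-key score lists: one entry per instance where the key appears on either
# side; a matched pair is scored by exact match, a one-sided key scores 0.0.
key_statistics = {key: [] for key in all_reference_keys}
for reference, prediction in zip(references, predictions):
    for key in all_reference_keys:
        if key not in reference and key not in prediction:
            continue
        if key in reference and key in prediction:
            key_statistics[key].append(float(prediction[key] == reference[key]))
        else:
            key_statistics[key].append(0.0)

per_key = {key: mean(scores) for key, scores in key_statistics.items()}
accuracy_macro = mean(per_key.values())   # unweighted mean over keys -> 0.5
accuracy_micro = sum(map(sum, key_statistics.values())) / sum(map(len, key_statistics.values()))   # 3/5 = 0.6

print(per_key)                            # name: 1.0, city: 0.5, zip: 0.0 (order may vary)
print(accuracy_macro, accuracy_micro)     # macro and micro differ once keys occur unevenly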
