revived no_deep just to compare performance #1254

Merged · 3 commits · Oct 6, 2024
4 changes: 2 additions & 2 deletions .secrets.baseline
@@ -3,7 +3,7 @@
     "files": "^.secrets.baseline$",
     "lines": null
   },
-  "generated_at": "2024-09-30T08:33:42Z",
+  "generated_at": "2024-10-04T17:48:13Z",
   "plugins_used": [
     {
       "name": "AWSKeyDetector"
@@ -82,7 +82,7 @@
       "hashed_secret": "fa172616e9af3d2a24b5597f264eab963fe76889",
       "is_secret": false,
       "is_verified": false,
-      "line_number": 2022,
+      "line_number": 2014,
       "type": "Hex High Entropy String",
       "verified_result": null
     }
2 changes: 1 addition & 1 deletion prepare/cards/xlsum.py
@@ -8,7 +8,7 @@
 from unitxt.collections_operators import Wrap
 from unitxt.test_utils.card import test_card
 
-configs = get_dataset_config_names("GEM/xlsum")  # the languages
+configs = get_dataset_config_names("GEM/xlsum", trust_remote_code=True)  # the languages
 # now configs is the list of all languages showing in the dataset
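For context, a short sketch of what the added flag does; the surrounding behavior is an assumption about recent Hugging Face `datasets` releases, not something stated in this diff. `GEM/xlsum` is loaded via a dataset script, and newer `datasets` versions refuse to execute such remote code unless the caller opts in explicitly:

# Hedged sketch: without trust_remote_code=True, recent `datasets`
# releases raise an error instead of running the GEM/xlsum loading script.
from datasets import get_dataset_config_names

configs = get_dataset_config_names("GEM/xlsum", trust_remote_code=True)
print(len(configs))  # one config per language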
24 changes: 8 additions & 16 deletions src/unitxt/metrics.py
@@ -8,7 +8,6 @@
 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict
 from dataclasses import field
-from operator import itemgetter
 from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import evaluate
@@ -664,24 +663,18 @@ class BulkInstanceMetric(StreamOperator, MetricWithConfidenceInterval):
 
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         global_score = {}
 
         instances = []
 
         # consume the stream
-        references, predictions = map(
-            list,
-            zip(
-                *[
-                    itemgetter("references", "prediction")(
-                        self.verify_instance(instance)
-                    )
-                    for instance in stream
-                ]
-            ),
-        )
+        for instance in stream:
+            self.verify_instance(instance)
+            instances.append(instance)
 
+        predictions = [instance["prediction"] for instance in instances]
+        references = [instance["references"] for instance in instances]
         task_data = [
             instance["task_data"] if "task_data" in instance else {}
-            for instance in stream
+            for instance in instances
         ]
         self._validate_references_and_prediction(references, predictions)
         # compute the metric over all refs and preds
@@ -696,7 +689,7 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
             instance_score["score"] = instance_score[self.main_score]
             instance_score["score_name"] = self.main_score
 
-        for instance, score in zip(stream, instance_scores):
+        for instance, score in zip(instances, instance_scores):
             if "score" not in instance:
                 instance["score"] = {"global": {}, "instance": {}}
 
@@ -705,7 +698,6 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
                     score, instance["score"]["instance"]
                 )
             )
-            instances.append(instance)
 
         for reduction, fields in self.reduction_map.items():
             assert (
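The refactor above is more than stylistic: the old code iterated `stream` several times (inside the `zip` comprehension, again when building `task_data`, and again when attaching scores), which silently breaks when the stream is backed by a single-pass generator. A plain-Python sketch (not unitxt API) of the failure mode and the fix:

# A generator can be consumed only once, so re-iterating an exhausted
# stream silently yields nothing.
def stream():
    yield {"prediction": "a", "references": ["a"]}
    yield {"prediction": "b", "references": ["b"]}

s = stream()
predictions = [inst["prediction"] for inst in s]        # consumes the generator
task_data = [inst.get("task_data", {}) for inst in s]   # s is exhausted: []
assert predictions == ["a", "b"] and task_data == []

# The fix: materialize the stream once, then derive every view from the list.
instances = list(stream())
predictions = [inst["prediction"] for inst in instances]
task_data = [inst.get("task_data", {}) for inst in instances]
assert len(predictions) == len(task_data) == 2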
74 changes: 38 additions & 36 deletions src/unitxt/operators.py
@@ -82,14 +82,13 @@
     StreamOperator,
 )
 from .random_utils import new_random_generator
-from .settings_utils import get_constants, get_settings
-from .stream import DynamicStream, Stream
+from .settings_utils import get_settings
+from .stream import DynamicStream, ListStream, Stream
 from .text_utils import nested_tuple_to_string
 from .type_utils import isoftype
 from .utils import deepcopy, flatten_dict
 
 settings = get_settings()
-constants = get_constants()
 
 
 class FromIterables(StreamInitializerOperator):
@@ -203,13 +202,12 @@ def process(
 
     def get_mapped_value(self, instance, key, mapper, val):
         val_as_str = str(val)  # make sure the value is a string
-        if self.strict and (val_as_str not in mapper):
+        if val_as_str in mapper:
+            return mapper[val_as_str]
+        if self.strict:
             raise KeyError(
                 f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
             )
-        # By default deep copy the value in mapper to avoid shared modifications
-        if val_as_str in mapper:
-            return deepcopy(mapper[val_as_str])
         return val
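Note the semantic shift in `get_mapped_value`: a matched value used to be deep-copied on every lookup ("to avoid shared modifications"); it is now returned by reference, so all instances mapped through the same key share one object. A plain-Python sketch of the trade-off:

# `mapper` stands in for a MapInstanceValues mapper.
import copy

mapper = {"1": {"label": "positive"}}

a = mapper["1"]                 # new behavior: returned by reference
b = mapper["1"]
a["note"] = "edited"
assert b["note"] == "edited"    # cheaper, but aliased across instances

c = copy.deepcopy(mapper["1"])  # old behavior: an isolated copy per lookup
c["note"] = "private"
assert mapper["1"]["note"] == "edited"  # c's edit does not leak back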
@@ -429,11 +427,6 @@ def process(
     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
         self.verify_field_definition()
-        # Need to deep copy instance, because when assigning two dictionary fields,
-        # dict_set() the target field dictionary fields.
-        # This means that if this target field was assigned to another field before,
-        # the field is updated as well.
-        instance = deepcopy(instance)
         for from_field, to_field in self._field_to_field:
             try:
                 old_value = dict_get(
@@ -847,14 +840,15 @@ class Copy(FieldOperator):
 
     """
 
-    use_deep_copy: bool = True
-
     def process_value(self, value: Any) -> Any:
-        if self.use_deep_copy:
-            return copy.deepcopy(value)
         return value
 
 
+class DeepCopy(FieldOperator):
+    def process_value(self, value: Any) -> Any:
+        return copy.deepcopy(value)
+
+
 @deprecation(version="2.0.0", alternative=Copy)
 class CopyFields(Copy):
     pass
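The net effect is that `Copy` becomes a cheap by-reference copy, while callers that relied on isolation must opt into the new `DeepCopy` operator. A hedged usage sketch; the `field`/`to_field` argument names follow the `FieldOperator` convention used elsewhere in unitxt, not this diff:

from unitxt.operators import Copy, DeepCopy

cheap = Copy(field="data", to_field="data_alias")         # shares the object
isolated = DeepCopy(field="data", to_field="data_clone")  # independent clone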
@@ -1602,7 +1596,21 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         from .metrics import Metric
 
-        first_instance = stream.peek()
+        # The number of instances in the input stream is assumed to be small, so each
+        # metric can consume all of them and keep them in main memory (CI even generates
+        # some 1000 copies of them).
+        # We therefore start by deep copying, producing a 'frozen' snapshot of the stream:
+        # the instances have passed the task's preprocess_steps and inference, and are now
+        # to be evaluated. This frozen snapshot is fed into each of the metrics listed in
+        # metric_field, so that the evaluation of one metric does not affect the evaluation
+        # of another (typically via changes to the instance made by the preprocess_steps
+        # of a MetricPipeline, as illustrated in docs/adding_metrics/Using Metric Pipelines).
+
+        instances_upon_entrance_to_metrics_evaluations = []
+        for instance in stream:
+            instances_upon_entrance_to_metrics_evaluations.append(deepcopy(instance))
+
+        first_instance = instances_upon_entrance_to_metrics_evaluations[0]
 
         metric_names = first_instance.get(self.metric_field, [])
         if not metric_names:
@@ -1619,16 +1627,6 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         # by the first listed metric (as desired).
         metric_names = list(reversed(metric_names))
 
-        # Workaround: The metric/MetricPipeline modifies the stream itself, sometimes making it
-        # incompatible with further metrics' processing, instead of just modifying the score field.
-        # Here we keep all the fields besides the score, and restore them after the metric finishes.
-        first_instance = stream.peek()
-        keys_to_restore = set(first_instance.keys()).difference({"score"})
-        multi_stream = MultiStream({stream_name: stream})
-        multi_stream = CopyFields(
-            field_to_field={k: f"{k}_orig" for k in keys_to_restore}
-        )(multi_stream)
-
         for metric_name in metric_names:
             metric = self.get_artifact(metric_name)
             assert isinstance(
@@ -1637,17 +1635,21 @@ def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
 
             if not self.calc_confidence_intervals:
                 metric.disable_confidence_interval_calculation()
-
+            multi_stream = MultiStream(
+                {
+                    "tmp": ListStream(
+                        instances_list=instances_upon_entrance_to_metrics_evaluations,
+                        copying=True,  # ensures deep copy when iterating over instances
+                    )
+                }
+            )
             multi_stream = metric(multi_stream)
-            multi_stream = CopyFields(
-                field_to_field={f"{k}_orig": k for k in keys_to_restore}
-            )(multi_stream)
+            for evaluated_instance, freezed_instance in zip(
+                multi_stream["tmp"], instances_upon_entrance_to_metrics_evaluations
+            ):
+                freezed_instance["score"] = deepcopy(evaluated_instance["score"])
 
-        multi_stream = RemoveFields(fields=[f"{k}_orig" for k in keys_to_restore])(
-            multi_stream
-        )
-        stream = multi_stream[stream_name]
-        yield from stream
+        yield from instances_upon_entrance_to_metrics_evaluations


 class MergeStreams(MultiStreamOperator):
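The 'frozen snapshot' pattern above replaces the older save-and-restore workaround: instead of stashing every field under a `*_orig` alias and copying it back, each metric now runs on deep copies, and only the resulting `score` is grafted onto the frozen originals. A self-contained plain-Python sketch of the idea (toy metric, not unitxt API):

import copy

def exact_match(instances):
    # toy stand-in for a metric pipeline that mutates its input
    for inst in instances:
        inst["prediction"] = inst["prediction"].lower()  # destructive step
        inst["score"] = {"exact_match": float(inst["prediction"] in inst["references"])}
    return instances

instances = [{"prediction": "Paris", "references": ["paris"]}]
frozen = [copy.deepcopy(inst) for inst in instances]

for metric in [exact_match]:
    evaluated = metric([copy.deepcopy(inst) for inst in frozen])
    for frozen_inst, evaluated_inst in zip(frozen, evaluated):
        frozen_inst["score"] = copy.deepcopy(evaluated_inst["score"])

assert frozen[0]["prediction"] == "Paris"          # original field intact
assert frozen[0]["score"] == {"exact_match": 1.0}  # only the score landed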
8 changes: 5 additions & 3 deletions src/unitxt/stream_operators.py
@@ -31,6 +31,7 @@
 
 """
 
+import copy
 from typing import (
     List,
     Literal,
@@ -154,6 +155,7 @@ class DuplicateSplit(MultiStreamOperator):
 
     def process(self, multi_stream: MultiStream) -> MultiStream:
         assert self.split in multi_stream
-        generators = multi_stream
-        generators[self.to_split] = generators[self.split]
-        return MultiStream(generators)
+        new_stream = copy.deepcopy(multi_stream[self.split])
+        new_stream.set_copying(copying=True)
+        multi_stream[self.to_split] = new_stream
+        return multi_stream
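The old `DuplicateSplit` body registered the very same stream object under both split names, so consuming one split could exhaust the other; the fix duplicates the stream via `copy.deepcopy` and turns on copying for the clone. A plain-Python sketch of the aliasing hazard with raw generators (unitxt streams wrap generators, so the same hazard applies):

def make_split():
    yield from [{"x": 1}, {"x": 2}]

splits = {"train": make_split()}
splits["validation"] = splits["train"]          # old behavior: same object

assert [inst["x"] for inst in splits["train"]] == [1, 2]
assert list(splits["validation"]) == []         # the alias is already exhausted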
8 changes: 7 additions & 1 deletion src/unitxt/test_utils/metrics.py
@@ -1,4 +1,5 @@
 import json
+from copy import deepcopy
 from typing import Any, List, Optional
 
 from ..eval_utils import evaluate
@@ -68,7 +69,12 @@ def apply_metric(
         {"prediction": prediction, "references": reference}
         for prediction, reference in zip(predictions, references)
     ]
-    multi_stream = MultiStream.from_iterables({"test": test_iterable}, copying=True)
+    # break any cross reference from one instance to another,
+    # imitating what's done at the entrance to operators.ApplyMetric
+    ti = []
+    for instance in test_iterable:
+        ti.append(deepcopy(instance))
+    multi_stream = MultiStream.from_iterables({"test": ti})
 
     output_multi_stream = metric(multi_stream)
     output_stream = output_multi_stream["test"]
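The per-instance `deepcopy` here matters when callers build `test_iterable` from shared objects: two instances referencing the same list would otherwise see each other's in-place edits during metric evaluation. A minimal sketch:

from copy import deepcopy

shared_refs = ["yes"]
test_iterable = [
    {"prediction": "yes", "references": shared_refs},
    {"prediction": "no", "references": shared_refs},
]

ti = [deepcopy(instance) for instance in test_iterable]
ti[0]["references"].append("y")           # mutate one copied instance
assert ti[1]["references"] == ["yes"]     # its sibling copy is unaffected
assert shared_refs == ["yes"]             # and so are the originals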