Skip to content

Commit 121c268

Browse files
pklprivPrzemysław Klocekyoavkatz
authored
Create new IntersectCorrespondingFields operator (#1531)
* filter for entity types intro * code update * optimalisation * improv * remove filter and add its functionality to intersect * typo * Created a new type of intersect operator Signed-off-by: Yoav Katz <[email protected]> * Updated documentation Signed-off-by: Yoav Katz <[email protected]> --------- Signed-off-by: Yoav Katz <[email protected]> Co-authored-by: Przemysław Klocek <[email protected]> Co-authored-by: Yoav Katz <[email protected]> Co-authored-by: Yoav Katz <[email protected]>
1 parent bc65c5c commit 121c268

File tree

2 files changed

+237
-1
lines changed

2 files changed

+237
-1
lines changed

src/unitxt/operators.py

+108-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
from .random_utils import new_random_generator
8686
from .settings_utils import get_settings
8787
from .stream import DynamicStream, Stream
88-
from .text_utils import nested_tuple_to_string
88+
from .text_utils import nested_tuple_to_string, to_pretty_string
8989
from .type_utils import isoftype
9090
from .utils import (
9191
LRUCache,
@@ -1477,6 +1477,113 @@ def process_value(self, value: Any) -> Any:
14771477
return [e for e in value if e in self.allowed_values]
14781478

14791479

1480+
class IntersectCorrespondingFields(InstanceOperator):
1481+
"""Intersects the value of a field, which must be a list, with a given list , and removes corresponding elements from other list fields.
1482+
1483+
For example:
1484+
1485+
Assume the instances contain a field of 'labels' and a field with the labels' corresponding 'positions' in the text.
1486+
1487+
IntersectCorrespondingFields(field="label",
1488+
allowed_values=["b", "f"],
1489+
corresponding_fields_to_intersect=["position"])
1490+
1491+
would keep only "b" and "f" values in 'labels' field and
1492+
their respective values in the 'position' field.
1493+
(All other fields are not effected)
1494+
1495+
Given this input:
1496+
1497+
[
1498+
{"label": ["a", "b"],"position": [0,1],"other" : "not"},
1499+
{"label": ["a", "c", "d"], "position": [0,1,2], "other" : "relevant"},
1500+
{"label": ["a", "b", "f"], "position": [0,1,2], "other" : "field"}
1501+
]
1502+
1503+
So the output would be:
1504+
[
1505+
{"label": ["b"], "position":[1],"other" : "not"},
1506+
{"label": [], "position": [], "other" : "relevant"},
1507+
{"label": ["b", "f"],"position": [1,2], "other" : "field"},
1508+
]
1509+
1510+
Args:
1511+
field - the field to intersected (must contain list values)
1512+
allowed_values (list) - list of values to keep
1513+
corresponding_fields_to_intersect (list) - additional list fields from which values
1514+
are removed based the corresponding indices of values removed from the 'field'
1515+
"""
1516+
1517+
field: str
1518+
allowed_values: List[str]
1519+
corresponding_fields_to_intersect: List[str]
1520+
1521+
def verify(self):
1522+
super().verify()
1523+
1524+
if not isinstance(self.allowed_values, list):
1525+
raise ValueError(
1526+
f"The allowed_field_values is not a type list but '{type(self.allowed_field_values)}'"
1527+
)
1528+
1529+
def process(
1530+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
1531+
) -> Dict[str, Any]:
1532+
if self.field not in instance:
1533+
raise ValueError(
1534+
f"Field '{self.field}' is not in provided instance.\n"
1535+
+ to_pretty_string(instance)
1536+
)
1537+
1538+
for corresponding_field in self.corresponding_fields_to_intersect:
1539+
if corresponding_field not in instance:
1540+
raise ValueError(
1541+
f"Field '{corresponding_field}' is not in provided instance.\n"
1542+
+ to_pretty_string(instance)
1543+
)
1544+
1545+
if not isinstance(instance[self.field], list):
1546+
raise ValueError(
1547+
f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
1548+
+ to_pretty_string(instance, keys=[self.field])
1549+
)
1550+
1551+
num_values_in_field = len(instance[self.field])
1552+
1553+
if set(self.allowed_values) == set(instance[self.field]):
1554+
return instance
1555+
1556+
indices_to_keep = [
1557+
i
1558+
for i, value in enumerate(instance[self.field])
1559+
if value in set(self.allowed_values)
1560+
]
1561+
1562+
result_instance = {}
1563+
for field_name, field_value in instance.items():
1564+
if (
1565+
field_name in self.corresponding_fields_to_intersect
1566+
or field_name == self.field
1567+
):
1568+
if not isinstance(field_value, list):
1569+
raise ValueError(
1570+
f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
1571+
)
1572+
if len(field_value) != num_values_in_field:
1573+
raise ValueError(
1574+
f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
1575+
+ to_pretty_string(instance, keys=[self.field, field_name])
1576+
)
1577+
result_instance[field_name] = [
1578+
value
1579+
for index, value in enumerate(field_value)
1580+
if index in indices_to_keep
1581+
]
1582+
else:
1583+
result_instance[field_name] = field_value
1584+
return result_instance
1585+
1586+
14801587
class RemoveValues(FieldOperator):
14811588
"""Removes elements in a field, which must be a list, using a given list of unallowed.
14821589

tests/library/test_operators.py

+129
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
FromIterables,
3232
IndexOf,
3333
Intersect,
34+
IntersectCorrespondingFields,
3435
IterableSource,
3536
JoinStr,
3637
LengthBalancer,
@@ -658,6 +659,134 @@ def test_intersect(self):
658659
tester=self,
659660
)
660661

662+
def test_intersect_corresponding_fields(self):
663+
inputs = [
664+
{"label": ["a", "b"], "position": [0, 1], "other": "not"},
665+
{"label": ["a", "c", "d"], "position": [0, 1, 2], "other": "relevant"},
666+
{"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"},
667+
]
668+
669+
targets = [
670+
{"label": ["b"], "position": [1], "other": "not"},
671+
{"label": [], "position": [], "other": "relevant"},
672+
{"label": ["b", "f"], "position": [1, 2], "other": "field"},
673+
]
674+
675+
check_operator(
676+
operator=IntersectCorrespondingFields(
677+
field="label",
678+
allowed_values=["b", "f"],
679+
corresponding_fields_to_intersect=["position"],
680+
),
681+
inputs=inputs,
682+
targets=targets,
683+
tester=self,
684+
)
685+
686+
exception_texts = [
687+
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
688+
"""Field 'acme_field' is not in provided instance.
689+
label (list):
690+
[0] (str):
691+
a
692+
[1] (str):
693+
b
694+
position (list):
695+
[0] (int):
696+
0
697+
[1] (int):
698+
1
699+
other (str):
700+
not
701+
""",
702+
]
703+
check_operator_exception(
704+
operator=IntersectCorrespondingFields(
705+
field="acme_field",
706+
allowed_values=["b", "f"],
707+
corresponding_fields_to_intersect=["other"],
708+
),
709+
inputs=inputs,
710+
exception_texts=exception_texts,
711+
tester=self,
712+
)
713+
714+
exception_texts = [
715+
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
716+
"""Field 'acme_field' is not in provided instance.
717+
label (list):
718+
[0] (str):
719+
a
720+
[1] (str):
721+
b
722+
position (list):
723+
[0] (int):
724+
0
725+
[1] (int):
726+
1
727+
other (str):
728+
not
729+
""",
730+
]
731+
check_operator_exception(
732+
operator=IntersectCorrespondingFields(
733+
field="label",
734+
allowed_values=["b", "f"],
735+
corresponding_fields_to_intersect=["acme_field"],
736+
),
737+
inputs=inputs,
738+
exception_texts=exception_texts,
739+
tester=self,
740+
)
741+
742+
exception_texts = [
743+
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
744+
"Value of field 'other' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\nother (str):\n not\n",
745+
]
746+
check_operator_exception(
747+
operator=IntersectCorrespondingFields(
748+
field="other",
749+
allowed_values=["b", "f"],
750+
corresponding_fields_to_intersect=["other"],
751+
),
752+
inputs=inputs,
753+
exception_texts=exception_texts,
754+
tester=self,
755+
)
756+
757+
inputs = [
758+
{"label": ["a", "b"], "position": [0, 1, 2], "other": "not"},
759+
{"label": ["a", "c", "d"], "position": [0, 1, 2], "other": "relevant"},
760+
{"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"},
761+
]
762+
exception_texts = [
763+
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
764+
"""Number of elements in field 'position' is not the same as the number of elements in field 'label' so the IntersectCorrespondingFields can not remove corresponding values.
765+
label (list):
766+
[0] (str):
767+
a
768+
[1] (str):
769+
b
770+
position (list):
771+
[0] (int):
772+
0
773+
[1] (int):
774+
1
775+
[2] (int):
776+
2
777+
""",
778+
]
779+
check_operator_exception(
780+
operator=IntersectCorrespondingFields(
781+
field="label",
782+
allowed_values=["b", "f"],
783+
corresponding_fields_to_intersect=["position"],
784+
),
785+
inputs=inputs,
786+
exception_texts=exception_texts,
787+
tester=self,
788+
)
789+
661790
def test_remove_none(self):
662791
inputs = [
663792
{"references": [["none"], ["none"]]},

0 commit comments

Comments
 (0)