|
85 | 85 | from .random_utils import new_random_generator
|
86 | 86 | from .settings_utils import get_settings
|
87 | 87 | from .stream import DynamicStream, Stream
|
88 |
| -from .text_utils import nested_tuple_to_string |
| 88 | +from .text_utils import nested_tuple_to_string, to_pretty_string |
89 | 89 | from .type_utils import isoftype
|
90 | 90 | from .utils import (
|
91 | 91 | LRUCache,
|
@@ -1477,6 +1477,113 @@ def process_value(self, value: Any) -> Any:
|
1477 | 1477 | return [e for e in value if e in self.allowed_values]
|
1478 | 1478 |
|
1479 | 1479 |
|
class IntersectCorrespondingFields(InstanceOperator):
    """Intersects the value of a list field with a given list, and removes corresponding elements from other list fields.

    The elements of 'field' that are not in 'allowed_values' are dropped, and the
    elements at the same indices are dropped from every field named in
    'corresponding_fields_to_intersect'. All other fields are not affected.

    For example:

    Assume the instances contain a field of 'label' and a field with the labels'
    corresponding 'position' in the text.

        IntersectCorrespondingFields(field="label",
                                     allowed_values=["b", "f"],
                                     corresponding_fields_to_intersect=["position"])

    would keep only "b" and "f" values in the 'label' field and
    their respective values in the 'position' field.

    Given this input:

        [
            {"label": ["a", "b"], "position": [0, 1], "other": "not"},
            {"label": ["a", "c", "d"], "position": [0, 1, 2], "other": "relevant"},
            {"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"}
        ]

    the output would be:

        [
            {"label": ["b"], "position": [1], "other": "not"},
            {"label": [], "position": [], "other": "relevant"},
            {"label": ["b", "f"], "position": [1, 2], "other": "field"}
        ]

    Args:
        field (str): the field to intersect (must contain list values).
        allowed_values (list): list of values to keep.
        corresponding_fields_to_intersect (list): additional list fields from which
            values are removed based on the corresponding indices of the values
            removed from 'field'.

    Raises:
        ValueError: from verify() if 'allowed_values' is not a list; from process()
            if 'field' or any corresponding field is missing from the instance, is
            not a list, or if a corresponding field differs in length from 'field'.
    """

    field: str
    allowed_values: List[str]
    corresponding_fields_to_intersect: List[str]

    def verify(self):
        super().verify()

        if not isinstance(self.allowed_values, list):
            # The message must reference 'allowed_values'; referencing a
            # non-existent attribute here would raise AttributeError and hide
            # the intended ValueError.
            raise ValueError(
                f"The allowed_values is not a type list but '{type(self.allowed_values)}'"
            )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.field not in instance:
            raise ValueError(
                f"Field '{self.field}' is not in provided instance.\n"
                + to_pretty_string(instance)
            )

        for corresponding_field in self.corresponding_fields_to_intersect:
            if corresponding_field not in instance:
                raise ValueError(
                    f"Field '{corresponding_field}' is not in provided instance.\n"
                    + to_pretty_string(instance)
                )

        field_values = instance[self.field]
        if not isinstance(field_values, list):
            raise ValueError(
                f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
                + to_pretty_string(instance, keys=[self.field])
            )

        num_values_in_field = len(field_values)

        # Build the lookup set once instead of once per element.
        allowed = set(self.allowed_values)

        # Nothing to remove: the instance is returned as is (same object).
        # Note that the corresponding fields are not validated on this path.
        if allowed == set(field_values):
            return instance

        # Ascending positions of the elements that survive the intersection.
        indices_to_keep = [
            index for index, value in enumerate(field_values) if value in allowed
        ]

        fields_to_intersect = {self.field, *self.corresponding_fields_to_intersect}

        result_instance = {}
        for field_name, field_value in instance.items():
            if field_name not in fields_to_intersect:
                result_instance[field_name] = field_value
                continue

            if not isinstance(field_value, list):
                raise ValueError(
                    f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
                )
            if len(field_value) != num_values_in_field:
                raise ValueError(
                    f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
                    + to_pretty_string(instance, keys=[self.field, field_name])
                )
            # Lengths match, so every kept index is valid for this field too.
            result_instance[field_name] = [field_value[i] for i in indices_to_keep]
        return result_instance
| 1585 | + |
| 1586 | + |
1480 | 1587 | class RemoveValues(FieldOperator):
|
1481 | 1588 | """Removes elements in a field, which must be a list, using a given list of unallowed.
|
1482 | 1589 |
|
|
0 commit comments