diff --git a/pyproject.toml b/pyproject.toml index 4e8d29ca..3bdc3eae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "aiohttp>=3.10.5", "aiohttp-retry>=2.9.1", "pytest-asyncio>=0.24.0", - "gcsfs>=2025.5.1,<2026", + "gcsfs>=2025.5.1", "deprecated>=1.2.18", "types-deprecated>=1.2.15.20250304", ] diff --git a/src/dapla_pseudo/utils.py b/src/dapla_pseudo/utils.py index f8a6a839..563aa42a 100644 --- a/src/dapla_pseudo/utils.py +++ b/src/dapla_pseudo/utils.py @@ -20,6 +20,7 @@ from dapla_pseudo.v1.models.core import PseudoKeyset from dapla_pseudo.v1.models.core import PseudoRule from dapla_pseudo.v1.models.core import RedactKeywordArgs +from dapla_pseudo.v1.mutable_dataframe import FieldMatch from dapla_pseudo.v1.mutable_dataframe import MutableDataFrame from dapla_pseudo.v1.supported_file_format import SupportedOutputFileFormat @@ -161,54 +162,170 @@ def build_pseudo_field_request( """Builds a FieldRequest object.""" mutable_df.match_rules(rules, target_rules) matched_fields = mutable_df.get_matched_fields() - requests: list[PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest] = ( - [] + if mutable_df.hierarchical: + return _build_hierarchical_field_requests( + pseudo_operation=pseudo_operation, + mutable_df=mutable_df, + matched_fields=matched_fields, + custom_keyset=custom_keyset, + target_custom_keyset=target_custom_keyset, + target_rules=target_rules, + ) + + else: + return _build_tabular_field_requests( + pseudo_operation=pseudo_operation, + matched_fields=matched_fields, + custom_keyset=custom_keyset, + target_custom_keyset=target_custom_keyset, + target_rules=target_rules, + ) + + +def _build_tabular_field_requests( + pseudo_operation: PseudoOperation, + matched_fields: dict[str, FieldMatch], + custom_keyset: PseudoKeyset | str | None, + target_custom_keyset: PseudoKeyset | str | None, + target_rules: list[PseudoRule] | None, +) -> list[PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest]: + return [ + _build_single_field_request( + pseudo_operation=pseudo_operation, + request_name=field.path, + representative=field, + values=field.get_value(), + custom_keyset=custom_keyset, + target_custom_keyset=target_custom_keyset, + target_rules=target_rules, + ) + for field in matched_fields.values() + ] + + +def _build_hierarchical_field_requests( + pseudo_operation: PseudoOperation, + mutable_df: MutableDataFrame, + matched_fields: dict[str, FieldMatch], + custom_keyset: PseudoKeyset | str | None, + target_custom_keyset: PseudoKeyset | str | None, + target_rules: list[PseudoRule] | None, +) -> list[PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest]: + grouped_matches = _group_hierarchical_fields_for_requests( + mutable_df=mutable_df, + matched_fields=matched_fields, ) + return [ + _build_single_field_request( + pseudo_operation=pseudo_operation, + request_name=request_name, + representative=fields[0], + values=[value for field in fields for value in field.get_value()], + custom_keyset=custom_keyset, + target_custom_keyset=target_custom_keyset, + target_rules=target_rules, + ) + for request_name, fields in grouped_matches + ] + + +def _build_single_field_request( + pseudo_operation: PseudoOperation, + request_name: str, + representative: FieldMatch, + values: list[str | int | None], + custom_keyset: PseudoKeyset | str | None, + target_custom_keyset: PseudoKeyset | str | None, + target_rules: list[PseudoRule] | None, +) -> PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest: req: PseudoFieldRequest | DepseudoFieldRequest | RepseudoFieldRequest - match pseudo_operation: - case PseudoOperation.PSEUDONYMIZE: - for field in matched_fields.values(): - try: - req = PseudoFieldRequest( - pseudo_func=field.func, - name=field.path, - pattern=field.pattern, - values=field.get_value(), - keyset=KeyWrapper(custom_keyset).keyset, - ) - requests.append(req) - except ValidationError as e: - raise Exception(f"Path or column: {field.path}") from e - case PseudoOperation.DEPSEUDONYMIZE: - for field in matched_fields.values(): - try: - req = DepseudoFieldRequest( - pseudo_func=field.func, - name=field.path, - pattern=field.pattern, - values=field.get_value(), - keyset=KeyWrapper(custom_keyset).keyset, - ) - requests.append(req) - except ValidationError as e: - raise Exception(f"Path or column: {field.path}") from e - - case PseudoOperation.REPSEUDONYMIZE: - if target_rules is not None: - for field in matched_fields.values(): - try: - req = RepseudoFieldRequest( - source_pseudo_func=field.func, - target_pseudo_func=field.target_func, - name=field.path, - pattern=field.pattern, - values=field.get_value(), - source_keyset=KeyWrapper(custom_keyset).keyset, - target_keyset=KeyWrapper(target_custom_keyset).keyset, - ) - requests.append(req) - except ValidationError as e: - raise Exception(f"Path or column: {field.path}") from e - else: - raise ValueError("Found no target rules") - return requests + + try: + match pseudo_operation: + case PseudoOperation.PSEUDONYMIZE: + req = PseudoFieldRequest( + pseudo_func=representative.func, + name=request_name, + pattern=representative.pattern, + values=values, + keyset=KeyWrapper(custom_keyset).keyset, + ) + case PseudoOperation.DEPSEUDONYMIZE: + req = DepseudoFieldRequest( + pseudo_func=representative.func, + name=request_name, + pattern=representative.pattern, + values=values, + keyset=KeyWrapper(custom_keyset).keyset, + ) + case PseudoOperation.REPSEUDONYMIZE: + if target_rules is None: + raise ValueError("Found no target rules") + + req = RepseudoFieldRequest( + source_pseudo_func=representative.func, + target_pseudo_func=representative.target_func, + name=request_name, + pattern=representative.pattern, + values=values, + source_keyset=KeyWrapper(custom_keyset).keyset, + target_keyset=KeyWrapper(target_custom_keyset).keyset, + ) + except ValidationError as e: + raise Exception(f"Path or column: {request_name}") from e + + return req + + +def _group_hierarchical_fields_for_requests( + mutable_df: MutableDataFrame, + matched_fields: dict[str, FieldMatch], +) -> list[tuple[str, list[FieldMatch]]]: + """Group hierarchical field matches into pseudo-service requests. + + Example input paths: + - ``person_info[0]/fnr`` + - ``person_info[1]/fnr`` + - ``person_info[2]/fnr`` + + These are grouped into one request named ``person_info/fnr``. The grouped + request contains all values from all matching leaf paths. + + Two paths share a request only when all of these match: + - normalized request name (array indices removed) + - pattern + - source pseudo function + - target pseudo function (for repseudonymize) + + """ + grouped: dict[tuple[str, str, str, str | None], list[FieldMatch]] = {} + + # Group paths that can share one API request. + for field in matched_fields.values(): + request_name = _remove_array_indices(field.path) + target_func = str(field.target_func) if field.target_func else None + group_key = (request_name, field.pattern, str(field.func), target_func) + grouped.setdefault(group_key, []).append(field) + + grouped_matches: list[tuple[str, list[FieldMatch]]] = [] + + # If a group is batched, store slice boundaries so one response list can be + # written back to the original leaf paths. + for _, fields in grouped.items(): + representative = fields[0] + request_name = _remove_array_indices(representative.path) + + if len(fields) > 1: + mutable_df.map_batch_to_leaf_slices( + request_name, + [(field.path, len(field.get_value())) for field in fields], + ) + grouped_matches.append((request_name, fields)) + else: + grouped_matches.append((representative.path, fields)) + + return grouped_matches + + +def _remove_array_indices(path: str) -> str: + return re.sub(r"\[\d+]", "", path) diff --git a/src/dapla_pseudo/v1/client.py b/src/dapla_pseudo/v1/client.py index 8dd23469..439c7149 100644 --- a/src/dapla_pseudo/v1/client.py +++ b/src/dapla_pseudo/v1/client.py @@ -150,7 +150,7 @@ async def _post( split_pseudo_requests = self._split_requests(pseudo_requests) aio_session = ClientSession( - connector=TCPConnector(limit=100, enable_cleanup_closed=True), + connector=TCPConnector(limit=100), timeout=ClientTimeout(total=TIMEOUT_DEFAULT), ) async with RetryClient( @@ -315,7 +315,7 @@ async def _post_to_sid_endpoint( ] async with ClientSession( - connector=TCPConnector(limit=100, enable_cleanup_closed=True), + connector=TCPConnector(limit=100), timeout=ClientTimeout(total=TIMEOUT_DEFAULT), ) as session: diff --git a/src/dapla_pseudo/v1/mutable_dataframe.py b/src/dapla_pseudo/v1/mutable_dataframe.py index 4f3f50c2..8e7084e7 100644 --- a/src/dapla_pseudo/v1/mutable_dataframe.py +++ b/src/dapla_pseudo/v1/mutable_dataframe.py @@ -63,6 +63,7 @@ def __init__( """Initialize the class.""" self.dataset: pl.DataFrame | dict[str, Any] | pl.LazyFrame = dataframe self.matched_fields: dict[str, FieldMatch] = {} + self.batched_fields: dict[str, list[tuple[str, int, int]]] = {} self.matched_fields_metrics: dict[str, int] | None = None self.hierarchical: bool = hierarchical self.schema = ( @@ -75,6 +76,7 @@ def match_rules( self, rules: list[PseudoRule], target_rules: list[PseudoRule] | None ) -> None: """Create references to all the columns that matches the given pseudo rules.""" + self.batched_fields = {} if self.hierarchical is False: assert isinstance(self.dataset, pl.DataFrame) or isinstance( self.dataset, pl.LazyFrame @@ -119,6 +121,40 @@ def extract_column_data( for match in matches: self.matched_fields[match.path] = match + def map_batch_to_leaf_slices( + self, batch_name: str, segments: list[tuple[str, int]] + ) -> None: + """Store how a batched response maps back to concrete leaf paths. + + A hierarchical batch request flattens values from multiple concrete paths + into one list that is sent to the pseudo service. Example: + + - ``person_info[0]/fnr`` contributes 1 value + - ``person_info[1]/fnr`` contributes 2 values + - ``person_info[2]/fnr`` contributes 1 value + + This method converts that to slice boundaries for the batched request name + (e.g. ``person_info/fnr``): + + - ``person_info[0]/fnr`` -> ``[0:1]`` + - ``person_info[1]/fnr`` -> ``[1:3]`` + - ``person_info[2]/fnr`` -> ``[3:4]`` + + Later, ``update(batch_name, data)`` uses these slices to split one flat + response list into the right chunks and write each chunk back to the + correct leaf path. + """ + if len(segments) <= 1: + return + + offset = 0 + indexed_segments: list[tuple[str, int, int]] = [] + for path, length in segments: + indexed_segments.append((path, offset, offset + length)) + offset += length + + self.batched_fields[batch_name] = indexed_segments + def get_matched_fields(self) -> dict[str, FieldMatch]: """Get a reference to all the columns that matched pseudo rules.""" return self.matched_fields @@ -130,6 +166,9 @@ def update(self, path: str, data: list[str | None]) -> None: self.dataset, pl.LazyFrame ) self.dataset = self.dataset.with_columns(pl.Series(data).alias(path)) + elif (batched_segments := self.batched_fields.get(path)) is not None: + for leaf_path, start, end in batched_segments: + self.update(leaf_path, data[start:end]) elif (field_match := self.matched_fields.get(path)) is not None: assert isinstance(self.dataset, dict) tree = self.dataset diff --git a/tests/data/datadoc/expected_metadata_test_pseudonymize_hierarchical_redact.json b/tests/data/datadoc/expected_metadata_test_pseudonymize_hierarchical_redact.json index ebfcda91..76c8c3a1 100644 --- a/tests/data/datadoc/expected_metadata_test_pseudonymize_hierarchical_redact.json +++ b/tests/data/datadoc/expected_metadata_test_pseudonymize_hierarchical_redact.json @@ -1,31 +1,7 @@ [ { "short_name": "fnr", - "data_element_path": "person_info[0].fnr", - "pseudonymization": { - "encryption_algorithm": "REDACT", - "encryption_algorithm_parameters": [ - { - "placeholder": ":" - } - ] - } - }, - { - "short_name": "fnr", - "data_element_path": "person_info[1].fnr", - "pseudonymization": { - "encryption_algorithm": "REDACT", - "encryption_algorithm_parameters": [ - { - "placeholder": ":" - } - ] - } - }, - { - "short_name": "fnr", - "data_element_path": "person_info[2].fnr", + "data_element_path": "person_info.fnr", "pseudonymization": { "encryption_algorithm": "REDACT", "encryption_algorithm_parameters": [ diff --git a/tests/data/datadoc/expected_metadata_test_pseudonymize_sid_null.json b/tests/data/datadoc/expected_metadata_test_pseudonymize_sid_null.json index 9a4c0a11..0e6c6649 100644 --- a/tests/data/datadoc/expected_metadata_test_pseudonymize_sid_null.json +++ b/tests/data/datadoc/expected_metadata_test_pseudonymize_sid_null.json @@ -4,7 +4,7 @@ "data_element_path": "fnr", "pseudonymization": { "stable_identifier_type": "FREG_SNR", - "stable_identifier_version": "2026-01-31", + "stable_identifier_version": "2026-02-28", "encryption_algorithm": "TINK-FPE", "encryption_key_reference": "papis-common-key-1", "encryption_algorithm_parameters": [ diff --git a/tests/test_utils.py b/tests/test_utils.py index 299e963d..9d218ad4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -152,3 +152,31 @@ def test_build_repseudo_field_request() -> None: pattern="**/foo", values=["baz"], ) + + +def test_build_pseudo_field_request_hierarchical_batching() -> None: + data = [ + {"struct": {"foo": "baz"}}, + {"struct": {"foo": "qux"}}, + ] + df = MutableDataFrame(pl.DataFrame(data), hierarchical=True) + rules = [ + PseudoRule.from_json( + '{"name":"my-rule","pattern":"**/foo","path":"struct/foo","func":"daead(keyId=ssb-common-key-1)"}' + ) + ] + + requests = build_pseudo_field_request(PseudoOperation.PSEUDONYMIZE, df, rules) + + assert len(requests) == 1 + assert requests[0].name == "struct/foo" + assert requests[0].pattern == "**/foo" + assert requests[0].values == ["baz", "qux"] + + # Ensure batched responses can be scattered back to concrete paths. + df.update("struct/foo", ["#", "#"]) + modified_df = df.to_polars() + assert isinstance(modified_df, pl.DataFrame) + struct_values = modified_df.get_column("struct").to_list() + assert struct_values[0]["foo"] == "#" + assert struct_values[1]["foo"] == "#" diff --git a/tests/v1/integration/test_hierarchical.py b/tests/v1/integration/test_hierarchical.py index 34977ffc..be904feb 100644 --- a/tests/v1/integration/test_hierarchical.py +++ b/tests/v1/integration/test_hierarchical.py @@ -156,3 +156,67 @@ def test_pseudonymize_hierarchical_complex( assert_frame_equal( result.to_polars(), df_personer_hierarchical_complex_pseudonymized ) + + +@pytest.mark.usefixtures("setup") +@integration_test() +def test_hierarchical_request_batching_daead( + df_personer_hierarchical: pl.DataFrame, + df_personer_hierarchical_pseudonymized: pl.DataFrame, +) -> None: + """Verify hierarchical non-REDACT behavior is grouped into one logical field result. + + This test asserts real end-to-end behavior without patching internals. For the + chosen fixture there are three leaf paths under ``person_info/fnr``. If requests + are batched, the result metadata is aggregated under a single field key. + """ + rule = PseudoRule( + name="my-rule", + func=PseudoFunction( + function_type=PseudoFunctionTypes.DAEAD, kwargs=DaeadKeywordArgs() + ), + pattern="**/person_info/fnr", + path="person_info/fnr", + ) + + result = ( + Pseudonymize.from_polars(df_personer_hierarchical) + .add_rules(rule) + .run(hierarchical=True) + ) + + datadoc_model = result.datadoc_model + assert isinstance(datadoc_model, list) + assert len(datadoc_model) == 1 + assert datadoc_model[0]["data_element_path"] == "person_info.fnr" + assert_frame_equal(result.to_polars(), df_personer_hierarchical_pseudonymized) + + +@pytest.mark.usefixtures("setup") +@integration_test() +def test_hierarchical_request_batching_redact( + df_personer_hierarchical: pl.DataFrame, + df_personer_hierarchical_redacted: pl.DataFrame, +) -> None: + """Verify hierarchical REDACT requests are grouped like other functions.""" + rule = PseudoRule( + name="my-rule", + func=PseudoFunction( + function_type=PseudoFunctionTypes.REDACT, + kwargs=RedactKeywordArgs(placeholder=":"), + ), + pattern="**/person_info/fnr", + path="person_info/fnr", + ) + + result = ( + Pseudonymize.from_polars(df_personer_hierarchical) + .add_rules(rule) + .run(hierarchical=True) + ) + + datadoc_model = result.datadoc_model + assert isinstance(datadoc_model, list) + assert len(datadoc_model) == 1 + assert datadoc_model[0]["data_element_path"] == "person_info.fnr" + assert_frame_equal(result.to_polars(), df_personer_hierarchical_redacted) diff --git a/uv.lock b/uv.lock index cf0845d4..81291b34 100644 --- a/uv.lock +++ b/uv.lock @@ -789,7 +789,7 @@ requires-dist = [ { name = "dapla-toolbelt-metadata", specifier = ">=0.10.5,<1.0.0" }, { name = "deprecated", specifier = ">=1.2.18" }, { name = "fsspec", specifier = ">=2023.5.0" }, - { name = "gcsfs", specifier = ">=2025.5.1,<2026" }, + { name = "gcsfs", specifier = ">=2025.5.1" }, { name = "msgspec", specifier = ">=0.18.6" }, { name = "numpy", specifier = ">=1.26.4" }, { name = "orjson", specifier = ">=3.10.1" },