diff --git a/src/country_workspace/datasources/rdi.py b/src/country_workspace/datasources/rdi.py index dce5a9ed..24c8728c 100644 --- a/src/country_workspace/datasources/rdi.py +++ b/src/country_workspace/datasources/rdi.py @@ -1,4 +1,6 @@ import io +from collections.abc import Iterable +from typing import Mapping, Any, TypedDict, cast from django.db.transaction import atomic from hope_smart_import.readers import open_xls_multi @@ -7,53 +9,157 @@ from country_workspace.utils.fields import clean_field_name RDI = str | io.BytesIO +Row = Mapping[str, Any] +Sheet = Iterable[Row] + +INDIVIDUAL = "individual" +HOUSEHOLD = "household" + + +class Config(TypedDict): + batch_name: str + household_pk_col: str + master_column_label: str + detail_column_label: str + check_before: bool + + +class ColumnConfigurationError(Exception): + def __init__(self, column_name: str) -> None: + super().__init__(column_name) + self.column_name = column_name + + def __str__(self) -> str: + return f"Column {self.column_name} not found." + + +class SheetProcessingError(Exception): + def __init__(self, sheet_name: str, row_index: int) -> None: + super().__init__(sheet_name, row_index) + self.sheet_name = sheet_name + self.row_index = row_index + + def __str__(self) -> str: + return f"Failed to process {self.sheet_name} sheet at row {self.row_index}" + + +class MissingHouseholdError(Exception): + def __init__(self, row_index: int, household_key: str) -> None: + super().__init__(row_index, household_key) + self.row_index = row_index + self.household_key = household_key + + def __str__(self) -> str: + return f"Missing household {self.household_key} for individual at row {self.row_index}" + + +class HouseholdValidationError(Exception): + def __init__(self, household_key: int) -> None: + super().__init__(household_key) + self.household_key = household_key + + def __str__(self) -> str: + return f"Failed to validate household {self.household_key}." + + +def normalize_row(row: Row) -> Mapping[str, Any]: + return {clean_field_name(k): v for k, v in row.items()} + + +def get_value(row: Row, column_name: str) -> Any: + if column_name in row: + return row[column_name] + + raise ColumnConfigurationError(column_name) + + +def filter_rows_with_household_pk(config: Config, *sheets: Sheet) -> Iterable[Sheet]: + household_pk_col = config["household_pk_col"] + + def has_household_pk(row: Row) -> bool: + return bool(get_value(row, household_pk_col)) + + return (filter(has_household_pk, sheet) for sheet in sheets) + + +def process_households(sheet: Sheet, job: AsyncJob, batch: Batch, config: Config) -> Mapping[int, Household]: + mapping = {} + + for i, row in enumerate(sheet, 1): + name = get_value(row, config["master_column_label"]) + household_key = get_value(row, config["household_pk_col"]) + + try: + mapping[household_key] = cast( + Household, + job.program.households.create( + batch=batch, + name=name, + flex_fields=normalize_row(row), + ), + ) + except Exception as e: + raise SheetProcessingError(HOUSEHOLD, i) from e + + return mapping + + +def process_individuals( + sheet: Sheet, household_mapping: Mapping[int, Household], job: AsyncJob, batch: Batch, config: Config +) -> int: + processed = 0 + + for i, row in enumerate(sheet, 1): + name = get_value(row, config["detail_column_label"]) + household_key = get_value(row, config["household_pk_col"]) + household = household_mapping.get(household_key) + + if not household: + raise MissingHouseholdError(i, household_key) + + try: + job.program.individuals.create( + batch=batch, + name=name, + household_id=household.pk, + flex_fields=normalize_row(row), + ) + except Exception as e: + raise SheetProcessingError(INDIVIDUAL, i) from e + + processed += 1 + + return processed + + +def validate_households(config: Config, household_mapping: Mapping[int, Household]) -> None: + if config["check_before"]: + for household_key, household in household_mapping.items(): + if not household.validate_with_checker(): + raise HouseholdValidationError(household_key) def import_from_rdi(job: AsyncJob) -> dict[str, int]: - ret = {"household": 0, "individual": 0} - hh_ids = {} with atomic(): - batch_name = job.config["batch_name"] - household_pk_col = job.config["household_pk_col"] - master_column_label = job.config["master_column_label"] - detail_column_label = job.config["detail_column_label"] + config: Config = job.config rdi = job.file batch = Batch.objects.create( - name=batch_name, + name=config["batch_name"], program=job.program, country_office=job.program.country_office, imported_by=job.owner, source=Batch.BatchSource.RDI, ) - for sheet_index, sheet_generator in open_xls_multi(rdi, sheets=[0, 1]): - for line, raw_record in enumerate(sheet_generator, 1): - record = {} - for k, v in raw_record.items(): - record[clean_field_name(k)] = v - if record[household_pk_col]: - try: - if sheet_index == 0: - hh: "Household" = job.program.households.create( - batch=batch, - name=raw_record[master_column_label], - flex_fields=record, - ) - hh_ids[record[household_pk_col]] = hh.pk - ret["household"] += 1 - elif sheet_index == 1: - try: - name = record[detail_column_label] - except KeyError: - raise Exception( - "Error in configuration. '%s' is not a valid column name" % detail_column_label, - ) - job.program.individuals.create( - batch=batch, - name=name, - household_id=hh_ids[record[household_pk_col]], - flex_fields=record, - ) - ret["individual"] += 1 - except Exception as e: # noqa: BLE001 - raise Exception("Error processing sheet %s line %s: %s" % (1 + sheet_index, line, e)) - return ret + (_, household_sheet), (_, individual_sheet) = open_xls_multi(rdi, sheets=[0, 1]) + + household_sheet, individual_sheet = filter_rows_with_household_pk(config, household_sheet, individual_sheet) + + household_mapping = process_households(household_sheet, job, batch, config) + individuals_number = process_individuals(individual_sheet, household_mapping, job, batch, config) + + validate_households(config, household_mapping) + + return { + "household": len(household_mapping), + "individual": individuals_number, + } diff --git a/src/country_workspace/workspaces/admin/program.py b/src/country_workspace/workspaces/admin/program.py index aae325f2..7d52940b 100644 --- a/src/country_workspace/workspaces/admin/program.py +++ b/src/country_workspace/workspaces/admin/program.py @@ -19,7 +19,7 @@ from ...contrib.aurora.forms import ImportAuroraForm from ...contrib.kobo.forms import ImportKoboForm from ...contrib.kobo.sync import import_data as import_from_kobo -from ...datasources.rdi import import_from_rdi +from ...datasources.rdi import import_from_rdi, Config as RDIConfig from ...models import AsyncJob from ...utils.flex_fields import get_checker_fields from ..models import CountryProgram @@ -266,6 +266,13 @@ def import_data(self, request: HttpRequest, pk: str) -> "HttpResponse": def import_rdi(self, request: HttpRequest, program: CountryProgram) -> "ImportFileForm | None": form = ImportFileForm(request.POST, request.FILES, prefix="rdi") if form.is_valid(): + config: RDIConfig = { + "batch_name": form.cleaned_data["batch_name"] or batch_name_default(), + "household_pk_col": form.cleaned_data["pk_column_name"], + "master_column_label": form.cleaned_data["master_column_label"], + "detail_column_label": form.cleaned_data["detail_column_label"], + "check_before": form.cleaned_data["check_before"], + } job: AsyncJob = AsyncJob.objects.create( description="RDI importing", type=AsyncJob.JobType.TASK, @@ -273,12 +280,7 @@ def import_rdi(self, request: HttpRequest, program: CountryProgram) -> "ImportFi file=request.FILES["rdi-file"], program=program, owner=request.user, - config={ - "batch_name": form.cleaned_data["batch_name"] or batch_name_default(), - "household_pk_col": form.cleaned_data["pk_column_name"], - "master_column_label": form.cleaned_data["master_column_label"], - "detail_column_label": form.cleaned_data["detail_column_label"], - }, + config=config, ) job.queue() self.message_user(request, _("Import scheduled"), messages.SUCCESS) diff --git a/tests/datasources/test_rdi.py b/tests/datasources/test_rdi.py new file mode 100644 index 00000000..17676bb5 --- /dev/null +++ b/tests/datasources/test_rdi.py @@ -0,0 +1,250 @@ +from collections.abc import Mapping +from unittest.mock import Mock, call + +import pytest +from pytest_mock import MockerFixture + +from country_workspace.datasources.rdi import ( + normalize_row, + get_value, + ColumnConfigurationError, + SheetProcessingError, + MissingHouseholdError, + HouseholdValidationError, + filter_rows_with_household_pk, + process_households, + process_individuals, + validate_households, + import_from_rdi, + Config, + Sheet, + Row, +) +from country_workspace.models import Household + +HOUSEHOLD_1_PK = 1 +HOUSEHOLD_2_PK = 2 +HOUSEHOLD_1_NAME = "Household 1" +HOUSEHOLD_2_NAME = "Household 2" + + +@pytest.fixture +def config() -> Config: + return { + "batch_name": "batch_name", + "household_pk_col": "household_pk", + "master_column_label": "master_column", + "detail_column_label": "detail_column", + "check_before": False, + } + + +@pytest.fixture +def household_sheet(config: Config) -> Sheet: + return [ + {config["master_column_label"]: HOUSEHOLD_1_NAME, config["household_pk_col"]: HOUSEHOLD_1_PK}, + {config["master_column_label"]: HOUSEHOLD_1_NAME, config["household_pk_col"]: HOUSEHOLD_2_PK}, + ] + + +@pytest.fixture +def individual_sheet(config: Config) -> Sheet: + return [ + { + config["detail_column_label"]: "John Doe", + config["household_pk_col"]: HOUSEHOLD_1_PK, + }, + { + config["detail_column_label"]: "Doe John", + config["household_pk_col"]: HOUSEHOLD_2_PK, + }, + ] + + +@pytest.fixture +def household_mapping() -> Mapping[int, Mock]: + return { + HOUSEHOLD_1_PK: Mock(name=HOUSEHOLD_1_NAME), + HOUSEHOLD_2_PK: Mock(name=HOUSEHOLD_2_NAME), + } + + +def test_column_configuration_error_format() -> None: + error = ColumnConfigurationError(column_name := "test_column") + assert column_name in str(error) + + +def test_sheet_processing_error_format() -> None: + error = SheetProcessingError(sheet_name := "test_sheet", row_index := 42) + assert sheet_name in str(error) + assert str(row_index) in str(error) + + +def test_missing_household_error_format() -> None: + error = MissingHouseholdError(row_index := 42, household_key := "test_household_key") + assert str(row_index) in str(error) + assert household_key in str(error) + + +def test_household_validation_error_format() -> None: + error = HouseholdValidationError(household_key := "test_household_key") + assert household_key in str(error) + + +def test_normalize_row_calls_clean_field_name(mocker: MockerFixture) -> None: + row = {(key := "key"): (value := "value")} + clean_field_name_mock = mocker.patch("country_workspace.datasources.rdi.clean_field_name") + + result = normalize_row(row) + + assert result == {clean_field_name_mock.return_value: value} + clean_field_name_mock.assert_called_once_with(key) + + +def test_get_value_returns_value() -> None: + row = {(column := "column"): (column_value := "value")} + + value = get_value(row, column) + + assert value == column_value + + +def test_get_value_raise_exception_when_key_is_missing() -> None: + row: Row = {} + + with pytest.raises(ColumnConfigurationError): + get_value(row, "column") + + +def test_filter_rows_with_household_pk(mocker: MockerFixture, config: Config, household_sheet: Sheet) -> None: + household_sheet_list = list(household_sheet) + get_value_mock = mocker.patch("country_workspace.datasources.rdi.get_value") + get_value_mock.side_effect = True, False + + result = [list(s) for s in filter_rows_with_household_pk(config, household_sheet)] + + assert result == [[household_sheet_list[0]]] + get_value_mock.assert_has_calls( + ( + call(household_sheet_list[0], config["household_pk_col"]), + call(household_sheet_list[1], config["household_pk_col"]), + ) + ) + + +def test_process_households(config: Config, household_sheet: Sheet) -> None: + job = Mock() + batch = Mock() + + result = process_households(household_sheet, job, batch, config) + + assert result == { + row[config["household_pk_col"]]: job.program.households.create.return_value for row in household_sheet + } + job.program.households.create.assert_has_calls( + [ + call(batch=batch, name=row[config["master_column_label"]], flex_fields=normalize_row(row)) + for row in household_sheet + ] + ) + + +def test_process_households_failed_to_save_household(config: Config, household_sheet: Sheet) -> None: + job = Mock() + batch = Mock() + + job.program.households.create.side_effect = Exception("Something went wrong") + + with pytest.raises(SheetProcessingError): + process_households(household_sheet, job, batch, config) + + +def test_process_individuals( + config: Config, individual_sheet: Sheet, household_mapping: Mapping[int, Household] +) -> None: + job = Mock() + batch = Mock() + + result = process_individuals(individual_sheet, household_mapping, job, batch, config) + + assert result == len(list(individual_sheet)) + job.program.individuals.create.assert_has_calls( + [ + call( + batch=batch, + name=row[config["detail_column_label"]], + household_id=household_mapping[row[config["household_pk_col"]]].pk, + flex_fields=normalize_row(row), + ) + for row in individual_sheet + ] + ) + + +def test_validate_households(config: Config, household_mapping: Mapping[int, Mock]) -> None: + config["check_before"] = True + + validate_households(config, household_mapping) + + for household in household_mapping.values(): + household.validate_with_checker.assert_called_once() + + +def test_validate_households_raises_exception_on_failed_validation( + config: Config, household_mapping: Mapping[int, Mock] +) -> None: + config["check_before"] = True + household_mapping[HOUSEHOLD_1_PK].validate_with_checker.return_value = False + + with pytest.raises(HouseholdValidationError): + validate_households(config, household_mapping) + + +def test_validate_households_check_before_is_false(config: Config, household_mapping: Mapping[int, Mock]) -> None: + config["check_before"] = False + + validate_households(config, household_mapping) + + for household in household_mapping.values(): + household.validate_with_checker.assert_not_called() + + +def test_import_from_rdi( + mocker: MockerFixture, + config: Config, + household_sheet: Sheet, + individual_sheet: Sheet, + household_mapping: Mapping[int, Mock], +) -> None: + job = Mock() + job.config = config + batch_class_mock = mocker.patch("country_workspace.datasources.rdi.Batch") + open_xls_multi_mock = mocker.patch("country_workspace.datasources.rdi.open_xls_multi") + open_xls_multi_mock.return_value = (0, household_sheet), (1, individual_sheet) + filter_rows_with_household_pk_mock = mocker.patch("country_workspace.datasources.rdi.filter_rows_with_household_pk") + filter_rows_with_household_pk_mock.return_value = household_sheet, individual_sheet + process_households_mock = mocker.patch("country_workspace.datasources.rdi.process_households") + process_households_mock.return_value = household_mapping + process_individuals_mock = mocker.patch("country_workspace.datasources.rdi.process_individuals") + process_individuals_mock.return_value = (processed_individuals := len(list(individual_sheet))) + validate_households_mock = mocker.patch("country_workspace.datasources.rdi.validate_households") + + result = import_from_rdi(job) + + assert result == {"household": len(household_mapping), "individual": processed_individuals} + batch_class_mock.objects.create.assert_called_once_with( + name=config["batch_name"], + program=job.program, + country_office=job.program.country_office, + imported_by=job.owner, + source=batch_class_mock.BatchSource.RDI, + ) + open_xls_multi_mock.assert_called_once_with(job.file, sheets=[0, 1]) + filter_rows_with_household_pk_mock.assert_called_once_with(config, household_sheet, individual_sheet) + process_households_mock.assert_called_once_with( + household_sheet, job, batch_class_mock.objects.create.return_value, config + ) + process_individuals_mock.assert_called_once_with( + individual_sheet, household_mapping, job, batch_class_mock.objects.create.return_value, config + ) + validate_households_mock.assert_called_once_with(config, household_mapping) diff --git a/tests/workspace/test_ws_import.py b/tests/workspace/test_ws_import.py index e4aff03d..483f62b3 100644 --- a/tests/workspace/test_ws_import.py +++ b/tests/workspace/test_ws_import.py @@ -68,7 +68,7 @@ def test_import_data_rdi(force_migrated_records, app, program): res.forms["import-file"]["_selected_tab"] = "rdi" res.forms["import-file"]["rdi-file"] = Upload("rdi_one.xlsx", data) - res.forms["import-file"]["rdi-detail_column_label"] = "full_name" + res.forms["import-file"]["rdi-detail_column_label"] = "full_name_i_c" res = res.forms["import-file"].submit() assert res.status_code == 302 assert program.households.count() == 1