Skip to content

Commit d661a79

Browse files
Use ValidatorMixin.validate fail_if_alien argument
1 parent 718295a commit d661a79

File tree

9 files changed

+51
-131
lines changed

9 files changed

+51
-131
lines changed

src/country_workspace/contrib/aurora/pipeline.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from country_workspace.contrib.aurora.client import AuroraClient
66
from country_workspace.models import AsyncJob, Batch, Household, Individual
77
from country_workspace.utils.config import BatchNameConfig, FailIfAlienConfig
8-
from country_workspace.utils.fields import uppercase_field_value, RecordPreprocessor, create_json_record_preprocessor
8+
from country_workspace.utils.fields import uppercase_field_value, clean_field_names
99

1010

1111
class Config(BatchNameConfig, FailIfAlienConfig):
@@ -43,38 +43,30 @@ def import_from_aurora(job: AsyncJob) -> dict[str, int]:
4343
source=Batch.BatchSource.AURORA,
4444
)
4545
client = AuroraClient()
46-
individual_preprocessor = create_json_record_preprocessor(config, job.program.individual_checker)
47-
household_preprocessor = create_json_record_preprocessor(config, job.program.household_checker)
4846
with atomic():
4947
for record in client.get(f"registration/{config['registration_reference_pk']}/records/"):
5048
inds_data = _collect_by_prefix(record["flatten"], config.get("individuals_column_prefix"))
5149
if inds_data:
52-
hh = create_household(
53-
batch, record["flatten"], config.get("household_column_prefix"), household_preprocessor
54-
)
50+
hh = create_household(batch, record["flatten"], config.get("household_column_prefix"))
5551
total_hh += 1
5652
total_ind += len(
5753
create_individuals(
5854
household=hh,
5955
data=inds_data,
6056
household_label_column=config.get("household_label_column"),
61-
preprocess_record=individual_preprocessor,
6257
)
6358
)
6459
return {"households": total_hh, "individuals": total_ind}
6560

6661

67-
def create_household(
68-
batch: Batch, data: dict[str, Any], prefix: str, preprocess_record: RecordPreprocessor
69-
) -> Household:
62+
def create_household(batch: Batch, data: dict[str, Any], prefix: str) -> Household:
7063
"""
7164
Create a Household object from the provided data and associate it with a batch.
7265
7366
Args:
7467
batch (Batch): The batch to which the household will be linked.
7568
data (dict[str, Any]): A dictionary containing household-related information.
7669
prefix (str): The prefix used to filter and group household-related information.
77-
preprocess_record (RecordPreprocessor): The function normalizing field names and checking if they are valid.
7870
7971
Returns:
8072
Household: The newly created household instance.
@@ -87,19 +79,16 @@ def create_household(
8779
if len(flex_fields) > 1:
8880
raise ValueError("Multiple households found")
8981
flex_fields = next(iter(flex_fields.values()), {})
90-
return batch.program.households.create(batch=batch, flex_fields=preprocess_record(flex_fields))
82+
return batch.program.households.create(batch=batch, flex_fields=clean_field_names(flex_fields))
9183

9284

93-
def create_individuals(
94-
household: Household, data: dict[str, Any], household_label_column: str, preprocess_record: RecordPreprocessor
95-
) -> list[Individual]:
85+
def create_individuals(household: Household, data: dict[str, Any], household_label_column: str) -> list[Individual]:
9686
"""Create and associate Individual objects with a given Household.
9787
9888
Args:
9989
household (Household): The household to which the individuals will be linked.
10090
data (dict[str, Any]): A dictionary mapping indices to individual details.
10191
household_label_column (str): The key in the individual data used to determine the household label.
102-
preprocess_record (RecordPreprocessor): The function normalizing field names and checking if they are valid.
10392
10493
Returns:
10594
list[Individual]: A list of successfully created Individual instances.
@@ -109,7 +98,7 @@ def create_individuals(
10998
head_found = False
11099

111100
for raw_individual in data.values():
112-
individual = preprocess_record(raw_individual)
101+
individual = clean_field_names(raw_individual)
113102
if not head_found:
114103
head_found = _update_household_label_from_individual(household, individual, household_label_column)
115104
individuals.append(

src/country_workspace/contrib/kobo/sync.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from country_workspace.contrib.kobo.api.data.asset import Asset
88
from country_workspace.contrib.kobo.api.data.submission import Submission
99
from country_workspace.contrib.kobo.models import KoboSubmission
10-
from country_workspace.models import AsyncJob, Batch, Household, Individual, Program
10+
from country_workspace.models import AsyncJob, Batch, Household, Individual
1111
from country_workspace.utils.config import FailIfAlienConfig, BatchNameConfig
12-
from country_workspace.utils.fields import create_json_record_preprocessor, RecordPreprocessor
12+
from country_workspace.utils.fields import clean_field_names
1313

1414

1515
class Config(BatchNameConfig, FailIfAlienConfig):
@@ -32,9 +32,7 @@ def extract_household_data(submission: Submission, individual_records_field: str
3232
return {key: value for key, value in submission.items() if key != individual_records_field}
3333

3434

35-
def create_individuals(
36-
batch: Batch, household: Household, submission: Submission, config: Config, preprocess_record: RecordPreprocessor
37-
) -> int:
35+
def create_individuals(batch: Batch, household: Household, submission: Submission, config: Config) -> int:
3836
individuals = []
3937
for raw_individual in submission.get(config["individual_records_field"], []):
4038
individual = {
@@ -46,22 +44,20 @@ def create_individuals(
4644
batch=batch,
4745
household=household,
4846
name=individual.get(fullname, ""),
49-
flex_fields=preprocess_record(individual),
47+
flex_fields=clean_field_names(individual),
5048
),
5149
)
5250
household.program.individuals.bulk_create(individuals)
5351
return len(individuals)
5452

5553

56-
def create_household(
57-
batch: Batch, submission: Submission, config: Config, preprocess_record: RecordPreprocessor
58-
) -> Household:
54+
def create_household(batch: Batch, submission: Submission, config: Config) -> Household:
5955
household_fields = extract_household_data(submission, config["individual_records_field"])
6056
return cast(
6157
Household,
6258
batch.program.households.create(
6359
batch=batch,
64-
flex_fields=preprocess_record(household_fields),
60+
flex_fields=clean_field_names(household_fields),
6561
),
6662
)
6763

@@ -74,21 +70,18 @@ class ImportResult(TypedDict):
7470
individuals: int
7571

7672

77-
def import_asset(batch: Batch, asset: Asset, config: Config, program: Program) -> ImportResult:
73+
def import_asset(batch: Batch, asset: Asset, config: Config) -> ImportResult:
7874
household_counter = 0
7975
individual_counter = 0
8076

81-
individual_preprocessor = create_json_record_preprocessor(config, program.individual_checker)
82-
household_preprocessor = create_json_record_preprocessor(config, program.household_checker)
83-
8477
with cache.lock(ASSET_CACHE_KEY.format(asset_id=asset.uid)):
8578
submission_ids = set(KoboSubmission.objects.filter(asset_uid=asset.uid).values_list("submission_id", flat=True))
8679
for submission in asset.submissions:
8780
if submission.id in submission_ids:
8881
continue
89-
household = create_household(batch, submission, config, household_preprocessor)
82+
household = create_household(batch, submission, config)
9083
household_counter += 1
91-
individual_counter += create_individuals(batch, household, submission, config, individual_preprocessor)
84+
individual_counter += create_individuals(batch, household, submission, config)
9285

9386
return ImportResult(households=household_counter, individuals=individual_counter)
9487

@@ -111,7 +104,7 @@ def import_data(job: AsyncJob) -> ImportResult:
111104
for asset in client.assets:
112105
# TODO: fetch specific asset
113106
if config["project_id"] == asset.uid:
114-
import_result = import_asset(batch, asset, config, job.program)
107+
import_result = import_asset(batch, asset, config)
115108
household_counter += import_result["households"]
116109
individual_counter += import_result["individuals"]
117110

src/country_workspace/datasources/rdi.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from country_workspace.models import AsyncJob, Batch, Household
99
from country_workspace.utils.config import FailIfAlienConfig, BatchNameConfig
10-
from country_workspace.utils.fields import create_json_record_preprocessor, Record
10+
from country_workspace.utils.fields import clean_field_names, Record
1111

1212
RDI = str | io.BytesIO
1313
Sheet = Iterable[Record]
@@ -80,21 +80,17 @@ def has_household_pk(row: Record) -> bool:
8080
def process_households(sheet: Sheet, job: AsyncJob, batch: Batch, config: Config) -> Mapping[int, Household]:
8181
mapping = {}
8282

83-
preprocess_json_record = create_json_record_preprocessor(config, job.program.household_checker)
84-
8583
for i, row in enumerate(sheet, 1):
8684
name = get_value(row, config["master_column_label"])
8785
household_key = get_value(row, config["household_pk_col"])
8886

89-
preprocessed_row = preprocess_json_record(row)
90-
9187
try:
9288
mapping[household_key] = cast(
9389
Household,
9490
job.program.households.create(
9591
batch=batch,
9692
name=name,
97-
flex_fields=preprocessed_row,
93+
flex_fields=clean_field_names(row),
9894
),
9995
)
10096
except Exception as e:
@@ -108,8 +104,6 @@ def process_individuals(
108104
) -> int:
109105
processed = 0
110106

111-
preprocess_json_record = create_json_record_preprocessor(config, job.program.individual_checker)
112-
113107
for i, row in enumerate(sheet, 1):
114108
name = get_value(row, config["detail_column_label"])
115109
household_key = get_value(row, config["household_pk_col"])
@@ -118,14 +112,12 @@ def process_individuals(
118112
if not household:
119113
raise MissingHouseholdError(i, household_key)
120114

121-
preprocessed_row = preprocess_json_record(row)
122-
123115
try:
124116
job.program.individuals.create(
125117
batch=batch,
126118
name=name,
127119
household_id=household.pk,
128-
flex_fields=preprocessed_row,
120+
flex_fields=clean_field_names(row),
129121
)
130122
except Exception as e:
131123
raise SheetProcessingError(INDIVIDUAL, i) from e
@@ -138,7 +130,7 @@ def process_individuals(
138130
def validate_households(config: Config, household_mapping: Mapping[int, Household]) -> None:
139131
if config["check_before"]:
140132
for household_key, household in household_mapping.items():
141-
if not household.validate_with_checker():
133+
if not household.validate_with_checker(fail_if_alien=config["fail_if_alien"]):
142134
raise HouseholdValidationError(household_key)
143135

144136

src/country_workspace/models/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def save(
112112
def checker(self) -> "DataChecker":
113113
raise NotImplementedError
114114

115-
def validate_with_checker(self) -> bool:
116-
errors = self.checker.validate([self.flex_fields])
115+
def validate_with_checker(self, fail_if_alien: bool = False) -> bool:
116+
errors = self.checker.validate([self.flex_fields], fail_if_alien=fail_if_alien)
117117
if errors:
118118
self.errors = errors[1]
119119
else:

src/country_workspace/models/household.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ def program(self) -> "Program":
3636
def country_office(self) -> "Office":
3737
return self.batch.program.country_office
3838

39-
def validate_with_checker(self) -> bool:
39+
def validate_with_checker(self, fail_if_alien: bool = False) -> bool:
4040
hh_valid = True
4141
for ind in self.members.all():
42-
if not ind.validate_with_checker():
42+
if not ind.validate_with_checker(fail_if_alien=fail_if_alien):
4343
hh_valid = False
4444
if hh_valid:
45-
super().validate_with_checker()
45+
super().validate_with_checker(fail_if_alien=fail_if_alien)
4646
errors = self.program.beneficiary_validator.validate(self)
4747
if errors:
4848
self.errors["dct"] = errors

src/country_workspace/utils/fields.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
from collections.abc import Callable, Mapping
2-
from typing import Any
32
from functools import reduce
3+
from typing import Any
44

55
from django.utils import timezone
6-
from hope_flex_fields.models import DataChecker
7-
8-
from country_workspace.utils.config import FailIfAlienConfig
96

107
batch_name_default: Callable[[], str] = lambda: f"Batch {timezone.now()}"
118
rdi_name_default: Callable[[], str] = lambda: f"RDI to HOPE {timezone.now()}"
129

1310
Record = Mapping[str, Any]
14-
RecordPreprocessor = Callable[[Record], Record]
1511

1612

1713
TO_REMOVE = "_h_c", "_h_f", "_i_c", "_i_f"
@@ -30,30 +26,8 @@ def clean_field_name(v: str) -> str:
3026
return reduce(lambda name, substr: name.replace(substr, ""), TO_REMOVE, v.lower())
3127

3228

33-
class ExtraFieldInRecordError(Exception):
34-
def __init__(self, *fields: str) -> None:
35-
super().__init__(*fields)
36-
self.fields = fields
37-
38-
def __str__(self) -> str:
39-
return f"Extra fields found: {', '.join(self.fields)}"
40-
41-
42-
def create_json_record_preprocessor(config: FailIfAlienConfig, checker: DataChecker) -> Callable[[Record], Record]:
43-
if config["fail_if_alien"]:
44-
field_names = {field.name for _, field in checker.get_fields()}
45-
else:
46-
field_names = set()
47-
48-
def preprocessor(record: Record) -> Record:
49-
cleaned_record = {clean_field_name(k): v for k, v in record.items()}
50-
51-
if config["fail_if_alien"] and (extra_fields := cleaned_record.keys() - field_names):
52-
raise ExtraFieldInRecordError(*extra_fields)
53-
54-
return cleaned_record
55-
56-
return preprocessor
29+
def clean_field_names(record: Record) -> Record:
30+
return {clean_field_name(k): v for k, v in record.items()}
5731

5832

5933
def uppercase_field_value(k: str, v: Any) -> str:

0 commit comments

Comments (0)