Skip to content

Commit 59f2729

Browse files
authored
Add --infer-enum-from-integers flag for low-cardinality integer enum detection
1 parent de982a9 commit 59f2729

3 files changed

Lines changed: 66 additions & 0 deletions

File tree

schema_automator/cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
infer_foreign_keys_option = click.option('--infer-foreign-keys/--no-infer-foreign-keys', default=False, help='infer ranges/foreign keys')
7373
infer_optional_option = click.option('--infer-optional/--no-infer-optional', default=False, help='mark slots as not required when columns have null or empty values (ignored in pandera mode)')
7474
infer_mixed_types_option = click.option('--infer-mixed-types/--no-infer-mixed-types', default=False, help='use any_of to represent columns with mixed types')
75+
infer_enum_from_integers_option = click.option('--infer-enum-from-integers/--no-infer-enum-from-integers', default=False, help='treat low-cardinality integer columns as enum candidates')
7576
enum_columns_option = click.option('--enum-columns', '-E', multiple=True, help='column(s) that is forced to be an enum')
7677
enum_mask_columns_option = click.option('--enum-mask-columns', multiple=True, help='column(s) that are excluded from being enums')
7778
max_enum_size_option = click.option('--max-enum-size', default=50, help='do not create an enum if more than max distinct members')
@@ -123,6 +124,7 @@ def main(verbose: int, quiet: bool):
123124
@max_enum_size_option
124125
@infer_optional_option
125126
@infer_mixed_types_option
127+
@infer_enum_from_integers_option
126128
@click.option('--data-dictionary-row-count',
127129
type=click.INT,
128130
help='rows that provide metadata about columns')
@@ -163,6 +165,7 @@ def generalize_tsv(tsvfile, output, class_name, schema_name, pandera: bool, anno
163165
@max_enum_size_option
164166
@infer_optional_option
165167
@infer_mixed_types_option
168+
@infer_enum_from_integers_option
166169
@click.option('--robot/--no-robot', default=False, help='set if the TSV is a ROBOT template')
167170
def generalize_tsvs(tsvfiles, output, schema_name, **kwargs):
168171
"""
@@ -193,6 +196,7 @@ def generalize_tsvs(tsvfiles, output, schema_name, **kwargs):
193196
@max_enum_size_option
194197
@infer_optional_option
195198
@infer_mixed_types_option
199+
@infer_enum_from_integers_option
196200
@click.option('--class-name', '-c', default=DEFAULT_CLASS_NAME, help='Core class name in schema')
197201
@click.option('--pandera/--no-pandera', default=False, help='set to use panderas as inference engine')
198202
@click.option('--data-output', help='Path to file of downloaded data')

schema_automator/generalizers/csv_data_generalizer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ class CsvDataGeneralizer(Generalizer):
116116
infer_mixed_types: bool = False
117117
"""If true, use any_of to represent columns with mixed types instead of collapsing to string"""
118118

119+
infer_enum_from_integers: bool = False
120+
"""If true, treat low-cardinality integer columns as enum candidates"""
121+
119122
def infer_linkages(self, files: List[str], **kwargs) -> List[ForeignKey]:
120123
"""
121124
Heuristic procedure for determining which tables are linked to others via implicit foreign keys
@@ -456,6 +459,21 @@ def convert_dicts(self,
456459
logging.info(f"Slot {sn} has range {s['range']}")
457460
if self.infer_optional and sn in slot_has_nulls and not s.get('identifier'):
458461
s['required'] = False
462+
if (self.infer_enum_from_integers
463+
and s.get('range') == 'integer'
464+
and sn not in enum_mask_columns
465+
and not s.get('identifier')):
466+
n_distinct = len(vals)
467+
n_total = len(slot_values[sn]) + 1
468+
if (sn in enum_columns
469+
or ((n_distinct / n_total) < self.enum_threshold
470+
and 0 < n_distinct <= self.max_enum_size)):
471+
enum_name = sn.replace(' ', '_').replace('(s)', '') + '_enum'
472+
s['range'] = enum_name
473+
enums[enum_name] = {
474+
'permissible_values': {str(v): {'description': str(v)} for v in vals}
475+
}
476+
logging.info(f"Slot {sn}: low-cardinality integers treated as enum {enum_name}")
459477
if 'any_of' not in s and (s.get('range') == 'string' or sn in enum_columns) and sn not in enum_mask_columns:
460478
filtered_vals = \
461479
[v

tests/test_generalizers/test_csv_data_generalizer.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,50 @@ def test_infer_mixed_types_off_by_default(self):
170170
self.assertEqual(schema.slots["score"].range, "string")
171171
self.assertEqual(len(schema.slots["score"].any_of), 0)
172172

173+
def test_infer_enum_from_integers(self):
174+
rows = [
175+
{"id": "1", "name": "Alice", "status": "1"},
176+
{"id": "2", "name": "Bob", "status": "2"},
177+
{"id": "3", "name": "Carol", "status": "1"},
178+
{"id": "4", "name": "Dave", "status": "2"},
179+
{"id": "5", "name": "Eve", "status": "1"},
180+
{"id": "6", "name": "Frank", "status": "2"},
181+
{"id": "7", "name": "Grace", "status": "1"},
182+
{"id": "8", "name": "Hank", "status": "2"},
183+
{"id": "9", "name": "Ivy", "status": "1"},
184+
{"id": "10", "name": "Jack", "status": "2"},
185+
]
186+
ie = CsvDataGeneralizer(infer_enum_from_integers=True, enum_threshold=0.5)
187+
schema = ie.convert_dicts(rows, "test", "Pet")
188+
# status has 2 distinct values over 10 rows => ratio 2 / (10 + 1) ~= 0.18 < 0.5 threshold
189+
self.assertEqual(schema.slots["status"].range, "status_enum")
190+
pvs = list(schema.enums["status_enum"].permissible_values.keys())
191+
self.assertCountEqual(pvs, ["1", "2"])
192+
193+
def test_infer_enum_from_integers_high_cardinality_stays_integer(self):
194+
rows = [{"id": str(i), "val": str(i)} for i in range(1, 21)]
195+
ie = CsvDataGeneralizer(infer_enum_from_integers=True, enum_threshold=0.1)
196+
schema = ie.convert_dicts(rows, "test", "Thing")
197+
# 20 distinct out of 20 rows => ratio 20 / (20 + 1) ~= 0.95, well above the 0.1 threshold
198+
self.assertEqual(schema.slots["val"].range, "integer")
199+
200+
def test_infer_enum_from_integers_off_by_default(self):
201+
rows = [
202+
{"id": "1", "status": "1"},
203+
{"id": "2", "status": "2"},
204+
{"id": "3", "status": "1"},
205+
{"id": "4", "status": "2"},
206+
{"id": "5", "status": "1"},
207+
{"id": "6", "status": "2"},
208+
{"id": "7", "status": "1"},
209+
{"id": "8", "status": "2"},
210+
{"id": "9", "status": "1"},
211+
{"id": "10", "status": "2"},
212+
]
213+
ie = CsvDataGeneralizer()
214+
schema = ie.convert_dicts(rows, "test", "Pet")
215+
self.assertEqual(schema.slots["status"].range, "integer")
216+
173217
def _convert(self, base_name: str, cn='Example', index_slot='examples') -> SchemaDefinition:
174218
ie = CsvDataGeneralizer()
175219
fn = f'{base_name}.tsv'

0 commit comments

Comments
 (0)