Skip to content

Commit 667e581

Browse files
authored
Merge pull request #59 from mindsdb/refactor/modules
[refactor] Inference engines
2 parents 42419a8 + 9c63ba0 commit 667e581

File tree

19 files changed

+817
-739
lines changed

19 files changed

+817
-739
lines changed

pyproject.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "type_infer"
3-
version = "0.0.17"
3+
version = "0.0.18"
44
description = "Automated type inference for Machine Learning pipelines."
55
authors = ["MindsDB Inc. <[email protected]>"]
66
license = "GPL-3.0"
@@ -15,12 +15,17 @@ numpy = "^1.15"
1515
pandas = "^2"
1616
dataclasses-json = "^0.6.3"
1717
colorlog = "^6.5.0"
18-
langid = "^1.1.6"
19-
nltk = "^3"
20-
toml = "^0.10.2"
2118
psutil = "^5.9.0"
19+
toml = "^0.10.2"
2220

21+
# rule based deps, part of core
22+
langid = "^1.1.6"
23+
nltk = "^3"
2324

2425
[build-system]
2526
requires = ["poetry-core"]
2627
build-backend = "poetry.core.masonry.api"
28+
29+
# TODO: update once this engine is introduced
30+
[tool.poetry.extras]
31+
# bert = ["torch"]

tests/integration_tests/test_type_infer.py renamed to tests/integration_tests/test_rule_based.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from datetime import datetime, timedelta
66

77
from type_infer.dtype import dtype
8-
from type_infer.infer import infer_types
8+
from type_infer.api import infer_types
99

1010

11-
class TestTypeInference(unittest.TestCase):
11+
class TestRuleBasedTypeInference(unittest.TestCase):
1212
def test_0_airline_sentiment(self):
1313
df = pd.read_csv("tests/data/airline_sentiment_sample.csv")
14-
inferred_types = infer_types(df, pct_invalid=0)
14+
config = {'engine': 'rule_based', 'pct_invalid': 0, 'seed': 420, 'mp_cutoff': 1e4}
15+
inferred_types = infer_types(df, config=config)
1516

1617
expected_types = {
1718
'airline_sentiment': 'categorical',
@@ -44,6 +45,7 @@ def test_0_airline_sentiment(self):
4445

4546
def test_1_stack_overflow_survey(self):
4647
df = pd.read_csv("tests/data/stack_overflow_survey_sample.csv")
48+
config = {'engine': 'rule_based', 'pct_invalid': 0, 'seed': 420, 'mp_cutoff': 1e4}
4749

4850
expected_types = {
4951
'Respondent': 'integer',
@@ -68,7 +70,7 @@ def test_1_stack_overflow_survey(self):
6870
'Professional': 'No Information'
6971
}
7072

71-
inferred_types = infer_types(df, pct_invalid=0)
73+
inferred_types = infer_types(df, config=config)
7274

7375
for col in expected_types:
7476
self.assertTrue(expected_types[col], inferred_types.dtypes[col])
@@ -90,7 +92,10 @@ def test_2_simple(self):
9092
# manual tinkering
9193
df['float'].iloc[-n_corrupted:] = 'random string'
9294

93-
inferred_types = infer_types(df, pct_invalid=100 * (n_corrupted) / n_points)
95+
pct_invalid = 100 * (n_corrupted) / n_points
96+
config = {'engine': 'rule_based', 'pct_invalid': pct_invalid, 'seed': 420, 'mp_cutoff': 1e4}
97+
98+
inferred_types = infer_types(df, config=config)
9499
expected_types = {
95100
'date': dtype.date,
96101
'datetime': dtype.datetime,

tests/unit_tests/rule_based/__init__.py

Whitespace-only changes.

tests/unit_tests/test_dates.py renamed to tests/unit_tests/rule_based/test_dates.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import unittest
22

33
from type_infer.dtype import dtype
4-
from type_infer.infer import type_check_date
4+
from type_infer.rule_based.core import RuleBasedEngine
5+
6+
type_check_date = RuleBasedEngine.type_check_date
57

68

79
class TestDates(unittest.TestCase):
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import unittest
2+
import random
3+
4+
import pandas as pd
5+
from type_infer.rule_based.core import RuleBasedEngine
6+
from type_infer.dtype import dtype
7+
8+
get_column_data_type = RuleBasedEngine.get_column_data_type
9+
10+
11+
class TestInferDtypes(unittest.TestCase):
    """Checks the dtype that `RuleBasedEngine.get_column_data_type` reports for signed numeric columns."""

    def test_negative_integers(self):
        """A column of whole numbers drawn from [-10, 10] is expected to be typed as integer."""
        # Dedicated seeded generator: reproducible data without touching the global `random` state.
        rng = random.Random(420)
        data = pd.DataFrame([-rng.randint(-10, 10) for _ in range(100)], columns=['test_col'])
        engine = RuleBasedEngine()
        # Five values are returned; only the first one (the inferred dtype) is checked here.
        dtyp, _, _, _, _ = engine.get_column_data_type(data['test_col'], data, 'test_col', 0.0)
        self.assertEqual(dtyp, dtype.integer)

    def test_negative_floats(self):
        """The same kind of column, cast to float and with a single 0.1 appended, is expected to be typed as float."""
        rng = random.Random(420)
        # The trailing 0.1 is the only value in the column with a fractional part.
        values = [float(-rng.randint(-10, 10)) for _ in range(100)] + [0.1]
        data = pd.DataFrame(values, columns=['test_col'])
        engine = RuleBasedEngine()
        dtyp, _, _, _, _ = engine.get_column_data_type(data['test_col'], data, 'test_col', 0.0)
        self.assertEqual(dtyp, dtype.float)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import unittest
2+
3+
from type_infer.rule_based.helpers import tokenize_text
4+
5+
6+
class TestDates(unittest.TestCase):
    # NOTE(review): despite its name, this class only exercises `tokenize_text`.
    # The name is left unchanged so existing test selectors keep working.

    def test_get_tokens(self):
        """Expects `tokenize_text` to yield bare words and to split contractions into two tokens."""
        sentences = ['hello, world!', ' !hello! world!!,..#', '#hello!world']
        for sent in sentences:
            # subTest reports every failing sentence instead of stopping at the first one.
            with self.subTest(sentence=sent):
                self.assertEqual(list(tokenize_text(sent)), ['hello', 'world'])

        # assertEqual (unlike a bare assert) is not stripped under `python -O` and shows a diff on failure.
        self.assertEqual(list(tokenize_text("don't wouldn't")), ['do', 'not', 'would', 'not'])

tests/unit_tests/test_infer_dtypes.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

tests/unit_tests/test_misc.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from pathlib import Path
44

55
import type_infer
6-
from type_infer.helpers import tokenize_text
76

87

98
class TestDates(unittest.TestCase):
@@ -19,10 +18,3 @@ def test_versions_are_in_sync(self):
1918
package_init_version = type_infer.__version__
2019

2120
self.assertEqual(package_init_version, pyproject_version)
22-
23-
def test_get_tokens(self):
24-
sentences = ['hello, world!', ' !hello! world!!,..#', '#hello!world']
25-
for sent in sentences:
26-
assert list(tokenize_text(sent)) == ['hello', 'world']
27-
28-
assert list(tokenize_text("don't wouldn't")) == ['do', 'not', 'would', 'not']

type_infer/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from type_infer import base
22
from type_infer import dtype
3-
from type_infer import infer
3+
from type_infer import api
44
from type_infer import helpers
55

6+
__version__ = '0.0.18'
67

7-
__version__ = '0.0.17'
88

9-
10-
__all__ = ['base', 'dtype', 'infer', 'helpers', '__version__']
9+
__all__ = [
10+
'__version__',
11+
'base', 'dtype', 'api', 'helpers',
12+
]

type_infer/api.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Dict, Optional
2+
import pandas as pd
3+
4+
from type_infer.base import TypeInformation, ENGINES
5+
from type_infer.rule_based.core import RuleBasedEngine
6+
7+
8+
def infer_types(
    data: pd.DataFrame,
    config: Optional[Dict] = None
) -> TypeInformation:
    """
    Infers the data types of each column of the dataset by analyzing a small sample of
    each column's items.

    Inputs
    ----------
    data : pd.DataFrame
        The input dataset for which we want to infer data type information.
    config : Optional[Dict]
        Engine configuration. Missing keys are filled with these defaults:
        'engine' -> 'rule_based', 'pct_invalid' -> 2, 'seed' -> 420 and, for the
        rule-based engine only, 'mp_cutoff' -> 1e4. The dictionary passed in is
        copied, never modified.

    Returns
    ----------
    TypeInformation
        The result of the selected engine's `infer` call on `data`.

    Raises
    ----------
    ValueError
        If `config['engine']` does not match a known engine.
    """
    # Shallow copy so that filling in defaults never leaks into the caller's dictionary.
    config = dict(config) if config is not None else {}

    # Set global defaults if missing
    config.setdefault('engine', 'rule_based')
    config.setdefault('pct_invalid', 2)
    config.setdefault('seed', 420)

    if config['engine'] != ENGINES.RULE_BASED:
        # ValueError is still an `Exception`, so existing broad handlers keep working.
        raise ValueError(f'Unknown engine {config["engine"]}')

    # Engine-specific default, only relevant to the rule-based engine.
    config.setdefault('mp_cutoff', 1e4)

    engine = RuleBasedEngine(config)
    return engine.infer(data)

0 commit comments

Comments
 (0)