Skip to content

Commit 4946699

Browse files
committed
dev: filter_on_column, tests big refactor
1 parent bb38493 commit 4946699

9 files changed

+210
-93
lines changed

Diff for: data_flow/data_flow.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
import os
22
import tempfile
3+
from typing import Any
34

45
import fireducks.pandas as fd
56
import pandas as pd
67
import polars as pl
78
from pyarrow import feather
89

9-
from data_flow.lib import FileType
10-
from data_flow.lib.data_columns import data_get_columns, data_delete_columns, data_rename_columns, data_select_columns
10+
from data_flow.lib import FileType, Operator
11+
from data_flow.lib.data_columns import (
12+
data_get_columns,
13+
data_delete_columns,
14+
data_rename_columns,
15+
data_select_columns,
16+
data_filter_on_column,
17+
)
1118
from data_flow.lib.data_from import (
1219
from_csv_2_file,
1320
from_feather_2_file,
@@ -188,7 +195,26 @@ def columns_select(self, columns: list):
188195
else:
189196
data_select_columns(tmp_filename=self.__filename, file_type=self.__file_type, columns=columns)
190197

191-
# def filter_on_column(self, column: str, value: Any, operator: Operator):
192-
# if self.__in_memory:
193-
#
194-
#
198+
def filter_on_column(self, column: str, value: Any, operator: Operator):
199+
if self.__in_memory:
200+
match operator:
201+
case Operator.Eq:
202+
self.__data = self.__data[self.__data[column] == value]
203+
case Operator.Gte:
204+
self.__data = self.__data[self.__data[column] >= value]
205+
case Operator.Lte:
206+
self.__data = self.__data[self.__data[column] <= value]
207+
case Operator.Gt:
208+
self.__data = self.__data[self.__data[column] > value]
209+
case Operator.Lt:
210+
self.__data = self.__data[self.__data[column] < value]
211+
case Operator.Ne:
212+
self.__data = self.__data[self.__data[column] != value]
213+
else:
214+
data_filter_on_column(
215+
tmp_filename=self.__filename,
216+
file_type=self.__file_type,
217+
column=column,
218+
value=value,
219+
operator=operator,
220+
)

Diff for: data_flow/lib/data_columns.py

+33
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from typing import Any
2+
13
import fireducks.pandas as fd
24

35
from data_flow.lib.FileType import FileType
6+
from data_flow.lib.Operator import Operator
47

58

69
def data_get_columns(tmp_filename: str, file_type: FileType) -> list:
@@ -47,3 +50,33 @@ def data_select_columns(tmp_filename: str, file_type: FileType, columns: list) -
4750
data.to_feather(tmp_filename)
4851
case _:
4952
raise ValueError(f"File type not implemented: {file_type} !")
53+
54+
55+
def data_filter_on_column(tmp_filename: str, file_type: FileType, column: str, value: Any, operator: Operator) -> None:
56+
match file_type:
57+
case FileType.parquet:
58+
data = fd.read_parquet(tmp_filename)
59+
case FileType.feather:
60+
data = fd.read_feather(tmp_filename)
61+
case _:
62+
raise ValueError(f"File type not implemented: {file_type} !")
63+
64+
match operator:
65+
case Operator.Eq:
66+
data = data[data[column] == value]
67+
case Operator.Gte:
68+
data = data[data[column] >= value]
69+
case Operator.Lte:
70+
data = data[data[column] <= value]
71+
case Operator.Gt:
72+
data = data[data[column] > value]
73+
case Operator.Lt:
74+
data = data[data[column] < value]
75+
case Operator.Ne:
76+
data = data[data[column] != value]
77+
78+
match file_type:
79+
case FileType.parquet:
80+
data.to_parquet(tmp_filename)
81+
case FileType.feather:
82+
data.to_feather(tmp_filename)

Diff for: tests/BaseTestCase.py

+80
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import unittest
2+
from typing import Callable
23
from zipfile import ZipFile
34

45
import pandas as pd
56

7+
from data_flow import DataFlow
8+
from data_flow.lib import Operator
9+
610

711
class BaseTestCase(unittest.TestCase):
812
def setUp(self):
@@ -18,3 +22,79 @@ def setUp(self):
1822

1923
def assertPandasEqual(self, df1: pd.DataFrame, df2: pd.DataFrame):
2024
self.assertTrue(df1.equals(df2), "Pandas DataFrames are not equal !")
25+
26+
def all(self, function: Callable):
27+
self._sequence(data=function())
28+
self._filter_Eq(data=function())
29+
self._filter_Gte(data=function())
30+
self._filter_Lte(data=function())
31+
self._filter_Gt(data=function())
32+
self._filter_Lt(data=function())
33+
self._filter_Ne(data=function())
34+
35+
# @count_assertions
36+
def _sequence(self, data: DataFlow.DataFrame) -> None:
37+
self.assertPandasEqual(data.to_pandas(), DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas())
38+
polars = data.to_polars()
39+
40+
self.assertEqual(10, len(data.columns()))
41+
42+
data.columns_delete(
43+
[
44+
"Industry_aggregation_NZSIOC",
45+
"Industry_code_NZSIOC",
46+
"Industry_name_NZSIOC",
47+
"Industry_code_ANZSIC06",
48+
"Variable_code",
49+
"Variable_name",
50+
"Variable_category",
51+
]
52+
)
53+
54+
self.assertEqual(3, len(data.columns()))
55+
self.assertListEqual(["Year", "Units", "Value"], data.columns())
56+
57+
data.columns_rename(columns_mapping={"Year": "_year_", "Units": "_units_"})
58+
self.assertListEqual(["_year_", "_units_", "Value"], data.columns())
59+
60+
data.columns_select(columns=["_year_"])
61+
self.assertListEqual(["_year_"], data.columns())
62+
63+
self.assertPandasEqual(
64+
DataFlow().DataFrame().from_polars(polars).to_pandas(),
65+
DataFlow().DataFrame().from_csv(self.CSV_FILE).to_pandas(),
66+
)
67+
68+
def _filter_Eq(self, data: DataFlow.DataFrame) -> None:
69+
data.filter_on_column(column="Year", operator=Operator.Eq, value=2018)
70+
self.assertListEqual([2018], list(data.to_pandas().Year.unique()))
71+
72+
def _filter_Gte(self, data: DataFlow.DataFrame) -> None:
73+
data.filter_on_column(column="Year", operator=Operator.Gte, value=2018)
74+
result = data.to_pandas().Year.unique().tolist()
75+
result.sort()
76+
self.assertListEqual([2018, 2019, 2020, 2021, 2022, 2023], result)
77+
78+
def _filter_Lte(self, data: DataFlow.DataFrame) -> None:
79+
data.filter_on_column(column="Year", operator=Operator.Lte, value=2018)
80+
result = data.to_pandas().Year.unique().tolist()
81+
result.sort()
82+
self.assertListEqual([2013, 2014, 2015, 2016, 2017, 2018], result)
83+
84+
def _filter_Gt(self, data: DataFlow.DataFrame) -> None:
85+
data.filter_on_column(column="Year", operator=Operator.Gt, value=2018)
86+
result = data.to_pandas().Year.unique().tolist()
87+
result.sort()
88+
self.assertListEqual([2019, 2020, 2021, 2022, 2023], result)
89+
90+
def _filter_Lt(self, data: DataFlow.DataFrame) -> None:
91+
data.filter_on_column(column="Year", operator=Operator.Lt, value=2018)
92+
result = data.to_pandas().Year.unique().tolist()
93+
result.sort()
94+
self.assertListEqual([2013, 2014, 2015, 2016, 2017], result)
95+
96+
def _filter_Ne(self, data: DataFlow.DataFrame) -> None:
97+
data.filter_on_column(column="Year", operator=Operator.Ne, value=2018)
98+
result = data.to_pandas().Year.unique().tolist()
99+
result.sort()
100+
self.assertListEqual([2013, 2014, 2015, 2016, 2017, 2019, 2020, 2021, 2022, 2023], result)

Diff for: tests/SequenceTestCase.py

-37
This file was deleted.

Diff for: tests/test_data_flow_csv.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,32 @@
33
from data_flow import DataFlow
44
from data_flow.lib import FileType
55
from data_flow.lib.tools import delete_file
6-
from tests.SequenceTestCase import SequenceTestCase
6+
from tests.BaseTestCase import BaseTestCase
77

88

9-
class DataFlowCSVTestCase(SequenceTestCase):
9+
class DataFlowCSVTestCase(BaseTestCase):
1010
def setUp(self):
1111
super().setUp()
1212
delete_file(self.TEST_CSV_FILE)
1313
DataFlow().DataFrame().from_csv(self.CSV_FILE).to_csv(self.TEST_CSV_FILE)
1414

1515
def test_memory(self):
16-
df = DataFlow().DataFrame().from_csv(self.TEST_CSV_FILE)
17-
18-
self._sequence(data=df)
16+
self.all(self.__memory)
1917

2018
def test_parquet(self):
21-
df = DataFlow().DataFrame(in_memory=False).from_csv(self.TEST_CSV_FILE)
22-
23-
self._sequence(data=df)
19+
self.all(self.__parquet)
2420

2521
def test_feather(self):
26-
df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_csv(self.TEST_CSV_FILE)
22+
self.all(self.__feather)
23+
24+
def __memory(self) -> DataFlow.DataFrame:
25+
return DataFlow().DataFrame().from_csv(self.TEST_CSV_FILE)
26+
27+
def __parquet(self) -> DataFlow.DataFrame:
28+
return DataFlow().DataFrame(in_memory=False).from_csv(self.TEST_CSV_FILE)
2729

28-
self._sequence(data=df)
30+
def __feather(self) -> DataFlow.DataFrame:
31+
return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_csv(self.TEST_CSV_FILE)
2932

3033

3134
if __name__ == "__main__":

Diff for: tests/test_data_flow_feather.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,32 @@
33
from data_flow import DataFlow
44
from data_flow.lib import FileType
55
from data_flow.lib.tools import delete_file
6-
from tests.SequenceTestCase import SequenceTestCase
6+
from tests.BaseTestCase import BaseTestCase
77

88

9-
class DataFlowFeatherTestCase(SequenceTestCase):
9+
class DataFlowFeatherTestCase(BaseTestCase):
1010
def setUp(self):
1111
super().setUp()
1212
delete_file(self.TEST_FEATHER_FILE)
1313
DataFlow().DataFrame().from_csv(self.CSV_FILE).to_feather(self.TEST_FEATHER_FILE)
1414

1515
def test_memory(self):
16-
df = DataFlow().DataFrame().from_feather(self.TEST_FEATHER_FILE)
17-
18-
self._sequence(data=df)
16+
self.all(self.__memory)
1917

2018
def test_parquet(self):
21-
df = DataFlow().DataFrame(in_memory=False).from_feather(self.TEST_FEATHER_FILE)
22-
23-
self._sequence(data=df)
19+
self.all(self.__parquet)
2420

2521
def test_feather(self):
26-
df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_feather(self.TEST_FEATHER_FILE)
22+
self.all(self.__feather)
23+
24+
def __memory(self) -> DataFlow.DataFrame:
25+
return DataFlow().DataFrame().from_feather(self.TEST_FEATHER_FILE)
26+
27+
def __parquet(self) -> DataFlow.DataFrame:
28+
return DataFlow().DataFrame(in_memory=False).from_feather(self.TEST_FEATHER_FILE)
2729

28-
self._sequence(data=df)
30+
def __feather(self) -> DataFlow.DataFrame:
31+
return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_feather(self.TEST_FEATHER_FILE)
2932

3033

3134
if __name__ == "__main__":

Diff for: tests/test_data_flow_hdf.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,32 @@
33
from data_flow import DataFlow
44
from data_flow.lib import FileType
55
from data_flow.lib.tools import delete_file
6-
from tests.SequenceTestCase import SequenceTestCase
6+
from tests.BaseTestCase import BaseTestCase
77

88

9-
class DataFlowHdfTestCase(SequenceTestCase):
9+
class DataFlowHdfTestCase(BaseTestCase):
1010
def setUp(self):
1111
super().setUp()
1212
delete_file(self.TEST_HDF_FILE)
1313
DataFlow().DataFrame().from_csv(self.CSV_FILE).to_hdf(self.TEST_HDF_FILE)
1414

1515
def test_memory(self):
16-
df = DataFlow().DataFrame().from_hdf(self.TEST_HDF_FILE)
17-
18-
self._sequence(data=df)
16+
self.all(self.__memory)
1917

2018
def test_parquet(self):
21-
df = DataFlow().DataFrame(in_memory=False).from_hdf(self.TEST_HDF_FILE)
22-
23-
self._sequence(data=df)
19+
self.all(self.__parquet)
2420

2521
def test_feather(self):
26-
df = DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_hdf(self.TEST_HDF_FILE)
22+
self.all(self.__feather)
23+
24+
def __memory(self) -> DataFlow.DataFrame:
25+
return DataFlow().DataFrame().from_hdf(self.TEST_HDF_FILE)
26+
27+
def __parquet(self) -> DataFlow.DataFrame:
28+
return DataFlow().DataFrame(in_memory=False).from_hdf(self.TEST_HDF_FILE)
2729

28-
self._sequence(data=df)
30+
def __feather(self) -> DataFlow.DataFrame:
31+
return DataFlow().DataFrame(in_memory=False, file_type=FileType.feather).from_hdf(self.TEST_HDF_FILE)
2932

3033

3134
if __name__ == "__main__":

0 commit comments

Comments
 (0)