Commit d9ed7a2

fixed dataset read

1 parent e6a5018 commit d9ed7a2

File tree

.gitignore
dags/dataset.py
dataset.py
tests/test_dataset_generation.py

4 files changed: +113 additions, −14 deletions

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -11,4 +11,4 @@ model_*
 *.jpeg
 history.json
 data/
-!tests/dataset.csv
+!tests/dataset.csv

dags/dataset.py

Lines changed: 2 additions & 1 deletion

@@ -29,6 +29,7 @@
     DEFAULT_AMOUNT_OF_CIRCUITS,
     DEFAULT_THREADS,
     images_gen_checkpoint_file,
+    dataset_file
 )
 from ghz import gen_circuit
 from export.kaggle import upload_dataset as upload_dataset_kaggle
@@ -106,7 +107,7 @@ def update_checkpoint(checkpoint: Checkpoint, stage: Stages):
     """

     gen_df = PythonOperator(
-        task_id="gen_df", python_callable=start_df, op_args=[folder]
+        task_id="gen_df", python_callable=start_df, op_args=[dataset_file(folder)]
     )
     gen_df.doc_md = """
     Generate an empty dataframe and saves it as an csv file.
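
The task now receives the full csv path instead of the base folder, matching the new start_df(filename) signature below. dataset_file itself is outside this diff; the following is a minimal sketch of the assumed helper (hypothetical implementation, the real one lives elsewhere in the project and may differ):

import os

# Hypothetical sketch only: the real dataset_file is defined in the
# project's config/utils module and is not shown in this commit.
# Assumed behaviour: resolve the dataset csv path inside a base folder.
def dataset_file(base_folder: str) -> str:
    return os.path.join(base_folder, "dataset.csv")

# Under that assumption, the operator now passes e.g. "data/dataset.csv"
# to start_df instead of the bare folder "data/".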

dataset.py

Lines changed: 6 additions & 5 deletions

@@ -384,8 +384,6 @@ def get_duplicated_files_list_by_diff(df:pl.LazyFrame, clean_df:pl.LazyFrame) ->
     """
     Get the files that are duplicated by applying a df diff.
     """
-
-
     duplicated_files = df.join(clean_df, on=df.collect_schema().names(), how="anti").collect().get_column("file")

     return duplicated_files.to_list() # type: ignore
@@ -467,15 +465,18 @@ def append_rows_to_df(file_path: FilePath, rows: Rows):
         writer.writerows(rows)


-def start_df(base_file_path: FilePath):
+def start_df(filename:FilePath):
     """
     generates an empty df and saves it on a csv file.

     It's not a good idea to use the scan_csv+sink_csv, but for
     an empty lazyFrame it works well.
     """
+    if(os.path.exists(filename)):
+        return
+
     df = create_df()
-    save_df(df, dataset_file(base_file_path))
+    save_df(df, filename)

     del df
     gc.collect()
@@ -484,7 +485,7 @@ def start_df(base_file_path: FilePath):
 def main(args: Arguments):
     """generate, clean and save dataset and images"""

-    crate_dataset_folder(args.target_folder)
+    crate_dataset_folder(dataset_file(args.target_folder))

     start_df(args.target_folder)
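
start_df now takes the target filename directly and returns early when the file already exists, so an existing dataset is never overwritten. create_df and save_df are not part of this diff; the sketch below shows how they could plausibly look in polars, assuming an empty LazyFrame built from the project's schema and written with sink_csv (the scan_csv+sink_csv caveat in the docstring applies to non-empty frames):

import polars as pl

# Placeholder schema: the real one is utils.datatypes.df_schema,
# which is not shown in this commit.
df_schema = {"file": pl.Utf8, "qubits": pl.Int64}

def create_df() -> pl.LazyFrame:
    # Empty LazyFrame carrying only the expected column names and dtypes.
    return pl.LazyFrame(schema=df_schema)

def save_df(df: pl.LazyFrame, filename: str) -> None:
    # Stream the (here empty) frame straight to disk without collecting it.
    df.sink_csv(filename)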

tests/test_dataset_generation.py

Lines changed: 104 additions & 7 deletions

@@ -1,10 +1,19 @@
 from typing import List
 import os
+import gc

 import pytest
 import polars as pl

-from dataset import clean_duplicated_rows_df, open_csv, save_df, get_duplicated_files_list_by_diff
+from dataset import (
+    clean_duplicated_rows_df,
+    open_csv,
+    save_df,
+    start_df,
+    get_duplicated_files_list_by_diff
+)
+from utils.datatypes import df_schema
+

 @pytest.fixture()
 def base_df() -> str:
@@ -13,6 +22,7 @@ def base_df() -> str:
     """
     return os.path.join(".", "tests", "dataset.csv")

+
 @pytest.fixture()
 def tmp_df() -> str:
     """
@@ -52,15 +62,60 @@ def duplicated_files() -> List[str]:


 @pytest.fixture(autouse=True)
-def clear_file(tmp_df):
+def clear_file(tmp_df, tmp_df2):
     """
     Clear tmp csv file
     """
-    if(not os.path.exists(tmp_df)):
-        return
-    os.remove(tmp_df)
+    if(os.path.exists(tmp_df)):
+        os.remove(tmp_df)
+    if(os.path.exists(tmp_df2)):
+        os.remove(tmp_df2)
+
+class TestCSVFile:
+    def test_open_csv(self,base_df):
+        """
+        Should open the df with no problems and cast it
+        to correct data types.
+        """
+        df = open_csv(base_df).collect()
+
+        assert len(df) == 11
+        assert df.schema == df_schema
+
+    def test_gen_df_no_previous_file(self,tmp_df):
+        """
+        should create a new csv file.
+        """
+
+        assert not os.path.exists(tmp_df)
+        start_df(tmp_df)
+        assert os.path.exists(tmp_df)
+
+        df_data = pl.read_csv(tmp_df)
+        assert len(df_data) == 0
+
+    def test_gen_df_file_already_exists(self,base_df,tmp_df):
+        """
+        Should not overwrite the existent file.
+        """
+
+        df = pl.read_csv(base_df)
+        df.write_csv(tmp_df)
+
+        assert os.path.exists(tmp_df)
+        assert len(pl.read_csv(tmp_df)) == 11
+        start_df(tmp_df)
+        assert os.path.exists(tmp_df)
+        assert len(pl.read_csv(tmp_df)) == 11
+
+
+
+
+
+
+class TestDatasetClean:
+

-class TestDatasetGeneration:
     """Test dataset generation parts"""

     def test_clean_duplicated_rows_return_the_correct_of_rows(self, base_df):
@@ -123,7 +178,7 @@ def test_save_df_with_modifications_different_files_and_rename(self, base_df, tm

         assert len(target_csv) == 8

-    def test_get_duplicated_files_list_by_diff(self,base_df, duplicated_files):
+    def test_get_duplicated_files_list_by_diff(self, base_df, duplicated_files):
         """
         Must take the diff between the raw csv and the cleaned one
         and return a list of files that are duplicated and must be
@@ -136,6 +191,48 @@ def test_get_duplicated_files_list_by_diff(self,base_df, duplicated_files):

         assert files_list == duplicated_files

+    def test_remove_duplicates_sequence(self, base_df, tmp_df, tmp_df2):
+        """
+        We must be able to run the entire clean up sequence without losing
+        any data.
+        """
+
+        df = pl.read_csv(base_df)
+        df.write_csv(tmp_df)
+
+        del df
+        gc.collect()
+
+
+        df = open_csv(tmp_df)
+        assert len(df.collect()) == 11
+        clean_df = clean_duplicated_rows_df(df)
+        assert len(clean_df.collect()) == 8
+        duplicated_files = get_duplicated_files_list_by_diff(df, clean_df)
+        assert len(duplicated_files) == 3
+
+        save_df(clean_df, tmp_df2)
+
+        assert os.path.exists(tmp_df2)
+        assert len(pl.read_csv(tmp_df2)) == 8
+
+        os.remove(tmp_df)
+        os.rename(tmp_df2, tmp_df)
+
+        assert os.path.exists(tmp_df)
+        assert len(pl.read_csv(tmp_df)) == 8
+        assert not os.path.exists(tmp_df2)
+
+        del df
+        del clean_df
+        gc.collect()
+
+        assert os.path.exists(tmp_df)
+        assert len(pl.read_csv(tmp_df)) == 8
+
+
+


 # SINCE SAVING A LAZY FRAME AS CSV IN THE SAME FILE IS NOT STABLE,
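
The new end-to-end test drives open_csv → clean_duplicated_rows_df → get_duplicated_files_list_by_diff → save_df against the sample csv and checks that 3 of the 11 rows are dropped. The snippet below is a self-contained toy version of the same dedup-by-diff idea with made-up data, under the assumption that clean_duplicated_rows_df deduplicates on the feature columns while keeping the first file per group (the real implementation is not shown in this commit):

import polars as pl

# Toy data: only the "file" column name is taken from the diff,
# the other columns and values are placeholders.
raw = pl.LazyFrame({
    "file": ["a.png", "b.png", "c.png", "d.png"],
    "qubits": [2, 3, 3, 4],
    "depth": [5, 7, 7, 9],
})

# Assumed behaviour of clean_duplicated_rows_df: keep one row per unique
# feature combination, deduplicating on everything except "file".
clean = raw.unique(subset=["qubits", "depth"], keep="first", maintain_order=True)

# The anti-join used by get_duplicated_files_list_by_diff: rows of the raw
# frame with no exact match in the cleaned frame are the dropped duplicates,
# and their "file" values are the images that can be deleted.
duplicated_files = (
    raw.join(clean, on=raw.collect_schema().names(), how="anti")
    .collect()
    .get_column("file")
    .to_list()
)
print(duplicated_files)  # ["c.png"]  (its row duplicates b.png's features)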
