1414from _ctgan .synthesizer import _CTGANSynthesizer as CTGAN
1515from tabgan .abc_sampler import Sampler , SampleData
1616from tabgan .adversarial_model import AdversarialModel
17- from tabgan .utils import setup_logging
17+ from tabgan .utils import setup_logging , get_year_mnth_dt_from_date , collect_dates
1818
1919warnings .filterwarnings ("ignore" , category = FutureWarning )
2020
@@ -317,12 +317,13 @@ def get_columns_if_exists(df, col) -> pd.DataFrame:
317317
318318if __name__ == "__main__" :
319319 setup_logging (logging .DEBUG )
320+ train_size = 100
320321 train = pd .DataFrame (
321- np .random .randint (- 10 , 150 , size = (100 , 4 )), columns = list ("ABCD" )
322+ np .random .randint (- 10 , 150 , size = (train_size , 4 )), columns = list ("ABCD" )
322323 )
323324 logging .info (train )
324- target = pd .DataFrame (np .random .randint (0 , 2 , size = (100 , 1 )), columns = list ("Y" ))
325- test = pd .DataFrame (np .random .randint (0 , 100 , size = (100 , 4 )), columns = list ("ABCD" ))
325+ target = pd .DataFrame (np .random .randint (0 , 2 , size = (train_size , 1 )), columns = list ("Y" ))
326+ test = pd .DataFrame (np .random .randint (0 , 100 , size = (train_size , 4 )), columns = list ("ABCD" ))
326327 _sampler (OriginalGenerator (gen_x_times = 15 ), train , target , test )
327328 _sampler (
328329 GANGenerator (gen_x_times = 10 , only_generated_data = False ,
@@ -336,3 +337,18 @@ def get_columns_if_exists(df, col) -> pd.DataFrame:
336337 None ,
337338 train ,
338339 )
# Date-column demo: attach a random 'Date' column to ``train``, run it through
# the tabgan date helper, generate data with GANGenerator and pass the result
# to ``collect_dates``.
min_date = pd.to_datetime('2019-01-01')
max_date = pd.to_datetime('2021-12-31')

# Number of calendar days in the inclusive range [min_date, max_date].
d = (max_date - min_date).days + 1

# One random day offset in [0, d) per training row, so every generated date
# falls inside [min_date, max_date].  ``np`` is used directly: the ``pd.np``
# alias was deprecated in pandas 1.0 and removed in pandas 2.0, where
# accessing it raises AttributeError.
train['Date'] = min_date + pd.to_timedelta(
    np.random.randint(d, size=train_size), unit='d'
)
# NOTE(review): presumably adds date-part columns derived from 'Date' (at
# least 'year', which is passed as a categorical column below) -- confirm
# against tabgan.utils.
train = get_year_mnth_dt_from_date(train, 'Date')

# The raw 'Date' column is dropped before generation.  The same features are
# passed (as two separately built frames) for both the train and the test
# argument, with no target.
new_train, new_target = GANGenerator(
    gen_x_times=1.1,
    cat_cols=['year'],
    bot_filter_quantile=0.001,
    top_filter_quantile=0.999,
    is_post_process=True,
    pregeneration_frac=2,
    only_generated_data=False,
).generate_data_pipe(
    train.drop('Date', axis=1),
    None,
    train.drop('Date', axis=1),
)
# NOTE(review): presumably rebuilds a date column from the generated date
# parts -- confirm against tabgan.utils.collect_dates.
new_train = collect_dates(new_train)
0 commit comments