11"""
22The highest level classes for pipelines.
33"""
4- import os
54import copy
5+ import os
66import pickle
77from typing import Dict
88
99import pandas as pd
10-
11- from automatminer .base import LoggableMixin , DFTransformer
10+ from automatminer .base import DFTransformer , LoggableMixin
1211from automatminer .presets import get_preset_config
13- from automatminer .utils .ml import regression_or_classification
14- from automatminer .utils .pkg import check_fitted , set_fitted , \
15- return_attrs_recursively , AutomatminerError , VersionError , get_version , \
16- save_dict_to_file
1712from automatminer .utils .log import AMM_DEFAULT_LOGGER
13+ from automatminer .utils .ml import regression_or_classification
14+ from automatminer .utils .pkg import (
15+ AutomatminerError ,
16+ VersionError ,
17+ check_fitted ,
18+ get_version ,
19+ return_attrs_recursively ,
20+ save_dict_to_file ,
21+ set_fitted ,
22+ )
1823
1924
2025class MatPipe (DFTransformer , LoggableMixin ):
@@ -88,15 +93,23 @@ class MatPipe(DFTransformer, LoggableMixin):
8893 target (str): The name of the column where target values are held.
8994 """
9095
91- def __init__ (self , autofeaturizer = None , cleaner = None , reducer = None ,
92- learner = None , logger = AMM_DEFAULT_LOGGER ):
96+ def __init__ (
97+ self ,
98+ autofeaturizer = None ,
99+ cleaner = None ,
100+ reducer = None ,
101+ learner = None ,
102+ logger = AMM_DEFAULT_LOGGER ,
103+ ):
93104 transformers = [autofeaturizer , cleaner , reducer , learner ]
94105 if not all (transformers ):
95106 if any (transformers ):
96- raise AutomatminerError ("Please specify all dataframe"
97- "transformers (autofeaturizer, learner,"
98- "reducer, and cleaner), or none (to use"
99- "default)." )
107+ raise AutomatminerError (
108+ "Please specify all dataframe"
109+ "transformers (autofeaturizer, learner,"
110+ "reducer, and cleaner), or none (to use"
111+ "default)."
112+ )
100113 else :
101114 config = get_preset_config ("express" )
102115 autofeaturizer = config ["autofeaturizer" ]
@@ -117,7 +130,7 @@ def __init__(self, autofeaturizer=None, cleaner=None, reducer=None,
117130 super (MatPipe , self ).__init__ ()
118131
119132 @staticmethod
120- def from_preset (preset : str = ' express' , ** powerups ):
133+ def from_preset (preset : str = " express" , ** powerups ):
121134 """
122135 Get a preset MatPipe from a string using
123136 automatminer.presets.get_preset_config
@@ -238,8 +251,7 @@ def predict(self, df, ignore=None):
238251 return merged_df
239252
240253 @set_fitted
241- def benchmark (self , df , target , kfold , fold_subset = None , cache = False ,
242- ignore = None ):
254+ def benchmark (self , df , target , kfold , fold_subset = None , cache = False , ignore = None ):
243255 """
244256 If the target property is known for all data, perform an ML benchmark
245257 using MatPipe. Used for getting an idea of how well AutoML can predict
@@ -292,22 +304,26 @@ def benchmark(self, df, target, kfold, fold_subset=None, cache=False,
292304 if os .path .exists (cache_src ):
293305 self .logger .warning (
294306 "Cache src {} already found! Ensure this featurized data "
295- "matches the df being benchmarked." .format (cache_src ))
307+ "matches the df being benchmarked." .format (cache_src )
308+ )
296309 self .logger .warning ("Running pre-featurization for caching." )
297310 self .autofeaturizer .fit_transform (df , target )
298311 elif cache_src and not cache :
299312 raise AutomatminerError (
300313 "Caching was enabled in AutoFeaturizer but not in benchmark. "
301314 "Either disable caching in AutoFeaturizer or enable it by "
302- "passing cache=True to benchmark." )
315+ "passing cache=True to benchmark."
316+ )
303317 elif cache and not cache_src :
304318 raise AutomatminerError (
305319 "MatPipe cache is enabled, but no cache_src was defined in "
306320 "autofeaturizer. Pass the cache_src argument to AutoFeaturizer "
307- "or use the cache_src get_preset_config powerup." )
321+ "or use the cache_src get_preset_config powerup."
322+ )
308323 else :
309- self .logger .debug ("No caching being used in AutoFeaturizer or "
310- "benchmark." )
324+ self .logger .debug (
325+ "No caching being used in AutoFeaturizer or " "benchmark."
326+ )
311327
312328 if not fold_subset :
313329 fold_subset = list (range (kfold .n_splits ))
@@ -372,25 +388,20 @@ def summarize(self, filename=None) -> Dict[str, str]:
372388 "drop_na_targets" ,
373389 ]
374390 cleaner_data = {
375- attr : str (getattr (self .cleaner , attr ))
376- for attr in cleaner_attrs
391+ attr : str (getattr (self .cleaner , attr )) for attr in cleaner_attrs
377392 }
378393
379- reducer_attrs = [
380- "reducers" ,
381- "reducer_params" ,
382- ]
394+ reducer_attrs = ["reducers" , "reducer_params" ]
383395 reducer_data = {
384- attr : str (getattr (self .reducer , attr ))
385- for attr in reducer_attrs
396+ attr : str (getattr (self .reducer , attr )) for attr in reducer_attrs
386397 }
387398
388399 attrs = {
389400 "featurizers" : self .autofeaturizer .featurizers ,
390401 "ml_model" : str (self .learner .best_pipeline ),
391402 "feature_reduction" : reducer_data ,
392403 "data_cleaning" : cleaner_data ,
393- "features" : self .learner .features
404+ "features" : self .learner .features ,
394405 }
395406 if filename :
396407 save_dict_to_file (attrs , filename )
@@ -416,12 +427,16 @@ def save(self, filename="mat.pipe"):
416427
417428 temp_logger = copy .deepcopy (self ._logger )
418429 loggables = [
419- self , self .learner , self .reducer , self .cleaner , self .autofeaturizer
430+ self ,
431+ self .learner ,
432+ self .reducer ,
433+ self .cleaner ,
434+ self .autofeaturizer ,
420435 ]
421436 for loggable in loggables :
422437 loggable ._logger = AMM_DEFAULT_LOGGER
423438
424- with open (filename , 'wb' ) as f :
439+ with open (filename , "wb" ) as f :
425440 pickle .dump (self , f )
426441
427442 # Reassign live memory objects for further use in this object
@@ -446,7 +461,7 @@ def load(filename, logger=True, supress_version_mismatch=False):
446461 Returns:
447462 pipe (MatPipe): A MatPipe object.
448463 """
449- with open (filename , 'rb' ) as f :
464+ with open (filename , "rb" ) as f :
450465 pipe = pickle .load (f )
451466
452467 if pipe .version != get_version () and not supress_version_mismatch :
0 commit comments