1- """!
2- * Copyright (c) Microsoft Corporation. All rights reserved.
3- * Licensed under the MIT License. See LICENSE file in the
4- * project root for license information.
5- """
1+ # !
2+ # * Copyright (c) Microsoft Corporation. All rights reserved.
3+ # * Licensed under the MIT License. See LICENSE file in the
4+ # * project root for license information.
65import time
76from typing import Callable , Optional
87from functools import partial
@@ -311,7 +310,7 @@ def size(state: AutoMLState, config: dict) -> float:
311310
312311
313312class AutoML :
314- """The AutoML class
313+ """The AutoML class.
315314
316315 Example:
317316
@@ -359,10 +358,10 @@ def model(self):
359358 return self .__dict__ .get ("_trained_estimator" )
360359
361360 def best_model_for_estimator (self , estimator_name ):
362- """Return the best model found for a particular estimator
361+ """Return the best model found for a particular estimator.
363362
364363 Args:
365- estimator_name: a str of the estimator's name
364+ estimator_name: a str of the estimator's name.
366365
367366 Returns:
368367 An object with `predict()` and `predict_proba()` method (for
@@ -398,7 +397,7 @@ def best_config_per_estimator(self):
398397
399398 @property
400399 def best_loss (self ):
401- """A float of the best loss found"""
400+ """A float of the best loss found. """
402401 return self ._state .best_loss
403402
404403 @property
@@ -421,7 +420,7 @@ def classes_(self):
421420
422421 @property
423422 def time_to_find_best_model (self ) -> float :
424- """Time taken to find best model in seconds"""
423+ """Time taken to find best model in seconds. """
425424 return self .__dict__ .get ("_time_taken_best_iter" )
426425
427426 def predict (self , X_test ):
@@ -490,7 +489,7 @@ def _preprocess(self, X):
490489 if issparse (X ):
491490 X = X .tocsr ()
492491 if self ._transformer :
493- X = self ._transformer .transform (X , self . _state . task )
492+ X = self ._transformer .transform (X )
494493 return X
495494
496495 def _validate_data (
@@ -583,13 +582,11 @@ def _validate_data(
583582 X_val .shape [0 ] == y_val .shape [0 ]
584583 ), "# rows in X_val must match length of y_val."
585584 if self ._transformer :
586- self ._state .X_val = self ._transformer .transform (X_val , self . _state . task )
585+ self ._state .X_val = self ._transformer .transform (X_val )
587586 else :
588587 self ._state .X_val = X_val
589588 if self ._label_transformer :
590- self ._state .y_val = self ._label_transformer .transform (
591- y_val , self ._state .task
592- )
589+ self ._state .y_val = self ._label_transformer .transform (y_val )
593590 else :
594591 self ._state .y_val = y_val
595592 else :
@@ -852,26 +849,26 @@ def _prepare_data(self, eval_method, split_ratio, n_splits):
852849 )
853850
854851 def add_learner (self , learner_name , learner_class ):
855- """Add a customized learner
852+ """Add a customized learner.
856853
857854 Args:
858- learner_name: A string of the learner's name
859- learner_class: A subclass of flaml.model.BaseEstimator
855+ learner_name: A string of the learner's name.
856+ learner_class: A subclass of flaml.model.BaseEstimator.
860857 """
861858 self ._state .learner_classes [learner_name ] = learner_class
862859
863860 def get_estimator_from_log (self , log_file_name , record_id , task ):
864- """Get the estimator from log file
861+ """Get the estimator from log file.
865862
866863 Args:
867- log_file_name: A string of the log file name
864+ log_file_name: A string of the log file name.
868865 record_id: An integer of the record ID in the file,
869- 0 corresponds to the first trial
866+ 0 corresponds to the first trial.
870867 task: A string of the task type,
871- 'binary', 'multi', 'regression', 'ts_forecast', 'rank'
868+ 'binary', 'multi', 'regression', 'ts_forecast', 'rank'.
872869
873870 Returns:
874- An estimator object for the given configuration
871+ An estimator object for the given configuration.
875872 """
876873
877874 with training_log_reader (log_file_name ) as reader :
@@ -910,16 +907,16 @@ def retrain_from_log(
910907 auto_augment = True ,
911908 ** fit_kwargs ,
912909 ):
913- """Retrain from log file
910+ """Retrain from log file.
914911
915912 Args:
916- log_file_name: A string of the log file name
917- X_train: A numpy array of training data in shape n*m
913+ log_file_name: A string of the log file name.
914+ X_train: A numpy array or dataframe of training data in shape n*m.
918915 For 'ts_forecast' task, the first column of X_train
919916 must be the timestamp column (datetime type). Other
920917 columns in the dataframe are assumed to be exogenous
921918 variables (categorical or numeric).
922- y_train: A numpy array of labels in shape n*1
919+ y_train: A numpy array or series of labels in shape n*1.
923920 dataframe: A dataframe of training data including label column.
924921 For 'ts_forecast' task, dataframe must be specified and should
925922 have at least two columns: timestamp and label, where the first
@@ -1080,11 +1077,13 @@ def _decide_eval_method(self, time_budget):
10801077
10811078 @property
10821079 def search_space (self ) -> dict :
1083- """Search space
1084- Must be called after fit(...) (use max_iter=0 to prevent actual fitting)
1080+ """Search space.
1081+
1082+ Must be called after fit(...)
1083+ (use max_iter=0 and retrain_final=False to prevent actual fitting).
10851084
10861085 Returns:
1087- A dict of the search space
1086+ A dict of the search space.
10881087 """
10891088 estimator_list = self .estimator_list
10901089 if len (estimator_list ) == 1 :
@@ -1101,7 +1100,7 @@ def search_space(self) -> dict:
11011100
11021101 @property
11031102 def low_cost_partial_config (self ) -> dict :
1104- """Low cost partial config
1103+ """Low cost partial config.
11051104
11061105 Returns:
11071106 A dict.
@@ -1112,7 +1111,6 @@ def low_cost_partial_config(self) -> dict:
11121111 to each learner's low_cost_partial_config; the estimator index as
11131112 an integer corresponding to the cheapest learner is appended to the
11141113 list at the end.
1115-
11161114 """
11171115 if len (self .estimator_list ) == 1 :
11181116 estimator = self .estimator_list [0 ]
@@ -1146,7 +1144,6 @@ def cat_hp_cost(self) -> dict:
11461144 a list of the cat_hp_cost's as the value, corresponding
11471145 to each learner's cat_hp_cost; the cost relative to lgbm for each
11481146 learner (as a list itself) is appended to the list at the end.
1149-
11501147 """
11511148 if len (self .estimator_list ) == 1 :
11521149 estimator = self .estimator_list [0 ]
@@ -1198,28 +1195,28 @@ def prune_attr(self) -> Optional[str]:
11981195
11991196 @property
12001197 def min_resource (self ) -> Optional [float ]:
1201- """Attribute for pruning
1198+ """Attribute for pruning.
12021199
12031200 Returns:
1204- A float for the minimal sample size or None
1201+ A float for the minimal sample size or None.
12051202 """
12061203 return self ._min_sample_size if self ._sample else None
12071204
12081205 @property
12091206 def max_resource (self ) -> Optional [float ]:
1210- """Attribute for pruning
1207+ """Attribute for pruning.
12111208
12121209 Returns:
1213- A float for the maximal sample size or None
1210+ A float for the maximal sample size or None.
12141211 """
12151212 return self ._state .data_size if self ._sample else None
12161213
12171214 @property
12181215 def trainable (self ) -> Callable [[dict ], Optional [float ]]:
1219- """Training function
1216+ """Training function.
12201217
12211218 Returns:
1222- A function that evaluates each config and returns the loss
1219+ A function that evaluates each config and returns the loss.
12231220 """
12241221 self ._state .time_from_start = 0
12251222 for estimator in self .estimator_list :
@@ -1255,10 +1252,10 @@ def train(config: dict):
12551252
12561253 @property
12571254 def metric_constraints (self ) -> list :
1258- """Metric constraints
1255+ """Metric constraints.
12591256
12601257 Returns:
1261- A list of the metric constraints
1258+ A list of the metric constraints.
12621259 """
12631260 constraints = []
12641261 if np .isfinite (self ._pred_time_limit ):
@@ -1310,7 +1307,7 @@ def fit(
13101307 use_ray = False ,
13111308 ** fit_kwargs ,
13121309 ):
1313- """Find a model for a given task
1310+ """Find a model for a given task.
13141311
13151312 Args:
13161313 X_train: A numpy array or a pandas dataframe of training data in
@@ -1499,6 +1496,7 @@ def custom_metric(
14991496 and eval_method == "holdout"
15001497 and self ._state .X_val is None
15011498 or eval_method == "cv"
1499+ and (max_iter > 0 or retrain_full is True )
15021500 or max_iter == 1
15031501 )
15041502 self ._auto_augment = auto_augment
0 commit comments