Skip to content

Commit 3f09c69

Browse files
authored
Merge pull request #274 from microsoft/docstr
update docstr
2 parents 62a3170 + 5b68f55 commit 3f09c69

24 files changed

Lines changed: 1250 additions & 1108 deletions

flaml/automl.py

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
"""!
2-
* Copyright (c) Microsoft Corporation. All rights reserved.
3-
* Licensed under the MIT License. See LICENSE file in the
4-
* project root for license information.
5-
"""
1+
# !
2+
# * Copyright (c) Microsoft Corporation. All rights reserved.
3+
# * Licensed under the MIT License. See LICENSE file in the
4+
# * project root for license information.
65
import time
76
from typing import Callable, Optional
87
from functools import partial
@@ -311,7 +310,7 @@ def size(state: AutoMLState, config: dict) -> float:
311310

312311

313312
class AutoML:
314-
"""The AutoML class
313+
"""The AutoML class.
315314
316315
Example:
317316
@@ -359,10 +358,10 @@ def model(self):
359358
return self.__dict__.get("_trained_estimator")
360359

361360
def best_model_for_estimator(self, estimator_name):
362-
"""Return the best model found for a particular estimator
361+
"""Return the best model found for a particular estimator.
363362
364363
Args:
365-
estimator_name: a str of the estimator's name
364+
estimator_name: a str of the estimator's name.
366365
367366
Returns:
368367
An object with `predict()` and `predict_proba()` method (for
@@ -398,7 +397,7 @@ def best_config_per_estimator(self):
398397

399398
@property
400399
def best_loss(self):
401-
"""A float of the best loss found"""
400+
"""A float of the best loss found."""
402401
return self._state.best_loss
403402

404403
@property
@@ -421,7 +420,7 @@ def classes_(self):
421420

422421
@property
423422
def time_to_find_best_model(self) -> float:
424-
"""Time taken to find best model in seconds"""
423+
"""Time taken to find best model in seconds."""
425424
return self.__dict__.get("_time_taken_best_iter")
426425

427426
def predict(self, X_test):
@@ -490,7 +489,7 @@ def _preprocess(self, X):
490489
if issparse(X):
491490
X = X.tocsr()
492491
if self._transformer:
493-
X = self._transformer.transform(X, self._state.task)
492+
X = self._transformer.transform(X)
494493
return X
495494

496495
def _validate_data(
@@ -583,13 +582,11 @@ def _validate_data(
583582
X_val.shape[0] == y_val.shape[0]
584583
), "# rows in X_val must match length of y_val."
585584
if self._transformer:
586-
self._state.X_val = self._transformer.transform(X_val, self._state.task)
585+
self._state.X_val = self._transformer.transform(X_val)
587586
else:
588587
self._state.X_val = X_val
589588
if self._label_transformer:
590-
self._state.y_val = self._label_transformer.transform(
591-
y_val, self._state.task
592-
)
589+
self._state.y_val = self._label_transformer.transform(y_val)
593590
else:
594591
self._state.y_val = y_val
595592
else:
@@ -852,26 +849,26 @@ def _prepare_data(self, eval_method, split_ratio, n_splits):
852849
)
853850

854851
def add_learner(self, learner_name, learner_class):
855-
"""Add a customized learner
852+
"""Add a customized learner.
856853
857854
Args:
858-
learner_name: A string of the learner's name
859-
learner_class: A subclass of flaml.model.BaseEstimator
855+
learner_name: A string of the learner's name.
856+
learner_class: A subclass of flaml.model.BaseEstimator.
860857
"""
861858
self._state.learner_classes[learner_name] = learner_class
862859

863860
def get_estimator_from_log(self, log_file_name, record_id, task):
864-
"""Get the estimator from log file
861+
"""Get the estimator from log file.
865862
866863
Args:
867-
log_file_name: A string of the log file name
864+
log_file_name: A string of the log file name.
868865
record_id: An integer of the record ID in the file,
869-
0 corresponds to the first trial
866+
0 corresponds to the first trial.
870867
task: A string of the task type,
871-
'binary', 'multi', 'regression', 'ts_forecast', 'rank'
868+
'binary', 'multi', 'regression', 'ts_forecast', 'rank'.
872869
873870
Returns:
874-
An estimator object for the given configuration
871+
An estimator object for the given configuration.
875872
"""
876873

877874
with training_log_reader(log_file_name) as reader:
@@ -910,16 +907,16 @@ def retrain_from_log(
910907
auto_augment=True,
911908
**fit_kwargs,
912909
):
913-
"""Retrain from log file
910+
"""Retrain from log file.
914911
915912
Args:
916-
log_file_name: A string of the log file name
917-
X_train: A numpy array of training data in shape n*m
913+
log_file_name: A string of the log file name.
914+
X_train: A numpy array or dataframe of training data in shape n*m.
918915
For 'ts_forecast' task, the first column of X_train
919916
must be the timestamp column (datetime type). Other
920917
columns in the dataframe are assumed to be exogenous
921918
variables (categorical or numeric).
922-
y_train: A numpy array of labels in shape n*1
919+
y_train: A numpy array or series of labels in shape n*1.
923920
dataframe: A dataframe of training data including label column.
924921
For 'ts_forecast' task, dataframe must be specified and should
925922
have at least two columns: timestamp and label, where the first
@@ -1080,11 +1077,13 @@ def _decide_eval_method(self, time_budget):
10801077

10811078
@property
10821079
def search_space(self) -> dict:
1083-
"""Search space
1084-
Must be called after fit(...) (use max_iter=0 to prevent actual fitting)
1080+
"""Search space.
1081+
1082+
Must be called after fit(...)
1083+
(use max_iter=0 and retrain_final=False to prevent actual fitting).
10851084
10861085
Returns:
1087-
A dict of the search space
1086+
A dict of the search space.
10881087
"""
10891088
estimator_list = self.estimator_list
10901089
if len(estimator_list) == 1:
@@ -1101,7 +1100,7 @@ def search_space(self) -> dict:
11011100

11021101
@property
11031102
def low_cost_partial_config(self) -> dict:
1104-
"""Low cost partial config
1103+
"""Low cost partial config.
11051104
11061105
Returns:
11071106
A dict.
@@ -1112,7 +1111,6 @@ def low_cost_partial_config(self) -> dict:
11121111
to each learner's low_cost_partial_config; the estimator index as
11131112
an integer corresponding to the cheapest learner is appended to the
11141113
list at the end.
1115-
11161114
"""
11171115
if len(self.estimator_list) == 1:
11181116
estimator = self.estimator_list[0]
@@ -1146,7 +1144,6 @@ def cat_hp_cost(self) -> dict:
11461144
a list of the cat_hp_cost's as the value, corresponding
11471145
to each learner's cat_hp_cost; the cost relative to lgbm for each
11481146
learner (as a list itself) is appended to the list at the end.
1149-
11501147
"""
11511148
if len(self.estimator_list) == 1:
11521149
estimator = self.estimator_list[0]
@@ -1198,28 +1195,28 @@ def prune_attr(self) -> Optional[str]:
11981195

11991196
@property
12001197
def min_resource(self) -> Optional[float]:
1201-
"""Attribute for pruning
1198+
"""Attribute for pruning.
12021199
12031200
Returns:
1204-
A float for the minimal sample size or None
1201+
A float for the minimal sample size or None.
12051202
"""
12061203
return self._min_sample_size if self._sample else None
12071204

12081205
@property
12091206
def max_resource(self) -> Optional[float]:
1210-
"""Attribute for pruning
1207+
"""Attribute for pruning.
12111208
12121209
Returns:
1213-
A float for the maximal sample size or None
1210+
A float for the maximal sample size or None.
12141211
"""
12151212
return self._state.data_size if self._sample else None
12161213

12171214
@property
12181215
def trainable(self) -> Callable[[dict], Optional[float]]:
1219-
"""Training function
1216+
"""Training function.
12201217
12211218
Returns:
1222-
A function that evaluates each config and returns the loss
1219+
A function that evaluates each config and returns the loss.
12231220
"""
12241221
self._state.time_from_start = 0
12251222
for estimator in self.estimator_list:
@@ -1255,10 +1252,10 @@ def train(config: dict):
12551252

12561253
@property
12571254
def metric_constraints(self) -> list:
1258-
"""Metric constraints
1255+
"""Metric constraints.
12591256
12601257
Returns:
1261-
A list of the metric constraints
1258+
A list of the metric constraints.
12621259
"""
12631260
constraints = []
12641261
if np.isfinite(self._pred_time_limit):
@@ -1310,7 +1307,7 @@ def fit(
13101307
use_ray=False,
13111308
**fit_kwargs,
13121309
):
1313-
"""Find a model for a given task
1310+
"""Find a model for a given task.
13141311
13151312
Args:
13161313
X_train: A numpy array or a pandas dataframe of training data in
@@ -1499,6 +1496,7 @@ def custom_metric(
14991496
and eval_method == "holdout"
15001497
and self._state.X_val is None
15011498
or eval_method == "cv"
1499+
and (max_iter > 0 or retrain_full is True)
15021500
or max_iter == 1
15031501
)
15041502
self._auto_augment = auto_augment

flaml/data.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
"""!
2-
* Copyright (c) Microsoft Corporation. All rights reserved.
3-
* Licensed under the MIT License.
4-
"""
5-
1+
# !
2+
# * Copyright (c) Microsoft Corporation. All rights reserved.
3+
# * Licensed under the MIT License. See LICENSE file in the
4+
# * project root for license information.
65
import numpy as np
76
from scipy.sparse import vstack, issparse
87
import pandas as pd
@@ -130,17 +129,15 @@ def get_output_from_log(filename, time_budget):
130129
"""Get output from log file
131130
132131
Args:
133-
filename: A string of the log file name
134-
time_budget: A float of the time budget in seconds
132+
filename: A string of the log file name.
133+
time_budget: A float of the time budget in seconds.
135134
136135
Returns:
137-
search_time_list: A list of the finished time of each logged iter
138-
best_error_list:
139-
A list of the best validation error after each logged iter
140-
error_list: A list of the validation error of each logged iter
141-
config_list:
142-
A list of the estimator, sample size and config of each logged iter
143-
logged_metric_list: A list of the logged metric of each logged iter
136+
search_time_list: A list of the finished time of each logged iter.
137+
best_error_list: A list of the best validation error after each logged iter.
138+
error_list: A list of the validation error of each logged iter.
139+
config_list: A list of the estimator, sample size and config of each logged iter.
140+
logged_metric_list: A list of the logged metric of each logged iter.
144141
"""
145142

146143
best_config = None
@@ -208,9 +205,21 @@ def concat(X1, X2):
208205

209206

210207
class DataTransformer:
211-
"""transform X, y"""
208+
"""Transform input training data."""
212209

213210
def fit_transform(self, X, y, task):
211+
"""Fit transformer and process the input training data according to the task type.
212+
213+
Args:
214+
X: A numpy array or a pandas dataframe of training data.
215+
y: A numpy array or a pandas series of labels.
216+
task: A string of the task type, e.g.,
217+
'classification', 'regression', 'ts_forecast', 'rank'.
218+
219+
Returns:
220+
X: Processed numpy array or pandas dataframe of training data.
221+
y: Processed numpy array or pandas series of labels.
222+
"""
214223
if isinstance(X, pd.DataFrame):
215224
X = X.copy()
216225
n = X.shape[0]
@@ -320,17 +329,30 @@ def fit_transform(self, X, y, task):
320329
y = self.label_transformer.fit_transform(y)
321330
else:
322331
self.label_transformer = None
332+
self._task = task
323333
return X, y
324334

325-
def transform(self, X, task):
335+
def transform(self, X):
336+
"""Process data using fit transformer.
337+
338+
Args:
339+
X: A numpy array or a pandas dataframe of training data.
340+
y: A numpy array or a pandas series of labels.
341+
task: A string of the task type, e.g.,
342+
'classification', 'regression', 'ts_forecast', 'rank'.
343+
344+
Returns:
345+
X: Processed numpy array or pandas dataframe of training data.
346+
y: Processed numpy array or pandas series of labels.
347+
"""
326348
X = X.copy()
327349
if isinstance(X, pd.DataFrame):
328350
cat_columns, num_columns, datetime_columns = (
329351
self._cat_columns,
330352
self._num_columns,
331353
self._datetime_columns,
332354
)
333-
if task == TS_FORECAST:
355+
if self._task == TS_FORECAST:
334356
X = X.rename(columns={X.columns[0]: TS_TIMESTAMP_COL})
335357
ds_col = X.pop(TS_TIMESTAMP_COL)
336358
if datetime_columns:
@@ -357,7 +379,7 @@ def transform(self, X, task):
357379
X[column] = X[column].map(datetime.toordinal)
358380
del tmp_dt
359381
X = X[cat_columns + num_columns].copy()
360-
if task == TS_FORECAST:
382+
if self._task == TS_FORECAST:
361383
X.insert(0, TS_TIMESTAMP_COL, ds_col)
362384
for column in cat_columns:
363385
if X[column].dtype.name == "object":

0 commit comments

Comments
 (0)