MAINT, DOC: Modify Hyperband docs for first-time users #671

Draft · wants to merge 8 commits into main
54 changes: 40 additions & 14 deletions dask_ml/model_selection/_hyperband.py
@@ -26,7 +26,7 @@ def _get_hyperband_params(R, eta=3):
Returns
-------
brackets : Dict[int, Tuple[int, int]]
A dictionary of the form {bracket_id: (n_models, n_initial_iter)}
A dictionary of the form {bracket_id: (n_params, n_initial_iter)}

Notes
-----
@@ -177,12 +177,12 @@ class HyperbandSearchCV(BaseIncrementalSearchCV):
before computation happens with ``metadata`` or after computation
happens with ``metadata_``. These dictionaries both have keys

* ``n_models``, an int representing how many models will be/is created.
* ``n_params``, an int representing how many models will be/is created.
* ``partial_fit_calls``, an int representing how many times
``partial_fit`` will be/is called.
* ``brackets``, a list of the brackets that Hyperband runs. Each
bracket has different values for training time importance and
hyperparameter importance. In addition to ``n_models`` and
hyperparameter importance. In addition to ``n_params`` and
``partial_fit_calls``, each element in this list has keys
* ``bracket``, an int, the bracket ID. Each bracket corresponds to
a different level of training time importance.
@@ -280,16 +280,22 @@ class HyperbandSearchCV(BaseIncrementalSearchCV):
the longest trained model, ``n_examples = 10 * len(X)``.
* how many hyper-parameter combinations to sample (``n_params``)

These can be rough guesses. To determine the chunk size and ``max_iter``,
These can be rough guesses. More parameter combinations than ``n_params`` will
be sampled; if necessary, see
:func:`~dask_ml.model_selection.HyperbandSearchCV.metadata`
for the exact number of sampled parameters.

With these constraints, let's define the inputs of Hyperband to be the following:

1. Let the chunk size be ``chunk_size = n_examples / n_params``
2. Let ``max_iter = n_params``

Then, every estimator sees no
more than ``max_iter * chunk_size = n_examples`` examples.
Hyperband will actually sample some more hyper-parameter combinations than
``n_examples`` (which is why rough guesses are adequate). For example,
let's say
A feature of Hyperband and its underlying mathematics is that the
iteration count ``max_iter`` determines the number of parameters that
need to be sampled (which is why ``max_iter == n_params``).

For example, let's say

* about 200 or 300 hyper-parameters need to be tested to effectively
search the possible hyper-parameters
@@ -299,6 +305,11 @@ class HyperbandSearchCV(BaseIncrementalSearchCV):
Let's decide to provide ``81 * len(X)`` examples and to sample 243
parameters. Then each chunk will be 1/3rd the dataset and ``max_iter=243``.

The chunk size should be chosen so that the array is evenly chunked;
there shouldn't be any chunks with, e.g., only 2 examples. Specifying
``verbose=True`` will display some information about the chunk sizes.
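
As a rough sketch of this rule of thumb (the ``SGDClassifier``, the search
space, and the dataset sizes below are placeholders, not recommendations):

.. code-block:: python

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    from dask_ml.datasets import make_classification
    from dask_ml.model_selection import HyperbandSearchCV

    n_params = 243                       # hyper-parameter combinations to sample
    n_examples = 81 * 12_000             # examples the longest-trained model sees
    chunk_size = n_examples // n_params  # 4000 examples, exactly 1/3 of the data

    # Evenly chunked: 12,000 examples split into three chunks of 4,000
    X, y = make_classification(n_samples=12_000, n_features=20, chunks=chunk_size)

    model = SGDClassifier(tol=1e-3)
    params = {"alpha": np.logspace(-5, -1, num=1000)}

    search = HyperbandSearchCV(model, params, max_iter=n_params, verbose=True)
    # search.fit(X, y, classes=[0, 1])  # requires a dask.distributed Client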

If you use ``HyperbandSearchCV``, please use the citation for [2]_

.. code-block:: tex
@@ -453,10 +464,12 @@ def _fit(self, X, y, **fit_params):
{b: SHA.history_ for b, SHA in SHAs.items()}, brackets.keys(), SHAs, key
)

total_pf_calls = sum(m["partial_fit_calls"] for m in meta)
self.metadata_ = {
"n_models": sum(m["n_models"] for m in meta),
"partial_fit_calls": sum(m["partial_fit_calls"] for m in meta),
"n_params_actual": sum(m["n_params"] for m in meta),
"total_partial_fit_calls": total_pf_calls,
"brackets": meta,
"random_search_comparison": self._get_random_comparison(total_pf_calls, self.max_iter)
}

self.best_index_ = int(best_index)
@@ -476,7 +489,7 @@ def _fit(self, X, y, **fit_params):
@property
def metadata(self):
bracket_info = _hyperband_paper_alg(self.max_iter, eta=self.aggressiveness)
num_models = sum(b["n_models"] for b in bracket_info)
num_models = sum(b["n_params"] for b in bracket_info)
for bracket in bracket_info:
bracket["decisions"] = sorted(list(bracket["decisions"]))
num_partial_fit = sum(b["partial_fit_calls"] for b in bracket_info)
@@ -490,12 +503,25 @@ def metadata(self):

bracket_info = sorted(bracket_info, key=lambda x: x["bracket"])
info = {
"partial_fit_calls": num_partial_fit,
"n_models": num_models,
"total_partial_fit_calls": num_partial_fit,
"n_params_actual": num_models,
"brackets": bracket_info,
"random_search_comparison": self._get_random_comparison(num_partial_fit, self.max_iter)
}
return info

def _get_random_comparison(self, total_pf_calls, model_pf_calls):
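# Illustration (hypothetical numbers, not computed here): if Hyperband makes
# 5000 partial_fit calls in total and max_iter == 100, a random search doing
# the same amount of work could sample 5000 / 100 == 50 parameters, each
# trained with the full 100 partial_fit calls.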
return {
"meta": (
"Assume random search (e.g, RandomizedSearchCV) does the "
"same amount of work as this HyperbandSearchCV (or does "
"the same number of total partial fit calls). How many "
"parameters can be sampled in this case?"
),
"n_params": total_pf_calls / model_pf_calls,
"total_partial_fit_calls": total_pf_calls,
}


def _get_meta(hists, brackets, SHAs, key):
meta_ = []
@@ -516,7 +542,7 @@ def _get_meta(hists, brackets, SHAs, key):
meta_.append(
{
"decisions": sorted(list(decisions)),
"n_models": len(hist),
"n_params": len(hist),
"bracket": bracket,
"partial_fit_calls": sum(calls.values()),
"SuccessiveHalvingSearchCV params": _get_SHA_params(SHAs[bracket]),
@@ -600,7 +626,7 @@ def _hyperband_paper_alg(R, eta=3):
info = [
{
"bracket": k,
"n_models": hist["num_estimators"],
"n_params": hist["num_estimators"],
"partial_fit_calls": sum(hist["estimators"].values()),
"decisions": {int(h) for h in hist["decisions"]},
}
25 changes: 19 additions & 6 deletions dask_ml/model_selection/_incremental.py
@@ -99,10 +99,12 @@ def _partial_fit(model_and_meta, X, y, fit_params):
def _score(model_and_meta, X, y, scorer):
start = time()
model, meta = model_and_meta
if scorer:
score = scorer(model, X, y)
else:
score = model.score(X, y)

with log_errors():
if scorer:
score = scorer(model, X, y)
else:
score = model.score(X, y)

meta = dict(meta)
meta.update(score=score, score_time=time() - start)
@@ -175,8 +177,19 @@ def _fit(
assert len(X_train) == len(y_train)

train_eg = yield client.map(len, y_train)
msg = "[CV%s] For training there are between %d and %d examples in each chunk"
logger.info(msg, prefix, min(train_eg), max(train_eg))
msg = (
"[CV%s] For the chunks passed to partial_fit, there are between "
"%d and %d examples in each chunk. The median chunk size is %d."
)
logger.info(msg, prefix, min(train_eg), max(train_eg), np.median(train_eg))
if min(train_eg) <= 0.5 * max(train_eg):
msg = (
"The number of examples for each partial_fit call is unbalanced. "
"Between {} and {} examples are in each chunk. "
"The median chunk size is {}."
)
warn(msg.format(min(train_eg), max(train_eg), np.median(train_eg)))

# Order by which we process training data futures
order = []
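
If this warning fires, the arrays can usually be rebalanced before fitting.
A minimal sketch with ``dask.array`` (the shapes and chunk sizes here are
only illustrative, mirroring the new ``test_unbalanced_warns`` test):

    import dask.array as da

    # 40 examples in uneven chunks of (10, 10, 10, 4, 6)
    X = da.random.random((40, 4), chunks=((10, 10, 10, 4, 6), 4))
    y = da.random.randint(0, 2, size=(40,), chunks=((10, 10, 10, 4, 6),))

    # Rechunk so every partial_fit call sees the same number of examples
    X = X.rechunk({0: 10})
    y = y.rechunk({0: 10})
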
6 changes: 3 additions & 3 deletions dask_ml/model_selection/_successive_halving.py
@@ -130,7 +130,7 @@ class SuccessiveHalvingSearchCV(IncrementalSearchCV):
computation that will be performed, and ``metadata_`` describes the
computation that has been performed. Both dictionaries have keys

* ``n_models``: the number of models for this run of successive halving
* ``n_params``: the number of parameter combinations (models) for this run of successive halving
* ``max_iter``: the maximum number of times ``partial_fit`` is called.
At least one model will have this many ``partial_fit`` calls.
* ``partial_fit_calls``: the total number of ``partial_fit`` calls.
@@ -261,7 +261,7 @@ def metadata(self):
meta = _simulate_sha(n, r, self.aggressiveness, max_iter=self.max_iter)
return {
"partial_fit_calls": meta["total_calls"],
"n_models": self.n_initial_parameters,
"n_params": self.n_initial_parameters,
"max_iter": meta["max_iter"],
}

@@ -271,7 +271,7 @@ def metadata_(self):
calls = [v[-1]["partial_fit_calls"] for v in self.model_history_.values()]
return {
"partial_fit_calls": sum(calls),
"n_models": self.n_initial_parameters,
"n_params": self.n_initial_parameters,
"max_iter": max(calls),
}

39 changes: 32 additions & 7 deletions tests/model_selection/test_hyperband.py
@@ -134,13 +134,21 @@ def _test_mirrors_paper(c, s, a, b):
assert alg.metadata == alg.metadata_

assert isinstance(alg.metadata["brackets"], list)
assert set(alg.metadata.keys()) == {"n_models", "partial_fit_calls", "brackets"}
assert set(alg.metadata.keys()) == {
"n_params_actual",
"total_partial_fit_calls",
"brackets",
"random_search_comparison"
}
assert set(alg.metadata["random_search_comparison"].keys()) == {
"meta", "n_params", "total_partial_fit_calls"
}

# Looping over alg.metadata["brackets"] is okay because alg.metadata
# == alg.metadata_
for bracket in alg.metadata["brackets"]:
assert set(bracket.keys()) == {
"n_models",
"n_params",
"partial_fit_calls",
"bracket",
"SuccessiveHalvingSearchCV params",
@@ -182,7 +190,10 @@ def test_hyperband_patience(c, s, a, b):
# This makes sure models aren't trained for too long
assert all(x <= alg_patience + 1 for x in actual_iter)

assert alg.metadata_["partial_fit_calls"] <= alg.metadata["partial_fit_calls"]
assert (
alg.metadata_["total_partial_fit_calls"]
<= alg.metadata["total_partial_fit_calls"]
)
assert alg.best_score_ >= 0.9

max_iter = 6
@@ -240,9 +251,9 @@ def test_successive_halving_params(c, s, a, b):
metadata = alg.metadata["brackets"]
for k, (true_meta, SHA) in enumerate(zip(metadata, SHAs)):
yield SHA.fit(X, y)
n_models = len(SHA.model_history_)
n_params = len(SHA.model_history_)
pf_calls = [v[-1]["partial_fit_calls"] for v in SHA.model_history_.values()]
assert true_meta["n_models"] == n_models
assert true_meta["n_params"] == n_params
assert true_meta["partial_fit_calls"] == sum(pf_calls)


@@ -348,7 +359,7 @@ def test_same_random_state_same_params(c, s, a, b):
{"value": values},
random_state=seed,
max_iter=2,
n_initial_parameters=h.metadata["n_models"],
n_initial_parameters=h.metadata["n_params_actual"],
)
X, y = make_classification(n_samples=10, n_features=4, chunks=10)
yield h.fit(X, y)
@@ -366,7 +377,7 @@ def test_same_random_state_same_params(c, s, a, b):
# Getting the `value`s that are the same for both searches
same = set(v_passive).intersection(set(v_h))

passive_models = h.metadata["brackets"][0]["n_models"]
passive_models = h.metadata["brackets"][0]["n_params"]
assert len(same) == passive_models


@@ -432,3 +443,17 @@ def test_history(c, s, a, b):
for model_hist in alg.model_history_.values():
calls = [h["partial_fit_calls"] for h in model_hist]
assert (np.diff(calls) >= 1).all() or len(calls) == 1


@gen_cluster(client=True, timeout=5000)
def test_unbalanced_warns(c, s, a, b):
X, y = make_classification(
n_samples=40, n_features=4, chunks=((10, 10, 10, 4, 6), 4)
)
model = ConstantFunction()
params = {"value": scipy.stats.uniform(0, 1)}
alg = HyperbandSearchCV(model, params, max_iter=9, random_state=42)

match = "The number of examples for each partial_fit call is unbalanced"
with pytest.warns(UserWarning, match=match):
yield alg.fit(X, y)
2 changes: 1 addition & 1 deletion tests/model_selection/test_successive_halving.py
@@ -62,7 +62,7 @@ def _test_sha_max_iter(c, s, a, b):
assert search.metadata == search.metadata_
assert set(search.metadata.keys()) == {
"partial_fit_calls",
"n_models",
"n_params",
"max_iter",
}
