.. note::

    Machine Learning support is experimental in dask-sql.
    We encourage you to try it out and report any issues on our issue tracker.
As all SQL statements in dask-sql are eventually converted into Python calls, it is straightforward to include any custom Python function or library, e.g. Machine Learning libraries. Although it would be possible to register custom functions (see :ref:`custom`) for this and use them, it is much more convenient if this functionality is already included in the core SQL language.
These three statements help with training and using models. Every :class:`~dask_sql.Context` has a registry for models, which can be used for training or prediction.
For a full example, see :ref:`machine_learning`.
.. code-block:: sql

    CREATE [ OR REPLACE ] MODEL [ IF NOT EXISTS ] <model-name> WITH ( <key> = <value> [ , ... ] ) AS ( SELECT ... )

    DROP MODEL [ IF EXISTS ] <model-name>

    SELECT <expression> FROM PREDICT ( MODEL <model-name>, SELECT ... )
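To make the three statements concrete, here is a minimal end-to-end sketch in Python. The demo table ``df``, its columns, and the choice of ``sklearn.linear_model.LogisticRegression`` are assumptions for this illustration, not part of the SQL syntax.

.. code-block:: python

    # A minimal sketch, assuming dask-sql and scikit-learn are installed.
    import dask.datasets
    from dask_sql import Context

    c = Context()

    # Demo data: one day of the dask example timeseries plus a boolean target
    df = dask.datasets.timeseries(start="2000-01-01", end="2000-01-02")
    df["target"] = df["x"] > 0
    c.create_table("df", df)

    # Train a model and register it at the context
    c.sql("""
        CREATE MODEL my_model WITH (
            model_class = 'sklearn.linear_model.LogisticRegression',
            target_column = 'target',
            wrap_predict = True
        ) AS (
            SELECT x, y, target FROM "df"
        )
    """)

    # Use it for prediction ...
    result = c.sql("""
        SELECT * FROM PREDICT (
            MODEL my_model,
            SELECT x, y FROM "df"
        )
    """)
    print(result.head())

    # ... and remove it again
    c.sql("DROP MODEL my_model")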
``IF [ NOT ] EXISTS`` and ``CREATE OR REPLACE`` behave similarly to their analogous flags in ``CREATE TABLE``. See :ref:`creation` for more information.
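As a sketch (reusing the context ``c`` and demo table ``df`` from the example above), the flags control what happens when a model with the same name is already registered:

.. code-block:: python

    # The training part is shared between the variants below
    train_sql = """
        WITH (
            model_class = 'sklearn.linear_model.LogisticRegression',
            target_column = 'target'
        ) AS (
            SELECT x, y, target FROM "df"
        )
    """

    c.sql("CREATE MODEL my_model " + train_sql)                # raises if "my_model" exists
    c.sql("CREATE MODEL IF NOT EXISTS my_model " + train_sql)  # no-op if "my_model" exists
    c.sql("CREATE OR REPLACE MODEL my_model " + train_sql)     # overwrites "my_model"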
``CREATE MODEL`` creates and trains a model on the data from the given ``SELECT`` query and registers it at the context. The select query is a normal ``SELECT`` query (following the same syntax as described in :ref:`select`) or even a call to ``PREDICT`` (which, however, typically does not make sense), and its result is used as the training data.
The key-value parameters control how and which model is trained:
model_class
    This argument needs to be present. It is the full Python module path to the class of the model to train. Any model class with a sklearn-like interface is valid, but it might or might not work well with Dask dataframes. You might need to install the necessary packages to use a given model.

target_column
    Which column from the data to use as the target. If not empty, it is removed automatically from the training data. Defaults to an empty string, in which case no target is fed to the model training (e.g. for unsupervised algorithms). This means you typically want to set this parameter.

wrap_predict
    Boolean flag indicating whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. Defaults to false. You typically want to set this to true for sklearn models when predicting on big data.

wrap_fit
    Boolean flag indicating whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`. Defaults to false. You typically want to set this to true for sklearn models when training on big data.

fit_kwargs
    Keyword arguments passed to the call to ``fit()``.
All other arguments are passed to the constructor of the model class.
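For example, the following sketch (hyperparameter values made up) forwards ``n_estimators`` and ``max_depth`` to the constructor of ``sklearn.ensemble.GradientBoostingClassifier``:

.. code-block:: python

    # n_estimators and max_depth are not dask-sql parameters; they are
    # passed through to GradientBoostingClassifier(...)
    c.sql("""
        CREATE OR REPLACE MODEL my_model WITH (
            model_class = 'sklearn.ensemble.GradientBoostingClassifier',
            target_column = 'target',
            n_estimators = 100,
            max_depth = 3
        ) AS (
            SELECT x, y, target FROM "df"
        )
    """)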
Example:

.. code-block:: sql

    CREATE MODEL my_model WITH (
        model_class = 'xgboost.XGBClassifier',
        target_column = 'target'
    ) AS (
        SELECT x, y, target
        FROM "data"
    )
This SQL call is not a 1:1 replacement for normal Python training and cannot fulfill all use cases or requirements! If you are dealing with large amounts of data, you might run into problems during model training and/or prediction, depending on whether your model can cope with Dask dataframes.
- If you are training on relatively small amounts of data but predicting on large data samples, you might want to set ``wrap_predict`` to true. With this option, model inference will be parallelized/distributed (see the sketch after this list).
- If you are training on large amounts of data, you can try setting ``wrap_fit`` to true. This will do the same for the training step, but it works only for models that have a ``partial_fit`` method.
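As a sketch of both flags (the estimator choice is made up): ``sklearn.linear_model.SGDClassifier`` implements ``partial_fit``, so it can be wrapped for distributed training as well as distributed prediction.

.. code-block:: python

    # SGDClassifier has partial_fit, so wrap_fit is applicable; depending on
    # the estimator, incremental training may need extra parameters (e.g. the
    # list of classes), which would go through fit_kwargs.
    c.sql("""
        CREATE OR REPLACE MODEL my_model WITH (
            model_class = 'sklearn.linear_model.SGDClassifier',
            target_column = 'target',
            wrap_fit = True,
            wrap_predict = True
        ) AS (
            SELECT x, y, target FROM "df"
        )
    """)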
``DROP MODEL`` removes the model with the given name from the registered models.
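For example:

.. code-block:: python

    # IF EXISTS makes this a no-op if no model named "my_model" is registered
    c.sql("DROP MODEL IF EXISTS my_model")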
``PREDICT`` predicts the target using the given model and the dataframe from the ``SELECT`` query. The return value is the input dataframe with an additional column named "target", which contains the predicted values.
The model needs to be registered at the context before using it in this function, either by calling :func:`~dask_sql.Context.register_model` explicitly or by training a model using the ``CREATE MODEL`` SQL statement above.
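A minimal sketch of the explicit route, assuming a small in-memory sample, a fitted scikit-learn estimator, and that :func:`~dask_sql.Context.register_model` receives the list of training columns (all made up for this illustration):

.. code-block:: python

    # Train an estimator outside of SQL on a small in-memory sample ...
    from sklearn.linear_model import LogisticRegression

    sample = df[["x", "y", "target"]].head(1000)
    estimator = LogisticRegression().fit(sample[["x", "y"]], sample["target"])

    # ... and register it; training_columns tells dask-sql which columns
    # to pass to predict()
    c.register_model("my_external_model", estimator, training_columns=["x", "y"])

    result = c.sql("""
        SELECT * FROM PREDICT (
            MODEL my_external_model,
            SELECT x, y FROM "df"
        )
    """)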
A model can be anything that has a ``predict`` method. Please note, however, that it will need to operate on Dask dataframes. If you are using a model that is not optimized for this, you might run out of memory if your data is larger than the RAM of a single machine. To prevent this, have a look at the :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit` meta-estimator. If you are using a model trained with ``CREATE MODEL`` and the ``wrap_predict`` flag set to true, this is done automatically.
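When registering a model yourself, the same wrapping can be applied manually; a minimal sketch, assuming the wrapper follows the familiar dask-ml ``ParallelPostFit`` interface and reusing the fitted ``estimator`` from above:

.. code-block:: python

    from dask_sql.physical.rel.custom.wrappers import ParallelPostFit

    # Partition-wise predict() on Dask dataframes instead of materializing
    # the full data at once
    wrapped = ParallelPostFit(estimator=estimator)
    c.register_model("my_parallel_model", wrapped, training_columns=["x", "y"])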
Using this SQL statement is roughly equivalent to doing:

.. code-block:: python

    df = context.sql("<select query>")
    model = ...  # look up the model in the context's model registry
    target = model.predict(df)
    return df.assign(target=target)
The select query is a normal ``SELECT`` query (following the same syntax as described in :ref:`select`) or even another call to ``PREDICT``.
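As a sketch of the nested case (both model names and the column layout are made up), the prediction of one model can be fed as a feature into another:

.. code-block:: python

    # "first_model" adds a "target" column, which is renamed and then
    # consumed as an input feature by "second_model"
    result = c.sql("""
        SELECT * FROM PREDICT (
            MODEL second_model,
            SELECT x, y, target AS first_prediction FROM PREDICT (
                MODEL first_model,
                SELECT x, y FROM "df"
            )
        )
    """)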