diff --git a/src/booster.jl b/src/booster.jl index 648acca..5312019 100644 --- a/src/booster.jl +++ b/src/booster.jl @@ -275,20 +275,17 @@ function deserialize(::Type{Booster}, buf::AbstractVector{UInt8}, data=DMatrix[] deserialize!(b, buf) end + # sadly this is type unstable because we might return a transpose """ predict(b::Booster, data; margin=false, training=false, ntree_limit=0) - Use the model `b` to run predictions on `data`. This will return a `Vector{Float32}` which can be compared to training or test target data. - If `ntree_limit > 0` only the first `ntree_limit` trees will be used in prediction. - ## Examples ```julia (X, y) = (randn(100,3), randn(100)) b = xgboost((X, y), 10) - ŷ = predict(b, X) ``` """ @@ -314,6 +311,58 @@ function predict(b::Booster, Xy::DMatrix; end predict(b::Booster, Xy; kw...) = predict(b, DMatrix(Xy); kw...) + +""" + predictbytype(b::Booster, data::DMatrix; type=0, training=false, ntree_limit=0) + +Use the model `b` to run predictions on `data`. + +This version of predict gives access to contribution and interaction values. + +If `ntree_limit > 0` only the first `ntree_limit` trees will be used in prediction. + +The 'type' parameter conforms to prediction types specified in the XGBoost documentation. +Options include: + 0 => normal (default) + 1 => output margin + 2 => predict contribution + 3 => predict approximate contribution + 4 => predict feature interactions + 5 => predict approximate feature interactions + 6 => predict leaf training (see XGBoost documentation) + +The shape of returned data varies with 'type' option and certain objectives. + +## Examples +```julia +(X, y) = (randn(100,3), randn(100)) +b = xgboost((X, y), 10) + +ŷ = predict(b, X, type=2) +``` +""" +function predictbytype(b::Booster, Xy::DMatrix; + type::Integer=0, # 0-normal, 1-margin, 2-contrib, 3-est. contrib,4-interact,5-est. interact, 6-leaf + training::Bool=false, + ntree_lower_limit::Integer=0, + ntree_limit::Integer=0, # 0 corresponds to no limit + ) + opts = Dict("type"=>type , + "iteration_begin"=>ntree_lower_limit, + "iteration_end"=>ntree_limit, + "strict_shape"=>false, + "training"=>training, + ) |> JSON3.write + oshape = Ref{Ptr{Lib.bst_ulong}}() + odim = Ref{Lib.bst_ulong}() + o = Ref{Ptr{Cfloat}}() + xgbcall(XGBoosterPredictFromDMatrix, b.handle, Xy.handle, opts, oshape, odim, o) + dims = reverse(unsafe_wrap(Array, oshape[], odim[])) + o = unsafe_wrap(Array, o[], tuple(dims...)) + length(dims) > 1 ? permutedims(o, reverse(1:ndims(o))) : o # permutedims to handle ndims>=3 +end + + function evaliter(b::Booster, watch, n::Integer=1) o = Ref{Ptr{Int8}}() names = collect(Iterators.map(string, keys(watch)))