Skip to content

Commit 4b5bfab

Browse files
committed
_setup_dataset -> setup_dataset
1 parent a6a6659 commit 4b5bfab

File tree

1 file changed

+66
-78
lines changed

1 file changed

+66
-78
lines changed

src/dataset.jl

Lines changed: 66 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -344,73 +344,12 @@ function set_tuning(
344344
end
345345

346346
# ---------------------------------------------------------------------------- #
347-
# internal setup dataset #
347+
# setup dataset #
348348
# ---------------------------------------------------------------------------- #
349-
function _setup_dataset(
350-
X :: AbstractDataFrame,
351-
y :: AbstractVector{<:Label},
352-
w :: MaybeVector = nothing;
353-
model :: MLJ.Model = _DefaultModel(y),
354-
resampling :: ResamplingStrategy = Holdout(fraction_train=0.7, shuffle=true),
355-
valid_ratio :: Real = 0.0,
356-
seed :: MaybeInt = nothing,
357-
balancing :: MaybeBalancing = nothing,
358-
tuning :: MaybeTuning = nothing,
359-
win :: WinFunc = adaptivewindow(nwindows=3, overlap=0.1),
360-
features :: Tuple{Vararg{Base.Callable}} = (maximum, minimum),
361-
reducefunc :: Base.Callable = mean
362-
)::AbstractDataSet
363-
# setup rng
364-
if !isnothing(seed)
365-
rng = Xoshiro(seed)
366-
# propagate user rng to every field that needs it
367-
hasproperty(model, :rng) && set_rng!(model, rng)
368-
hasproperty(resampling, :rng) && (resampling = set_rng(resampling, rng))
369-
else
370-
rng = TaskLocalRNG()
371-
end
372-
373-
# Modal models need features to be passed in model params
374-
hasproperty(model, :features) && set_conditions!(model, features)
375-
# MLJ.TunedModels can't automatically assign a measure to Modal models
376-
if model isa Modal && !isnothing(tuning)
377-
isnothing(get_measure(tuning)) && (tuning.measure = LogLoss())
378-
end
379-
380-
# handle multidimensional datasets:
381-
# propositional models requiring feature aggregation
382-
# modal models requiring reducing data size
383-
if DataTreatments.is_multidim_dataset(X)
384-
if model isa Modal
385-
t = DataTreatment(X, :reducesize; win, features, reducefunc)
386-
X = DataFrame(get_dataset(t), Symbol.(get_featureid(t)))
387-
tinfo = ReductionInfo(features, win, reducefunc)
388-
else
389-
t = DataTreatment(X, :aggregate; win, features)
390-
X = DataFrame(get_dataset(t), Symbol.(get_featureid(t)))
391-
tinfo = AggregationInfo(features, win)
392-
end
393-
else
394-
X = code_dataset(X)
395-
# some algos, like xgboost, don't accept datasets with non-float numeric values, only floats
396-
X = to_float_dataset(X)
397-
tinfo = nothing
398-
end
399-
400-
ttpairs, pinfo = partition(y; resampling, valid_ratio, rng)
401-
402-
isnothing(seed) && (seed = 1)
403-
isnothing(balancing) || (model = set_balancing(model, balancing, seed))
404-
isnothing(tuning) || (model = set_tuning(model, tuning, rng))
405-
406-
mach = isnothing(w) ? MLJ.machine(model, X, y) : MLJ.machine(model, X, y, w)
407-
408-
DataSet(mach, ttpairs, pinfo; tinfo)
349+
function setup_dataset(X::AbstractDataFrame, y::AbstractVector, args...; kwargs...)
350+
throw(ArgumentError("Target variable y must have elements of type Label, " * "got eltype: $(eltype(y))"))
409351
end
410352

411-
# ---------------------------------------------------------------------------- #
412-
# setup dataset #
413-
# ---------------------------------------------------------------------------- #
414353
"""
415354
setup_dataset(
416355
X, y, w=nothing;
@@ -547,20 +486,68 @@ dts = setup_dataset(
547486
548487
# See also: [`DataSet`](@ref), [`PropositionalDataSet`](@ref), [`ModalDataSet`](@ref), [`symbolic_analysis`](@ref)
549488
"""
550-
# setup_dataset(args...; kwargs...) = _setup_dataset(args...; kwargs...)
551-
552489
function setup_dataset(
553-
X::AbstractDataFrame,
554-
y::AbstractVector{<:Label},
555-
args...;
556-
model :: MLJ.Model = _DefaultModel(y),
557-
kwargs...
558-
)
559-
_setup_dataset(X, check_y(y, model), args...; model, kwargs...)
560-
end
490+
X :: AbstractDataFrame,
491+
y :: AbstractVector{<:Label},
492+
w :: MaybeVector = nothing;
493+
model :: MLJ.Model = _DefaultModel(y),
494+
resampling :: ResamplingStrategy = Holdout(fraction_train=0.7, shuffle=true),
495+
valid_ratio :: Real = 0.0,
496+
seed :: MaybeInt = nothing,
497+
balancing :: MaybeBalancing = nothing,
498+
tuning :: MaybeTuning = nothing,
499+
win :: WinFunc = adaptivewindow(nwindows=3, overlap=0.1),
500+
features :: Tuple{Vararg{Base.Callable}} = (maximum, minimum),
501+
reducefunc :: Base.Callable = mean
502+
)::AbstractDataSet
503+
y = check_y(y, model)
561504

562-
function setup_dataset(X::AbstractDataFrame, y::AbstractVector, args...; kwargs...)
563-
throw(ArgumentError("Target variable y must have elements of type Label, " * "got eltype: $(eltype(y))"))
505+
# setup rng
506+
if !isnothing(seed)
507+
rng = Xoshiro(seed)
508+
# propagate user rng to every field that needs it
509+
hasproperty(model, :rng) && set_rng!(model, rng)
510+
hasproperty(resampling, :rng) && (resampling = set_rng(resampling, rng))
511+
else
512+
rng = TaskLocalRNG()
513+
end
514+
515+
# Modal models need features to be passed in model params
516+
hasproperty(model, :features) && set_conditions!(model, features)
517+
# MLJ.TunedModels can't automatically assign a measure to Modal models
518+
if model isa Modal && !isnothing(tuning)
519+
isnothing(get_measure(tuning)) && (tuning.measure = LogLoss())
520+
end
521+
522+
# handle multidimensional datasets:
523+
# propositional models requiring feature aggregation
524+
# modal models requiring reducing data size
525+
if DataTreatments.is_multidim_dataset(X)
526+
if model isa Modal
527+
t = DataTreatment(X, :reducesize; win, features, reducefunc)
528+
X = DataFrame(get_dataset(t), Symbol.(get_featureid(t)))
529+
tinfo = ReductionInfo(features, win, reducefunc)
530+
else
531+
t = DataTreatment(X, :aggregate; win, features)
532+
X = DataFrame(get_dataset(t), Symbol.(get_featureid(t)))
533+
tinfo = AggregationInfo(features, win)
534+
end
535+
else
536+
X = code_dataset(X)
537+
# some algos, like xgboost, don't accept datasets with non-float numeric values, only floats
538+
X = to_float_dataset(X)
539+
tinfo = nothing
540+
end
541+
542+
ttpairs, pinfo = partition(y; resampling, valid_ratio, rng)
543+
544+
isnothing(seed) && (seed = 1)
545+
isnothing(balancing) || (model = set_balancing(model, balancing, seed))
546+
isnothing(tuning) || (model = set_tuning(model, tuning, rng))
547+
548+
mach = isnothing(w) ? MLJ.machine(model, X, y) : MLJ.machine(model, X, y, w)
549+
550+
DataSet(mach, ttpairs, pinfo; tinfo)
564551
end
565552

566553
"""
@@ -570,8 +557,9 @@ Convenience method when target variable is a column in the feature DataFrame.
570557
"""
571558
function setup_dataset(
572559
X::AbstractDataFrame,
573-
y::Symbol;
560+
y::Symbol,
561+
args...;
574562
kwargs...
575563
)::AbstractDataSet
576-
setup_dataset(X[!, Not(y)], X[!, y]; kwargs...)
564+
setup_dataset(X[!, Not(y)], X[!, y], args...; kwargs...)
577565
end

0 commit comments

Comments
 (0)