Skip to content

Commit 7a0f78c

Browse files
committed
inject rng in rule extraction strategy
1 parent ea5072e commit 7a0f78c

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

src/dataset.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ Extract the logiset (if present) from the dataset's MLJ machine.
306306
"""
307307
get_logiset(ds::ModalDataSet)::SupportedLogiset = ds.mach.data[1].modalities[1]
308308

309+
get_rng(ds::AbstractDataSet) = get_rng(ds.pinfo)
310+
309311
# ---------------------------------------------------------------------------- #
310312
# MLJ models's extra setup #
311313
# ---------------------------------------------------------------------------- #

src/extractrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ to_namedtuple(x) = NamedTuple{fieldnames(typeof(x))}(ntuple(i -> getfield(x, i),
2323
# ---------------------------------------------------------------------------- #
2424
function extractrules(
2525
extractor :: InTreesRuleExtractor,
26-
params :: NamedTuple,
26+
_ :: NamedTuple,
2727
ds :: AbstractDataSet,
2828
solem :: Vector{AbstractModel}
2929
)::Vector{DecisionSet}

src/partition.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ struct PartitionInfo{T} <: AbstractPartitionInfo
3232
end
3333

3434
# ---------------------------------------------------------------------------- #
35-
# base show #
35+
# methods #
3636
# ---------------------------------------------------------------------------- #
37+
get_rng(p::PartitionInfo) = p.rng
38+
3739
function Base.show(io::IO, info::PartitionInfo)
3840
println(io, "PartitionInfo:")
3941
for field in fieldnames(PartitionInfo)

src/symbolic_analysis.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,18 @@ function sole_predict(solem::AbstractModel, y_test::AbstractVector{<:Label})
218218
preds
219219
end
220220

221+
# set the random number generator for a rule extraction strategy
222+
function set_rng(r::RuleExtractor, rng::Random.AbstractRNG)::RuleExtractor
223+
T = typeof(r)
224+
225+
fnames = fieldnames(T)
226+
fvalues = map(fnames) do fn
227+
fn === :rng ? rng : getfield(r, fn)
228+
end
229+
230+
return T(; NamedTuple{fnames}(fvalues)...)
231+
end
232+
221233
# ---------------------------------------------------------------------------- #
222234
# eval measures #
223235
# ---------------------------------------------------------------------------- #
@@ -285,10 +297,11 @@ end
285297
# internal symbolic_analysis #
286298
# ---------------------------------------------------------------------------- #
287299
function _symbolic_analysis!(
288-
modelset::ModelSet;
289-
extractor::Union{MaybeRuleExtractor,Tuple{RuleExtractor,NamedTuple}}=nothing,
290-
association::MaybeAbstractAssociationRuleExtractor=nothing,
291-
measures::Tuple{Vararg{FussyMeasure}}=()
300+
modelset :: AbstractModelSet;
301+
extractor :: Union{MaybeRuleExtractor,Tuple{RuleExtractor,NamedTuple}}=nothing,
302+
# extractor::MaybeRuleExtractor=nothing,
303+
association :: MaybeAbstractAssociationRuleExtractor=nothing,
304+
measures :: Tuple{Vararg{FussyMeasure}}=()
292305
)::ModelSet
293306
ds = dsetup(modelset)
294307
solem = solemodels(modelset)
@@ -300,7 +313,8 @@ function _symbolic_analysis!(
300313
else
301314
params = NamedTuple(;)
302315
end
303-
316+
317+
:rng fieldnames(typeof(extractor)) && (extractor = set_rng(extractor, get_rng(ds)))
304318
extractrules(extractor, params, ds, solem)
305319
end)
306320

0 commit comments

Comments
 (0)