Skip to content

Commit dbeaf9c

Browse files
authored
Merge pull request #287 from FluxML/dev
Rebase entity-tutorial
2 parents 71eeafd + 318a61e commit dbeaf9c

File tree

5 files changed

+14
-7
lines changed

5 files changed

+14
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJFlux"
22
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
33
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>", "Ayush Shridhar <ayush.shridhar1999@gmail.com>"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/classifier.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ function MLJModelInterface.predict(
4040
)
4141
chain, levels, ordinal_mappings, _ = fitresult
4242
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings) # what if Xnew is a matrix
43-
X = reformat(Xnew)
43+
X = _f32(reformat(Xnew), 0)
4444
probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...)
4545
return MLJModelInterface.UnivariateFinite(levels, probs)
4646
end
@@ -69,7 +69,7 @@ function MLJModelInterface.predict(
6969
)
7070
chain, levels, ordinal_mappings, _ = fitresult
7171
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
72-
X = reformat(Xnew)
72+
X = _f32(reformat(Xnew), 0)
7373
probs = vec(chain(X))
7474
return MLJModelInterface.UnivariateFinite(levels, probs; augment = true)
7575
end

src/encoders.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function ordinal_encoder_transform(X, mapping_matrix)
6767
test_levels = levels(col)
6868
check_unkown_levels(train_levels, test_levels)
6969
level2scalar = mapping_matrix[ind]
70-
new_col = [recode(unwrap.(col), level2scalar...)...]
70+
new_col = unwrap.(recode(col, level2scalar...))
7171
push!(new_feats, new_col)
7272
else
7373
push!(new_feats, col)

src/regressor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ function MLJModelInterface.predict(model::NeuralNetworkRegressor,
2727
Xnew)
2828
chain, ordinal_mappings = fitresult[1], fitresult[3]
2929
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
30-
Xnew_ = reformat(Xnew)
30+
Xnew_ = _f32(reformat(Xnew), 0)
3131
return [chain(values.(tomat(Xnew_[:, i])))[1]
3232
for i in 1:size(Xnew_, 2)]
3333
end
@@ -74,7 +74,7 @@ function MLJModelInterface.predict(model::MultitargetNeuralNetworkRegressor,
7474
fitresult, Xnew)
7575
chain, target_column_names, ordinal_mappings, _ = fitresult
7676
Xnew = ordinal_encoder_transform(Xnew, ordinal_mappings)
77-
X = reformat(Xnew)
77+
X = _f32(reformat(Xnew), 0)
7878
ypred = [chain(values.(tomat(X[:, i])))
7979
for i in 1:size(X, 2)]
8080
output =

test/mlj_model_interface.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ mutable struct LisasBuilder
133133
n1::Int
134134
end
135135

136+
# UndefVarError accepts two inputs from julia > v"1.9"
137+
_UndefVarError(var, scope) = @static if VERSION < v"1.10"
138+
UndefVarError(var)
139+
else
140+
UndefVarError(var, scope)
141+
end
142+
136143
@testset "builder errors and issue #237" begin
137144
# create a builder with an intentional flaw;
138145
# `Chains` is undefined - it should be `Chain`
@@ -153,7 +160,7 @@ end
153160
y = rand(Float32, 75)
154161
@test_logs(
155162
(:error, MLJFlux.ERR_BUILDER),
156-
@test_throws UndefVarError(:Chains) MLJBase.fit(model, 0, X, y)
163+
@test_throws _UndefVarError(:Chains, Flux) MLJBase.fit(model, 0, X, y)
157164
)
158165
end
159166

0 commit comments

Comments
 (0)