2 changes: 2 additions & 0 deletions src/MLJFlux.jl
@@ -30,9 +30,11 @@ include("image.jl")
include("fit_utils.jl")
include("entity_embedding_utils.jl")
include("mlj_model_interface.jl")
include("mlj_embedder_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
+export EntityEmbedder
export CUDALibs, CPU1

include("deprecated.jl")
12 changes: 6 additions & 6 deletions src/entity_embedding.jl
@@ -25,7 +25,7 @@ entityprops = [
numfeats = 4

# Run it through the categorical embedding layer
-embedder = EntityEmbedder(entityprops, 4)
+embedder = EntityEmbedderLayer(entityprops, 4)
julia> output = embedder(batch)
5×10 Matrix{Float64}:
0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1
@@ -35,18 +35,18 @@ julia> output = embedder(batch)
-0.847354 -0.847354 -1.66261 -1.66261 -1.66261 -1.66261 -0.847354 -0.847354
```
""" # 1. Define layer struct to hold parameters
-struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
+struct EntityEmbedderLayer{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
embedders::A1
modifiers::A2 # applied on the input before passing it to the embedder
numfeats::I
end

# 2. Define the forward pass (i.e., calling an instance of the layer)
-(m::EntityEmbedder)(x) =
+(m::EntityEmbedderLayer)(x) =
(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))
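# For a batch `x` of size (numfeats, batchsize), each categorical feature
# contributes a (newdim, batchsize) block and each continuous feature a
# (1, batchsize) block, so the concatenated output has
# (n_continuous + sum of newdims) rows.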

# 3. Define the constructor which initializes the parameters and returns the instance
-function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
+function EntityEmbedderLayer(entityprops, numfeats; init = Flux.randn32)
embedders = []
modifiers = []
# Setup entityprops
@@ -66,8 +66,8 @@ function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
end
end

-EntityEmbedder(embedders, modifiers, numfeats)
+EntityEmbedderLayer(embedders, modifiers, numfeats)
end

# 4. Register it as layer with Flux
-Flux.@layer EntityEmbedder
+Flux.@layer EntityEmbedderLayer
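For review context: the constructor pairs each feature index with a modifier (which slices that feature's row out of the batch, integer-coding it for categorical features) and an embedder (`Flux.Embedding` for categorical features, a pass-through otherwise). Below is a minimal self-contained sketch of that dispatch, assuming only the public `Flux.Embedding` API; the modifier bodies are plausible reconstructions, not the package code:

```julia
using Flux

# One categorical feature (index 2, 10 levels, embedded into 2 dims) among 4 features:
entityprops = [(index = 2, levels = 10, newdim = 2)]
numfeats = 4
cat_inds = [p.index for p in entityprops]

embedders, modifiers = Any[], Any[]
for i in 1:numfeats
    if i in cat_inds
        p = entityprops[findfirst(==(i), cat_inds)]
        push!(embedders, Flux.Embedding(p.levels => p.newdim))
        push!(modifiers, (x, k) -> Int.(x[k, :]))  # row of level codes, as integers
    else
        push!(embedders, identity)
        push!(modifiers, (x, k) -> x[k:k, :])      # continuous row passes through
    end
end

batch = rand(Float32, numfeats, 8)
batch[2, :] .= rand(1:10, 8)  # categorical levels stored numerically

out = vcat([embedders[i](modifiers[i](batch, i)) for i in 1:numfeats]...)
@assert size(out) == (3 + 2, 8)  # 3 continuous rows + 2 embedding dims
```

The pass-through entries for continuous features are what make the layer transparent when `entityprops` is empty, as the tests below verify.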
2 changes: 1 addition & 1 deletion src/entity_embedding_utils.jl
@@ -100,7 +100,7 @@ function construct_model_chain_with_entityembs(
)
chain = try
Flux.Chain(
-EntityEmbedder(entityprops, shape[1]; init = Flux.glorot_uniform(rng)),
+EntityEmbedderLayer(entityprops, shape[1]; init = Flux.glorot_uniform(rng)),
build(model, rng, (entityemb_output_dim, shape[2])),
) |> move
catch ex
128 changes: 128 additions & 0 deletions src/mlj_embedder_interface.jl
@@ -0,0 +1,128 @@
### EntityEmbedder with MLJ Interface

# 1. Interface Struct
mutable struct EntityEmbedder{M <: MLJFluxModel} <: Unsupervised
model::M
end;


# 2. Constructor
function EntityEmbedder(model)
return EntityEmbedder(model)
end;


# 3. Fitted parameters (for user access)
MMI.fitted_params(::EntityEmbedder, fitresult) = fitresult

# 4. Fit method
function MMI.fit(transformer::EntityEmbedder, verbosity::Int, X, y)
return MMI.fit(transformer.model, verbosity, X, y)
end;


# 5. Transform method
function MMI.transform(transformer::EntityEmbedder, fitresult, Xnew)
Xnew_transf = MMI.transform(transformer.model, fitresult, Xnew)
return Xnew_transf
end
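# Note: `fit` and `transform` delegate entirely to the wrapped MLJFlux model;
# the wrapper adds no state of its own and only re-exposes the supervised
# model's entity embeddings through the unsupervised-transformer interface.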

# 6. Extra metadata
MMI.metadata_pkg(
EntityEmbedder,
package_name = "MLJFlux",
package_uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845",
package_url = "https://github.com/FluxML/MLJFlux.jl",
is_pure_julia = true,
)

MMI.metadata_model(
EntityEmbedder,
input_scitype = Table,
output_scitype = Table,
load_path = "MLJTransforms.EntityEmbedder",
)

MMI.target_in_fit(::Type{<:EntityEmbedder}) = true
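# Although `EntityEmbedder` is `Unsupervised`, the target `y` is required at
# `fit` time to train the wrapped supervised network; hence `target_in_fit = true`.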





"""
$(MMI.doc_header(EntityEmbedder))

`EntityEmbedder` implements entity embeddings as described in "Entity Embeddings of Categorical Variables" by Cheng Guo and Felix Berkhahn (arXiv, 2016).

# Training data

In MLJ (or MLJBase), bind an instance `model` to data with

mach = machine(model, X, y)

Here:


- `X` is any table of input features (eg, a `DataFrame`). Features to be transformed must
have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)` to
check scitypes.

- `y` is the target, which can be any `AbstractVector` whose element
  scitype is `Continuous` or `Count` for regression problems, or
  `Multiclass` or `OrderedFactor` for classification problems; check the scitype with `scitype(y)`.

Train the machine using `fit!(mach)`.

# Hyper-parameters

- `model`: The underlying MLJFlux deep learning model used to train the entity embeddings. Currently, `NeuralNetworkClassifier`, `NeuralNetworkRegressor`, and `MultitargetNeuralNetworkRegressor` are supported.

# Operations

- `transform(mach, Xnew)`: Transform the categorical features of `Xnew` into dense `Continuous` vectors using the trained `MLJFlux.EntityEmbedderLayer` present in the network.
  See the MLJFlux documentation [here](https://fluxml.ai/MLJFlux.jl/dev/), in particular the `embedding_dims` hyperparameter.


# Examples

```julia
using MLJFlux
using MLJ
using CategoricalArrays

# Setup some data
N = 200
X = (;
Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))),
Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
Column5 = randn(Float32, N),
Column6 = categorical(
repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
),
)
y = categorical(repeat([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], Int(N / 10))) # Classification target of length N

# Load and instantiate model
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2))

emb = EntityEmbedder(clf)

# Construct machine
mach = machine(emb, X, y)

# Train model
fit!(mach)

# Transform data using model to encode categorical columns
Xnew = transform(mach, X)
Xnew
```

See also
[`TargetEncoder`](@ref)
"""
EntityEmbedder
6 changes: 3 additions & 3 deletions src/types.jl
@@ -194,7 +194,7 @@ MMI.metadata_pkg.(
const MODELSUPPORTDOC = """
In addition to features with `Continuous` scientific element type, this model supports
categorical features in the input table. If present, such features are embedded into dense
-vectors by the use of an additional `EntityEmbedder` layer after the input, as described in
+vectors by the use of an additional `EntityEmbedderLayer` after the input, as described in
Entity Embeddings of Categorical Variables by Cheng Guo, Felix Berkhahn arXiv, 2016.
"""

@@ -204,7 +204,7 @@ const XDOC = """
scitype (typically `Float32`); or (ii) a table of input features (eg, a `DataFrame`)
whose columns have `Continuous`, `Multiclass` or `OrderedFactor` element scitype; check
column scitypes with `schema(X)`. If any `Multiclass` or `OrderedFactor` features
-appear, the constructed network will use an `EntityEmbedder` layer to transform
+appear, the constructed network will use an `EntityEmbedderLayer` to transform
them into dense vectors. If `X` is a `Matrix`, it is assumed that columns correspond to
features and rows to observations.

@@ -222,7 +222,7 @@ const EMBDOC = """
const TRANSFORMDOC = """
- `transform(mach, Xnew)`: Assuming `Xnew` has the same schema as `X`, transform the
categorical features of `Xnew` into dense `Continuous` vectors using the
-`MLJFlux.EntityEmbedder` layer present in the network. Does nothing in case the model
+`MLJFlux.EntityEmbedderLayer` present in the network. Does nothing if the model
was trained on an input `X` that lacks categorical features.
"""

28 changes: 22 additions & 6 deletions test/entity_embedding.jl
@@ -22,7 +22,7 @@ entityprops = [
(index = 4, levels = 2, newdim = 2),
]

-embedder = MLJFlux.EntityEmbedder(entityprops, 4)
+embedder = MLJFlux.EntityEmbedderLayer(entityprops, 4)

output = embedder(batch)

@@ -68,7 +68,7 @@ end
]

cat_model = Chain(
-MLJFlux.EntityEmbedder(entityprops, 4),
+MLJFlux.EntityEmbedderLayer(entityprops, 4),
Dense(9 => (ind == 1) ? 10 : 1),
finalizer[ind],
)
@@ -143,7 +143,7 @@ end
@testset "Transparent when no categorical variables" begin
entityprops = []
numfeats = 4
-embedder = MLJFlux.EntityEmbedder(entityprops, 4)
+embedder = MLJFlux.EntityEmbedderLayer(entityprops, 4)
output = embedder(batch)
@test output ≈ batch
@test eltype(output) == Float32
@@ -187,21 +187,37 @@ end
3 4
])

stable_rng = StableRNG(123)

for j in eachindex(embedding_dims)
for i in eachindex(models)
# Without lightweight wrapper
clf = models[i](
-builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2),
+builder = MLJFlux.MLP(hidden = (10, 10)),
optimiser = Optimisers.Adam(0.01),
batch_size = 8,
epochs = 100,
acceleration = CUDALibs(),
optimiser_changes_trigger_retraining = true,
embedding_dims = embedding_dims[j],
rng = stable_rng
)

mach = machine(clf, X, ys[1])

fit!(mach, verbosity = 0)
Xnew = transform(mach, X)
# With lightweight wrapper
clf2 = deepcopy(clf)
emb = MLJFlux.EntityEmbedder(clf2)
mach_emb = machine(emb, X, ys[1])
fit!(mach_emb, verbosity = 0)
Xnew_emb = transform(mach_emb, X)
@test Xnew == Xnew_emb

# Pipeline doesn't throw an error
pipeline = emb |> clf
mach_pipe = machine(pipeline, X, ys[1])
fit!(mach_pipe, verbosity = 0)
yhat = predict_mode(mach_pipe, X)

mapping_matrices = MLJFlux.get_embedding_matrices(
fitted_params(mach).chain,