variable axis on which the "one-hot" property holds #35

@nomadbl

Motivation and description

I am working on a layer that produces one-hot outputs, so I am looking into using OneHotArrays.jl.
My gripe is that the type currently only supports one-hot vectors that extend along the first axis.

I thought I'd write down my thoughts and some possible implementations of a variable axis, to get feedback and context from the maintainers and other users here (I am very new to Julia and Flux, coming from Python).

Possible Implementation

Implementation path 1 (WIP): change the constructors, size, and getindex:

struct OneHotArray{T<:Integer,N,var"N+1",I<:Union{T,AbstractArray{T,N}}} <: AbstractArray{Bool,var"N+1"}
  indices::I
  nlabels::Int
  axis::Int  # dimension along which the one-hot property holds, in 1:N+1
end
OneHotArray{T,N,I}(indices, L::Int, axis::Int=1) where {T,N,I} = OneHotArray{T,N,N + 1,I}(indices, L, axis)
OneHotArray(indices::T, L::Int, axis::Int=1) where {T<:Integer} = OneHotArray{T,0,1,T}(indices, L, axis)
OneHotArray(indices::I, L::Int, axis::Int=1) where {T,N,I<:AbstractArray{T,N}} = OneHotArray{T,N,N + 1,I}(indices, L, axis)

# Insert the label dimension at position `axis` of the indices' size.
function Base.size(x::OneHotArray)
  sz = size(x.indices)
  return (sz[1:x.axis-1]..., x.nlabels, sz[x.axis:end]...)
end

# Dispatch on the full dimensionality M = N + 1; linear indexing falls back to Base.
Base.@propagate_inbounds function Base.getindex(x::OneHotArray{T,N,M}, I::Vararg{Int,M}) where {T,N,M}
  @boundscheck checkbounds(x, I...)
  label = I[x.axis]                            # position along the one-hot axis
  Ip = (I[1:x.axis-1]..., I[x.axis+1:end]...)  # remaining position, into x.indices
  return x.indices[Ip...] == label
end

The idea is to keep the representation sparse, so that optimized multiplication, backprop, etc. can still be used later.
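
To make the efficiency argument concrete: multiplying by a one-hot matrix needs no arithmetic, only indexing. Below is a minimal sketch in plain Julia (it does not use OneHotArrays.jl); dense_onehot, mul_axis1, and mul_axis2 are hypothetical helpers, not package API. With labels along axis 1, W * X just selects columns of W; with labels along axis 2, X * V just selects rows of V, which is presumably the case a variable-axis sparse type would want to handle without materializing anything.

```julia
using LinearAlgebra  # generic matrix multiplication

# Dense reference: one-hot matrix with labels along axis 1, size (L, length(indices)).
function dense_onehot(indices::AbstractVector{<:Integer}, L::Integer)
    X = falses(L, length(indices))
    for (j, k) in enumerate(indices)
        X[k, j] = true  # column j is hot at row k
    end
    return X
end

# Labels along axis 1 (X is L × B): W * X only gathers columns of W.
mul_axis1(W::AbstractMatrix, indices::AbstractVector{<:Integer}) = W[:, indices]

# Labels along axis 2 (X is B × L): X * V only gathers rows of V.
mul_axis2(indices::AbstractVector{<:Integer}, V::AbstractMatrix) = V[indices, :]

W = Float64[1 2 3 4; 5 6 7 8]  # 2 × 4, so L = 4
idx = [4, 1, 1]                # a batch of B = 3 labels
X = dense_onehot(idx, 4)       # 4 × 3

mul_axis1(W, idx) == W * X                                          # true
mul_axis2(idx, permutedims(W)) == permutedims(X) * permutedims(W)   # true
```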

While working on this I also hit upon path 2: reuse all of the original code, and use the new axis parameter to permute the underlying object, which is one-hot along axis 1, into the right layout before computations.
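
One possible reading of path 2, sketched with a dense Bool array standing in for the first-axis OneHotArray (onehot_first and onehot_along are hypothetical names): build the array one-hot along axis 1, then move dimension 1 to position axis with Base's lazy PermutedDimsArray, using the permutation (2, ..., axis, 1, axis+1, ..., N+1).

```julia
# Dense stand-in for a first-axis one-hot array: size (L, size(indices)...).
function onehot_first(indices::AbstractArray{<:Integer}, L::Integer)
    A = falses(L, size(indices)...)
    for c in CartesianIndices(indices)
        A[indices[c], c] = true  # slice c is hot at label indices[c]
    end
    return A
end

# Path 2 (sketch): keep the first-axis layout and lazily move dimension 1
# to position `axis`; PermutedDimsArray is a Base wrapper and does not copy.
function onehot_along(indices::AbstractArray{<:Integer}, L::Integer, axis::Integer)
    n = ndims(indices) + 1
    1 <= axis <= n || throw(ArgumentError("axis must be in 1:$n, got $axis"))
    perm = (2:axis..., 1, axis+1:n...)  # e.g. n = 3, axis = 2 gives (2, 1, 3)
    return PermutedDimsArray(onehot_first(indices, L), perm)
end

X = onehot_along([2, 1, 3], 3, 2)  # 3 × 3: one row per sample, labels along axis 2
size(X)                            # (3, 3)
(X[1, 2], X[2, 1], X[3, 3])        # (true, true, true)
```

Since PermutedDimsArray accepts any AbstractArray, the dense stand-in could presumably be swapped for an existing OneHotArray. The likely cost is that the result is then a wrapper type, so specialized methods (such as fast multiplication) would not apply to it without extra dispatch on the wrapper.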

I expect to open a PR soon, but I'd love to hear your thoughts: do you think the first approach is better (is it more memory- and compute-efficient?), given that it is probably harder to maintain and test?
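
On the testing concern: whichever path is chosen, both could be checked against the same definition, namely that X[I...] is true exactly when indices at I (with the axis entry removed) equals I[axis]. A minimal sketch of such a checker, assuming a hypothetical helper named matches_onehot that works on any AbstractArray{Bool}:

```julia
# Check a candidate array X against the one-hot definition along `axis`:
# X[I...] must be true exactly when indices[I without its axis entry...] == I[axis].
function matches_onehot(X::AbstractArray{Bool}, indices::AbstractArray{<:Integer},
                        L::Integer, axis::Integer)
    1 <= axis <= ndims(indices) + 1 || throw(ArgumentError("axis out of range: $axis"))
    sz = size(indices)
    size(X) == (sz[1:axis-1]..., L, sz[axis:end]...) || return false
    for I in CartesianIndices(X)
        t = Tuple(I)
        rest = (t[1:axis-1]..., t[axis+1:end]...)  # position in `indices`
        X[I] == (indices[rest...] == t[axis]) || return false
    end
    return true
end

labels = [2, 3, 1]
X = [labels[i] == l for i in 1:3, l in 1:3]  # samples along axis 1, labels along axis 2
matches_onehot(X, labels, 3, 2)  # true
matches_onehot(X, labels, 3, 1)  # false
```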
