
For transports #244

Open
@cscherrer

Description


I got this working, sort of:

julia> d = For(j -> Normal(j, 2.0), 1:3)
For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))

julia> test_transport(d, Normal() ^ 3)
Test Summary:                                                                                             | Pass  Total  Time
transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,)) |    8      8  0.0s
DefaultTestSet("transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))", Any[], 8, false, false, true, 1.66725e9, 1.66725e9)

To do this, I added for_constructor, which is like For but a little smarter: if the inferred return type of f is a singleton type, it collapses to a power measure instead of building a For:

for_constructor(f, x) = for_constructor(f, (x,))

@generated function for_constructor(f::F, inds::I) where {F,I<:Tuple}
    eltypes = Tuple{eltype.(I.types)...}
    quote
        T = Core.Compiler.return_type(f, $eltypes)
        _for(T, f, inds, static(Base.issingletontype(T)))
    end
end

function _for(::Type{T}, f::F, inds::I, ::True) where {T,F,I}
    instance(T) ^ size(first(inds))
end

function _for(::Type{T}, f::F, inds::I, ::False) where {T,F,I}
    For{T,F,I}(f, inds)
end
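For readers without MeasureTheory loaded, here is a toy sketch of the same collapse-to-power dispatch. StdToy, ShiftedToy, ToyPower, ToyFor, toy_for and _toy_for are hypothetical stand-ins, not MeasureBase or MeasureTheory names. The sketch uses only Base and, like the code above, relies on the internal (non-public) Core.Compiler.return_type.

```julia
# Toy stand-ins (hypothetical, not MeasureBase/MeasureTheory types).
struct StdToy end                 # a singleton "measure"

struct ShiftedToy                 # a non-singleton "measure"
    μ::Int
end

struct ToyPower{M}                # plays the role of PowerMeasure
    parent::M
    dims::Tuple{Vararg{Int}}
end

struct ToyFor{T,F,I}              # plays the role of For
    f::F
    inds::I
end

toy_for(f, x) = toy_for(f, (x,))

function toy_for(f::F, inds::Tuple) where {F}
    # Ask inference what f returns for the element types of the index collections.
    argtypes = Tuple{map(eltype, inds)...}
    T = Core.Compiler.return_type(f, argtypes)
    # Dispatch on whether that type has exactly one instance.
    return _toy_for(T, f, inds, Val(Base.issingletontype(T)))
end

# Singleton result: every marginal is the same measure, so collapse to a power.
_toy_for(::Type{T}, f, inds, ::Val{true}) where {T} = ToyPower(T.instance, size(first(inds)))

# Otherwise, keep the general For-style construction.
_toy_for(::Type{T}, f::F, inds::I, ::Val{false}) where {T,F,I} = ToyFor{T,F,I}(f, inds)
```

Unlike the version above, this sketch calls return_type at run time rather than inside a @generated function, and uses Val in place of Static.jl's static. Because the branch depends on what inference reports, a function whose return type infers as Any simply falls back to the ToyFor branch.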

Then we just need the standard transport methods, defined for any AbstractProductMeasure:

function MeasureBase.transport_origin(d::AbstractProductMeasure)
    for_constructor(MeasureBase.transport_origin, marginals(d))
end

function MeasureBase.to_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.to_origin, marginals(d), x)
end

function MeasureBase.from_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.from_origin, marginals(d), x)
end
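These methods just apply the per-marginal transport elementwise. A toy sketch of that pattern, assuming (hypothetically) that each marginal is a normal whose origin is the standard normal, so the per-marginal transport is standardization. ToyNormal, toy_to_origin and toy_from_origin are illustrative names, not MeasureBase API.

```julia
# Hypothetical stand-in for a normal marginal with mean μ and std σ.
struct ToyNormal
    μ::Float64
    σ::Float64
end

# Per-marginal transport to and from an (assumed) standard-normal origin.
toy_to_origin(d::ToyNormal, x::Real) = (x - d.μ) / d.σ
toy_from_origin(d::ToyNormal, z::Real) = d.μ + d.σ * z

# Product versions: map the per-marginal transport over marginals and values.
# Like the map-based methods above, each call allocates a new result array.
toy_to_origin(ds::AbstractVector{ToyNormal}, x::AbstractVector) = map(toy_to_origin, ds, x)
toy_from_origin(ds::AbstractVector{ToyNormal}, z::AbstractVector) = map(toy_from_origin, ds, z)
```

For example, with marginals ToyNormal(j, 2.0) for j in 1:3, the point [1.0, 4.0, 0.0] maps to [0.0, 1.0, -1.5], and from_origin maps it back.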

Well, almost. There's also this bug:

julia> MeasureBase._origin_depth(Normal() ^ 3)
ERROR: MethodError: no method matching ^(::MeasureBase.NoTransportOrigin{StdNormal}, ::Tuple{Int64})
Closest candidates are:
  ^(::AbstractMeasure, ::Tuple) at ~/git/MeasureBase.jl/src/combinators/power.jl:55
  ^(::AbstractMeasure, ::Any) at ~/git/MeasureBase.jl/src/combinators/power.jl:56
Stacktrace:
 [1] _for(#unused#::Type{MeasureBase.NoTransportOrigin{StdNormal}}, f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}}, #unused#::Static.True)
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:37
 [2] macro expansion
   @ ~/git/MeasureTheory.jl/src/combinators/for.jl:32 [inlined]
 [3] for_constructor(f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:28
 [4] for_constructor(f::Function, x::FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:26
 [5] transport_origin(d::PowerMeasure{StdNormal, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:305
 [6] _origin_depth(ν::PowerMeasure{Normal{(), Tuple{}}, Tuple{Base.OneTo{Int64}}})
   @ MeasureBase ~/git/MeasureBase.jl/src/transport.jl:130
 [7] top-level scope
   @ REPL[60]:1

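We end up taking a power of a NoTransportOrigin, which makes no sense: StdNormal has no transport origin, so for_constructor sees the singleton type NoTransportOrigin{StdNormal} and _for tries to raise it to a power. As a quick fix, I temporarily changed MeasureBase._origin_depth to: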

@inline function _origin_depth(ν::NU) where {NU}
    ν_0 = ν
    Base.Cartesian.@nexprs 10 i -> begin  # 10 is just some "big enough" number
        ν_{i} = transport_origin(ν_{i - 1})
        if ν_{i} isa PowerMeasure
            # Treat a power measure like its marginal
            ν_{i} = ν_{i}.parent
        elseif ν_{i} isa NoTransportOrigin
            return static(i - 1)
        end
    end
    return static(10)
end
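The same counting logic can be sketched with a plain loop and toy types, to check which depths come out. NoOrigin, StdToy, ToyNormal, ToyPower, toy_origin and toy_origin_depth are hypothetical. The loop returns an Int rather than a static value, and it checks for NoOrigin after unwrapping a power (the version above skips that check in an iteration that unwraps), so in this toy a power of NoOrigin is also caught.

```julia
# Hypothetical toy types; NoOrigin plays the role of NoTransportOrigin.
struct NoOrigin end
struct StdToy end
struct ToyNormal
    μ::Float64
    σ::Float64
end
struct ToyPower{M}
    parent::M
    dims::Tuple{Vararg{Int}}
end

toy_origin(::ToyNormal) = StdToy()
toy_origin(::StdToy) = NoOrigin()
# The origin of a power is a power of the origin (possibly a power of NoOrigin,
# which this toy allows to be constructed).
toy_origin(p::ToyPower) = ToyPower(toy_origin(p.parent), p.dims)

# Count how many origin steps are possible, unwrapping powers to their parent.
function toy_origin_depth(ν; maxdepth::Int = 10)
    for i in 1:maxdepth
        ν = toy_origin(ν)
        ν isa ToyPower && (ν = ν.parent)
        ν isa NoOrigin && return i - 1
    end
    return maxdepth
end
```

With these definitions a power of ToyNormal has depth 1, the same as a single ToyNormal, and a power of StdToy has depth 0.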

This last part feels kind of hacky. Also, map allocates a new array in every to_origin and from_origin call. It would be nice to use mappedarray (from MappedArrays.jl) instead, but its result doesn't infer properly here. Maybe a modified version of it would?
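The idea behind mappedarray is to compute elements on access instead of materializing a new array. A minimal hand-rolled sketch of that idea for the two-argument case, in case a modified version is easier to get to infer. LazyMap2 is hypothetical (not MappedArrays.jl's implementation); it fixes the element type once at construction via Base.promote_op and assumes 1-based vectors. Whether it infers any better in the MeasureBase setting is untested.

```julia
# Hypothetical lazy alternative to map(f, a, b): nothing is computed until indexed.
struct LazyMap2{T,F,A<:AbstractVector,B<:AbstractVector} <: AbstractVector{T}
    f::F
    a::A
    b::B
end

function LazyMap2(f::F, a::A, b::B) where {F,A<:AbstractVector,B<:AbstractVector}
    size(a) == size(b) || throw(DimensionMismatch("a and b must have the same size"))
    # Element type from inference; if inference fails this may be Any or a Union.
    T = Base.promote_op(f, eltype(A), eltype(B))
    return LazyMap2{T,F,A,B}(f, a, b)
end

Base.size(m::LazyMap2) = size(m.a)
# Assumes 1-based a and b; f is applied on each access.
Base.getindex(m::LazyMap2, i::Int) = m.f(m.a[i], m.b[i])
```

In the product-measure case this would look like LazyMap2(to_origin, marginals(d), x), trading the up-front allocation for recomputation on every access.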

Also, a product whose marginals have different "origin depths" seems like a problem. A fixpoint approach would handle this, but I think the current approach will break. Any ideas for this, @oschulz?
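One possible reading of the fixpoint idea, sketched with hypothetical toy types (NoOrigin, StdToy, ToyNormal, ToyScaled, toy_origin, toy_depth and toy_base are not MeasureBase names): instead of one shared depth for the whole product, each marginal is moved to its origin repeatedly until no further origin exists. The marginals in the example have depths 1, 2 and 0, yet all reach the same base.

```julia
# Hypothetical toy types; NoOrigin plays the role of NoTransportOrigin.
struct NoOrigin end
struct StdToy end
struct ToyNormal
    μ::Float64
    σ::Float64
end
struct ToyScaled        # a toy measure two origin steps away from StdToy
    c::Float64
end

toy_origin(::StdToy) = NoOrigin()
toy_origin(::ToyNormal) = StdToy()
toy_origin(d::ToyScaled) = ToyNormal(0.0, d.c)

# Number of origin steps available for one marginal.
function toy_depth(ν)
    d = 0
    nxt = toy_origin(ν)
    while !(nxt isa NoOrigin)
        d += 1
        nxt = toy_origin(nxt)
    end
    return d
end

# Fixpoint: follow origins until there is no further origin, then stop.
function toy_base(ν)
    nxt = toy_origin(ν)
    while !(nxt isa NoOrigin)
        ν = nxt
        nxt = toy_origin(ν)
    end
    return ν
end
```

Applied per marginal, e.g. map(toy_base, (ToyNormal(1.0, 2.0), ToyScaled(3.0), StdToy())), every component ends at StdToy even though map(toy_depth, ...) gives (1, 2, 0).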
