I got this working, sort of:
```julia
julia> d = For(j -> Normal(j, 2.0), 1:3)
For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))

julia> test_transport(d, Normal() ^ 3)
Test Summary: | Pass Total Time
transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,)) | 8 8 0.0s
DefaultTestSet("transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))", Any[], 8, false, false, true, 1.66725e9, 1.66725e9)
```
To do this, I added `for_constructor`, which is like `For` but a little smarter: it can sometimes collapse to a power measure:
```julia
for_constructor(f, x) = for_constructor(f, (x,))

@generated function for_constructor(f::F, inds::I) where {F,I<:Tuple}
    eltypes = Tuple{eltype.(I.types)...}  # computed when the method is generated
    quote
        T = Core.Compiler.return_type(f, $eltypes)
        _for(T, f, inds, static(Base.issingletontype(T)))
    end
end

# Singleton element type: every element is the same, so collapse to a power
function _for(::Type{T}, f::F, inds::I, ::True) where {T,F,I}
    instance(T) ^ size(first(inds))
end

function _for(::Type{T}, f::F, inds::I, ::False) where {T,F,I}
    For{T,F,I}(f, inds)
end
```
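For comparison, here is a minimal, self-contained sketch of the same "collapse when the element type is a singleton" idea. It uses hypothetical toy types (`StdN`, `N`, `Pow`, `ForM`) instead of the real MeasureBase/MeasureTheory types, and plain runtime `Val` dispatch in place of the `@generated` function and Static.jl's `True`/`False`, so treat it as an illustration of the logic rather than a replacement:

```julia
# Hypothetical toy types (stand-ins, not the MeasureBase/MeasureTheory types).
struct StdN end                              # a singleton "measure"
struct N{T}; μ::T; σ::Float64; end           # a parametrized "measure"
struct Pow{M,S}; parent::M; size::S; end     # a "power" of one measure
struct ForM{F,I}; f::F; inds::I; end         # a generic "For"

# Infer the element type, then pick a representation based on it.
function collapse_for(f, inds::Tuple)
    T = Core.Compiler.return_type(f, Tuple{map(eltype, inds)...})
    return _collapse(T, f, inds, Val(Base.issingletontype(T)))
end
collapse_for(f, x) = collapse_for(f, (x,))

# Singleton element type: every element is identical, so use a power.
_collapse(::Type{T}, f, inds, ::Val{true}) where {T} = Pow(T.instance, size(first(inds)))
# Otherwise keep the function and indices around, like `For`.
_collapse(::Type{T}, f, inds, ::Val{false}) where {T} = ForM(f, inds)

p = collapse_for(j -> StdN(), 1:3)       # a Pow with size (3,)
q = collapse_for(j -> N(j, 2.0), 1:3)    # stays a ForM
```

As in the code above, any singleton return type takes the power branch, including a sentinel type such as `NoTransportOrigin`, which is how the error further down can arise.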
Then we just need the standard stuff:
```julia
function MeasureBase.transport_origin(d::AbstractProductMeasure)
    for_constructor(MeasureBase.transport_origin, marginals(d))
end

function MeasureBase.to_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.to_origin, marginals(d), x)
end

function MeasureBase.from_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.from_origin, marginals(d), x)
end
```
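To make the elementwise `map` concrete, here is a toy product of normal-like marginals. The type `Nrm` and the functions `to_origin1`/`from_origin1` (standardize and unstandardize) are hypothetical stand-ins, not MeasureBase's implementation; the point is only that each coordinate is transported by its own marginal, and a round trip through both maps returns the original point:

```julia
# Hypothetical normal-like marginal (a stand-in, not MeasureBase's Normal).
struct Nrm
    μ::Float64
    σ::Float64
end

# Per-marginal transport to and from a standard-normal origin.
to_origin1(d::Nrm, x) = (x - d.μ) / d.σ
from_origin1(d::Nrm, z) = d.μ + d.σ * z

# Product versions: transport each coordinate with its own marginal.
to_origin_prod(ds, x) = map(to_origin1, ds, x)
from_origin_prod(ds, z) = map(from_origin1, ds, z)

ds = [Nrm(j, 2.0) for j in 1:3]   # analogous to For(j -> Normal(j, 2.0), 1:3)
x = [0.5, 1.0, 4.0]
z = to_origin_prod(ds, x)          # [-0.25, -0.5, 0.5]
```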
Well, almost. There's also this bug:
```julia
julia> MeasureBase._origin_depth(Normal() ^ 3)
ERROR: MethodError: no method matching ^(::MeasureBase.NoTransportOrigin{StdNormal}, ::Tuple{Int64})
Closest candidates are:
  ^(::AbstractMeasure, ::Tuple) at ~/git/MeasureBase.jl/src/combinators/power.jl:55
  ^(::AbstractMeasure, ::Any) at ~/git/MeasureBase.jl/src/combinators/power.jl:56
Stacktrace:
 [1] _for(#unused#::Type{MeasureBase.NoTransportOrigin{StdNormal}}, f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}}, #unused#::Static.True)
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:37
 [2] macro expansion
   @ ~/git/MeasureTheory.jl/src/combinators/for.jl:32 [inlined]
 [3] for_constructor(f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:28
 [4] for_constructor(f::Function, x::FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:26
 [5] transport_origin(d::PowerMeasure{StdNormal, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:305
 [6] _origin_depth(ν::PowerMeasure{Normal{(), Tuple{}}, Tuple{Base.OneTo{Int64}}})
   @ MeasureBase ~/git/MeasureBase.jl/src/transport.jl:130
 [7] top-level scope
   @ REPL[60]:1
```
We end up taking a power of a `NoTransportOrigin`, which makes no sense. As a quick fix, I temporarily changed `MeasureBase._origin_depth` to:
```julia
@inline function _origin_depth(ν::NU) where {NU}
    ν_0 = ν
    Base.Cartesian.@nexprs 10 i -> begin  # 10 is just some "big enough" number
        ν_{i} = transport_origin(ν_{i - 1})
        # Unwrap powers so the next step is taken on the base measure
        if ν_{i} isa PowerMeasure
            ν_{i} = ν_{i}.parent
        end
        if ν_{i} isa NoTransportOrigin
            return static(i - 1)
        end
    end
    return static(10)
end
```
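For readers less used to `Base.Cartesian.@nexprs`, the unrolled loop above behaves roughly like the plain runtime loop below. The types (`NoOrigin`, `Std`, `Scaled`, `Pw`) and the function `origin` are hypothetical toys, not MeasureBase, and unlike the real version this returns an ordinary `Int` rather than a static value, so it gives up the compile-time result:

```julia
# Hypothetical toy types (stand-ins, not MeasureBase's).
struct NoOrigin end                   # sentinel: "no further origin"
struct Std end                        # a standard measure
struct Scaled; σ::Float64; end        # its origin is Std()
struct Pw{M}; parent::M; n::Int; end  # an n-fold power of a measure

origin(::Scaled) = Std()
origin(::Std) = NoOrigin()
origin(p::Pw) = Pw(origin(p.parent), p.n)

# Plain-loop analogue of `_origin_depth` (returns an Int, not a static).
function origin_depth(ν; maxdepth = 10)
    for i in 1:maxdepth
        ν = origin(ν)
        # Unwrap powers so the next step is taken on the base measure.
        if ν isa Pw
            ν = ν.parent
        end
        ν isa NoOrigin && return i - 1
    end
    return maxdepth
end

origin_depth(Pw(Scaled(2.0), 3))   # 1
```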
This last part feels kind of hacky. Also, we have the problem that `map` forces allocation. It would be nice to use `mappedarray` instead, but that doesn't infer properly. Maybe a modification of it could?
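One way to avoid the allocation would be a lazy elementwise view whose element type is computed once, up front. `LazyMap2` below is a hypothetical minimal sketch of that idea: it is not MappedArrays.jl's `mappedarray`, it assumes 1-based vectors, and whether it would infer any better in the MeasureBase setting is untested. It stores the function and both inputs and only evaluates `f(a[i], b[i])` on indexing:

```julia
# Hypothetical lazy elementwise map over two vectors (not MappedArrays.jl).
struct LazyMap2{T,F,A,B} <: AbstractVector{T}
    f::F
    a::A
    b::B
end

function LazyMap2(f::F, a::A, b::B) where {F,A<:AbstractVector,B<:AbstractVector}
    Base.require_one_based_indexing(a, b)
    length(a) == length(b) || throw(DimensionMismatch("lengths must match"))
    # The element type is inferred once here, from f and the input eltypes.
    T = Core.Compiler.return_type(f, Tuple{eltype(A),eltype(B)})
    return LazyMap2{T,F,A,B}(f, a, b)
end

Base.size(m::LazyMap2) = size(m.a)
Base.IndexStyle(::Type{<:LazyMap2}) = IndexLinear()
# Nothing is stored: each element is computed on access.
Base.@propagate_inbounds Base.getindex(m::LazyMap2, i::Int) = m.f(m.a[i], m.b[i])

μ = [1.0, 2.0, 3.0]
x = [0.5, 1.0, 4.0]
L = LazyMap2((m, xi) -> xi - m, μ, x)   # elements: -0.5, -1.0, 1.0
```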
Also, it seems like a problem if a product has marginals with different "origin depths". A fixpoint approach would handle this, but I think the current approach will break. Any ideas for this, @oschulz?
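To illustrate what a fixpoint approach could look like when marginals have different origin depths, here is a hypothetical toy (the types `Root`/`Shifted` and the function `origin_step` are made up, not MeasureBase). Each marginal takes one step toward its origin, a marginal with no further origin is left unchanged, and the whole product is iterated until nothing changes. It runs at runtime instead of producing a static depth, so it is only a sketch of the idea:

```julia
# Hypothetical toy "measures" (stand-ins, not MeasureBase types).
abstract type Toy end
struct Root <: Toy end            # has no further origin
struct Shifted <: Toy             # its origin is `inner`
    inner::Toy
end

# One transport step; a measure with no further origin is its own fixpoint.
origin_step(m::Root) = m
origin_step(m::Shifted) = m.inner

# For a product, step every marginal; shallower marginals just stay put.
origin_step(ms::AbstractVector{<:Toy}) = map(origin_step, ms)

# Iterate until the product stops changing; return it with the step count.
function to_fixpoint(m; maxiter = 100)
    for k in 0:maxiter
        m2 = origin_step(m)
        m2 == m && return (m, k)
        m = m2
    end
    error("no fixpoint within $maxiter steps")
end

ms = [Root(), Shifted(Root()), Shifted(Shifted(Root()))]   # depths 0, 1, 2
fp, depth = to_fixpoint(ms)   # depth == 2, the deepest marginal
```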