I got this working, sort of:
julia> d = For(j -> Normal(j, 2.0), 1:3)
For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))
julia> test_transport(d, Normal() ^ 3)
Test Summary: | Pass Total Time
transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,)) | 8 8 0.0s
DefaultTestSet("transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))", Any[], 8, false, false, true, 1.66725e9, 1.66725e9)
To do this, I added for_constructor, which is like For but a little smarter: when the inferred return type of f is a singleton type, it collapses to a power measure:
for_constructor(f, x) = for_constructor(f, (x,))
@generated function for_constructor(f::F, inds::I) where {F,I<:Tuple}
    # element types of the index collections, computed at generation time
    eltypes = Tuple{eltype.(I.types)...}
    quote
        T = Core.Compiler.return_type(f, $eltypes)
        _for(T, f, inds, static(Base.issingletontype(T)))
    end
end
# T is a singleton type: every f(j) is the same value, so build a power measure
function _for(::Type{T}, f::F, inds::I, ::True) where {T,F,I}
    instance(T) ^ size(first(inds))
end
# otherwise, fall back to a regular For
function _for(::Type{T}, f::F, inds::I, ::False) where {T,F,I}
    For{T,F,I}(f, inds)
end
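To make the collapse rule concrete outside of MeasureTheory, here is a Base-only toy version of the same idea. ToyStd, ToyNormal, ToyPower, ToyFor and toy_for are hypothetical stand-ins, not the real measure types, and the sketch uses a runtime `if` on the inferred return type instead of the @generated/static dispatch above.

```julia
# Hypothetical toy types standing in for real measures
abstract type ToyMeasure end
struct ToyStd <: ToyMeasure end                 # no fields => singleton type
struct ToyNormal <: ToyMeasure
    μ::Float64
    σ::Float64
end
struct ToyPower{M<:ToyMeasure,N} <: ToyMeasure
    parent::M
    size::NTuple{N,Int}
end
struct ToyFor{T,F,I} <: ToyMeasure
    f::F
    inds::I
end

# Collapse to a power when the inferred return type of f is a singleton type,
# i.e. every f(j) must be the same value.
function toy_for(f, inds::Tuple)
    T = Core.Compiler.return_type(f, Tuple{map(eltype, inds)...})
    if T isa DataType && isdefined(T, :instance)
        return ToyPower(T.instance, size(first(inds)))
    else
        return ToyFor{T,typeof(f),typeof(inds)}(f, inds)
    end
end
toy_for(f, x) = toy_for(f, (x,))

p = toy_for(_ -> ToyStd(), 1:3)                 # collapses to a ToyPower
q = toy_for(j -> ToyNormal(j, 2.0), 1:3)        # stays a ToyFor
```

A singleton check is a cheap, type-level way to know that f ignores its argument (as far as the result is concerned), which is exactly when a power measure is equivalent.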
Then we just need the standard stuff:
function MeasureBase.transport_origin(d::AbstractProductMeasure)
    for_constructor(MeasureBase.transport_origin, marginals(d))
end
function MeasureBase.to_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.to_origin, marginals(d), x)
end
function MeasureBase.from_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.from_origin, marginals(d), x)
end
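The product methods just apply each marginal's transport elementwise. A self-contained toy sketch of that pattern, assuming hypothetical ToyAffine marginals whose origin is a standard normal (so to_origin is standardization and from_origin is its inverse), shows the round trip:

```julia
# Hypothetical toy types, not MeasureBase's
struct ToyAffine
    μ::Float64
    σ::Float64
end
struct ToyProduct{V<:AbstractVector}
    marginals::V
end
toy_marginals(d::ToyProduct) = d.marginals

# Per-marginal transport to/from a standard-normal origin
toy_to_origin(d::ToyAffine, x) = (x - d.μ) / d.σ
toy_from_origin(d::ToyAffine, z) = d.μ + d.σ * z

# Product versions: map over marginals and coordinates (allocates a new array)
toy_to_origin(d::ToyProduct, x) = map(toy_to_origin, toy_marginals(d), x)
toy_from_origin(d::ToyProduct, z) = map(toy_from_origin, toy_marginals(d), z)

d = ToyProduct([ToyAffine(j, 2.0) for j in 1:3])
x = [1.0, 4.0, -1.0]
z = toy_to_origin(d, x)
```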
Well, almost. There's also this bug:
julia> MeasureBase._origin_depth(Normal() ^ 3)
ERROR: MethodError: no method matching ^(::MeasureBase.NoTransportOrigin{StdNormal}, ::Tuple{Int64})
Closest candidates are:
^(::AbstractMeasure, ::Tuple) at ~/git/MeasureBase.jl/src/combinators/power.jl:55
^(::AbstractMeasure, ::Any) at ~/git/MeasureBase.jl/src/combinators/power.jl:56
Stacktrace:
[1] _for(#unused#::Type{MeasureBase.NoTransportOrigin{StdNormal}}, f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}}, #unused#::Static.True)
@ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:37
[2] macro expansion
@ ~/git/MeasureTheory.jl/src/combinators/for.jl:32 [inlined]
[3] for_constructor(f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}})
@ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:28
[4] for_constructor(f::Function, x::FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}})
@ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:26
[5] transport_origin(d::PowerMeasure{StdNormal, Tuple{Base.OneTo{Int64}}})
@ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:305
[6] _origin_depth(ν::PowerMeasure{Normal{(), Tuple{}}, Tuple{Base.OneTo{Int64}}})
@ MeasureBase ~/git/MeasureBase.jl/src/transport.jl:130
[7] top-level scope
@ REPL[60]:1
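Frame [1] shows why this path is taken: the return type is NoTransportOrigin{StdNormal} and the singleton flag is Static.True. A field-less marker type passes a singleton test just like a real standard measure does. A small Base-only reproduction with hypothetical stand-in names:

```julia
# Hypothetical stand-ins for StdNormal and NoTransportOrigin{StdNormal}
struct StdNormalToy end
struct NoOriginToy{M} end   # field-less marker, so each NoOriginToy{M} is a singleton type

# Mimics transport_origin returning a "no origin" marker
toy_transport_origin(::StdNormalToy) = NoOriginToy{StdNormalToy}()

# Mirrors the definition of Base.issingletontype
is_singleton(T) = T isa DataType && isdefined(T, :instance)

T = Core.Compiler.return_type(toy_transport_origin, Tuple{StdNormalToy})
println(T, " singleton? ", is_singleton(T))
```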
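We end up taking a power of a NoTransportOrigin, which makes no sense: transport_origin(StdNormal) returns NoTransportOrigin{StdNormal}, which is a singleton type, so for_constructor collapses it to a power. As a quick fix, I temporarily changed MeasureBase._origin_depth to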
@inline function _origin_depth(ν::NU) where {NU}
    ν_0 = ν
    Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
        ν_{i} = transport_origin(ν_{i - 1})
        if ν_{i} isa PowerMeasure
            ν_{i} = ν_{i}.parent
        else
            if ν_{i} isa NoTransportOrigin
                return static(i - 1)
            end
        end
    end
    return static(10)
end
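For readability, here is a plain-loop analogue of that hack on hypothetical toy types (NoOrigin, Std, Scaled, Pow, origin and origin_depth are all made up for illustration). The real version unrolls with @nexprs so that each ν_i can have its own concrete type and the result is a static integer; the loop below reassigns ν, so it is type-unstable, but the control flow is the same.

```julia
struct NoOrigin end            # hypothetical marker: "no further origin"
struct Std end                 # hypothetical standard measure (end of the chain)
struct Scaled{M}               # hypothetical measure whose origin is its parent
    parent::M
    σ::Float64
end
struct Pow{M}                  # hypothetical power measure
    parent::M
    n::Int
end

origin(::Std) = NoOrigin()
origin(d::Scaled) = d.parent
origin(d::Pow) = Pow(origin(d.parent), d.n)

function origin_depth(ν; maxdepth = 10)
    for i in 1:maxdepth
        ν = origin(ν)
        if ν isa Pow
            ν = ν.parent               # unwrap powers, as in the hack
        elseif ν isa NoOrigin
            return i - 1               # the previous measure was the last with an origin
        end
    end
    return maxdepth
end
```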
This last part feels kind of hacky. Also, map forces allocation. It would be nice to use mappedarray instead, but that doesn't infer properly. Maybe a modified version of it could?
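For experimenting with the allocation question, a lazy two-argument mapped array can be sketched in a few lines of Base Julia. LazyMap2 is a hypothetical type, not MappedArrays.jl's implementation; it computes its element type with Core.Compiler.return_type, and whether that infers better than mappedarray for the MeasureTheory use case is untested here.

```julia
# Hypothetical lazy elementwise f(a[i], b[i]); nothing is computed until indexing
struct LazyMap2{T,N,F,A<:AbstractArray,B<:AbstractArray} <: AbstractArray{T,N}
    f::F
    a::A
    b::B
end

function LazyMap2(f::F, a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {F,N}
    axes(a) == axes(b) || throw(DimensionMismatch("axes of a and b must match"))
    T = Core.Compiler.return_type(f, Tuple{eltype(a),eltype(b)})   # inferred element type
    return LazyMap2{T,N,F,typeof(a),typeof(b)}(f, a, b)
end

Base.size(m::LazyMap2) = size(m.a)
Base.axes(m::LazyMap2) = axes(m.a)
Base.@propagate_inbounds Base.getindex(m::LazyMap2{T,N}, I::Vararg{Int,N}) where {T,N} =
    m.f(m.a[I...], m.b[I...])

m = LazyMap2(+, [1, 2, 3], [10, 20, 30])
```

In the product case this would hypothetically be used as LazyMap2(to_origin, marginals(d), x) in place of map.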
Also, a product whose marginals have different "origin depths" seems like a problem. A fixpoint approach would handle this, but I think the current approach will break. Any ideas for this, @oschulz?
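One possible reading of the fixpoint idea, sketched on hypothetical toy types (NoOriginM, StdM, ScaledM, origin_m, terminal_origin): instead of one global depth for the whole product, follow each marginal's origin chain until the next step has no origin, and record how many steps that took. The marginals below need different numbers of steps but reach the same terminal origin.

```julia
struct NoOriginM end           # hypothetical marker: "no further origin"
struct StdM end                # hypothetical standard measure
struct ScaledM{M}              # hypothetical measure whose origin is its parent
    parent::M
    σ::Float64
end

origin_m(::StdM) = NoOriginM()
origin_m(d::ScaledM) = d.parent

# Follow origin_m until the next step has no origin; return (fixpoint, steps taken)
function terminal_origin(d; maxdepth = 10)
    for depth in 0:maxdepth
        o = origin_m(d)
        o isa NoOriginM && return (d, depth)
        d = o
    end
    error("no fixpoint within $maxdepth steps")
end

margs = (StdM(), ScaledM(StdM(), 2.0), ScaledM(ScaledM(StdM(), 2.0), 3.0))
results = map(terminal_origin, margs)
depths = map(last, results)    # per-marginal depths differ
origins = map(first, results)  # but every marginal ends at StdM()
```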