-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtransport.jl
More file actions
82 lines (72 loc) · 2.94 KB
/
transport.jl
File metadata and controls
82 lines (72 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
using Test
using MeasureBase.Interface: transport_to, test_transport
using MeasureBase: StdUniform, StdExponential, StdLogistic, StdNormal
using MeasureBase: Dirac
using LogExpFunctions: logit
using ChainRulesTestUtils
@testset "transport_to" begin
    # Differentiation rule of the internal `_origin_depth` helper, checked on a
    # pushforward measure with a static-zero output tangent.
    test_rrule(
        MeasureBase._origin_depth,
        pushfwd(exp, StdUniform()),
        output_tangent = static(0),
    )

    # Transport in both directions between a standard measure `μ` and its pushforward
    # under `f`. Each (f, μ) pair runs in its own testset so a failure names the pair.
    @testset "pushfwd($f, $(nameof(typeof(μ))))" for (f, μ) in [
        (logit, StdUniform()),
        (log, StdExponential()),
        (exp, StdNormal()),
    ]
        test_transport(μ, pushfwd(f, μ))
        test_transport(pushfwd(f, μ), μ)
    end

    # Every ordered pair of standard measures, plus constant-multiple and power-measure
    # variants of each pair.
    std_measures = (StdUniform(), StdExponential(), StdLogistic(), StdNormal())
    for μ0 in std_measures, ν0 in std_measures
        @testset "transport_to (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin
            test_transport(ν0, μ0)
            test_transport(2.2 * ν0, 2.2 * μ0)
            # A univariate measure and a one-element power measure are interchangeable.
            test_transport(ν0, μ0^1)
            test_transport(ν0^1, μ0)
            test_transport(ν0^3, μ0^3)
            # Different shapes with the same total length (2*3*2 == 3*4 == 12).
            test_transport(ν0^(2, 3, 2), μ0^(3, 4))
            test_transport(2.2 * ν0^(2, 3, 2), 2.2 * μ0^(3, 4))
            # Applying a transport to a point of the wrong size must throw.
            @test_throws ArgumentError transport_to(ν0, μ0)(rand(μ0^12))
            @test_throws ArgumentError transport_to(ν0^3, μ0^3)(rand(μ0^(3, 4)))
        end
    end

    # A Dirac point mass pairs with zero-length power measures (plain or static sizes).
    # Constructing a transport against a non-empty power measure must throw.
    @testset "transform from/to Dirac" begin
        μ = Dirac(4.2)
        test_transport(StdExponential()^0, μ)
        test_transport(StdExponential()^(0, 0, 0), μ)
        test_transport(μ, StdExponential()^static(0))
        test_transport(μ, StdExponential()^(static(0), static(0)))
        @test_throws ArgumentError transport_to(StdExponential()^1, μ)
        @test_throws ArgumentError transport_to(μ, StdExponential()^1)
    end

    # Passing a measure *type* instead of an instance: the result must equal the
    # transport to/from an explicitly sized instance (e.g. `^6` for a (2, 3) power
    # measure), and the call must be type-inferable.
    @testset "transport_to autosel" begin
        @test @inferred(transport_to(StdExponential, StdUniform())) ==
              transport_to(StdExponential(), StdUniform())
        @test @inferred(transport_to(StdExponential, StdUniform()^(2, 3))) ==
              transport_to(StdExponential()^6, StdUniform()^(2, 3))
        @test @inferred(transport_to(StdUniform(), StdExponential)) ==
              transport_to(StdUniform(), StdExponential())
        @test @inferred(transport_to(StdUniform()^(2, 3), StdExponential)) ==
              transport_to(StdUniform()^(2, 3), StdExponential()^6)
    end

    # Product measures built from a Tuple and from a NamedTuple, each 1 + 3 = 4
    # dimensional, transported to and from a 2×2 power measure.
    @testset "transport for products" begin
        test_transport(
            StdUniform()^(2, 2),
            productmeasure((StdExponential(), StdLogistic()^3)),
        )
        test_transport(
            productmeasure((StdExponential(), StdLogistic()^3)),
            StdUniform()^(2, 2),
        )
        test_transport(
            StdUniform()^(2, 2),
            productmeasure((a = StdExponential(), b = StdLogistic()^3)),
        )
        test_transport(
            productmeasure((a = StdExponential(), b = StdLogistic()^3)),
            StdUniform()^(2, 2),
        )
    end
end