Skip to content

Commit 682014d

Browse files
authored
Merge pull request #78 from EHTJulia/ptiede-paramtype
paramtype
2 parents 28c84f8 + 16d9c3a commit 682014d

File tree

3 files changed

+46
-24
lines changed

3 files changed

+46
-24
lines changed

docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ VLBISkyModels.DomainParams
2121
VLBISkyModels.build_param
2222
VLBISkyModels.getparam
2323
VLBISkyModels.@unpack_params
24-
24+
VLBISkyModels.paramtype
2525
```
2626

2727
### Time Frequency Domain

src/models/geometric_models.jl

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,17 @@ By default if T isn't given, `Gaussian` defaults to `Float64`
4343
"""
4444
struct Gaussian{T} <: GeometricModel{T} end
4545
Gaussian() = Gaussian{Float64}()
46-
radialextent(::Gaussian{T}) where {T} = convert(T, 5)
46+
radialextent(::Gaussian{T}) where {T} = convert(paramtype(T), 5)
4747

48-
@inline function intensity_point(::Gaussian{T}, p) where {T}
48+
@inline function intensity_point(::Gaussian{D}, p) where {D}
4949
x, y = _getxy(p)
50+
T = paramtype(D)
5051
return exp(-(x^2 + y^2) / 2) / T(2 * pi)
5152
end
5253

53-
@inline function visibility_point(::Gaussian{T}, p) where {T}
54+
@inline function visibility_point(::Gaussian{D}, p) where {D}
5455
u, v = _getuv(p)
56+
T = paramtype(D)
5557
return exp(-2 * T(π)^2 * (u^2 + v^2)) + zero(T)im
5658
end
5759
"""
@@ -70,7 +72,7 @@ unit Gaussian.
7072
struct TBlob{T} <: GeometricModel{T}
7173
slope::T
7274
norm::T
73-
function TBlob(slope::Number)
75+
function TBlob(slope::Number)
7476
T = typeof(slope)
7577
norm = tblobnorm(slope)
7678
return new{T}(slope, norm)
@@ -87,7 +89,7 @@ visanalytic(::Type{<:TBlob}) = NotAnalytic()
8789
@inline getnorm(s::TBlob{<:DomainParams}, p) = tblobnorm(s.norm(p))
8890
radialextent(m::TBlob) = 5 * m.slope / (m.slope - 2)
8991

90-
function intensity_point(m::TBlob{T}, p) where {T}
92+
function intensity_point(m::TBlob, p)
9193
x, y = _getxy(p)
9294
= x^2 + y^2
9395
@unpack_params slope = m(p)
@@ -109,14 +111,16 @@ By default if T isn't given, `Disk` defaults to `Float64`
109111
struct Disk{T} <: GeometricModel{T} end
110112
Disk() = Disk{Float64}()
111113

112-
@inline function intensity_point(::Disk{T}, p) where {T}
114+
@inline function intensity_point(::Disk{D}, p) where {D}
113115
x, y = _getxy(p)
114116
r = x^2 + y^2
115-
return r < 1 ? one(T) / (π) : zero(T)
117+
T = paramtype(D)
118+
return r < 1 ? one(T) / T(π) : zero(T)
116119
end
117120

118-
@inline function visibility_point(::Disk{T}, p) where {T}
121+
@inline function visibility_point(::Disk{D}, p) where {D}
119122
u, v = _getuv(p)
123+
T = paramtype(D)
120124
ur = 2 * T(π) * (sqrt(u^2 + v^2) + eps(T))
121125
return 2 * besselj1(ur) / (ur) + zero(T)im
122126
end
@@ -138,8 +142,9 @@ struct SlashedDisk{T} <: GeometricModel{T}
138142
slash::T
139143
end
140144

141-
function intensity_point(m::SlashedDisk{T}, p) where {T}
145+
function intensity_point(m::SlashedDisk{D}, p) where {D}
142146
x, y = _getxy(p)
147+
T = paramtype(D)
143148
r2 = x^2 + y^2
144149
@unpack_params slash = m(p)
145150
s = 1 - slash
@@ -151,8 +156,9 @@ function intensity_point(m::SlashedDisk{T}, p) where {T}
151156
end
152157
end
153158

154-
function visibility_point(m::SlashedDisk{T}, p) where {T}
159+
function visibility_point(m::SlashedDisk{D}, p) where {D}
155160
u, v = _getuv(p)
161+
T = paramtype(D)
156162
@unpack_params slash = m(p)
157163
k = 2 * T(π) * sqrt(u^2 + v^2) + eps(T)
158164
s = 1 - slash
@@ -167,7 +173,7 @@ function visibility_point(m::SlashedDisk{T}, p) where {T}
167173
return norm * (v1 + v3)
168174
end
169175

170-
radialextent(::SlashedDisk{T}) where {T} = convert(T, 3)
176+
radialextent(::SlashedDisk{T}) where {T} = convert(paramtype(T), 3)
171177

172178
"""
173179
$(TYPEDEF)
@@ -180,12 +186,12 @@ By default if `T` isn't given, `Gaussian` defaults to `Float64`
180186
"""
181187
struct Ring{T} <: GeometricModel{T} end
182188
Ring() = Ring{Float64}()
183-
radialextent(::Ring{T}) where {T} = convert(T, 3 / 2)
189+
radialextent(::Ring{T}) where {T} = convert(paramtype(T), 3 / 2)
184190

185-
@inline function intensity_point(::Ring{T}, p) where {T}
191+
@inline function intensity_point(::Ring{D}, p) where {D}
186192
x, y = _getxy(p)
193+
T = paramtype(D)
187194
r = hypot(x, y)
188-
θ = atan(x, y)
189195
dr = T(1e-2)
190196
if (abs(r - 1) < dr / 2)
191197
acc = one(T)
@@ -195,8 +201,9 @@ radialextent(::Ring{T}) where {T} = convert(T, 3 / 2)
195201
end
196202
end
197203

198-
@inline function visibility_point(::Ring{T}, p) where {T}
204+
@inline function visibility_point(::Ring{D}, p) where {D}
199205
u, v = _getuv(p)
206+
T = paramtype(D)
200207
k = 2 * T(π) * sqrt(u^2 + v^2) + eps(T)
201208
vis = besselj0(k) + zero(T) * im
202209
return vis
@@ -276,10 +283,11 @@ end
276283
# Depreciate this method since we are moving to vectors for simplificty
277284
#@deprecate MRing(a::Tuple, b::Tuple) MRing(a::AbstractVector, b::AbstractVector)
278285

279-
radialextent(::MRing{T}) where {T} = convert(T, 3 / 2)
286+
radialextent(::MRing{T}) where {T} = convert(paramtype(T), 3 / 2)
280287

281-
@inline function intensity_point(m::MRing{T}, p) where {T}
288+
@inline function intensity_point(m::MRing{D}, p) where {D}
282289
x, y = _getxy(p)
290+
T = paramtype(D)
283291
r = hypot(x, y)
284292
θ = atan(-x, y)
285293
dr = T(0.02)
@@ -296,8 +304,9 @@ radialextent(::MRing{T}) where {T} = convert(T, 3 / 2)
296304
end
297305
end
298306

299-
@inline function visibility_point(m::MRing{T}, p) where {T}
307+
@inline function visibility_point(m::MRing{D}, p) where {D}
300308
@unpack_params α, β = m(p)
309+
T = paramtype(D)
301310
u, v = _getuv(p)
302311
k = T(2π) * sqrt(u^2 + v^2) + eps(T)
303312
vis = besselj0(k) + zero(T) * im
@@ -433,7 +442,8 @@ function _crescentnorm(m::ConcordanceCrescent, p)
433442
return 2 /* f)
434443
end
435444

436-
function intensity_point(m::ConcordanceCrescent{T}, p) where {T}
445+
function intensity_point(m::ConcordanceCrescent{D}, p) where {D}
446+
T = paramtype(D)
437447
x, y = _getxy(p)
438448
r2 = x^2 + y^2
439449
norm = _crescentnorm(m, p)
@@ -445,8 +455,9 @@ function intensity_point(m::ConcordanceCrescent{T}, p) where {T}
445455
end
446456
end
447457

448-
function visibility_point(m::ConcordanceCrescent{T}, p) where {T}
458+
function visibility_point(m::ConcordanceCrescent{D}, p) where {D}
449459
u, v = _getuv(p)
460+
T = paramtype(D)
450461
k = 2 * T(π) * sqrt(u^2 + v^2) + eps(T)
451462
norm = T(π) * _crescentnorm(m, p) / k
452463
@unpack_params router, rinner, shift, slash = m(p)
@@ -495,7 +506,7 @@ struct ExtendedRing{T} <: GeometricModel{T}
495506
end
496507
visanalytic(::Type{<:ExtendedRing}) = NotAnalytic()
497508

498-
radialextent(::ExtendedRing{T}) where {T} = convert(T, 6)
509+
radialextent(::ExtendedRing{T}) where {T} = convert(paramtype(T), 6)
499510

500511
@fastmath @inline function intensity_point(m::ExtendedRing, p)
501512
x, y = _getxy(p)
@@ -533,18 +544,20 @@ This is just a convenience function for `stretched(ParabolicSegment(), a, h)`
533544
return stretched(ParabolicSegment(), a, h)
534545
end
535546

536-
function intensity_point(::ParabolicSegment{T}, p) where {T}
547+
function intensity_point(::ParabolicSegment{D}, p) where {D}
537548
x, y = _getxy(p)
538549
yw = (1 - x^2)
550+
T = paramtype(D)
539551
if abs(y - yw) < T(0.01 / 2) && abs(x) < 1
540552
return 1 / T(2 * 0.01)
541553
else
542554
return zero(T)
543555
end
544556
end
545557

546-
function visibility_point(::ParabolicSegment{T}, p) where {T}
558+
function visibility_point(::ParabolicSegment{D}, p) where {D}
547559
u, v = _getuv(p)
560+
T = paramtype(D)
548561
ϵ = sqrt(eps(T))
549562
= complex(v + ϵ)
550563
phase = cispi(T(3) / 4 + 2 *+ u^2 / (2 * vϵ))

src/models/multidomain/multidomain.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ abstract type DomainParams{T} end
3030
abstract type FrequencyParams{T} <: DomainParams{T} end
3131
abstract type TimeParams{T} <: DomainParams{T} end
3232

33+
"""
34+
paramtype(::Type{<:DomainParams})
35+
36+
Computes the base parameter type of the DomainParams. If `!<:DomainParams` then it just returns
37+
the type.
38+
"""
39+
@inline paramtype(::Type{<:DomainParams{T}}) where {T} = paramtype(T)
40+
@inline paramtype(T::Type{<:Any}) = T
41+
3342
"""
3443
getparam(m, s::Symbol, p)
3544

0 commit comments

Comments
 (0)