Skip to content

Commit 4e71d55

Browse files
Drop T from public structs, orchestrators, and easy precision-defensive sites
1 parent 4b707d2 commit 4e71d55

25 files changed

Lines changed: 195 additions & 200 deletions

File tree

KomaMRIBase/src/datatypes/Phantom.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,25 @@ julia> obj = Phantom(x=[0.0])
2929
julia> obj.ρ
3030
```
3131
"""
32-
@with_kw mutable struct Phantom{T<:Real}
33-
name::String = "spins"
34-
x::AbstractVector{T} = @isdefined(T) ? T[] : Float64[]
35-
y::AbstractVector{T} = zeros(eltype(x), size(x))
36-
z::AbstractVector{T} = zeros(eltype(x), size(x))
37-
ρ::AbstractVector{T} = ones(eltype(x), size(x))
38-
T1::AbstractVector{T} = ones(eltype(x), size(x)) * 1_000_000
39-
T2::AbstractVector{T} = ones(eltype(x), size(x)) * 1_000_000
40-
T2s::AbstractVector{T} = ones(eltype(x), size(x)) * 1_000_000
32+
@with_kw mutable struct Phantom
33+
name::String = "spins"
34+
x::AbstractVector = Float64[]
35+
y::AbstractVector = zeros(eltype(x), size(x))
36+
z::AbstractVector = zeros(eltype(x), size(x))
37+
ρ::AbstractVector = ones(eltype(x), size(x))
38+
T1::AbstractVector = ones(eltype(x), size(x)) * 1_000_000
39+
T2::AbstractVector = ones(eltype(x), size(x)) * 1_000_000
40+
T2s::AbstractVector = ones(eltype(x), size(x)) * 1_000_000
4141
#Off-resonance related
42-
Δw::AbstractVector{T} = zeros(eltype(x), size(x))
42+
Δw::AbstractVector = zeros(eltype(x), size(x))
4343
#χ::Vector{SusceptibilityModel}
4444
#Diffusion
45-
Dλ1::AbstractVector{T} = zeros(eltype(x), size(x))
46-
Dλ2::AbstractVector{T} = zeros(eltype(x), size(x))
47-
::AbstractVector{T} = zeros(eltype(x), size(x))
45+
Dλ1::AbstractVector = zeros(eltype(x), size(x))
46+
Dλ2::AbstractVector = zeros(eltype(x), size(x))
47+
::AbstractVector = zeros(eltype(x), size(x))
4848
#Diff::Vector{DiffusionModel} #Diffusion map
4949
#Motion
50-
motion::Union{NoMotion, Motion{T}, MotionList{T}} = NoMotion()
50+
motion::Union{NoMotion, Motion, MotionList} = NoMotion()
5151
end
5252

5353
const NON_STRING_PHANTOM_FIELDS = Iterators.filter(x -> fieldtype(Phantom, x) != String, fieldnames(Phantom))

KomaMRIBase/src/datatypes/sequence/extensions/QuaternionRot.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ struct QuaternionRot <: Extension
1313
qz::Float64
1414
end
1515

16-
function QuaternionRot(R::AbstractMatrix{T}) where {T<:Real}
16+
function QuaternionRot(R::AbstractMatrix)
1717
size(R) == (3, 3) || throw(DimensionMismatch("Rotation matrix must be 3x3."))
18-
R = copyto!(similar(R, typeof(float(one(T)))), R)
18+
R = copyto!(similar(R, typeof(float(one(eltype(R))))), R)
1919
all(iszero, R) && throw(ArgumentError("Empty matrix provided in place of a rotation matrix."))
2020

2121
qs = sqrt(max(0, R[1, 1] + R[2, 2] + R[3, 3] + 1)) / 2

KomaMRIBase/src/motion/Action.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
abstract type AbstractAction{T<:Real} end
1+
abstract type AbstractAction end
22

33
Base.:(==)(a1::AbstractAction, a2::AbstractAction) = (typeof(a1) == typeof(a2)) & reduce(&, [getfield(a1, field) == getfield(a2, field) for field in fieldnames(typeof(a1))])
44
Base.:()(a1::AbstractAction, a2::AbstractAction) = (typeof(a1) == typeof(a2)) & reduce(&, [getfield(a1, field) getfield(a2, field) for field in fieldnames(typeof(a1))])

KomaMRIBase/src/motion/Motion.jl

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,16 @@ julia> motion = Motion(
2626
)
2727
```
2828
"""
29-
@with_kw mutable struct Motion{T<:Real}
30-
action::AbstractAction{T}
31-
time ::TimeCurve{T} = TimeRange(t_start=zero(typeof(action).parameters[1]), t_end=eps(typeof(action).parameters[1]))
29+
@with_kw mutable struct Motion
30+
action::AbstractAction
31+
time ::TimeCurve = TimeRange(t_start=0.0, t_end=eps(Float64))
3232
spins ::AbstractSpinSpan = AllSpins()
3333
end
3434

3535
# Main constructors
36-
function Motion(action)
37-
T = first(typeof(action).parameters)
38-
return Motion(action, TimeRange(t_start=zero(T), t_end=eps(T)), AllSpins())
39-
end
40-
function Motion(action, time::TimeCurve)
41-
T = first(typeof(action).parameters)
42-
return Motion(action, time, AllSpins())
43-
end
44-
function Motion(action, spins::AbstractSpinSpan)
45-
T = first(typeof(action).parameters)
46-
return Motion(action, TimeRange(t_start=zero(T), t_end=eps(T)), spins)
47-
end
36+
Motion(action) = Motion(action, TimeRange(t_start=0.0, t_end=eps(Float64)), AllSpins())
37+
Motion(action, time::TimeCurve) = Motion(action, time, AllSpins())
38+
Motion(action, spins::AbstractSpinSpan) = Motion(action, TimeRange(t_start=0.0, t_end=eps(Float64)), spins)
4839

4940
# Custom constructors
5041
"""
@@ -205,8 +196,8 @@ For each dimension (x, y, z), the output matrix has ``N_{\t{spins}}`` rows and `
205196
- `x, y, z`: (`::Tuple{AbstractArray, AbstractArray, AbstractArray}`) spin positions over time
206197
"""
207198
function get_spin_coords(
208-
m::Motion{T}, x::AbstractVector{T}, y::AbstractVector{T}, z::AbstractVector{T}, t
209-
) where {T<:Real}
199+
m::Motion, x::AbstractVector, y::AbstractVector, z::AbstractVector, t
200+
)
210201
ux, uy, uz = x .* (0*t), y .* (0*t), z .* (0*t) # Buffers for displacements
211202
t_unit = unit_time(t, m.time)
212203
idx = get_indexing_range(m.spins)

KomaMRIBase/src/motion/MotionList.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ julia> motionlist = MotionList(
3333
)
3434
```
3535
"""
36-
struct MotionList{T<:Real}
37-
motions::Vector{<:Motion{T}}
36+
struct MotionList
37+
motions::Vector{<:Motion}
3838
end
3939

4040
# NOTE: this constructor must be simplified once the Vector{<:Motion} approach is accomplished:
@@ -52,10 +52,10 @@ end
5252

5353
# NOTE: these vcat methods must be simplified once the Vector{<:Motion} approach is accomplished:
5454
# https://github.com/JuliaHealth/KomaMRI.jl/issues/480
55-
""" Addition of MotionLists """
55+
""" Addition of MotionLists """
5656
# MotionList + MotionList
57-
function Base.vcat(m1::MotionList{T}, m2::MotionList{T}, Ns1, Ns2) where {T<:Real}
58-
mv_aux = Motion{T}[]
57+
function Base.vcat(m1::MotionList, m2::MotionList, Ns1, Ns2)
58+
mv_aux = Motion[]
5959
for m in m1.motions
6060
m_aux = deepcopy(m)
6161
m_aux.spins = expand(m_aux.spins, Ns1)
@@ -70,8 +70,8 @@ function Base.vcat(m1::MotionList{T}, m2::MotionList{T}, Ns1, Ns2) where {T<:Rea
7070
return MotionList(mv_aux...)
7171
end
7272
# Motion + Motion
73-
function Base.vcat(m1::Motion{T}, m2::Motion{T}, Ns1, Ns2) where {T<:Real}
74-
mv_aux = Motion{T}[]
73+
function Base.vcat(m1::Motion, m2::Motion, Ns1, Ns2)
74+
mv_aux = Motion[]
7575
m_aux = deepcopy(m1)
7676
m_aux.spins = expand(m_aux.spins, Ns1)
7777
push!(mv_aux, m_aux)
@@ -82,9 +82,9 @@ function Base.vcat(m1::Motion{T}, m2::Motion{T}, Ns1, Ns2) where {T<:Real}
8282
return MotionList(mv_aux...)
8383
end
8484
# Motion + MotionList
85-
Base.vcat(m1::MotionList{T}, m2::Motion{T}, Ns1, Ns2) where {T<:Real} = vcat(m2, m1, Ns2, Ns1)
86-
function Base.vcat(m1::Motion{T}, m2::MotionList{T}, Ns1, Ns2) where {T<:Real}
87-
mv_aux = Motion{T}[]
85+
Base.vcat(m1::MotionList, m2::Motion, Ns1, Ns2) = vcat(m2, m1, Ns2, Ns1)
86+
function Base.vcat(m1::Motion, m2::MotionList, Ns1, Ns2)
87+
mv_aux = Motion[]
8888
m_aux = deepcopy(m1)
8989
m_aux.spins = expand(m_aux.spins, Ns1)
9090
push!(mv_aux, m_aux)
@@ -98,29 +98,29 @@ function Base.vcat(m1::Motion{T}, m2::MotionList{T}, Ns1, Ns2) where {T<:Real}
9898
end
9999

100100
""" MotionList sub-group """
101-
function Base.getindex(mv::MotionList{T}, p) where {T<:Real}
102-
motion_array_aux = Motion{T}[]
101+
function Base.getindex(mv::MotionList, p)
102+
motion_array_aux = Motion[]
103103
for m in mv.motions
104104
m[p] isa NoMotion ? nothing : push!(motion_array_aux, m[p])
105105
end
106106
return MotionList(motion_array_aux...)
107107
end
108-
function Base.view(mv::MotionList{T}, p) where {T<:Real}
109-
motion_array_aux = Motion{T}[]
108+
function Base.view(mv::MotionList, p)
109+
motion_array_aux = Motion[]
110110
for m in mv.motions
111111
@view(m[p]) isa NoMotion ? nothing : push!(motion_array_aux, @view(m[p]))
112112
end
113113
return MotionList(motion_array_aux...)
114114
end
115115

116116
""" Compare two MotionLists """
117-
function Base.:(==)(mv1::MotionList{T}, mv2::MotionList{T}) where {T<:Real}
117+
function Base.:(==)(mv1::MotionList, mv2::MotionList)
118118
if length(mv1) != length(mv2) return false end
119119
sort_motions!(mv1)
120120
sort_motions!(mv2)
121121
return reduce(&, mv1.motions .== mv2.motions)
122122
end
123-
function Base.:()(mv1::MotionList{T}, mv2::MotionList{T}) where {T<:Real}
123+
function Base.:()(mv1::MotionList, mv2::MotionList)
124124
if length(mv1) != length(mv2) return false end
125125
sort_motions!(mv1)
126126
sort_motions!(mv2)
@@ -131,14 +131,14 @@ end
131131
Base.length(m::MotionList) = length(m.motions)
132132

133133
function get_spin_coords(
134-
ml::MotionList{T}, x::AbstractVector{T}, y::AbstractVector{T}, z::AbstractVector{T}, t
135-
) where {T<:Real}
134+
ml::MotionList, x::AbstractVector, y::AbstractVector, z::AbstractVector, t
135+
)
136136
# Sort motions
137137
sort_motions!(ml)
138138
# Buffers for positions:
139139
xt, yt, zt = x .+ 0*t, y .+ 0*t, z .+ 0*t
140140
# Buffers for displacements:
141-
ux, uy, uz = xt .* zero(T), yt .* zero(T), zt .* zero(T)
141+
ux, uy, uz = zero.(xt), zero.(yt), zero.(zt)
142142
# Composable motions: they need to be run sequentially. Note that they depend on xt, yt, and zt
143143
for m in Iterators.filter(is_composable, ml.motions)
144144
t_unit = unit_time(t, m.time)
@@ -147,7 +147,7 @@ function get_spin_coords(
147147
displacement_y!(@view(uy[idx, :]), m.action, @view(xt[idx, :]), @view(yt[idx, :]), @view(zt[idx, :]), t_unit)
148148
displacement_z!(@view(uz[idx, :]), m.action, @view(xt[idx, :]), @view(yt[idx, :]), @view(zt[idx, :]), t_unit)
149149
xt .+= ux; yt .+= uy; zt .+= uz
150-
ux .*= zero(T); uy .*= zero(T); uz .*= zero(T)
150+
fill!(ux, 0); fill!(uy, 0); fill!(uz, 0)
151151
end
152152
# Additive motions: these motions can be run in parallel
153153
for m in Iterators.filter(!is_composable, ml.motions)
@@ -157,7 +157,7 @@ function get_spin_coords(
157157
displacement_y!(@view(uy[idx, :]), m.action, @view(x[idx]), @view(y[idx]), @view(z[idx]), t_unit)
158158
displacement_z!(@view(uz[idx, :]), m.action, @view(x[idx]), @view(y[idx]), @view(z[idx]), t_unit)
159159
xt .+= ux; yt .+= uy; zt .+= uz
160-
ux .*= zero(T); uy .*= zero(T); uz .*= zero(T)
160+
fill!(ux, 0); fill!(uy, 0); fill!(uz, 0)
161161
end
162162
return xt, yt, zt
163163
end

KomaMRIBase/src/motion/NoMotion.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Base.view(mv::NoMotion, p) = mv
2121
Base.vcat(m1::NoMotion, m2::NoMotion, Ns1, Ns2) = m1
2222
# NoMotion + MotionList
2323
Base.vcat(m1::MotionList, m2::NoMotion, Ns1, Ns2) = vcat(m2, m1, 0, Ns1)
24-
function Base.vcat(m1::NoMotion, m2::MotionList{T}, Ns1, Ns2) where {T}
25-
mv_aux = Motion{T}[]
24+
function Base.vcat(m1::NoMotion, m2::MotionList, Ns1, Ns2)
25+
mv_aux = Motion[]
2626
for m in m2.motions
2727
m_aux = deepcopy(m)
2828
m_aux.spins = expand(m_aux.spins, Ns2)
@@ -33,7 +33,7 @@ function Base.vcat(m1::NoMotion, m2::MotionList{T}, Ns1, Ns2) where {T}
3333
end
3434
# NoMotion + Motion
3535
Base.vcat(m1::Motion, m2::NoMotion, Ns1, Ns2) = vcat(m2, m1, 0, Ns1)
36-
function Base.vcat(m1::NoMotion, m2::Motion{T}, Ns1, Ns2) where {T}
36+
function Base.vcat(m1::NoMotion, m2::Motion, Ns1, Ns2)
3737
m_aux = deepcopy(m2)
3838
m_aux.spins = expand(m_aux.spins, Ns2)
3939
m_aux.spins = SpinRange(m_aux.spins.range .+ Ns1)
@@ -44,9 +44,5 @@ end
4444
Base.:(==)(m1::NoMotion, m2::NoMotion) = true
4545
Base.:()(m1::NoMotion, m2::NoMotion) = true
4646

47-
function get_spin_coords(
48-
mv::NoMotion, x::AbstractVector{T}, y::AbstractVector{T}, z::AbstractVector{T}, t
49-
) where {T<:Real}
50-
return x, y, z
51-
end
47+
get_spin_coords(::NoMotion, x::AbstractVector, y::AbstractVector, z::AbstractVector, t) = (x, y, z)
5248
add_key_time_points!(t, ::NoMotion) = nothing

KomaMRIBase/src/motion/TimeCurve.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ julia> timecurve = TimeCurve(t=[0.0, 0.2, 0.4, 0.6], t_unit=[0.0, 1.0, 1.0, 0.0]
5151
```
5252
![Time Curve 4](../assets/time-curve-4.svg)
5353
"""
54-
@with_kw struct TimeCurve{T<:Real}
55-
t::AbstractVector{T}
56-
t_unit::AbstractVector{T}
57-
periodic::Bool = false
58-
periods::Union{T,AbstractVector{T}} = oneunit(eltype(t))
59-
t_start::T = t[1]
60-
t_end::T = t[end]
54+
@with_kw struct TimeCurve
55+
t::AbstractVector
56+
t_unit::AbstractVector
57+
periodic::Bool = false
58+
periods::Union{Real, AbstractVector{<:Real}} = oneunit(eltype(t))
59+
t_start::Real = t[1]
60+
t_end::Real = t[end]
6161
@assert check_unique(t) "Vector t=$(t) contains duplicate elements. Please ensure all elements in t are unique and try again"
6262
end
6363

@@ -88,8 +88,8 @@ julia> timerange = TimeRange(t_start=0.6, t_end=1.4)
8888
```
8989
![Time Range](../assets/time-range.svg)
9090
"""
91-
TimeRange(t_start::T, t_end::T) where T = TimeCurve(t=[t_start, t_end], t_unit=[zero(T), oneunit(T)])
92-
TimeRange(; t_start=0.0, t_end=1.0) = TimeRange(t_start, t_end)
91+
TimeRange(t_start::Real, t_end::Real) = TimeCurve(t=[t_start, t_end], t_unit=[zero(t_start), oneunit(t_start)])
92+
TimeRange(; t_start=0.0, t_end=1.0) = TimeRange(t_start, t_end)
9393

9494
# Define our own Periodic function to avoid extending Periodic from Interpolations.jl (required since Julia 1.12)
9595
function Periodic end
@@ -113,13 +113,13 @@ julia> periodic = Periodic(period=1.0, asymmetry=0.2)
113113
```
114114
![Periodic](../assets/periodic.svg)
115115
"""
116-
function Periodic(period::T, asymmetry::T) where T
117-
if asymmetry == oneunit(T)
118-
return TimeCurve(t=[zero(T), period], t_unit=[zero(T), oneunit(T)], periodic=true)
119-
elseif asymmetry == zero(T)
120-
return TimeCurve(t=[zero(T), period], t_unit=[oneunit(T), zero(T)], periodic=true)
116+
function Periodic(period::Real, asymmetry::Real)
117+
if asymmetry == oneunit(period)
118+
return TimeCurve(t=[zero(period), period], t_unit=[zero(period), oneunit(period)], periodic=true)
119+
elseif asymmetry == zero(period)
120+
return TimeCurve(t=[zero(period), period], t_unit=[oneunit(period), zero(period)], periodic=true)
121121
else
122-
return TimeCurve(t=[zero(T), period*asymmetry, period], t_unit=[zero(T), oneunit(T), zero(T)], periodic=true)
122+
return TimeCurve(t=[zero(period), period*asymmetry, period], t_unit=[zero(period), oneunit(period), zero(period)], periodic=true)
123123
end
124124
end
125125
Periodic(; period=1.0, asymmetry=0.5) = Periodic(period, asymmetry)

KomaMRIBase/src/motion/actions/ArbitraryAction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
abstract type ArbitraryAction{T<:Real} <: AbstractAction{T} end
1+
abstract type ArbitraryAction <: AbstractAction end
22

33
function Base.getindex(action::ArbitraryAction, p)
44
return typeof(action)([getfield(action, d)[p,:] for d in fieldnames(typeof(action))]...)

KomaMRIBase/src/motion/actions/SimpleAction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
abstract type SimpleAction{T<:Real} <: AbstractAction{T} end
1+
abstract type SimpleAction <: AbstractAction end
22

33
Base.getindex(action::SimpleAction, p) = action
44
Base.view(action::SimpleAction, p) = action

KomaMRIBase/src/motion/actions/arbitraryactions/FlowPath.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ julia> f = FlowPath(
2929
)
3030
```
3131
"""
32-
@with_kw struct FlowPath{T<:Real} <: ArbitraryAction{T}
33-
dx::AbstractArray{T}
34-
dy::AbstractArray{T}
35-
dz::AbstractArray{T}
32+
@with_kw struct FlowPath <: ArbitraryAction
33+
dx::AbstractArray
34+
dy::AbstractArray
35+
dz::AbstractArray
3636
spin_reset::AbstractArray{Bool}
3737
end
3838

39-
FlowPath(dx::AbstractArray{T}, dy::AbstractArray{T}, dz::AbstractArray{T}, spin_reset::BitMatrix) where T<:Real = FlowPath(dx, dy, dz, collect(spin_reset))
39+
FlowPath(dx::AbstractArray, dy::AbstractArray, dz::AbstractArray, spin_reset::BitMatrix) = FlowPath(dx, dy, dz, collect(spin_reset))
4040

4141
function add_reset_times!(t, a::FlowPath, t_start, t_end, periods)
4242
aux = t_start .+ (t_end - t_start)/(size(a.spin_reset)[2]-1) * (getindex.(findall(a.spin_reset .== 1), 2) .- 1)

0 commit comments

Comments
 (0)