Skip to content

Commit 850f4b5

Browse files
authored
Merge pull request #128 from slimgroup/dft-fix
Fix DFT caching
2 parents 204ac0e + edd6313 commit 850f4b5

File tree

6 files changed

+63
-42
lines changed

6 files changed

+63
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "JUDI"
22
uuid = "f3b833dc-6b2e-5b9c-b940-873ed6319979"
33
authors = ["Philipp Witte, Mathias Louboutin"]
4-
version = "3.1.3"
4+
version = "3.1.4"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/pysource/fields.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,24 @@ def lr_src_fields(model, weight, wavelet, empty_ws=False):
167167
return source_weight, wavelett
168168

169169

170+
def frequencies(freq, fdim=None):
171+
"""
172+
Frequencies as a one dimensional Function
173+
174+
Parameters
175+
----------
176+
freq: List or 1D array
177+
List of frequencies
178+
"""
179+
if freq is None:
180+
return None, 0
181+
nfreq = np.shape(freq)[0]
182+
freq_dim = fdim or DefaultDimension(name='freq_dim', default_value=nfreq)
183+
f = Function(name='f', dimensions=(freq_dim,), shape=(nfreq,))
184+
f.data[:] = np.array(freq[:])
185+
return f, nfreq
186+
187+
170188
def fourier_modes(u, freq):
171189
"""
172190
On the fly DFT wavefield (frequency slices) and expression
@@ -182,10 +200,8 @@ def fourier_modes(u, freq):
182200
return None, None
183201

184202
# Frequencies
185-
nfreq = np.shape(freq)[0]
186-
freq_dim = DefaultDimension(name='freq_dim', default_value=nfreq)
187-
f = Function(name='f', dimensions=(freq_dim,), shape=(nfreq,))
188-
f.data[:] = np.array(freq[:])
203+
f, nfreq = frequencies(freq)
204+
freq_dim = f.dimensions[0]
189205

190206
dft_modes = []
191207
for wf in as_tuple(u):

src/pysource/operators.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __get__(self, obj, objtype):
7373

7474
@memoized_func
7575
def forward_op(p_params, tti, visco, space_order, spacing, save, t_sub, fs, pt_src,
76-
pt_rec, dft, dft_sub, ws, full_q):
76+
pt_rec, nfreq, dft_sub, ws, full_q):
7777
"""
7878
Low level forward operator creation, to be used through `propagator.py`
7979
Compute forward wavefield u = A(m)^{-1}*f and related quantities (u(xrcv))
@@ -86,7 +86,7 @@ def forward_op(p_params, tti, visco, space_order, spacing, save, t_sub, fs, pt_s
8686
scords = np.ones((1, ndim)) if pt_src else None
8787
rcords = np.ones((1, ndim)) if pt_rec else None
8888
wavelet = np.ones((nt, 1))
89-
freq_list = np.ones((2,)) if dft else None
89+
freq_list = np.ones((nfreq,)) if nfreq > 0 else None
9090
q = wavefield(model, 0, save=True, nt=nt, name="qwf") if full_q else 0
9191
wsrc = Function(name='src_weight', grid=model.grid, space_order=0) if ws else None
9292

@@ -120,7 +120,7 @@ def forward_op(p_params, tti, visco, space_order, spacing, save, t_sub, fs, pt_s
120120

121121
@memoized_func
122122
def adjoint_op(p_params, tti, visco, space_order, spacing, save, nv_weights, fs,
123-
pt_src, pt_rec, dft, dft_sub, ws, full_q):
123+
pt_src, pt_rec, nfreq, dft_sub, ws, full_q):
124124
"""
125125
Low level adjoint operator creation, to be used through `propagators.py`
126126
Compute adjoint wavefield v = adjoint(F(m))*y
@@ -134,7 +134,7 @@ def adjoint_op(p_params, tti, visco, space_order, spacing, save, nv_weights, fs,
134134
scords = np.ones((1, ndim)) if pt_src else None
135135
rcords = np.ones((1, ndim)) if pt_rec else None
136136
wavelet = np.ones((nt, 1))
137-
freq_list = np.ones((2,)) if dft else None
137+
freq_list = np.ones((nfreq,)) if nfreq > 0 else None
138138
q = wavefield(model, 0, save=True, nt=nt, fw=False, name="qwf") if full_q else 0
139139

140140
# Setting adjoint wavefield
@@ -167,7 +167,7 @@ def adjoint_op(p_params, tti, visco, space_order, spacing, save, nv_weights, fs,
167167

168168
@memoized_func
169169
def born_op(p_params, tti, visco, space_order, spacing, save, pt_src,
170-
pt_rec, fs, t_sub, ws, dft, dft_sub, isic, nlind):
170+
pt_rec, fs, t_sub, ws, nfreq, dft_sub, isic, nlind):
171171
"""
172172
Low level born operator creation, to be used through `interface.py`
173173
Compute linearized wavefield U = J(m)* δ m
@@ -181,7 +181,7 @@ def born_op(p_params, tti, visco, space_order, spacing, save, pt_src,
181181
wavelet = np.ones((nt, 1))
182182
scords = np.ones((1, ndim)) if pt_src else None
183183
rcords = np.ones((1, ndim)) if pt_rec else None
184-
freq_list = np.ones((2,)) if dft else None
184+
freq_list = np.ones((nfreq,)) if nfreq > 0 else None
185185
wsrc = Function(name='src_weight', grid=model.grid, space_order=0) if ws else None
186186
f0 = Constant('f0')
187187

@@ -220,7 +220,7 @@ def born_op(p_params, tti, visco, space_order, spacing, save, pt_src,
220220

221221
@memoized_func
222222
def adjoint_born_op(p_params, tti, visco, space_order, spacing, pt_rec, fs, w,
223-
save, t_sub, dft, dft_sub, isic):
223+
save, t_sub, nfreq, dft_sub, isic):
224224
"""
225225
Low level gradient operator creation, to be used through `propagators.py`
226226
Compute the action of the adjoint Jacobian onto a residual J'* δ d.
@@ -231,10 +231,11 @@ def adjoint_born_op(p_params, tti, visco, space_order, spacing, pt_rec, fs, w,
231231
ndim = len(spacing)
232232
residual = np.ones((nt, 1))
233233
rcords = np.ones((1, ndim)) if pt_rec else None
234-
freq_list = np.ones((2,)) if dft else None
234+
freq_list = np.ones((nfreq,)) if nfreq > 0 else None
235235
# Setting adjoint wavefieldgradient
236236
v = wavefield(model, space_order, fw=False)
237-
u = forward_wavefield(model, space_order, save=save, nt=nt, dft=dft, t_sub=t_sub)
237+
u = forward_wavefield(model, space_order, save=save, nt=nt,
238+
dft=nfreq > 0, t_sub=t_sub)
238239

239240
# Set up PDE expression and rearrange
240241
pde = wave_kernel(model, v, fw=False, f0=Constant('f0'))

src/pysource/propagators.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from kernels import wave_kernel
22
from geom_utils import src_rec, geom_expr
33
from fields import (fourier_modes, wavefield, lr_src_fields,
4-
wavefield_subsampled, norm_holder)
4+
wavefield_subsampled, norm_holder, frequencies)
55
from fields_exprs import extented_src
66
from sensitivity import grad_expr
7-
from utils import weight_fun, opt_op, fields_kwargs
7+
from utils import weight_fun, opt_op, fields_kwargs, nfreq
88
from operators import forward_op, adjoint_op, born_op, adjoint_born_op
99

1010
from devito import Operator, Function, Constant
@@ -32,7 +32,7 @@ def forward(model, src_coords, rcv_coords, wavelet, space_order=8, save=False,
3232
op = forward_op(model.physical_parameters, model.is_tti, model.is_viscoacoustic,
3333
space_order, model.spacing, save, t_sub, model.fs,
3434
src_coords is not None, rcv_coords is not None,
35-
freq_list is not None, dft_sub, ws is not None, qwf is not None)
35+
nfreq(freq_list), dft_sub, ws is not None, qwf is not None)
3636

3737
# Make kwargs
3838
kw = {'dt': model.critical_dt}
@@ -83,7 +83,7 @@ def adjoint(model, y, src_coords, rcv_coords, space_order=8, qwf=None, dft_sub=N
8383
op = adjoint_op(model.physical_parameters, model.is_tti, model.is_viscoacoustic,
8484
space_order, model.spacing, save, nv_weights, model.fs,
8585
src_coords is not None, rcv_coords is not None,
86-
freq_list is not None, dft_sub, ws is not None, qwf is not None)
86+
nfreq(freq_list), dft_sub, ws is not None, qwf is not None)
8787

8888
# On-the-fly Fourier
8989
dft_modes, fr = fourier_modes(v, freq_list)
@@ -133,12 +133,13 @@ def gradient(model, residual, rcv_coords, u, return_op=False, space_order=8,
133133
# Create operator and run
134134
op = adjoint_born_op(model.physical_parameters, model.is_tti, model.is_viscoacoustic,
135135
space_order, model.spacing, rcv_coords is not None, model.fs, w,
136-
not return_op, t_sub, freq is not None, dft_sub, isic)
136+
not return_op, t_sub, nfreq(freq), dft_sub, isic)
137137

138138
# Update kwargs
139139
kw = {'dt': model.critical_dt}
140+
f, _factor = frequencies(freq)
140141
f0q = Constant('f0', value=f0) if model.is_viscoacoustic else None
141-
kw.update(fields_kwargs(src, u, v, gradm, f0q))
142+
kw.update(fields_kwargs(src, u, v, gradm, f0q, f))
142143
kw.update(model.physical_params())
143144

144145
if return_op:
@@ -172,7 +173,7 @@ def born(model, src_coords, rcv_coords, wavelet, space_order=8, save=False,
172173
op = born_op(model.physical_parameters, model.is_tti, model.is_viscoacoustic,
173174
space_order, model.spacing, save,
174175
src_coords is not None, rcv_coords is not None, model.fs, t_sub,
175-
ws is not None, freq_list is not None, dft_sub, isic, nlind)
176+
ws is not None, nfreq(freq_list), dft_sub, isic, nlind)
176177

177178
# Make kwargs
178179
kw = {'dt': model.critical_dt}

src/pysource/sensitivity.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from devito import Eq
55
from devito.tools import as_tuple
66

7+
from fields import frequencies
78
from fields_exprs import sub_time, freesurface
89
from FD_utils import divs, grads
910

@@ -89,17 +90,15 @@ def crosscorr_freq(u, v, model, freq=None, dft_sub=None, **kwargs):
8990
tsave, factor = sub_time(time, dft_sub)
9091
expr = 0
9192

93+
fdim = as_tuple(u)[0][0].dimensions[0]
94+
f, nfreq = frequencies(freq, fdim=fdim)
95+
omega_t = 2*np.pi*f*tsave*factor*dt
96+
# Gradient weighting is (2*np.pi*f)**2/nt
97+
w = -(2*np.pi*f)**2/time.symbolic_max
98+
9299
for uu, vv in zip(u, v):
93100
ufr, ufi = uu
94-
# Frequencies
95-
nfreq = np.shape(freq)[0]
96-
fdim = ufr.dimensions[0]
97-
omega_t = lambda f: 2*np.pi*f*tsave*factor*dt
98-
# Gradient weighting is (2*np.pi*f)**2/nt
99-
w = lambda f: -(2*np.pi*f)**2/time.symbolic_max
100-
expr += sum(w(freq[ff])*(ufr._subs(fdim, ff)*cos(omega_t(freq[ff])) -
101-
ufi._subs(fdim, ff)*sin(omega_t(freq[ff])))
102-
for ff in range(nfreq))*vv
101+
expr += w*(ufr*cos(omega_t) - ufi*sin(omega_t))*vv
103102
return expr
104103

105104

@@ -139,21 +138,18 @@ def isic_freq(u, v, model, **kwargs):
139138
time = model.grid.time_dim
140139
dt = time.spacing
141140
tsave, factor = sub_time(time, kwargs.get('factor'))
141+
fdim = as_tuple(u)[0][0].dimensions[0]
142+
f, nfreq = frequencies(freq, fdim=fdim)
143+
omega_t = 2*np.pi*f*tsave*factor*dt
144+
w = -(2*np.pi*f)**2/time.symbolic_max
145+
w2 = factor / time.symbolic_max
146+
142147
expr = 0
143148
for uu, vv in zip(u, v):
144149
ufr, ufi = uu
145-
# Frequencies
146-
nfreq = np.shape(freq)[0]
147-
fdim = ufr.dimensions[0]
148-
omega_t = lambda f: 2*np.pi*f*tsave*factor*dt
149-
# Gradient weighting is (2*np.pi*f)**2/nt
150-
w = lambda f: -(2*np.pi*f)**2/time.symbolic_max
151-
w2 = factor / time.symbolic_max
152-
for ff in range(nfreq):
153-
cwt, swt = cos(omega_t(freq[ff])), sin(omega_t(freq[ff]))
154-
ufrf, ufif = ufr._subs(fdim, ff), ufi._subs(fdim, ff)
155-
idftu = (ufrf * cwt - ufif * swt)
156-
expr += w(freq[ff]) * idftu * vv * model.m - w2 * inner_grad(idftu, vv)
150+
cwt, swt = cos(omega_t), sin(omega_t)
151+
idftu = (ufr * cwt - ufi * swt)
152+
expr += w * idftu * vv * model.m - w2 * inner_grad(idftu, vv)
157153
return expr
158154

159155

src/pysource/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ def opt_op(model):
115115
return ('advanced', opts)
116116

117117

118+
def nfreq(freq_list):
119+
"""
120+
Check number of on-the-fly DFT frequencies.
121+
"""
122+
return 0 if freq_list is None else np.shape(freq_list)[0]
123+
124+
118125
def fields_kwargs(*args):
119126
"""
120127
Creates a dictionary of {f.name: f} for any field argument that is not None

0 commit comments

Comments
 (0)