Skip to content

Commit ec4162b

Browse files
authored
Merge pull request #213 from hiddenSymmetries/fw/smalltweaks
Small performance tweaks
2 parents 4f63543 + e51d8d6 commit ec4162b

File tree

7 files changed

+106
-33
lines changed

7 files changed

+106
-33
lines changed

src/simsopt/_core/derivative.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def __missing__(self, key):
2121

2222
def copy_numpy_dict(d):
2323
res = OptimizableDefaultDict({})
24-
for k in d:
25-
res[k] = d[k].copy()
24+
for k, v in d.items():
25+
res[k] = v.copy()
2626
return res
2727

2828

@@ -114,31 +114,41 @@ def __add__(self, other):
114114
y = other.data
115115
z = copy_numpy_dict(x)
116116
for k in y:
117-
z[k] += y[k]
118-
117+
if k in z:
118+
z[k] += y[k]
119+
else:
120+
z[k] = y[k].copy()
119121
return Derivative(z)
120122

121123
def __sub__(self, other):
122124
x = self.data
123125
y = other.data
124126
z = copy_numpy_dict(x)
125-
for k in y:
126-
z[k] -= y[k]
127-
127+
for k, yk in y.items():
128+
if k in z:
129+
z[k] -= yk
130+
else:
131+
z[k] = -yk
128132
return Derivative(z)
129133

130134
def __iadd__(self, other):
131135
x = self.data
132136
y = other.data
133-
for k in y:
134-
x[k] += y[k]
137+
for k, yk in y.items():
138+
if k in x:
139+
x[k] += yk
140+
else:
141+
x[k] = yk.copy()
135142
return self
136143

137144
def __isub__(self, other):
138145
x = self.data
139146
y = other.data
140-
for k in y:
141-
x[k] -= y[k]
147+
for k, yk in y.items():
148+
if k in x:
149+
x[k] -= yk
150+
else:
151+
x[k] = -yk
142152
return self
143153

144154
def __mul__(self, other):

src/simsopt/_core/graph_optimizable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,8 @@ def __init__(self,
531531
# instances of same class
532532
self._id = ImmutableId(next(self.__class__._ids))
533533
self.name = self.__class__.__name__ + str(self._id.id)
534+
hash_str = hashlib.sha256(self.name.encode('utf-8')).hexdigest()
535+
self.hash = int(hash_str, 16) % 10**32 # 32 digit int as hash
534536
self._children = set() # This gets populated when the object is passed
535537
# as argument to another Optimizable object
536538
self.return_fns = WeakKeyDefaultDict(list) # Store return fn's required by each child
@@ -583,8 +585,7 @@ def __str__(self):
583585
return self.name
584586

585587
def __hash__(self) -> int:
586-
hash_str = hashlib.sha256(self.name.encode('utf-8')).hexdigest()
587-
return int(hash_str, 16) % 10**32 # 32 digit int as hash
588+
return self.hash
588589

589590
def __eq__(self, other: Optimizable) -> bool:
590591
"""

src/simsopt/geo/curve.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,12 @@ def gamma_impl(self, gamma, quadpoints):
662662
663663
"""
664664

665-
self.curve.gamma_impl(gamma, quadpoints)
666-
gamma[:] = gamma @ self.rotmat
665+
if len(quadpoints) == len(self.curve.quadpoints) \
666+
and np.sum((quadpoints-self.curve.quadpoints)**2) < 1e-15:
667+
gamma[:] = self.curve.gamma() @ self.rotmat
668+
else:
669+
self.curve.gamma_impl(gamma, quadpoints)
670+
gamma[:] = gamma @ self.rotmat
667671

668672
def gammadash_impl(self, gammadash):
669673
r"""

src/simsopt/objectives/fluxobjective.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from simsopt._core.graph_optimizable import Optimizable
22
from .._core.derivative import derivative_dec
33
import numpy as np
4+
import simsoptpp as sopp
45

56

67
class SquaredFlux(Optimizable):
@@ -32,17 +33,10 @@ def __init__(self, surface, field, target=None):
3233
Optimizable.__init__(self, x0=np.asarray([]), depends_on=[field])
3334

3435
def J(self):
35-
xyz = self.surface.gamma()
3636
n = self.surface.normal()
37-
absn = np.linalg.norm(n, axis=2)
38-
unitn = n * (1./absn)[:, :, None]
39-
Bcoil = self.field.B().reshape(xyz.shape)
40-
Bcoil_n = np.sum(Bcoil*unitn, axis=2)
41-
if self.target is not None:
42-
B_n = (Bcoil_n - self.target)
43-
else:
44-
B_n = Bcoil_n
45-
return 0.5 * np.mean(B_n**2 * absn)
37+
Bcoil = self.field.B().reshape(n.shape)
38+
Btarget = self.target if self.target is not None else []
39+
return sopp.integral_BdotN(Bcoil, Btarget, n)
4640

4741
@derivative_dec
4842
def dJ(self):

src/simsoptpp/python.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,29 @@ PYBIND11_MODULE(simsoptpp, m) {
113113
return C;
114114
});
115115

116+
m.def("integral_BdotN", [](PyArray& Bcoil, PyArray& Btarget, PyArray& n) {
117+
int nphi = Bcoil.shape(0);
118+
int ntheta = Bcoil.shape(1);
119+
double *Bcoil_ptr = &(Bcoil(0, 0, 0));
120+
double *Btarget_ptr = NULL;
121+
if(Btarget.size() == Bcoil.size())
122+
Btarget_ptr = &(Btarget(0, 0, 0));
123+
double *n_ptr = &(n(0, 0, 0));
124+
double res = 0;
125+
#pragma omp parallel for reduction(+:res)
126+
for(int i=0; i<nphi*ntheta; i++){
127+
double normN = std::sqrt(n_ptr[3*i+0]*n_ptr[3*i+0] + n_ptr[3*i+1]*n_ptr[3*i+1] + n_ptr[3*i+2]*n_ptr[3*i+2]);
128+
double Nx = n_ptr[3*i+0]/normN;
129+
double Ny = n_ptr[3*i+1]/normN;
130+
double Nz = n_ptr[3*i+2]/normN;
131+
double BcoildotN = Bcoil_ptr[3*i+0]*Nx + Bcoil_ptr[3*i+1]*Ny + Bcoil_ptr[3*i+2]*Nz;
132+
if(Btarget_ptr != NULL)
133+
BcoildotN -= Btarget_ptr[3*i];
134+
res += (BcoildotN * BcoildotN) * normN;
135+
}
136+
return 0.5 * res / (nphi*ntheta);
137+
});
138+
116139
#ifdef VERSION_INFO
117140
m.attr("__version__") = VERSION_INFO;
118141
#else

tests/core/test_derivative.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,23 +273,49 @@ def test_sub_mul(self):
273273
assert np.allclose(dj1m2(opt1), -1*dj1(opt1))
274274
assert np.allclose(dj1m2(opt2), -dj2(opt2))
275275

276-
def test_iadd_isub_imul(self):
276+
def test_iadd(self):
277277
opt1 = Opt(n=3)
278278
opt2 = Opt(n=2)
279279

280280
dj1 = opt1.dfoo_vjp(np.ones(3))
281281
dj1_ = opt1.dfoo_vjp(np.ones(3))
282282
dj2 = opt2.dfoo_vjp(np.ones(2))
283+
dj2_ = opt2.dfoo_vjp(np.ones(2))
283284

285+
dj1 += dj1_
284286
dj1 += dj2
285-
assert np.allclose(dj1(opt2), dj2(opt2))
286-
dj1 += dj1
287287
assert np.allclose(dj1(opt1), 2*dj1_(opt1))
288-
dj1 -= 3*dj2
289-
assert np.allclose(dj1(opt2), -1*dj2(opt2))
290-
dj1 *= 1.5
291-
assert np.allclose(dj1(opt2), -1.5*dj2(opt2))
292-
assert np.allclose(dj1(opt1), 3*dj1_(opt1))
288+
assert np.allclose(dj1(opt2), dj2_(opt2))
289+
290+
def test_isub(self):
291+
opt1 = Opt(n=3)
292+
opt2 = Opt(n=2)
293+
294+
dj1 = opt1.dfoo_vjp(np.ones(3))
295+
dj1_ = opt1.dfoo_vjp(np.ones(3))
296+
dj2 = opt2.dfoo_vjp(np.ones(2))
297+
dj2_ = opt2.dfoo_vjp(np.ones(2))
298+
299+
dj1 -= 2*dj1_
300+
dj1 -= dj2
301+
assert np.allclose(dj1(opt1), (-1)*dj1_(opt1))
302+
assert np.allclose(dj1(opt2), -dj2_(opt2))
303+
304+
def test_imul(self):
305+
opt1 = Opt(n=3)
306+
opt2 = Opt(n=2)
307+
308+
dj1 = opt1.dfoo_vjp(np.ones(3))
309+
dj2 = opt2.dfoo_vjp(np.ones(2))
310+
311+
dj1_ = opt1.dfoo_vjp(np.ones(3))
312+
dj2_ = opt2.dfoo_vjp(np.ones(2))
313+
314+
dj1 *= 2.
315+
assert np.allclose(dj1(opt1), 2*dj1_(opt1))
316+
dj = dj1 + 4*dj2
317+
assert np.allclose(dj(opt1), 2*dj1_(opt1))
318+
assert np.allclose(dj(opt2), 4*dj2_(opt2))
293319

294320
def test_zero_when_not_found(self):
295321
opt1 = Opt(n=3)

tests/geo/test_curve.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,21 @@ def test_plot(self):
427427
ax = curve.plot(engine=engine, ax=ax, show=False, close=close)
428428
c.plot(engine=engine, ax=ax, close=close, plot_derivative=True, show=show)
429429

430+
def test_rotated_curve_gamma_impl(self):
431+
rc = get_curve("CurveXYZFourier", True, x=100)
432+
c = rc.curve
433+
mat = rc.rotmat
434+
435+
rcg = rc.gamma()
436+
cg = c.gamma()
437+
quadpoints = rc.quadpoints
438+
439+
assert np.allclose(rcg, cg@mat)
440+
# run gamma_impl so that the `else` in RotatedCurve.gamma_impl gets triggered
441+
tmp = np.zeros_like(cg[:10, :])
442+
rc.gamma_impl(tmp, quadpoints[:10])
443+
assert np.allclose(cg[:10, :]@mat, tmp)
444+
430445

431446
if __name__ == "__main__":
432447
unittest.main()

0 commit comments

Comments
 (0)