@@ -273,23 +273,49 @@ def test_sub_mul(self):
273273 assert np .allclose (dj1m2 (opt1 ), - 1 * dj1 (opt1 ))
274274 assert np .allclose (dj1m2 (opt2 ), - dj2 (opt2 ))
275275
276- def test_iadd_isub_imul (self ):
276+ def test_iadd (self ):
277277 opt1 = Opt (n = 3 )
278278 opt2 = Opt (n = 2 )
279279
280280 dj1 = opt1 .dfoo_vjp (np .ones (3 ))
281281 dj1_ = opt1 .dfoo_vjp (np .ones (3 ))
282282 dj2 = opt2 .dfoo_vjp (np .ones (2 ))
283+ dj2_ = opt2 .dfoo_vjp (np .ones (2 ))
283284
285+ dj1 += dj1_
284286 dj1 += dj2
285- assert np .allclose (dj1 (opt2 ), dj2 (opt2 ))
286- dj1 += dj1
287287 assert np .allclose (dj1 (opt1 ), 2 * dj1_ (opt1 ))
288- dj1 -= 3 * dj2
289- assert np .allclose (dj1 (opt2 ), - 1 * dj2 (opt2 ))
290- dj1 *= 1.5
291- assert np .allclose (dj1 (opt2 ), - 1.5 * dj2 (opt2 ))
292- assert np .allclose (dj1 (opt1 ), 3 * dj1_ (opt1 ))
288+ assert np .allclose (dj1 (opt2 ), dj2_ (opt2 ))
289+
290+ def test_isub (self ):
291+ opt1 = Opt (n = 3 )
292+ opt2 = Opt (n = 2 )
293+
294+ dj1 = opt1 .dfoo_vjp (np .ones (3 ))
295+ dj1_ = opt1 .dfoo_vjp (np .ones (3 ))
296+ dj2 = opt2 .dfoo_vjp (np .ones (2 ))
297+ dj2_ = opt2 .dfoo_vjp (np .ones (2 ))
298+
299+ dj1 -= 2 * dj1_
300+ dj1 -= dj2
301+ assert np .allclose (dj1 (opt1 ), (- 1 )* dj1_ (opt1 ))
302+ assert np .allclose (dj1 (opt2 ), - dj2_ (opt2 ))
303+
304+ def test_imul (self ):
305+ opt1 = Opt (n = 3 )
306+ opt2 = Opt (n = 2 )
307+
308+ dj1 = opt1 .dfoo_vjp (np .ones (3 ))
309+ dj2 = opt2 .dfoo_vjp (np .ones (2 ))
310+
311+ dj1_ = opt1 .dfoo_vjp (np .ones (3 ))
312+ dj2_ = opt2 .dfoo_vjp (np .ones (2 ))
313+
314+ dj1 *= 2.
315+ assert np .allclose (dj1 (opt1 ), 2 * dj1_ (opt1 ))
316+ dj = dj1 + 4 * dj2
317+ assert np .allclose (dj (opt1 ), 2 * dj1_ (opt1 ))
318+ assert np .allclose (dj (opt2 ), 4 * dj2_ (opt2 ))
293319
294320 def test_zero_when_not_found (self ):
295321 opt1 = Opt (n = 3 )
0 commit comments