@@ -20,6 +20,32 @@ def fixture_key():
2020 return prng .PRNGKey (seed = 1 )
2121
2222
23+ @testing .parametrize ("num_batches" , [1_000 ])
24+ @testing .parametrize ("num_samples_per_batch" , [1_000 ])
25+ @testing .parametrize ("dim" , [1 , 10 ])
26+ @testing .parametrize ("sample_fun" , [montecarlo .normal , montecarlo .rademacher ])
27+ def test_frobeniusnorm_squared (
28+ fun , key , num_batches , num_samples_per_batch , dim , sample_fun
29+ ):
30+ """Assert that the Frobenius norm estimate is accurate."""
31+ # Linearise function
32+ x0 = prng .uniform (key , shape = (dim ,)) # random lin. point
33+ _ , jvp = func .linearize (fun , x0 )
34+ J = func .jacfwd (fun )(x0 )
35+
36+ # Estimate the trace
37+ fun = sample_fun (shape = np .shape (x0 ), dtype = np .dtype (x0 ))
38+ estimate = hutch .frobeniusnorm_squared (
39+ jvp ,
40+ num_batches = num_batches ,
41+ key = key ,
42+ num_samples_per_batch = num_samples_per_batch ,
43+ sample_fun = fun ,
44+ )
45+ truth = np .trace (J .T @ J )
46+ assert np .allclose (estimate , truth , rtol = 1e-2 )
47+
48+
2349@testing .parametrize ("num_batches" , [1_000 ])
2450@testing .parametrize ("num_samples_per_batch" , [1_000 ])
2551@testing .parametrize ("dim" , [1 , 10 ])
0 commit comments