In the Laplace approximation, the Hessian of the loss function is computed to build a quadratic approximation of the loss around its minimum. Can this package be used to compute a block-diagonal approximation of the Hessian at the minimum? If so, could you please show (using jax and flax) how to compute it and how to define the quadratic approximation of the loss function? This should be something like 1/2 (theta - theta_star)^T H(L)(theta_star) (theta - theta_star), where theta_star is the minimum and H(L)(theta_star) is the Hessian of the loss function evaluated at theta_star.