Description
Hi!
Many thanks for open-sourcing this package.
I've been using code from this amazing package for my research (preprint here) and have found that the default single precision (float32) is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular:
- The `Periodic` kernel is rather sensitive to the matrix operations in `_sequential_kf()` and `_sequential_rts()`.
- Likewise, the `Matern32` kernel when the lengthscales are too large.
However, reverting to float64 by setting `config.update("jax_enable_x64", True)` makes everything quite slow, especially when I use objax neural network modules, since that flag puts all arrays into double precision.
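For reference, this is the global switch I mean (the standard JAX flag, set before any arrays are created), which is why it affects the network parameters too:

```python
from jax import config

# Enabling x64 makes every newly created array double precision by default,
# including the objax network weights, which is what slows things down.
config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.ones(3).dtype)  # float64 once the flag is set
```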
Currently, my solution is to set the neural network weights to float32 manually, and to convert the input arrays to float32 before they enter the network and the outputs back to float64 (roughly as in the sketch below). However, I was wondering whether there is a more elegant solution along the lines of https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.
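A minimal sketch of that workaround, assuming x64 is already enabled; the `net` definition and the `net_forward` wrapper are just placeholders for illustration, not code from this package:

```python
import jax.numpy as jnp
import objax

# Placeholder network; in practice this is whatever objax module feeds the GP.
net = objax.nn.Sequential([
    objax.nn.Linear(1, 32), objax.functional.relu,
    objax.nn.Linear(32, 1),
])

# With jax_enable_x64 on, the weights are created as float64;
# cast them back down so the network itself stays in single precision.
for name, var in net.vars().items():
    if isinstance(var, objax.TrainVar):
        var.assign(var.value.astype(jnp.float32))

def net_forward(x):
    """Hypothetical wrapper: float64 in -> float32 network -> float64 out."""
    y = net(x.astype(jnp.float32))
    # Cast back so the downstream Kalman filtering/smoothing runs in double precision.
    return y.astype(jnp.float64)
```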
Software and hardware details:
- objax==1.6.0
- jax==0.3.13
- jaxlib==0.3.10+cuda11.cudnn805
- NVIDIA-SMI 460.56, Driver Version 460.56, CUDA Version 11.2
- GeForce RTX 3090 GPUs
Thanks in advance.
Best,
Harrison