
Double Precision Issues #13

@harrisonzhu508

Description


Hi!

Many thanks for open-sourcing this package.

I've been using code from this amazing package for my research (preprint here) and have found that the default single precision (float32) is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular:

  • for the Periodic kernel, the matrix operations in _sequential_kf() and _sequential_rts() are particularly sensitive and can become unstable;
  • likewise, the Matern32 kernel becomes unstable when the lengthscales are too large.

However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, because the flag puts all arrays into double precision.
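(For reference, a minimal illustration of that global switch; once the flag is set, every newly created array, including the network parameters, defaults to float64, which is where the slowdown comes from:)

```python
from jax import config

# Flip JAX to double precision globally; this affects every array created
# afterwards, including the objax neural network weights.
config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.ones(3).dtype)  # float64 with the flag set, float32 by default
```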

Currently, my workaround is to set the neural network weights to float32 manually, convert input arrays to float32 before they enter the network, and convert the outputs back to float64 afterwards (a sketch is below). However, I was wondering whether there could be a more elegant solution along the lines of https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.
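For concreteness, this is roughly what the workaround looks like (a rough sketch; `net` stands in for an objax network module whose weights are assumed to have been cast to float32 separately):

```python
import jax.numpy as jnp


def apply_network_in_float32(net, x):
    """Run the neural network in single precision while the Kalman
    filtering/smoothing code stays in double precision.

    `net` is any callable (e.g. an objax module) whose weights have been
    cast to float32; `x` may be a float64 array.
    """
    x32 = jnp.asarray(x, dtype=jnp.float32)       # downcast inputs for the network
    out32 = net(x32)                              # forward pass in float32
    return jnp.asarray(out32, dtype=jnp.float64)  # upcast outputs for the KF/RTS operations
```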

Software and hardware details:

objax==1.6.0
jax==0.3.13
jaxlib==0.3.10+cuda11.cudnn805
NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
GeForce RTX 3090 GPUs

Thanks in advance.

Best,
Harrison
