[Bug] Implicit assumption of double precision can cause failures when single precision is used #2596

Open
@AVHopp

Description

🐛 Bug

In botorch/test_functions/base.py, the bounds buffer is hard-coded to double precision. This makes it impossible to use single precision, which is required for, e.g., MPS support on macOS, because the MPS backend does not support float64.

To reproduce

import torch
from botorch.test_functions import Rastrigin

torch.set_default_device("mps")
torch.set_default_dtype(torch.float32)

test = Rastrigin()

This yields the following error message (paths shortened for better readability):

Traceback (most recent call last):
  File "[...]/test.py", line 7, in <module>
    test = Rastrigin()
  File "[...]/lib/python3.10/site-packages/botorch/test_functions/synthetic.py", line 664, in __init__
    super().__init__(noise_std=noise_std, negate=negate, bounds=bounds)
  File "[...]/lib/python3.10/site-packages/botorch/test_functions/synthetic.py", line 83, in __init__
    super().__init__(noise_std=noise_std, negate=negate)
  File "[...]/lib/python3.10/site-packages/botorch/test_functions/base.py", line 51, in __init__
    "bounds", torch.tensor(self._bounds, dtype=torch.double).transpose(-1, -2)
  File "[...]/lib/python3.10/site-packages/torch/utils/_device.py", line 79, in __torch_function__
    return func(*args, **kwargs)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Expected Behavior

The initialization of the BaseTestProblem class should not force double precision on the bounds buffer. It should probably use torch.get_default_dtype() instead.
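A minimal sketch of the expected behavior, assuming the fix is simply to build the buffer with torch.get_default_dtype(). ToyTestProblem is a hypothetical stand-in for BaseTestProblem, not BoTorch's actual code; it only mirrors the register_buffer call quoted in the issue.

```python
import torch
from torch import nn


class ToyTestProblem(nn.Module):
    """Hypothetical stand-in for BaseTestProblem (not BoTorch's real class)."""

    def __init__(self, bounds: list[tuple[float, float]]) -> None:
        super().__init__()
        # Use the current default dtype instead of hard-coding torch.double,
        # so float32 works on backends without float64 support (e.g. MPS).
        # bounds arrives as d (lower, upper) pairs and is stored as a 2 x d tensor.
        self.register_buffer(
            "bounds",
            torch.tensor(bounds, dtype=torch.get_default_dtype()).transpose(-1, -2),
        )


torch.set_default_dtype(torch.float32)
problem = ToyTestProblem([(-5.12, 5.12)] * 3)
print(problem.bounds.dtype)  # torch.float32
print(problem.bounds.shape)  # torch.Size([2, 3])

# Users who want double precision can still cast the module explicitly,
# since nn.Module.to() also converts floating-point buffers.
double_problem = ToyTestProblem([(-5.12, 5.12)] * 3).to(torch.float64)
print(double_problem.bounds.dtype)  # torch.float64
```

Passing the dtype explicitly, rather than just dropping dtype=torch.double, also keeps integer-valued bounds such as (0, 1) from being inferred as int64.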

System information

  • BoTorch Version 0.12.0
  • PyTorch Version 2.4.1
  • macOS

Additional context

These seem to be the problematic lines:

self.register_buffer(
    "bounds", torch.tensor(self._bounds, dtype=torch.double).transpose(-1, -2)
)
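These lines explain why torch.set_default_dtype(torch.float32) has no effect: an explicit dtype=torch.double always overrides the default. Below is a minimal reproduction using a hypothetical toy module (HardCodedDoubleProblem, not BoTorch code), together with a possible interim workaround. The workaround assumes, without having been tested against BoTorch, that it is enough to construct the problem on the CPU and then cast its buffers afterwards with nn.Module.to. Using torch.device as a context manager requires PyTorch 2.0 or later.

```python
import torch
from torch import nn


class HardCodedDoubleProblem(nn.Module):
    """Hypothetical toy module mimicking the hard-coded dtype in base.py."""

    def __init__(self, bounds: list[tuple[float, float]]) -> None:
        super().__init__()
        # The explicit dtype=torch.double ignores torch.get_default_dtype().
        self.register_buffer(
            "bounds", torch.tensor(bounds, dtype=torch.double).transpose(-1, -2)
        )


torch.set_default_dtype(torch.float32)

# Construct on the CPU, where float64 is supported, so __init__ succeeds
# even if a default device such as "mps" has been set elsewhere.
with torch.device("cpu"):
    problem = HardCodedDoubleProblem([(-5.12, 5.12)] * 2)
print(problem.bounds.dtype)  # torch.float64, despite the float32 default

# Workaround sketch: nn.Module.to() casts floating-point buffers in place.
# On a Mac one might also pass device="mps" here (untested assumption).
problem = problem.to(dtype=torch.float32)
print(problem.bounds.dtype)  # torch.float32
```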

If this line is indeed the only problem, I'd be happy to open a small pull request, if that makes sense :)
