Open
Description
I'm testing out torchtune with MPS in torchchat, but I ran into the error below when passing in a large image (1536×2040 pixels).
It looks like the resize path used by torchtune/models/clip/_transform.py does not work on MPS, because the antialiased resize lowers to aten::_upsample_bilinear2d_aa, which is not implemented for the MPS backend.
File "/Users/jackkhuu/Desktop/oss/torchchat/torchchat/generate.py", line 766, in _gen_model_input
data = transform({"messages": messages}, inference=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torchtune/models/llama3_2_vision/_transform.py", line 211, in __call__
out = self.transform_image({"image": image}, inference=inference)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torchtune/models/clip/_transform.py", line 180, in __call__
image = resize_with_pad(
^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torchtune/modules/transforms/vision_utils/resize_with_pad.py", line 89, in resize_with_pad
image = F.resize(
^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torchvision/transforms/v2/functional/_geometry.py", line 188, in resize
return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torchvision/transforms/v2/functional/_utils.py", line 31, in wrapper
output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torchvision/transforms/v2/functional/_geometry.py", line 260, in resize_image
image = interpolate(
^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torch/nn/functional.py", line 4418, in interpolate
return handle_torch_function(
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torch/overrides.py", line 1717, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torch/utils/_device.py", line 106, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/jackkhuu/Desktop/oss/torchchat/.venv/lib/python3.12/site-packages/torch/nn/functional.py", line 4565, in interpolate
return torch._C._nn._upsample_bilinear2d_aa(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: The operator 'aten::_upsample_bilinear2d_aa.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
Activity