Skip to content

Commit d32d1cf

Browse files
authored
[utils][torch] PyTorch to memref - enable bf16 conversion (#49)
Adds a special case for bfloat16 torch.Tensor to memref conversion.
1 parent db37707 commit d32d1cf

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

lighthouse/utils/torch.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
import ml_dtypes
23

34
import torch
45
from mlir import ir
@@ -14,7 +15,15 @@ def to_memref(input: torch.Tensor) -> ctypes.Structure:
1415
Args:
1516
input: PyTorch tensor.
1617
"""
17-
return get_ranked_memref_descriptor(input.numpy())
18+
if input.dtype == torch.bfloat16:
19+
# NumPy doesn't support bf16 natively, which prevents
20+
# direct conversion from PyTorch.
21+
# Solved through non-destructive type casting.
22+
nparray = input.view(dtype=torch.uint16).numpy()
23+
nparray = nparray.view(ml_dtypes.bfloat16)
24+
else:
25+
nparray = input.numpy()
26+
return get_ranked_memref_descriptor(nparray)
1827

1928

2029
def to_packed_args(inputs: list[torch.Tensor]) -> ctypes.Array[ctypes.c_void_p]:

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ dev = [
1717

1818
[project.optional-dependencies]
1919
ingress_torch_mlir = [
20-
"torch-mlir==20260125.703"
20+
"torch-mlir==20260125.703",
21+
"ml_dtypes",
2122
]
2223
# Additional "targets" which pull in optional dependencies -- use `uv sync --extra TARGET`
2324
ingress_torch_cpu = [

0 commit comments

Comments
 (0)