File tree Expand file tree Collapse file tree 2 files changed +12
-2
lines changed
Expand file tree Collapse file tree 2 files changed +12
-2
lines changed Original file line number Diff line number Diff line change 11import ctypes
2+ import ml_dtypes
23
34import torch
45from mlir import ir
@@ -14,7 +15,15 @@ def to_memref(input: torch.Tensor) -> ctypes.Structure:
1415 Args:
1516 input: PyTorch tensor.
1617 """
17- return get_ranked_memref_descriptor (input .numpy ())
18+ if input .dtype == torch .bfloat16 :
19+ # Numpy doesn't support bf16 natively which disables
20+ # direct conversion from PyTorch.
21+ # Solved through non-destructive type casting.
22+ nparray = input .view (dtype = torch .uint16 ).numpy ()
23+ nparray = nparray .view (ml_dtypes .bfloat16 )
24+ else :
25+ nparray = input .numpy ()
26+ return get_ranked_memref_descriptor (nparray )
1827
1928
2029def to_packed_args (inputs : list [torch .Tensor ]) -> ctypes .Array [ctypes .c_void_p ]:
Original file line number Diff line number Diff line change @@ -17,7 +17,8 @@ dev = [
1717
1818[project .optional-dependencies ]
1919ingress_torch_mlir = [
20- " torch-mlir==20260125.703"
20+ " torch-mlir==20260125.703" ,
21+ " ml_dtypes" ,
2122]
2223# Additional "targets" which pull in optional dependencies -- use `uv sync --extra TARGET`
2324ingress_torch_cpu = [
You can’t perform that action at this time.
0 commit comments