Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ enum Device {
Xpu(usize),
Xla(usize),
Mlu(usize),
/// User didn't specify acceletor, torch
Hpu,
/// User didn't specify accelerator, torch
/// is responsible for choosing.
Anonymous(usize),
}
Expand Down Expand Up @@ -296,6 +297,7 @@ impl<'source> FromPyObject<'source> for Device {
"xpu" => Ok(Device::Xpu(0)),
"xla" => Ok(Device::Xla(0)),
"mlu" => Ok(Device::Mlu(0)),
"hpu" => Ok(Device::Hpu),
name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda),
name if name.starts_with("npu:") => parse_device(name).map(Device::Npu),
name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu),
Expand Down Expand Up @@ -327,6 +329,7 @@ impl<'py> IntoPyObject<'py> for Device {
Device::Xpu(n) => format!("xpu:{n}").into_pyobject(py).map(|x| x.into_any()),
Device::Xla(n) => format!("xla:{n}").into_pyobject(py).map(|x| x.into_any()),
Device::Mlu(n) => format!("mlu:{n}").into_pyobject(py).map(|x| x.into_any()),
Device::Hpu => "hpu".into_pyobject(py).map(|x| x.into_any()),
Device::Anonymous(n) => n.into_pyobject(py).map(|x| x.into_any()),
}
}
Expand Down
20 changes: 19 additions & 1 deletion bindings/python/tests/test_pt_comparison.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import unittest
from importlib.util import find_spec

import torch

Expand All @@ -9,11 +10,12 @@

try:
import torch_npu # noqa

npu_present = True
except Exception:
npu_present = False

hpu_present = find_spec("habana_frameworks") is not None


class TorchTestCase(unittest.TestCase):
def test_serialization(self):
Expand Down Expand Up @@ -170,6 +172,22 @@ def test_npu(self):
for k, v in reloaded.items():
self.assertTrue(torch.allclose(data[k], reloaded[k]))

@unittest.skipIf(not hpu_present, "HPU is not available")
def test_hpu(self):
    """Round-trip save/load of tensors on an Intel Gaudi (HPU) device.

    Saves two small zero tensors living on the HPU with ``save_file``,
    reloads them directly onto the HPU via ``load_file(..., device="hpu")``,
    and checks the values survive the round trip.
    """
    # Importing the Habana bindings registers the "hpu" device with torch;
    # the module itself is never referenced — the import exists purely for
    # its side effect.
    import habana_frameworks.torch.core  # noqa: F401

    data = {
        "test1": torch.zeros((2, 2), dtype=torch.float32).to("hpu"),
        "test2": torch.zeros((2, 2), dtype=torch.float16).to("hpu"),
    }
    local = "./tests/data/out_safe_pt_mmap_small_hpu.safetensors"
    save_file(data, local)

    reloaded = load_file(local, device="hpu")
    for k, v in reloaded.items():
        # Use the unpacked value directly instead of re-indexing reloaded[k].
        self.assertTrue(torch.allclose(data[k], v))

@unittest.skipIf(not torch.cuda.is_available(), "Cuda is not available")
def test_anonymous_accelerator(self):
data = {
Expand Down
Loading