-
Notifications
You must be signed in to change notification settings - Fork 183
Expand file tree
/
Copy pathtensor.py
More file actions
137 lines (113 loc) · 4.55 KB
/
tensor.py
File metadata and controls
137 lines (113 loc) · 4.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# tensor.py -*- Python -*-
#
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# (c) Copyright 2025-2026 Advanced Micro Devices, Inc.
import numpy as np
import pyxrt as xrt
from ..tensor_class import Tensor
from aie.helpers.util import np_ndarray_type_get_shape
class XRTTensor(Tensor):
    """
    Tensor whose storage is an XRT buffer object (BO) allocated through PyXRT.

    The BO is mapped into host memory and exposed as a numpy array, so the
    contents are accessible from the 'cpu' side and can be synced to the 'npu'.
    The class supports creation from a shape or from existing data, and gives
    access to the underlying array, its shape, and the buffer object itself.
    """

    def __init__(
        self,
        shape_or_data,
        dtype=np.uint32,
        device="npu",
        flags=xrt.bo.host_only,
        group_id=0,
        xrt_device=None,
    ):
        """
        Allocate an XRT buffer object and expose it as a numpy array.

        Args:
            shape_or_data (tuple or array-like):
                - A tuple is interpreted as a shape; the new tensor is zero-filled.
                - An object with a ``shape`` attribute is used as the data source as-is.
                - Anything else is converted with ``np.asarray(..., dtype=dtype)``
                  and then used as the data source.
            dtype (np.dtype, optional): Data type of the tensor. Defaults to np.uint32.
            device (str, optional): Device string identifier. Defaults to 'npu'.
            flags (optional): XRT buffer object flags. Defaults to xrt.bo.host_only.
            group_id (int, optional): XRT buffer object group ID. Defaults to 0.
            xrt_device (optional): Existing PyXRT device handle to use for BO allocation.
                When omitted, a new handle for device index 0 is opened for this tensor.
        """
        super().__init__(shape_or_data, dtype=dtype, device=device)

        # Reuse the caller's device handle when given; otherwise open device 0.
        self.xrt_device = xrt.device(0) if xrt_device is None else xrt_device

        # A tuple means "shape only"; every other input carries data to copy in.
        is_shape_spec = isinstance(shape_or_data, tuple)

        if is_shape_spec:
            # Validate the shape's "ShapeLike"-ness by round-tripping it
            # through a numpy ndarray type.
            ndarray_type = np.ndarray[shape_or_data, np.dtype[dtype]]
            self._shape = np_ndarray_type_get_shape(ndarray_type)
        else:
            if hasattr(shape_or_data, "shape"):
                # Already a shaped object: trust it and use it directly.
                np_data = shape_or_data
            else:
                # TODO(efficiency): Extra data copy here (when necessary)
                # so we can borrow verification of array-like things from numpy.
                np_data = np.asarray(shape_or_data, dtype=dtype)
            self._shape = np_data.shape

        # Ideally, we use xrt::ext::bo host-only BO but there are no bindings for that currently.
        # Eventually, xrt::ext::bo uses the 0 magic number that shall be fixed in the future,
        # so that is used as a default.
        # https://github.com/Xilinx/XRT/blob/9b114f18c4fcf4e3558291aa2d78f6d97c406365/src/runtime_src/core/common/api/xrt_bo.cpp#L1626
        element_count = np.prod(self._shape)
        size_in_bytes = int(element_count * np.dtype(self.dtype).itemsize)
        self._bo = xrt.bo(self.xrt_device, size_in_bytes, flags, group_id)

        # View the mapped BO memory as a numpy array of the tensor's shape.
        mapped = self._bo.map()
        flat_view = np.frombuffer(mapped, dtype=self.dtype)
        self._data = flat_view.reshape(self._shape)

        # Initialize the buffer contents: zeros for a bare shape, else the input data.
        if is_shape_spec:
            self._data.fill(0)
        else:
            np.copyto(self._data, np_data)

        # Push the initial contents to the device for 'npu' tensors.
        if self.device == "npu":
            self._sync_to_device()

    @property
    def data(self):
        """
        The numpy array that views the mapped buffer object.

        Returns:
            np.ndarray: The underlying data.
        """
        return self._data

    @property
    def shape(self):
        """
        The shape recorded for this tensor at construction time.

        Returns:
            tuple: The shape of the tensor.
        """
        return self._shape

    def _sync_to_device(self):
        """
        Sync the buffer object in the host-to-device direction.

        Returns:
            The result of the underlying ``bo.sync`` call.
        """
        direction = xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE
        return self._bo.sync(direction)

    def _sync_from_device(self):
        """
        Sync the buffer object in the device-to-host direction.

        Returns:
            The result of the underlying ``bo.sync`` call.
        """
        direction = xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE
        return self._bo.sync(direction)

    def __del__(self):
        """
        Destructor for the tensor.

        Drops this object's reference to the XRT buffer object, if one was
        ever created, and leaves ``_bo`` set to None.
        """
        # __init__ may have failed before the BO was allocated.
        if not hasattr(self, "_bo"):
            return
        del self._bo
        self._bo = None

    def buffer_object(self):
        """
        Return the XRT buffer object that backs this tensor.

        Returns:
            buffer_object: The XRT buffer object associated with this tensor.
        """
        return self._bo