-
Notifications
You must be signed in to change notification settings - Fork 48
Expand file tree
/
Copy pathaxpy.py
More file actions
178 lines (147 loc) · 5.31 KB
/
axpy.py
File metadata and controls
178 lines (147 loc) · 5.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Copyright (C) 2026, Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT
"""Vectorized AXPY Example
Implements the AXPY operation on 1D vectors [N]:
out = a * x + y
where a is a scalar, x and y are input vectors, and the result is written to a
separate output vector (y itself is not updated in place).
Uses a 1x2 AIE herd with DMA transfers between L3 and L1 memory.
Computation is vectorized using vector.fma (fused multiply-add)
with configurable VECTOR_SIZE (default 16).
"""
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ml_dtypes import bfloat16
from air.ir import *
from air.dialects.air import *
from air.dialects import arith
from air.dialects.arith import ConstantOp
from air.dialects.memref import AllocOp, DeallocOp
from air.dialects.vector import BroadcastOp, fma
from air.dialects.func import FuncOp
from air.dialects.scf import for_, yield_
from air.backend.xrt_runner import type_mapper, make_air_parser, run_on_npu
from utils import vec_read, vec_write
import numpy as np
# Fix NumPy's global RNG seed so the random test inputs generated in the
# __main__ section are the same on every run.
np.random.seed(42)
# Alias scf.for_ so the loop headers in build_module read like Python range().
range_ = for_
@module_builder
def build_module(n, tile_n, np_dtype_in, alpha=2.0, vector_size=16):
    """Build the AIR module computing ``out = alpha * x + y`` on 1D vectors.

    The emitted ``axpy`` function takes three ``[n]`` memrefs: x, y and the
    output. A 1x2 herd (``herd_0``) walks the vectors in steps of
    ``2 * tile_n`` elements. For every step, one tile of ``tile_n`` elements
    of x and of y is DMA-copied into L1 buffers, combined with ``vector.fma``
    in chunks of ``vector_size`` elements, and DMA-copied to the output.

    Args:
        n: Total number of elements per vector. Must be a positive multiple
            of ``2 * tile_n``.
        tile_n: Number of elements each tile processes per step. Must be a
            positive multiple of ``vector_size``.
        np_dtype_in: Element type as understood by ``type_mapper``.
        alpha: Scalar multiplier ``a``; emitted as a constant of the element
            type and broadcast to a vector.
        vector_size: Number of elements handled per ``vector.fma``.

    Returns:
        Whatever ``module_builder`` produces for the decorated builder (the
        script section below prints it and passes it to ``run_on_npu``).

    Raises:
        ValueError: If a size is not positive or the divisibility
            requirements above are not met.
    """
    xrt_dtype_in = type_mapper(np_dtype_in)
    num_tiles = 2

    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``
    # and a zero size would otherwise surface as a bare ZeroDivisionError.
    if n <= 0 or tile_n <= 0 or vector_size <= 0:
        raise ValueError(
            f"n, tile_n and vector_size must be positive "
            f"(got n={n}, tile_n={tile_n}, vector_size={vector_size})"
        )
    if n % (tile_n * num_tiles) != 0:
        raise ValueError(
            f"n ({n}) must be a multiple of tile_n * {num_tiles} "
            f"({tile_n * num_tiles})"
        )
    if tile_n % vector_size != 0:
        raise ValueError(
            f"tile_n ({tile_n}) must be a multiple of vector_size ({vector_size})"
        )

    index_type = IndexType.get()

    # Types: full-length L3 vectors, one-tile L1 buffers, one SIMD vector.
    l3memrefTy = MemRefType.get([n], xrt_dtype_in)
    l1MemrefTy = l1_memref_type([tile_n], xrt_dtype_in)
    vecTy = vec_type(vector_size, xrt_dtype_in)
    imap = identity_map_attr()

    @FuncOp.from_py_func(l3memrefTy, l3memrefTy, l3memrefTy)
    def axpy(arg0, arg1, arg2):
        # arg0 = x (input), arg1 = y (input), arg2 = output
        @herd(
            name="herd_0",
            sizes=[1, num_tiles],
            operands=[arg0, arg1, arg2],
        )
        def herd_body(
            _tx,
            _ty,
            _sx,
            _sy,
            _l3_x,
            _l3_y,
            _l3_out,
        ):
            # One tile-sized L1 buffer per operand, allocated once per tile.
            l1_x_data = AllocOp(l1MemrefTy, [], [])
            l1_y_data = AllocOp(l1MemrefTy, [], [])
            l1_out_data = AllocOp(l1MemrefTy, [], [])

            # Each iteration covers num_tiles * tile_n elements of the vectors.
            for _l_ivx in range_(0, n, tile_n * num_tiles):
                # Start of this tile's slice; presumably _l_ivx + _ty * tile_n
                # -- confirm against tile_offset_1d.
                offset = tile_offset_1d(_l_ivx, _ty, tile_n)

                # Bring the x and y slices into L1.
                dma_memcpy_nd(
                    l1_x_data,
                    _l3_x,
                    src_offsets=[offset],
                    src_sizes=[tile_n],
                    src_strides=[1],
                )
                dma_memcpy_nd(
                    l1_y_data,
                    _l3_y,
                    src_offsets=[offset],
                    src_sizes=[tile_n],
                    src_strides=[1],
                )

                c0 = ConstantOp(index_type, 0)
                cVecSize = ConstantOp(index_type, vector_size)
                cTileN = ConstantOp(index_type, tile_n)
                # Passed to vec_read; presumably the padding value of the
                # vector read -- confirm in utils.vec_read.
                cst0 = arith.ConstantOp(xrt_dtype_in, 0.0)

                # Broadcast scalar alpha to vector
                a_const = arith.ConstantOp(xrt_dtype_in, alpha)
                v_a = BroadcastOp(vecTy, a_const)

                # Walk the tile one vector at a time.
                for j in range_(c0, cTileN, cVecSize):
                    v_x = vec_read(l1_x_data, j, vector_size, c0, vecTy, cst0, imap)
                    v_y = vec_read(l1_y_data, j, vector_size, c0, vecTy, cst0, imap)
                    # a * x + y via vector.fma
                    v_result = fma(v_a, v_x, v_y)
                    vec_write(v_result, l1_out_data, j, vector_size, c0, imap)
                    yield_([])

                # Write result from l1_out back to L3 output buffer
                dma_memcpy_nd(
                    _l3_out,
                    l1_out_data,
                    dst_offsets=[offset],
                    dst_sizes=[tile_n],
                    dst_strides=[1],
                )

                # NOTE(review): these deallocs are emitted inside the tile
                # loop although the buffers are allocated once before it;
                # placement left unchanged -- confirm it is intended.
                DeallocOp(l1_x_data)
                DeallocOp(l1_y_data)
                DeallocOp(l1_out_data)

                yield_([])
if __name__ == "__main__":
    # Defaults for the command-line options declared below.
    N = 65536
    TILE_N = 1024
    INPUT_DATATYPE = bfloat16
    ALPHA = 2.0

    parser = make_air_parser("Builds, runs, and tests the AXPY example")
    parser.add_argument("--n", type=int, default=N, help="Total number of elements")
    parser.add_argument("--tile-n", type=int, default=TILE_N, help="Tile size")
    parser.add_argument(
        "--alpha", type=float, default=ALPHA, help="Scalar multiplier a"
    )
    parser.add_argument(
        "--vector-size",
        type=int,
        default=16,
        help="Vector size for SIMD operations",
    )
    args = parser.parse_args()

    mlir_module = build_module(
        args.n, args.tile_n, INPUT_DATATYPE, args.alpha, args.vector_size
    )
    # print_module_only is not declared here; presumably it is one of the
    # options make_air_parser adds -- confirm.
    if args.print_module_only:
        print(mlir_module)
        # sys.exit rather than the site-provided exit(), which is intended
        # for interactive sessions and is absent under ``python -S``.
        sys.exit(0)

    input_x = np.random.randn(args.n).astype(INPUT_DATATYPE)
    input_y = np.random.randn(args.n).astype(INPUT_DATATYPE)

    # Check only 100 randomly chosen positions instead of the whole output.
    # Indices are stored as a (1, 100) array: one row per dimension, one
    # column per sample, so zip(*...) yields one index tuple per sample.
    sampled_indices = np.vstack([np.random.randint(0, args.n, 100)])
    sampled_values = np.array(
        [args.alpha * input_x[i] + input_y[i] for i in zip(*sampled_indices)],
        dtype=INPUT_DATATYPE,
    )
    sampled_data = {
        "shape": (args.n,),
        "indices": sampled_indices,
        "values": sampled_values,
    }

    # The value returned by run_on_npu becomes the process exit status.
    sys.exit(
        run_on_npu(
            args,
            mlir_module,
            inputs=[input_x, input_y],
            instance_name="axpy",
            stochastic_expected_outputs=[sampled_data],
            rtol=1e-2,
        )
    )