forked from Xilinx/mlir-aie
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathruntime.py
More file actions
393 lines (339 loc) · 16.3 KB
/
runtime.py
File metadata and controls
393 lines (339 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
# runtime.py -*- Python -*-
#
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# (c) Copyright 2024-2026 Advanced Micro Devices, Inc.
"""Runtime: orchestrates host-side data movement and worker execution for an IRON program."""
from __future__ import annotations
from collections import defaultdict
from contextlib import contextmanager
import logging
import numpy as np
from typing import Callable
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
from ...utils import trace as trace_utils
from ... import ir # type: ignore
from ...dialects.aie import tile
from ...dialects.aiex import runtime_sequence
from ...dialects._aiex_ops_gen import dma_await_task, dma_free_task # type: ignore
from ...helpers.taplib import TensorAccessPattern
from ..dataflow import ObjectFifoHandle
from ..device import Tile, AnyShimTile
from ..resolvable import Resolvable
from ..worker import Worker, WorkerRuntimeBarrier, _BarrierSetOp
from .dmatask import DMATask
from .data import RuntimeData
from .endpoint import RuntimeEndpoint
from .taskgroup import RuntimeTaskGroup
from .task import (
RuntimeTask,
RuntimeStartTask,
InlineOpRuntimeTask,
FinishTaskGroupTask,
)
class Runtime(Resolvable):
    """Host-side description of the operations needed to run a program.

    A Runtime records, in call order, the operations that the host/runtime has
    to take care of: data movement into and out of ObjectFifos (``fill()`` /
    ``drain()``), starting Workers (``start()``), inline operations, barrier
    updates and task-group bookkeeping. The recorded tasks are replayed in the
    same order by ``resolve()``.
    """

    # Used to generate unique task group IDs within this Runtime.
    # The class attribute only provides the starting value: the first `+= 1`
    # executed through `self` creates an instance attribute, so every Runtime
    # instance counts its own task groups starting from 0.
    __task_group_index = 0

    def __init__(
        self,
        strict_task_groups: bool = True,
    ) -> None:
        """Initialize a runtime object.

        Args:
            strict_task_groups (bool): Disallows mixing the default group and explicit task groups during resolution.
                This can catch common errors, but can be set to False to disable the checks.
        """
        self._rt_data = []
        self._tasks: list[RuntimeTask] = []
        self._fifos = set()
        self._workers = []
        self._open_task_groups = []
        self._strict_task_groups = strict_task_groups

        # Trace configuration. Every attribute that enable_trace() writes is
        # initialised here so that reading it before (or without) a call to
        # enable_trace() yields a default instead of raising AttributeError.
        self._trace_size = None
        self._trace_workers = None
        self._ddr_id = 4
        self._coretile_events = None
        self._coremem_events = None
        self._memtile_events = None
        self._shimtile_events = None

    @contextmanager
    def sequence(self, *input_types: type[np.ndarray]):
        """A RuntimeSequence is a sequence of operations that are performed in
        support of a program. Common operations include input and output data movement.

        One RuntimeData handle is created per entry of ``input_types``. When the
        context exits, all task groups must have been finished, and duplicate
        runtime consumers registered on the same ObjectFifo are pruned.

        Raises:
            ValueError: If task groups are not finished within the sequence() context, an error will be raised.
            ValueError: If two runtime consumers of the same ObjectFifo disagree on depth or dims_from_stream.

        Yields:
            RuntimeData | tuple[RuntimeData, ...]: Handles to the runtime buffers matching the declared input types.
                A single handle is yielded bare; any other count is yielded as a tuple.
        """
        try:
            self._rt_data = list(map(RuntimeData, input_types))
            if len(self._rt_data) == 1:
                yield self._rt_data[0]
            else:
                # Hand out a copy so that callers cannot alias the internal list.
                yield tuple(self._rt_data.copy())
        finally:
            if self._open_task_groups:
                tgs_str = ", ".join(str(t) for t in self._open_task_groups)
                raise ValueError(f"Failed to close task groups: {tgs_str}")

            for of_handle in self._fifos:
                # Only consumer handles can carry duplicated runtime consumers.
                if not of_handle._is_prod:
                    self._prune_duplicate_runtime_consumers(of_handle._object_fifo)

    @staticmethod
    def _prune_duplicate_runtime_consumers(fifo_obj) -> None:
        """Drop redundant runtime consumers from ``fifo_obj``.

        It's very easy to accidentally generate multiple (identical) consumers
        in the runtime. The first consumer whose endpoint is a RuntimeEndpoint
        is kept; later ones with the same depth and dims_from_stream are removed.

        Args:
            fifo_obj: The object fifo whose ``_cons`` list is pruned in place.

        Raises:
            ValueError: If a later runtime consumer differs from the first one.
        """
        runtime_cons = None
        duplicates = []
        for cons in fifo_obj._cons:
            if not isinstance(cons.endpoint, RuntimeEndpoint):
                continue
            if runtime_cons is None:
                runtime_cons = cons
            elif (
                cons.depth == runtime_cons.depth
                and cons.dims_from_stream == runtime_cons.dims_from_stream
            ):
                duplicates.append(cons)
            else:
                raise ValueError(
                    f"Found two different RuntimeEndpoints for consumers of the same ObjectFifo: {fifo_obj}"
                )
        # Removal happens after the scan so the list is never mutated while
        # it is being iterated.
        for dup in duplicates:
            fifo_obj._cons.remove(dup)

    def _new_task_group(self) -> RuntimeTaskGroup:
        """Create a RuntimeTaskGroup with the next unused ID, without marking it as open."""
        tg = RuntimeTaskGroup(self.__task_group_index)
        self.__task_group_index += 1
        return tg

    def task_group(self) -> RuntimeTaskGroup:
        """Generate a handle to a RuntimeTaskGroup.
        This should be called within a Runtime.sequence() context.

        The group is tracked as open until finish_task_group() is called on it.

        Returns:
            RuntimeTaskGroup: The new RuntimeTaskGroup
        """
        tg = self._new_task_group()
        self._open_task_groups.append(tg)
        return tg

    def finish_task_group(self, task_group: RuntimeTaskGroup):
        """Close out a RuntimeTaskGroup.
        This should be called within a Runtime.sequence() context.

        Args:
            task_group (RuntimeTaskGroup): The task group to close. All associated tasks will be awaited or freed.

        Raises:
            ValueError: If the task group is not currently open in this Runtime.
        """
        try:
            self._open_task_groups.remove(task_group)
        except ValueError:
            # list.remove() only reports "x not in list"; name the group instead.
            raise ValueError(
                f"Task group {task_group} is not open in this Runtime; "
                f"it was already finished or was not created by task_group()"
            ) from None
        self._tasks.append(FinishTaskGroupTask(task_group))

    def fill(
        self,
        in_fifo: ObjectFifoHandle,
        source: RuntimeData,
        tap: TensorAccessPattern | None = None,
        task_group: RuntimeTaskGroup | None = None,
        wait: bool = False,
        tile: Tile = AnyShimTile,
    ) -> None:
        """Conceptually fill an ObjectFifoHandle (of type producer) with data from a runtime buffer.
        This should be called within a Runtime.sequence() context.

        Args:
            in_fifo (ObjectFifoHandle): The producer ObjectFifoHandle.
            source (RuntimeData): The input Runtime data buffer.
            tap (TensorAccessPattern | None, optional): A way of specifying how data in the buffer is accessed when sending it to the in_fifo.
                If None is given, this will default to a linear transfer containing all data in the source buffer. Defaults to None.
            task_group (RuntimeTaskGroup | None, optional): A TaskGroup to associate this task with. Defaults to None.
            wait (bool, optional): Whether this Task should be awaited on or not. If not, it will be freed when the task group is finished. Defaults to False.
            tile (Tile | None, optional): The Shim tile to associate the data transfer with. Defaults to AnyShimTile.

        Raises:
            ValueError: If ``source`` is not one of the buffers created by sequence().
        """
        if source not in self._rt_data:
            raise ValueError(
                f"Source {source} is not a RuntimeData object generated by sequence()"
            )
        if tap is None:
            tap = source.default_tap()
        # The runtime itself is the endpoint on the host side of this fifo.
        in_fifo.endpoint = RuntimeEndpoint(tile)
        self._fifos.add(in_fifo)
        self._tasks.append(DMATask(in_fifo, source, tap, task_group, wait))

    def drain(
        self,
        out_fifo: ObjectFifoHandle,
        dest: RuntimeData,
        tap: TensorAccessPattern | None = None,
        task_group: RuntimeTaskGroup | None = None,
        wait: bool = False,
        tile: Tile = AnyShimTile,
    ) -> None:
        """Conceptually drain an ObjectFifoHandle (of type consumer) of data and write that data to a runtime buffer.
        This should be called within a Runtime.sequence() context.

        Args:
            out_fifo (ObjectFifoHandle): The consumer ObjectFifoHandle.
            dest (RuntimeData): The output Runtime data buffer.
            tap (TensorAccessPattern | None, optional): A way of specifying how data in the buffer is accessed when reading from the out_fifo.
                If None is given, this will default to a linear transfer containing all data in the destination buffer. Defaults to None.
            task_group (RuntimeTaskGroup | None, optional): A TaskGroup to associate this task with. Defaults to None.
            wait (bool, optional): Whether this Task should be awaited on or not. If not, it will be freed when the task group is finished. Defaults to False.
            tile (Tile | None, optional): The Shim tile to associate the data transfer with. Defaults to AnyShimTile.

        Raises:
            ValueError: If ``dest`` is not one of the buffers created by sequence().
        """
        if dest not in self._rt_data:
            raise ValueError(
                f"Destination {dest} is not a RuntimeData object generated by sequence()"
            )
        if tap is None:
            tap = dest.default_tap()
        # The runtime itself is the endpoint on the host side of this fifo.
        out_fifo.endpoint = RuntimeEndpoint(tile)
        self._fifos.add(out_fifo)
        self._tasks.append(DMATask(out_fifo, dest, tap, task_group, wait))

    def start(self, *args: Worker):
        """A placeholder operation to indicate that one or more Worker should be started on the device.
        This should be called within a Runtime.sequence() context.

        Args:
            *args: One or more Workers. If more than one is given, they will be started in order.

        Raises:
            ValueError: If any argument is not a Worker. Nothing is recorded in that case.
        """
        # Validate everything first so that a bad argument cannot leave the
        # runtime with only part of the requested workers recorded.
        for worker in args:
            if not isinstance(worker, Worker):
                raise ValueError("Runtime can only start Worker objects")
        for worker in args:
            self._workers.append(worker)
            self._tasks.append(RuntimeStartTask(worker))

    def inline_ops(self, inline_func: Callable, inline_args: list):
        """Insert an InlineOpRuntimeTask into the runtime.
        This should be called within a Runtime.sequence() context.

        Args:
            inline_func (Callable): The function to execute within an MLIR context.
            inline_args (list): The state the function needs to execute.
        """
        # TODO: should filter args based on some criteria??
        self._tasks.append(InlineOpRuntimeTask(inline_func, inline_args))

    def enable_trace(
        self,
        trace_size: int | None = None,
        workers: list | None = None,
        ddr_id: int = 4,
        coretile_events: list | None = None,
        coremem_events: list | None = None,
        memtile_events: list | None = None,
        shimtile_events: list | None = None,
    ):
        """Enable hardware tracing for this program.

        Configures the AIE trace units and routes trace packets to DDR via the shim DMA.
        Should be called within a :meth:`sequence` context before data movement operations.

        This method only stores the settings on the Runtime; within this class,
        tracing is started by :meth:`resolve` when ``trace_size`` is greater than 0.

        Args:
            trace_size (int | None): Size of the trace buffer in bytes.
            workers (list[Worker] | None, optional): Specific workers to trace. If None,
                all workers with ``trace`` set will be traced. Defaults to None.
            ddr_id (int, optional): XRT inout buffer index (0-4) to write trace data
                into, mapping to group_id (3-7). Defaults to 4 (group_id 7).
                Set to -1 to append trace data after the last runtime_sequence
                tensor argument.
            coretile_events (list | None, optional): List of up to 8 core tile trace events.
                See ``https://xilinx.github.io/mlir-aie/AIEXDialect.html`` for available
                events under (type)EventAIE such as CoreEventAIE.
                Defaults to None (uses hardware defaults).
            coremem_events (list | None, optional): List of up to 8 core memory trace events.
                Defaults to None (uses hardware defaults).
            memtile_events (list | None, optional): List of up to 8 mem tile trace events.
                Defaults to None (uses hardware defaults).
            shimtile_events (list | None, optional): List of up to 8 shim tile trace events.
                Defaults to None (uses hardware defaults).
        """
        self._trace_size = trace_size
        self._trace_workers = workers
        self._ddr_id = ddr_id
        self._coretile_events = coretile_events
        self._coremem_events = coremem_events
        self._memtile_events = memtile_events
        self._shimtile_events = shimtile_events

    def set_barrier(self, barrier: WorkerRuntimeBarrier, value: int):
        """Set the value of a worker barrier.
        This should be called within a Runtime.sequence() context.

        Args:
            barrier (WorkerRuntimeBarrier): The WorkerRuntimeBarrier to set.
            value (int): The value to set the barrier to.
        """
        self._tasks.append(_BarrierSetOp(barrier, value))

    @property
    def workers(self) -> list[Worker]:
        """The workers associated with the Runtime by calls to start() (returned as a copy)."""
        return list(self._workers)

    @property
    def fifos(self) -> set[ObjectFifoHandle]:
        """The ObjectFifoHandles associated with the Runtime by calls to fill() and drain() (returned as a copy)."""
        return self._fifos.copy()

    def get_first_cons_shimtile(self):
        """Find the first consumer side of an objfifo that is in the 0th row
        and use it as the trace shim tile.

        Only the first consumer of each consumer handle's object fifo is inspected.

        Returns:
            The ``op`` of the first matching endpoint tile, or None if no
            consumer handle has its first consumer placed in row 0.
        """
        for of_handle in self._fifos:
            if of_handle._is_prod:
                continue
            endpoint_tile = of_handle._object_fifo._cons[0]._endpoint._tile
            if endpoint_tile.row == 0:
                return endpoint_tile.op
        return None

    def resolve(
        self,
        loc: ir.Location | None = None,
        ip: ir.InsertionPoint | None = None,
    ) -> None:
        """Replay the recorded runtime tasks inside a ``runtime_sequence``.

        Every recorded task is resolved in order. For each DMATask a completion
        action (await or free) is queued on its task group; the queued actions
        of a group run when that group is finished, awaits before frees. Tasks
        without an explicit group are collected in a default group that is
        finished after all tasks have been processed.

        Args:
            loc (ir.Location | None, optional): Not used by this implementation.
            ip (ir.InsertionPoint | None, optional): Not used by this implementation.

        Raises:
            ValueError: If a DMA task is assigned to a task group that was already finished.
            Exception: If strict task groups are enabled and both the default group
                and explicit task groups are used.
        """
        rt_dtypes = [rt_data.arr_type for rt_data in self._rt_data]
        # Maps a task group to its pending (function, args) completion actions;
        # the entry is set to None once the group has been finished.
        task_group_actions = defaultdict(list)

        @runtime_sequence(*rt_dtypes)
        def sequence(*args):
            if self._trace_size is not None and self._trace_size > 0:
                trace_utils.start_trace(
                    trace_size=self._trace_size,
                    ddr_id=self._ddr_id,
                    routing="single",
                )

            # Bind each runtime buffer handle to its sequence argument.
            for rt_data, rt_data_val in zip(self._rt_data, args):
                rt_data.op = rt_data_val

            def complete_group(tg):
                """Run the pending completion actions of ``tg`` and mark it finished."""
                actions = task_group_actions[tg]

                # We want to keep order, EXCEPT do waits before frees
                wait_tasks = [
                    (fn, fn_args) for (fn, fn_args) in actions if fn is dma_await_task
                ]
                free_tasks = [
                    (fn, fn_args) for (fn, fn_args) in actions if fn is dma_free_task
                ]

                # Check for anything unknown -- this shouldn't happen, but we'll catch it gracefully anyways.
                if len(wait_tasks) + len(free_tasks) != len(actions):
                    unknown_actions = [
                        (fn, fn_args)
                        for (fn, fn_args) in actions
                        if fn is not dma_await_task and fn is not dma_free_task
                    ]
                    # The entries are tuples, so they must be stringified
                    # before joining; str.join() rejects non-string items.
                    unknown_str = ",".join(str(action) for action in unknown_actions)
                    raise Exception(f"Unknown action type detected: {unknown_str}")

                for fn, fn_args in wait_tasks + free_tasks:
                    fn(*fn_args)
                task_group_actions[tg] = None

            # The default group is internal to resolution. It is deliberately
            # not registered as an open task group, so resolving does not leave
            # a group behind that would make a later sequence() fail.
            default_task_group = self._new_task_group()
            default_tasks = False
            task_group_tasks = False

            for task in self._tasks:
                task.resolve()

                if isinstance(task, DMATask):
                    if task.task_group:
                        task_group_tasks = True
                        current_task_group = task.task_group
                    else:
                        default_tasks = True
                        current_task_group = default_task_group

                    pending = task_group_actions[current_task_group]
                    if pending is None:
                        raise ValueError(
                            f"Task group {current_task_group} was already finished; "
                            f"cannot add more tasks to it"
                        )
                    completion = dma_await_task if task.will_wait() else dma_free_task
                    pending.append((completion, [task.task]))

                if isinstance(task, FinishTaskGroupTask):
                    complete_group(task.task_group)

                if self._strict_task_groups and default_tasks and task_group_tasks:
                    raise Exception(
                        f"Mixing explicit task groups and the default task group is prohibitted. "
                        f"Please assign all default tasks ({task_group_actions[default_task_group]}) to a task group."
                    )

            # Finish whatever was queued on the implicit default group.
            if task_group_actions[default_task_group]:
                complete_group(default_task_group)