Skip to content

Commit af8d685

Browse files
tsingmicro-public-etsingmicro-public-ezhzhcookie
authored
[TLE] Add TLE (DSA) support for tsingmicro (#426)
* [BACKEND] Add TLE (DSA) support and TLEToMK pipeline for tsingmicro * [TLE] third_party/tle DSA dialect and DsaToCore conversion * [BACKEND] TLEToMK, Tx81 recv/send, MK/Tx81/compiler integration * [PYTHON] triton.experimental.tle language and ir bindings * [BUILD] CMake, setup.py, wheel/build scripts * [CI] tsingmicro workflow and flaggems CI script * [EXAMPLE] tle DSA NOC GEMM example --------- Co-authored-by: tsingmicro-public-e <ludingtao@tsingmicro.com> * [BUILD] Fix tsingmicro build [BUILD] Fix tsingmicro build * [CI] Fix test tsingmicro path --------- Co-authored-by: tsingmicro-public-e <ludingtao@tsingmicro.com> Co-authored-by: zhengyang <zhengyang@baai.ac.cn>
1 parent 9cb8d8f commit af8d685

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+3835
-94
lines changed

.github/workflows/tsingmicro-build-and-test.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ jobs:
8484
export PYTHONPATH=${LLVM_SYSPATH}/python_packages/mlir_core:$PYTHONPATH
8585
export LD_LIBRARY_PATH=$TX8_DEPS_ROOT/lib:$LD_LIBRARY_PATH
8686
87-
cd third_party/tsingmicro/examples
87+
# tsingmicro
88+
pushd third_party/tsingmicro/examples
8889
python3 bare_matmul_autotune.py >result-bare_matmul_autotune.txt
8990
python3 embedding.py >result-embedding.txt
9091
python3 mult_ir.py >result-mult_ir.txt
@@ -99,3 +100,9 @@ jobs:
99100
python3 test_softmax.py >result-test_softmax.txt
100101
python3 test_vec_add.py >result-test_vec_add.txt
101102
python3 time1.py >result-time1.txt
103+
popd
104+
105+
# tle on tsingmicro
106+
pushd third_party/tsingmicro/examples/tle
107+
python3 test_tle_dsa_noc_gemm_4096.py >result-noc_gemm.txt
108+
popd

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,10 @@ if(TRITON_BUILD_PYTHON_MODULE)
285285
list(APPEND TRITON_PLUGIN_NAMES "proton")
286286
add_subdirectory(third_party/proton/dialect)
287287
288+
# Add TLE plugin
289+
list(APPEND TRITON_PLUGIN_NAMES "tle")
290+
add_subdirectory(third_party/tle)
291+
288292
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
289293
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
290294
set(TRITON_LIBRARIES
@@ -460,6 +464,8 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
460464
add_subdirectory(third_party/${CODEGEN_BACKEND})
461465
endforeach()
462466
add_subdirectory(third_party/proton/dialect)
467+
# flagtree tle
468+
add_subdirectory(third_party/tle)
463469
endif()
464470
465471
find_package(Threads REQUIRED)

python/setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,10 @@ def get_packages():
741741
"triton/backends",
742742
"triton/tools",
743743
"triton/tools/extra",
744+
"triton/experimental",
745+
"triton/experimental/tle",
746+
"triton/experimental/tle/language",
747+
"triton/experimental/tle/language/dsa",
744748
]
745749
if helper.flagtree_backend == "xpu":
746750
packages.append("triton/language/extra/xpu")

python/setup_tools/setup_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
'LLVM_INCLUDE_DIRS': Path(path) / "include",
1717
'LLVM_LIBRARY_DIR': Path(path) / "lib",
1818
'LLVM_SYSPATH': path,
19+
'PATH': os.pathsep.join([str(Path(path) / "bin"), os.getenv("PATH", "")]),
1920
})
2021

2122

@@ -432,6 +433,7 @@ def uninstall_triton():
432433
pre_hook=lambda: check_env('TX8_DEPS_ROOT'),
433434
post_hook=lambda path: set_env({
434435
'TX8_DEPS_ROOT': path,
436+
'TX8_YOC_RT_THREAD_SMP': Path(path) / "tx8-yoc-rt-thread-smp",
435437
}),
436438
)
437439

python/src/ir.cc

Lines changed: 1 addition & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2424
#include "mlir/Transforms/LocationSnapshot.h"
2525

26+
#include "ir.h"
2627
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
2728
#include "triton/Dialect/Triton/IR/Dialect.h"
2829
#include "triton/Dialect/Triton/IR/Types.h"
@@ -56,90 +57,6 @@ llvm::raw_ostream &mlir_dumps_or_dbgs() {
5657
}
5758
}
5859

59-
// A custom op builder that keeps track of the last location
60-
class TritonOpBuilder {
61-
public:
62-
TritonOpBuilder(MLIRContext *context) {
63-
builder = std::make_unique<OpBuilder>(context);
64-
lastLoc = std::make_unique<Location>(builder->getUnknownLoc());
65-
}
66-
67-
OpBuilder &getBuilder() { return *builder; }
68-
MLIRContext *getContext() { return builder->getContext(); }
69-
70-
bool isLineInfoEnabled() { return lineInfoEnabled; }
71-
72-
void setLastLoc(Location loc) {
73-
if (lineInfoEnabled)
74-
lastLoc = std::make_unique<Location>(loc);
75-
}
76-
77-
void setLastLoc(const std::string &fileName, int line, int column) {
78-
auto context = builder->getContext();
79-
setLastLoc(FileLineColLoc::get(context, fileName, line, column));
80-
}
81-
82-
Location getLastLoc() {
83-
assert(lastLoc);
84-
return *lastLoc;
85-
}
86-
87-
void setInsertionPointToStart(Block &block) {
88-
if (!block.empty())
89-
setLastLoc(block.begin()->getLoc());
90-
else
91-
setLastLoc(builder->getUnknownLoc());
92-
builder->setInsertionPointToStart(&block);
93-
}
94-
95-
void setInsertionPointToEnd(Block &block) {
96-
if (!block.empty())
97-
setLastLoc(block.back().getLoc());
98-
else
99-
setLastLoc(builder->getUnknownLoc());
100-
builder->setInsertionPointToEnd(&block);
101-
}
102-
103-
void setInsertionPointAfter(Operation &op) {
104-
setLastLoc(op.getLoc());
105-
builder->setInsertionPointAfter(&op);
106-
}
107-
108-
void restoreInsertionPoint(OpBuilder::InsertPoint pt) {
109-
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
110-
setLastLoc(pt.getPoint()->getLoc());
111-
else
112-
setLastLoc(builder->getUnknownLoc());
113-
builder->restoreInsertionPoint(pt);
114-
}
115-
116-
template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
117-
auto loc = getLastLoc();
118-
return builder->create<OpTy>(loc, std::forward<Args>(args)...);
119-
}
120-
121-
// Overload to create or fold a single result operation.
122-
template <typename OpTy, typename... Args>
123-
std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value>
124-
createOrFold(Args &&...args) {
125-
auto loc = getLastLoc();
126-
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
127-
}
128-
129-
// Overload to create or fold a zero result operation.
130-
template <typename OpTy, typename... Args>
131-
std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy>
132-
createOrFold(Args &&...args) {
133-
auto loc = getLastLoc();
134-
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
135-
}
136-
137-
private:
138-
std::unique_ptr<OpBuilder> builder;
139-
std::unique_ptr<Location> lastLoc;
140-
bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
141-
};
142-
14360
// Run the pass manager under a source manager diagnostic handler, which
14461
// enables emitted MLIR diagnostics to directly reference Python source
14562
// code. This diagnostic handler supports filtering diagnostic info by

python/src/ir.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#pragma once
2+
3+
#include "mlir/IR/Builders.h"
4+
#include "triton/Tools/Sys/GetEnv.hpp"
5+
6+
#include <cassert>
7+
#include <memory>
8+
9+
// A custom op builder that keeps track of the last location.
10+
class TritonOpBuilder {
11+
public:
12+
TritonOpBuilder(mlir::MLIRContext *context) {
13+
builder = std::make_unique<mlir::OpBuilder>(context);
14+
lastLoc = std::make_unique<mlir::Location>(builder->getUnknownLoc());
15+
}
16+
17+
mlir::OpBuilder &getBuilder() { return *builder; }
18+
mlir::MLIRContext *getContext() { return builder->getContext(); }
19+
20+
bool isLineInfoEnabled() { return lineInfoEnabled; }
21+
22+
void setLastLoc(mlir::Location loc) {
23+
if (lineInfoEnabled)
24+
lastLoc = std::make_unique<mlir::Location>(loc);
25+
}
26+
27+
void setLastLoc(const std::string &fileName, int line, int column) {
28+
auto context = builder->getContext();
29+
setLastLoc(mlir::FileLineColLoc::get(context, fileName, line, column));
30+
}
31+
32+
mlir::Location getLastLoc() {
33+
assert(lastLoc);
34+
return *lastLoc;
35+
}
36+
37+
void setInsertionPointToStart(mlir::Block &block) {
38+
if (!block.empty())
39+
setLastLoc(block.begin()->getLoc());
40+
else
41+
setLastLoc(builder->getUnknownLoc());
42+
builder->setInsertionPointToStart(&block);
43+
}
44+
45+
void setInsertionPointToEnd(mlir::Block &block) {
46+
if (!block.empty())
47+
setLastLoc(block.back().getLoc());
48+
else
49+
setLastLoc(builder->getUnknownLoc());
50+
builder->setInsertionPointToEnd(&block);
51+
}
52+
53+
void setInsertionPointAfter(mlir::Operation &op) {
54+
setLastLoc(op.getLoc());
55+
builder->setInsertionPointAfter(&op);
56+
}
57+
58+
void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) {
59+
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
60+
setLastLoc(pt.getPoint()->getLoc());
61+
else
62+
setLastLoc(builder->getUnknownLoc());
63+
builder->restoreInsertionPoint(pt);
64+
}
65+
66+
template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
67+
auto loc = getLastLoc();
68+
return builder->create<OpTy>(loc, std::forward<Args>(args)...);
69+
}
70+
71+
template <typename OpTy, typename... Args>
72+
std::enable_if_t<OpTy::template hasTrait<mlir::OpTrait::OneResult>(),
73+
mlir::Value>
74+
createOrFold(Args &&...args) {
75+
auto loc = getLastLoc();
76+
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
77+
}
78+
79+
template <typename OpTy, typename... Args>
80+
std::enable_if_t<OpTy::template hasTrait<mlir::OpTrait::ZeroResults>(), OpTy>
81+
createOrFold(Args &&...args) {
82+
auto loc = getLastLoc();
83+
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
84+
}
85+
86+
private:
87+
std::unique_ptr<mlir::OpBuilder> builder;
88+
std::unique_ptr<mlir::Location> lastLoc;
89+
bool lineInfoEnabled =
90+
!mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
91+
};

python/src/main.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ namespace py = pybind11;
99
#define FOR_EACH_3(MACRO, X, ...) MACRO(X) FOR_EACH_2(MACRO, __VA_ARGS__)
1010
#define FOR_EACH_4(MACRO, X, ...) MACRO(X) FOR_EACH_3(MACRO, __VA_ARGS__)
1111
#define FOR_EACH_5(MACRO, X, ...) MACRO(X) FOR_EACH_4(MACRO, __VA_ARGS__)
12+
#define FOR_EACH_6(MACRO, X, ...) MACRO(X) FOR_EACH_5(MACRO, __VA_ARGS__)
1213

1314
#define FOR_EACH_NARG(...) FOR_EACH_NARG_(__VA_ARGS__, FOR_EACH_RSEQ_N())
1415
#define FOR_EACH_NARG_(...) FOR_EACH_ARG_N(__VA_ARGS__)
15-
#define FOR_EACH_ARG_N(_1, _2, _3, _4, _5, N, ...) N
16-
#define FOR_EACH_RSEQ_N() 5, 4, 3, 2, 1, 0
16+
#define FOR_EACH_ARG_N(_1, _2, _3, _4, _5, _6, N, ...) N
17+
#define FOR_EACH_RSEQ_N() 6, 5, 4, 3, 2, 1, 0
1718

1819
#define CONCATENATE(x, y) CONCATENATE1(x, y)
1920
#define CONCATENATE1(x, y) x##y
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# flagtree tle
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# flagtree tle
2+
from .distributed import (
3+
B,
4+
P,
5+
S,
6+
ShardedTensor,
7+
ShardingSpec,
8+
device_mesh,
9+
distributed_barrier,
10+
distributed_dot,
11+
make_sharded_tensor,
12+
remote,
13+
reshard,
14+
shard_id,
15+
sharding,
16+
)
17+
18+
from . import language
19+
20+
# try:
21+
# from . import raw
22+
# except ModuleNotFoundError:
23+
# raw = None
24+
25+
__all__ = [
26+
"device_mesh",
27+
"S",
28+
"P",
29+
"B",
30+
"sharding",
31+
"ShardingSpec",
32+
"ShardedTensor",
33+
"make_sharded_tensor",
34+
"reshard",
35+
"remote",
36+
"shard_id",
37+
"distributed_barrier",
38+
"distributed_dot",
39+
"language",
40+
]
41+
42+
# if raw is not None:
43+
# __all__.append("raw")

0 commit comments

Comments
 (0)