Skip to content

Commit 3e6fc2c

Browse files
authored
[TLE] Move TLE distributed API under language namespace (#448)
* tle: move distributed APIs under language for v3.3.x
* tle: tolerate ImportError when optional raw module is unavailable
1 parent fbaed78 commit 3e6fc2c

File tree

4 files changed

+44
-44
lines changed

4 files changed

+44
-44
lines changed
Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,14 @@
11
# flagtree tle
2-
from .distributed import (
3-
B,
4-
P,
5-
S,
6-
ShardedTensor,
7-
ShardingSpec,
8-
device_mesh,
9-
distributed_barrier,
10-
distributed_dot,
11-
make_sharded_tensor,
12-
remote,
13-
reshard,
14-
shard_id,
15-
sharding,
16-
)
17-
182
from . import language
193

20-
# try:
21-
# from . import raw
22-
# except ModuleNotFoundError:
23-
# raw = None
4+
try:
5+
from . import raw
6+
except (ModuleNotFoundError, ImportError):
7+
raw = None
248

259
__all__ = [
26-
"device_mesh",
27-
"S",
28-
"P",
29-
"B",
30-
"sharding",
31-
"ShardingSpec",
32-
"ShardedTensor",
33-
"make_sharded_tensor",
34-
"reshard",
35-
"remote",
36-
"shard_id",
37-
"distributed_barrier",
38-
"distributed_dot",
3910
"language",
4011
]
4112

42-
# if raw is not None:
43-
# __all__.append("raw")
13+
if raw is not None:
14+
__all__.append("raw")
Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,39 @@
11
# flagtree tle
22
from .core import (
33
load, )
4+
from .distributed import (
5+
B,
6+
P,
7+
S,
8+
ShardedTensor,
9+
ShardingSpec,
10+
device_mesh,
11+
distributed_barrier,
12+
distributed_dot,
13+
make_sharded_tensor,
14+
remote,
15+
reshard,
16+
shard_id,
17+
sharding,
18+
)
419

520
__all__ = [
621
"load",
22+
"device_mesh",
23+
"S",
24+
"P",
25+
"B",
26+
"sharding",
27+
"ShardingSpec",
28+
"ShardedTensor",
29+
"make_sharded_tensor",
30+
"reshard",
31+
"remote",
32+
"shard_id",
33+
"distributed_barrier",
34+
"distributed_dot",
35+
"distributed",
736
"dsa",
837
]
938

10-
from . import dsa
39+
from . import distributed, dsa

python/triton/experimental/tle/distributed.py renamed to python/triton/experimental/tle/language/distributed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def remote(
711711
712712
Supported input:
713713
- tle buffered_tensor: returns a remote-marked buffered tensor; caller
714-
should then use `tleg.local_ptr(...)` to materialize remote pointers.
714+
should then use `tle.dsa.local_ptr(...)` to materialize remote pointers.
715715
716716
`shard_id` is the target block id inside the current thread block cluster.
717717
When `scope` is provided, launch cluster dimensions are inferred from that
@@ -730,7 +730,7 @@ def remote(
730730
if (hasattr(tensor, "_tle_remote_shard_id") or hasattr(tensor, "_tle_remote_scope")
731731
or hasattr(tensor.type, "_tle_remote_shard_id") or hasattr(tensor.type, "_tle_remote_scope")):
732732
raise ValueError("remote(buffered_tensor, ...) cannot be applied twice; "
733-
"materialize pointer views with tleg.local_ptr(remote_buffer, indices)")
733+
"materialize pointer views with tle.dsa.local_ptr(remote_buffer, indices)")
734734
if isinstance(shard_id, (int, tuple, list)):
735735
shard_id = _normalize_compile_time_remote_shard_id(shard_id, scope)
736736
else:

third_party/tsingmicro/examples/tle/test_tle_dsa_noc_gemm_4096.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import triton
44
import triton.language as tl
5-
from triton.experimental import tle
5+
import triton.experimental.tle.language as tle
66

77
TILE_NUM = 16
88
M = 4096
@@ -52,17 +52,17 @@ def dsa_shift_n_gemm_kernel(
5252
b_ptrs = B_ptr + offs_k[:, None] * N + offs_sub_n[None, :]
5353
b_init = tl.load(b_ptrs)
5454

55-
send_buf = tle.language.dsa.alloc((BLOCK_K, SUB_N), tl.float16)
56-
recv_buf = tle.language.dsa.alloc((BLOCK_K, SUB_N), tl.float16)
55+
send_buf = tle.dsa.alloc((BLOCK_K, SUB_N), tl.float16)
56+
recv_buf = tle.dsa.alloc((BLOCK_K, SUB_N), tl.float16)
5757

5858
offs_buf_k = tl.arange(0, BLOCK_K)[:, None] + tl.zeros((1, SUB_N), dtype=tl.int32)
5959
offs_buf_n = tl.arange(0, SUB_N)[None, :] + tl.zeros((BLOCK_K, 1), dtype=tl.int32)
6060

61-
send_ptr = tle.language.dsa.local_ptr(send_buf, [offs_buf_k, offs_buf_n])
62-
recv_ptr = tle.language.dsa.local_ptr(recv_buf, [offs_buf_k, offs_buf_n])
61+
send_ptr = tle.dsa.local_ptr(send_buf, [offs_buf_k, offs_buf_n])
62+
recv_ptr = tle.dsa.local_ptr(recv_buf, [offs_buf_k, offs_buf_n])
6363

6464
remote_recv_buf = tle.remote(recv_buf, send_next_tile)
65-
remote_recv_ptr = tle.language.dsa.local_ptr(remote_recv_buf, [offs_buf_k, offs_buf_n])
65+
remote_recv_ptr = tle.dsa.local_ptr(remote_recv_buf, [offs_buf_k, offs_buf_n])
6666

6767
tl.store(send_ptr, b_init)
6868

0 commit comments

Comments (0)