Skip to content

Commit b3cbf30

Browse files
committed
fix
1 parent 60251b1 commit b3cbf30

File tree

3 files changed

+15
-4
lines changed

3 files changed

+15
-4
lines changed

examples/quantization/matmul_a16wx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ def __call__(self, k_size: int, n_size: int, src_ptr: void_p, dst_ptr: void_p):
8181

8282
g_src = self.global_view(src_ptr, dtype=self.b_dtype, shape=[k_size, n_size])
8383
r_src = self.load_global(
84-
g_src, offsets=[offset_k, offset_n], layout=self.tile_layout
84+
g_src, offsets=[offset_k, offset_n], shape=self.tile_layout.shape
8585
)
86+
self.annotate_layout(r_src, self.tile_layout)
8687
r_dst = self.view(r_src, layout=self.flatten_tile_layout, dtype=uint8)
8788
g_dst = self.global_view(
8889
dst_ptr,
@@ -127,8 +128,9 @@ def __call__(self, k_size: int, n_size: int, src_ptr: void_p, dst_ptr: void_p):
127128
r_src = self.load_global(
128129
g_src,
129130
offsets=[self.blockIdx.x, self.blockIdx.y, 0],
130-
layout=self.flatten_tile_layout,
131+
shape=self.flatten_tile_layout.shape,
131132
)
133+
self.annotate_layout(r_src, self.flatten_tile_layout)
132134
r_dst = self.view(r_src, layout=self.tile_layout, dtype=self.b_dtype)
133135
g_dst = self.global_view(dst_ptr, dtype=self.b_dtype, shape=[k_size, n_size])
134136
self.store_global(g_dst, r_dst, offsets=[offset_k, offset_n], dims=[2])

python/tilus/lang/transpiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,10 @@ def visit_Call(self, expr: ast.Call) -> Any:
689689
raise TilusProgramError(self, expr, str(e))
690690
else:
691691
# case 4
692-
ret = func(*args, **kwargs)
692+
try:
693+
ret = func(*args, **kwargs)
694+
except TypeError as e:
695+
raise TilusProgramError(self, expr, str(e)) from e
693696
elif isinstance(func, types.FunctionType):
694697
# case 4
695698
ret = func(*args, **kwargs)

tests/instructions/test_where.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ def __call__(self, m: int32, n: int32, cond_ptr: ~boolean, x_ptr: ~int32, y_ptr:
4141
gy = self.global_view(y_ptr, dtype=int32, shape=(m, n))
4242
go = self.global_view(out_ptr, dtype=int32, shape=(m, n))
4343

44-
rc = self.load_global(gc, offsets=[m_offset, n_offset], layout=self.layout)
44+
rc = self.load_global(gc, offsets=[m_offset, n_offset], shape=self.layout.shape)
4545
rx = self.load_global(gx, offsets=[m_offset, n_offset], shape=[self.block_m, self.block_n])
4646
ry = self.load_global(gy, offsets=[m_offset, n_offset], shape=[self.block_m, self.block_n])
4747
ro = self.where(rc, rx, ry)
4848

4949
self.store_global(go, ro, offsets=[m_offset, n_offset])
5050

51+
self.annotate_layout(rc, self.layout)
52+
5153

5254
@pytest.mark.parametrize("m, n, layout", [[16, 16, spatial(4, 8)], [128, 128, local(2, 2).spatial(4, 8).local(2, 2)]])
5355
def test_where(
@@ -65,3 +67,7 @@ def test_where(
6567
kernel(m, n, cond, x, y, actual)
6668

6769
assert torch.allclose(actual, expected), f"Failed for layout {layout} with m={m}, n={n}"
70+
71+
72+
if __name__ == "__main__":
73+
pytest.main([__file__])

0 commit comments

Comments
 (0)