[Bugfix] Update truncation from i32 to i8 when storing to i1 #10011
greatofdream wants to merge 3 commits into triton-lang:main from
Conversation
```python
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = self.cast(ptr, ptr_ty)
val = self.cast(val, tl.int1)
```
I'm not convinced this is the right thing to do. The downcasting happens when the size of the element in memory is smaller than the element type, but booleans are stored as 8-bit values.
I don't think it makes sense to do extra masking of the higher bits here.
Do you have a real-life example where this is a problem?
> The downcasting happens when the size of the element in memory is smaller than the element type, but booleans are stored as 8-bit values.

Thanks for your feedback. The booleans are stored as 8-bit values, but the accumulator is i32, so a truncation happens when storing the booleans from i32.

> I don't think it makes sense to do extra masking of the higher bits here.

For example, when there are 256 True values in a boolean tensor, `tl.sum` computes 256 in the i32 accumulator, and `arith.trunci` to i8 truncates it to 0.

> Do you have a real-life example where this is a problem?

There is an example in #9991, which compares results between torch and a Triton kernel. I extracted the TTIR from the Triton kernel and pasted the load and store parts below:
- The `load` operation casts the `i8` as `i1` to avoid wrong truncation. The IR is as follows:

```mlir
%x_27 = tt.load %x_26 : tensor<128x2x!tt.ptr<i8>> loc(#loc20)
%x_28 = arith.cmpi ne, %x_27, %x : tensor<128x2xi8> loc(#loc20)
%input = arith.extui %x_28 : tensor<128x2xi1> to tensor<128x2xi32> loc(#loc30)
```
- However, the `store` operation does nothing of the sort. The IR is as follows:

```mlir
%0 = tt.bitcast %out_ptr0 : !tt.ptr<i1> -> !tt.ptr<i8> loc(#loc15)
%1 = arith.trunci %ret_29 : i32 to i8 loc(#loc15)
tt.store %0, %1 : !tt.ptr<i8> loc(#loc15)
```
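The asymmetry between the two paths can be modeled in plain Python (an illustrative sketch, not Triton's actual lowering): the load path re-derives a clean boolean with a `!= 0` compare, while the store path keeps only the low 8 bits.

```python
def load_bool_from_i8(byte: int) -> int:
    # Load path: compare the i8 value 'ne' against 0, then zero-extend
    # (mirrors the arith.cmpi ne + arith.extui pair in the IR above).
    return 1 if (byte & 0xFF) != 0 else 0

def store_bool_as_i8_buggy(acc_i32: int) -> int:
    # Store path before the fix: plain bitwise truncation i32 -> i8
    # (mirrors arith.trunci). A count that is a multiple of 256 becomes 0.
    return acc_i32 & 0xFF

count_of_true = 256                    # e.g. tl.sum over 256 True values
stored = store_bool_as_i8_buggy(count_of_true)
print(stored)                          # 0 -- the nonzero count is lost
print(load_bool_from_i8(stored))       # 0, i.e. False, although 256 > 0
```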
> Do you have a real-life example where this is a problem?

@ThomasRaoux I wrote a test case in python/test/unit/language/test_core.py which checks the summation of 32*32 True values. It should return True when the output is a boolean tensor.
I think the fact that int1 pointers are loaded/stored as int8 tensors, even though the language has native int1 is very unexpected behavior and a bit of a gotcha. It would be nice if we could fix it, but I'm not sure if it might inadvertently break someone by silently changing the type inference. It's a fairly scary change unfortunately :/
> I think the fact that int1 pointers are loaded/stored as int8 tensors, even though the language has native int1 is very unexpected behavior and a bit of a gotcha.

I believe using int8 as a container for int1 values is primarily a hardware-driven design choice. As long as the underlying boolean semantics are correctly preserved during load/store operations, treating the memory as an int8 tensor works fine, and the int8 layer stays transparent to the developer.

> but I'm not sure if it might inadvertently break someone by silently changing the type inference. It's a fairly scary change unfortunately

This modification only injects an additional `arith.cmpi` when a Triton kernel contains `tl.sum()` whose output dtype is boolean or `dtype=tl.int1`. The `arith.cmpi` is lowered from `val = self.cast(val, tl.int1)`.
Some existing code may already have suffered this surprising divergence from `torch.sum` for a long time.
It's not transparent to the developer; it changes the semantics so that they're treated as an integer in the range [-128, 127] instead of as a boolean. That's the whole point of this PR, is it not?
> It's not transparent to the developer; it changes the semantics so that they're treated as an integer in the range [-128, 127] instead of as a boolean. That's the whole point of this PR, is it not?

A simple example:

```python
@triton.jit
def triton_sum_bool(in_ptr0, out_ptr0, L: tl.constexpr):
    idx = tl.arange(0, L)
    x = tl.load(in_ptr0 + idx)
    ret = tl.sum(x)
    tl.store(out_ptr0, ret)
```

- According to the documentation, `ret = tl.sum(x)` upcasts `x` to int32.
- When `tl.store(out_ptr0, ret)` writes to a boolean tensor `out_ptr0`, `ret` is first cast to `int8` and then stored into `out`.
The ret values are undoubtedly treated as an integer, but the implicit cast to i8 is a bit strange.
In principle we could define the `tl.store(i1)` semantics as some dtype -> i8 -> i1 in Triton, but this behavior deviates from expectations in special cases.
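To make the signedness point concrete, here is a small plain-Python illustration (hypothetical, independent of Triton) of reading the stored i8 container back as a signed integer:

```python
import struct

count = 200                      # e.g. tl.sum over 200 True values, held in i32
byte = count & 0xFF              # what remains in the i8 container after trunci
signed = struct.unpack('b', bytes([byte]))[0]  # reinterpret as signed int8
print(signed)                    # -56: an integer in [-128, 127], not a boolean
```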
To clarify, this PR does not alter the overall type inference rules. Internally, the data flow handled the int32 -> i8 -> i1 conversion; the only distinction is that the data flow is changed to int32 -> i1 -> i8 -> i1. From an external perspective, the input/output type contracts are preserved, so existing type derivation remains unaffected.
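The changed flow can be sketched in plain Python (an illustrative model of the int32 -> i1 -> i8 path described above, not the actual compiler code):

```python
def store_bool_old(acc_i32: int) -> int:
    # Old flow: raw bitwise truncation i32 -> i8 (arith.trunci).
    return acc_i32 & 0xFF

def store_bool_new(acc_i32: int) -> int:
    # New flow: i32 -> i1 first (compare against 0, lowered as arith.cmpi),
    # then widen the boolean into the i8 container.
    as_i1 = 1 if acc_i32 != 0 else 0
    return as_i1 & 0xFF

print(store_bool_old(256))  # 0 -- the True count is lost
print(store_bool_new(256))  # 1 -- the boolean is preserved
```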
tl.sum with boolean (i1) output leads to storing i32 to i8. However, bitwise truncation returns 0 when the result is a multiple of 256. This PR fixes triton-lang#9991 by explicitly casting the result to tl.int1 for a boolean output tensor.
Check the boolean output for tl.sum(i1). When there are 32*32 True values, the output should not be wrongly truncated to 0 (False).
tl.sum(i1, dtype=tl.int1) already uses an i1 accumulator, so there is no need to cast in this case. The accumulator result is `val`, and this PR uses `val.type.scalar` to determine whether to cast the result to i1.
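That guard can be sketched as follows (names are illustrative, not Triton's actual API; the real code inspects `val.type.scalar` inside the store lowering):

```python
# Illustrative sketch of the decision described above: the extra i1 cast is
# only needed when the accumulator's scalar type is not already int1.
def needs_i1_cast(accumulator_scalar_ty: str, dest_is_bool: bool) -> bool:
    return dest_is_bool and accumulator_scalar_ty != "int1"

print(needs_i1_cast("int32", True))   # True:  i32 accumulator stored to bool
print(needs_i1_cast("int1", True))    # False: tl.sum(..., dtype=tl.int1) case
print(needs_i1_cast("int32", False))  # False: non-boolean destination
```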
New contributor declaration
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

Select one of the following.
- I have added tests.
  - `/test` for `lit` tests
  - `/unittest` for C++ tests
  - `/python/test` for end-to-end tests
- This PR does not need a test because it is too simple?

Select one of the following.
- The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)