
[Bugfix] Update truncation from i32 to i8 when storing to i1#10011

Open
greatofdream wants to merge 3 commits into triton-lang:main from greatofdream:fix_bool_sum

Conversation


@greatofdream greatofdream commented Apr 13, 2026

tl.sum with a boolean output leads to storing an i32 value into an i8 location. However, bitwise truncation returns 0 when the result is a multiple of 256.

This PR fixes #9991 by explicitly casting the result to tl.int1 before storing it into the boolean output tensor.
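The failure mode is easy to reproduce outside Triton. Below is a plain-Python sketch of what the `arith.trunci` i32 -> i8 lowering does to the accumulator (the helper name is mine, not Triton's):

```python
def trunc_i32_to_i8(value: int) -> int:
    """Model arith.trunci i32 -> i8: keep the low 8 bits, sign-extend the result."""
    low = value & 0xFF
    return low - 256 if low >= 0x80 else low

acc = sum([True] * 256)            # the i32 accumulator holds 256
stored = trunc_i32_to_i8(acc)      # the i8 store keeps only the low byte -> 0
print(acc, stored, bool(stored))   # 256 0 False: the boolean answer is wrong

# The fix normalizes to i1 (value != 0) before the store:
print(bool(int(acc != 0)))         # True
```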

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because too simple?.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@greatofdream greatofdream requested a review from ptillet as a code owner April 13, 2026 06:32
@greatofdream greatofdream changed the title [Bugfix] Fix incorrect trunci i32 to i8 lowering when storing to tl.int1 (bool) [Bugfix] Fix incorrect truncation from i32 to i8 lowering when storing to i1 Apr 13, 2026
@greatofdream greatofdream force-pushed the fix_bool_sum branch 2 times, most recently from d67dcb5 to ca37258 Compare April 13, 2026 06:45
@greatofdream greatofdream changed the title [Bugfix] Fix incorrect truncation from i32 to i8 lowering when storing to i1 [Bugfix] Update truncation from i32 to i8 when storing to i1 Apr 13, 2026
Comment thread: python/triton/language/semantic.py (Outdated)
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = self.cast(ptr, ptr_ty)
val = self.cast(val, tl.int1)
Collaborator


I'm not convinced this is the right thing to do. The downcasting happens when the size of the element in memory is smaller than the element type, but booleans are stored as 8-bit values.
I don't think it makes sense to do extra masking of the higher bits here.
Do you have a real life example where this is a problem?

Author


The downcasting happens when the size of the element in memory is smaller than the element type, but booleans are stored as 8-bit values.

Thanks for your feedback. Booleans are stored as 8-bit values, but the accumulator is i32, so a truncation occurs when storing the booleans from i32.

I don't think it makes sense to do extra masking of the higher bits here.

For example, when a boolean tensor contains 256 True values, tl.sum computes 256 in the i32 accumulator, and arith.trunci to i8 truncates it to 0.

Do you have a real life example where this is a problem?

There is an example in #9991, which compares the results of torch and a Triton kernel. I extracted the TTIR from the Triton kernel; the relevant load and store operations are as follows:

    %x_27 = tt.load %x_26 : tensor<128x2x!tt.ptr<i8>> loc(#loc20)
    %x_28 = arith.cmpi ne, %x_27, %x : tensor<128x2xi8> loc(#loc20)
    %input = arith.extui %x_28 : tensor<128x2xi1> to tensor<128x2xi32> loc(#loc30)

However, the store path does no such normalization; it performs only a bitwise truncation:

    %0 = tt.bitcast %out_ptr0 : !tt.ptr<i1> -> !tt.ptr<i8> loc(#loc15)
    %1 = arith.trunci %ret_29 : i32 to i8 loc(#loc15)
    tt.store %0, %1 : !tt.ptr<i8> loc(#loc15)
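The TTIR above maps onto NumPy almost one-to-one. Here is a sketch (NumPy stands in for the tensor ops, shapes simplified to 1-D) showing the load side preserving booleans while the store side truncates:

```python
import numpy as np

x = np.ones(256, dtype=np.int8)             # 256 True values in i8 storage
loaded = x != np.int8(0)                    # arith.cmpi ne : i8 -> i1
widened = loaded.astype(np.int32)           # arith.extui   : i1 -> i32
acc = int(widened.sum())                    # tl.sum's i32 accumulator -> 256
stored = np.array([acc], np.int32).astype(np.int8)[0]  # arith.trunci: i32 -> i8
print(acc, stored)                          # 256 0: the stored boolean reads back as False
```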

Author

@greatofdream greatofdream Apr 13, 2026


Do you have a real life example where this is a problem?

@ThomasRaoux I wrote a test case in python/test/unit/language/test_core.py which checks the summation of 32*32 True values. It should return True when the output is a boolean tensor.

Contributor

@peterbell10 peterbell10 Apr 13, 2026


I think the fact that int1 pointers are loaded/stored as int8 tensors, even though the language has a native int1, is very unexpected behavior and a bit of a gotcha. It would be nice if we could fix it, but I'm not sure if it might inadvertently break someone by silently changing the type inference. It's a fairly scary change unfortunately :/

Author


I think the fact that int1 pointers are loaded/stored as int8 tensors, even though the language has native int1 is very unexpected behavior and a bit of a gotcha.

I believe using int8 as a container for int1 values is primarily a hardware-driven design choice. As long as the underlying boolean semantics are correctly preserved during load/store operations, treating the memory as an int8 tensor works fine, and the int8 layer stays transparent to the developer.

but I'm not sure if it might inadvertently break someone by silently changing the type inference. It's a fairly scary change unfortunately

This modification only injects an additional arith.cmpi when a Triton kernel contains tl.sum() whose output dtype is boolean (dtype=tl.int1). The arith.cmpi is lowered from val = self.cast(val, tl.int1).
Some existing code may already have suffered this strange divergence from torch.sum's behavior for a long time.

Contributor


It's not transparent to the developer, it changes the semantics so that they're treated as an integer in the range [-128, 127] instead of as a boolean. That's the whole point of this PR, is it not?

Author

@greatofdream greatofdream Apr 14, 2026


It's not transparent to the developer, it changes the semantics so that they're treated as an integer in the range [-128, 127] instead of as a boolean. That's the whole point of this PR, is it not?

A simple example:

@triton.jit
def triton_sum_bool(in_ptr0, out_ptr0, L: tl.constexpr):
    idx = tl.arange(0, L)
    x = tl.load(in_ptr0 + idx)
    ret = tl.sum(x)
    tl.store(out_ptr0, ret)
  • According to the documentation, ret = tl.sum(x) upcasts x to int32.
  • When tl.store(out_ptr0, ret) writes to a boolean tensor out_ptr0, ret is first cast to int8 and then stored into out_ptr0.

The ret value is undoubtedly treated as an integer, but the implicit cast to i8 is a bit strange.

In principle we could define the tl.store(i1) semantics as dtype -> i8 -> i1 in Triton, but that behavior deviates from expectations in special cases.

To clarify, this PR does not alter the overall type inference rules. Internally, the data flow previously handled the int32 -> i8 -> i1 conversion; the only distinction is that the flow becomes int32 -> i1 -> i8 -> i1. From an external perspective, the input/output type contracts are preserved, so existing type derivation remains unaffected.
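The contrast between the two flows can be sketched in plain Python (the function names are mine; they model the MLIR ops, not Triton's API):

```python
def old_flow(acc: int) -> int:
    # int32 -> i8: arith.trunci keeps only the low byte
    return acc & 0xFF

def new_flow(acc: int) -> int:
    # int32 -> i1 -> i8: arith.cmpi ne 0 first, then widen into the i8 container
    return int(acc != 0)

# Both flows take an i32 and produce an i8-representable value, so the external
# type contract is unchanged. After the i1 load-back (value != 0), results
# differ only when the accumulator is a nonzero multiple of 256.
for acc in (0, 1, 255, 256, 512):
    print(acc, old_flow(acc), new_flow(acc))
```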

ZhangAiqiang added 2 commits April 14, 2026 22:08
tl.sum with a boolean (i1) output leads to storing i32 to i8. However,
bitwise truncation returns 0 when the result is a multiple of 256.

This PR fixes triton-lang#9991 by explicitly casting the result to tl.int1
for boolean output tensors.
Check the boolean output of tl.sum(i1). When there are 32*32 True values,
the output should not be wrongly truncated to 0 (False).
tl.sum(i1, dtype=tl.int1) already uses an i1 accumulator, so there is no
need to cast in that case. The accumulator result is val, and this PR uses
val.type.scalar to determine whether to cast the result to i1.


Development

Successfully merging this pull request may close these issues.

tl.store(i1, i32) implicitly casts int32 as int8 lead unexpected behavior
