diff --git a/CHANGELOG.md b/CHANGELOG.md index 59aaa5fd..4e598786 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed incorrect reduction condition in `fused_scatter_reduce` Triton kernel + ### Security ## [0.6.0] - 2026-03-24 diff --git a/pyg_lib/ops/scatter_reduce.py b/pyg_lib/ops/scatter_reduce.py index 0eea3b43..e0a31a5f 100644 --- a/pyg_lib/ops/scatter_reduce.py +++ b/pyg_lib/ops/scatter_reduce.py @@ -60,9 +60,9 @@ def _fused_scatter_reduce_forward_kernel( tl.atomic_add(out_ptr + out_offsets, inputs, mask=mask) elif REDUCE1 == 2: # mean tl.atomic_add(out_ptr + out_offsets, inputs, mask=mask) - elif REDUCE2 == 3: # min + elif REDUCE1 == 3: # min tl.atomic_min(out_ptr + out_offsets, inputs, mask=mask) - elif REDUCE3 == 4: # max + elif REDUCE1 == 4: # max tl.atomic_max(out_ptr + out_offsets, inputs, mask=mask) if REDUCE2 > 0: