[PT FE] Added support for torch_scatter::segment_mean_csr #33457
+235
−0
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Closes #29734
Details
Implemented PyTorch Frontend translator for
torch_scatter::segment_mean_csrusing a prefix-sum + gather boundaries approach (no unrolling over segments).The segmentation axis is computed dynamically as
axis = indptr.dim() - 1(matches torch_scatter semantics).Added handling for batch/broadcast-style indptr (when
axis > 0) via bidirectional broadcast to[src_batch_dims..., indptr_last_dim].Uses i64 indices for
indptr/shape operations to avoid overflow and to match Gather-supported index types.Handles the optional
outargument viacontext.mutate_input(out_idx, result)(if provided).Added a regression test for TorchScript that covers:
indptrwith shape(1, n)(segments along dim 1)indptr(segments along dim 0)Limitations / Notes
srcandindptr(validated).indptr[i+1] == indptr[i]) are handled by clamping length to 1 to avoid division by zero (result becomes 0 for empty segments).torch_scatteris not installed.How to test
PyTorch layer test:
python -m pytest tests/layer_tests/pytorch_tests/test_segment_mean_csr.py -k segment_mean_csr -vTickets