Skip to content

fix: trilu to use last two dims for N and M#1009

Open
sashass1315 wants to merge 1 commit into
zkonduit:mainfrom
sashass1315:tcck
Open

fix: trilu to use last two dims for N and M#1009
sashass1315 wants to merge 1 commit into
zkonduit:mainfrom
sashass1315:tcck

Conversation

@sashass1315

Copy link
Copy Markdown

Update trilu to derive N and M from the last two dimensions (len-2, len-1) instead of hard-coding a.dims()[1] and a.dims()[2]. This aligns with the ONNX Trilu spec ([, N, M]) and ensures correct behavior for inputs with zero or multiple batch dimensions. Previously, rank > 3 (or rank == 2) tensors produced incorrect masking because the loops assumed exactly one batch dimension. This change makes the implementation robust across all valid ranks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant