[WebGPU EP] Add EINSUM implementation #24358
@satyajandhyala @xiaofeihan1 @qjia7 @guschmue @fs-eire, please help review, thanks.
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
Pull Request Overview
This PR adds the native Einsum implementation for the WebGPU provider, enhancing operator support by integrating a new kernel alongside its corresponding tests. Key changes include:
- Adding a new test case in onnxruntime/test/providers/cpu/math/einsum_test.cc for explicit Einsum reduction with multi-input.
- Enabling the Einsum kernel registration in onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc.
- Introducing the Einsum operator implementation in onnxruntime/core/providers/webgpu/math/einsum.h.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
File | Description |
---|---|
onnxruntime/test/providers/cpu/math/einsum_test.cc | Added a test case for explicit Einsum reduction to scalar with multi-input. |
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Enabled the Einsum kernel by uncommenting the registration entry. |
onnxruntime/core/providers/webgpu/math/einsum.h | Introduced the new Einsum operator abstraction and related utility classes. |
lgtm
### Description

This PR adds a native implementation of the Einsum operator, based on and expanding the existing einsum.ts. All test cases in einsum_test.cc pass.

The equation attribute of the Einsum op is a string consisting of a left-hand side (LHS) and, optionally, a right-hand side (RHS), separated by '->'. Examples:

- "ij->ji": matrix transpose
- "ii->i": diagonal elements of a square matrix
- "ij->": sum over all elements of a matrix
- "ij,jk->ik": explicit matrix multiplication
- "ij,jk": implicit matrix multiplication
- "ij,jk->": matrix multiplication, then sum over all elements
- "ij,jk,kl->il": three-matrix multiplication
- "...ij,...jk->...ik": batched matmul with broadcasting
- ",...i->...i": elementwise multiplication of a matrix by a scalar
- "abc,cd->abc": keeps the shape of the original abc matrix, multiplying and summing over d

The LHS is a sequence of terms separated by commas. Each term corresponds to an input variable, and each symbol in a term corresponds to a dimension of that input. A symbol is either a letter ('a' to 'z' or 'A' to 'Z') or '...' to represent arbitrary dimensions; an empty term represents a scalar.

An empty RHS is handled differently in implicit and explicit modes:

- Implicit mode: the equation contains no arrow. The equation "ij,jk" is equivalent to "ij,jk->ik", which is a matrix multiplication.
- Explicit mode: the equation contains an arrow. The equation "ij,jk->" takes two steps: first a matrix multiplication, as in implicit mode, then a sum of the resulting matrix down to a scalar.

For all test cases, please refer to einsum_test.cc.
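To make the LHS/RHS split and the two modes concrete, here is a minimal C++ sketch of how an equation string could be parsed. It is not the code from this PR: `ParsedEquation` and `ParseEinsumEquation` are hypothetical names introduced for illustration. The implicit-mode rule used here (labels that appear exactly once, in alphabetical order, with any ellipsis first) follows my reading of the ONNX Einsum spec rather than anything stated in the description above. Whitespace inside the equation and validation of symbols are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical helper type for illustration; not taken from the PR.
struct ParsedEquation {
  std::vector<std::string> input_terms;  // one term per input; "" means a scalar input
  std::string output_term;               // "" means a scalar output
  bool is_explicit = false;              // true when the equation contains "->"
};

// Splits an Einsum equation into its input terms and output term.
// Assumes the equation has no whitespace and only valid symbols.
ParsedEquation ParseEinsumEquation(const std::string& equation) {
  ParsedEquation parsed;
  std::string lhs = equation;

  const std::size_t arrow = equation.find("->");
  if (arrow != std::string::npos) {
    // Explicit mode: everything after the arrow is the output, possibly empty.
    assert(equation.find("->", arrow + 2) == std::string::npos);  // at most one arrow
    parsed.is_explicit = true;
    lhs = equation.substr(0, arrow);
    parsed.output_term = equation.substr(arrow + 2);
  }

  // Split the LHS on commas. Empty pieces are kept: they denote scalar inputs.
  std::size_t start = 0;
  while (true) {
    const std::size_t comma = lhs.find(',', start);
    if (comma == std::string::npos) {
      parsed.input_terms.push_back(lhs.substr(start));
      break;
    }
    parsed.input_terms.push_back(lhs.substr(start, comma - start));
    start = comma + 1;
  }

  if (!parsed.is_explicit) {
    // Implicit mode (assumed rule): the output holds the labels that occur
    // exactly once across all inputs, in sorted order, after any ellipsis.
    std::map<char, int> counts;  // std::map iterates in ascending key order
    bool has_ellipsis = false;
    for (const std::string& term : parsed.input_terms) {
      for (char c : term) {
        if (c == '.') {
          has_ellipsis = true;
        } else {
          ++counts[c];
        }
      }
    }
    if (has_ellipsis) parsed.output_term = "...";
    for (const auto& entry : counts) {
      if (entry.second == 1) parsed.output_term += entry.first;
    }
  }
  return parsed;
}
```

With this sketch, `ParseEinsumEquation("ij,jk")` yields the output term `"ik"` with `is_explicit == false`, while `ParseEinsumEquation("ij,jk->")` yields an empty output term with `is_explicit == true`, which is the scalar-reduction case described above. Keeping the explicit flag separate from the output term is what lets the two empty-RHS cases be told apart.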