MatMul Broadcasting #197
Description
Describe the bug
Hello,
Thanks for this amazing project!
I was trying to run the sentence-transformers/all-MiniLM-L6-v2 model and got the following error:
Can't create session:
GpuError(CompileError { node: "/encoder/layer.0/attention/self/query/MatMul",
error: UnimplementedVariant {
variant: "broadcasting for two stacks of matrixes (left side has shape 1x512x384:f32,
right side has shape 384x384:f32)",
op: "MatMul" }
})
It seems that multiplying 1x512x384 @ 384x384 (which should give 1x512x384) is not supported. I'm not sure whether both operands need to have the same number of dimensions.
I understand that shape inference isn't supported for MatMul, and I don't know whether this issue is related. I might be doing something wrong when preparing the model, so could you help me with this? all-MiniLM-L6-v2 is basically a BERT model, and I saw that BERT runs successfully with wonnx.
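For reference, the ONNX MatMul operator is specified to behave like numpy.matmul: the last two dimensions are multiplied as matrices, and any leading "stack" dimensions are broadcast, so the two operands do not need the same rank. The sketch below uses a hypothetical helper, matmul_output_shape (not part of wonnx), to compute the expected output shape under those rules. It shows that 1x512x384 @ 384x384 should be a valid product with shape 1x512x384.

```rust
/// Hypothetical helper: output shape of an ONNX MatMul, assuming numpy.matmul semantics.
/// The last two dims are multiplied as matrices; leading ("stack") dims are broadcast.
fn matmul_output_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, String> {
    if a.is_empty() || b.is_empty() {
        return Err("MatMul operands must have rank >= 1".to_string());
    }
    // 1-D operands are promoted: on the left a vector becomes 1xK, on the right Kx1.
    let a_promoted = a.len() == 1;
    let b_promoted = b.len() == 1;
    let a2: Vec<usize> = if a_promoted { vec![1, a[0]] } else { a.to_vec() };
    let b2: Vec<usize> = if b_promoted { vec![b[0], 1] } else { b.to_vec() };

    let (m, k_a) = (a2[a2.len() - 2], a2[a2.len() - 1]);
    let (k_b, n) = (b2[b2.len() - 2], b2[b2.len() - 1]);
    if k_a != k_b {
        return Err(format!("inner dimensions differ: {} vs {}", k_a, k_b));
    }

    // Broadcast the stack dimensions, aligned from the right; missing dims count as 1.
    let sa = &a2[..a2.len() - 2];
    let sb = &b2[..b2.len() - 2];
    let rank = sa.len().max(sb.len());
    let (off_a, off_b) = (rank - sa.len(), rank - sb.len());
    let mut out = Vec::with_capacity(rank + 2);
    for i in 0..rank {
        let da = if i >= off_a { sa[i - off_a] } else { 1 };
        let db = if i >= off_b { sb[i - off_b] } else { 1 };
        let d = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            (x, y) => return Err(format!("cannot broadcast {} with {}", x, y)),
        };
        out.push(d);
    }
    // Drop the dimensions that were only added by 1-D promotion.
    if !a_promoted {
        out.push(m);
    }
    if !b_promoted {
        out.push(n);
    }
    Ok(out)
}

fn main() {
    // Shapes of the failing node: left [1, 512, 384], right [384, 384]
    let shape = matmul_output_shape(&[1, 512, 384], &[384, 384]).unwrap();
    println!("{:?}", shape); // [1, 512, 384]
}
```

Only the inner dimensions (384 and 384) have to agree; the missing leading dimension on the right operand is treated as 1 and broadcast.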
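When the right operand is a single 2-D matrix, as in this node, broadcasting it over the stack gives the same result as folding the leading dimensions of the left operand into its row dimension. A contiguous row-major 1x512x384 tensor has the same memory layout as a 512x384 matrix. The minimal CPU sketch below assumes contiguous row-major f32 buffers; the function names are hypothetical, and this is not how wonnx's GPU shaders are written. It compares the two ways of computing the product. This might suggest one direction for supporting this case, but that is an assumption, not a confirmed fix.

```rust
/// Naive row-major matmul: (m x k) @ (k x n) -> (m x n).
fn matmul_2d(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                out[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    out
}

/// Stacked left [batch, m, k] @ shared right [k, n] -> [batch, m, n],
/// broadcasting the right matrix to every matrix in the stack.
fn matmul_stack_broadcast(a: &[f32], b: &[f32], batch: usize, m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(batch * m * n);
    for s in 0..batch {
        let slice = &a[s * m * k..(s + 1) * m * k];
        out.extend(matmul_2d(slice, b, m, k, n));
    }
    out
}

/// Same product computed by folding the stack into the rows:
/// row-major [batch, m, k] is laid out exactly like [batch * m, k].
fn matmul_folded(a: &[f32], b: &[f32], batch: usize, m: usize, k: usize, n: usize) -> Vec<f32> {
    matmul_2d(a, b, batch * m, k, n)
}

fn main() {
    // Small stand-in for the failing shapes [1, 512, 384] @ [384, 384]
    let (batch, m, k, n) = (2, 3, 4, 5);
    let a: Vec<f32> = (0..batch * m * k).map(|x| x as f32 * 0.5).collect();
    let b: Vec<f32> = (0..k * n).map(|x| (x % 7) as f32 - 3.0).collect();
    let stacked = matmul_stack_broadcast(&a, &b, batch, m, k, n);
    let folded = matmul_folded(&a, &b, batch, m, k, n);
    assert_eq!(stacked, folded);
    println!("stack-broadcast and folded results match ({} values)", folded.len());
}
```

The folding trick only applies when the right side has no stack dimensions of its own. If both operands carry stack dimensions, each pair of matrices has to be multiplied separately after broadcasting.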
To Reproduce
- Converted sentence-transformers/all-MiniLM-L6-v2 to ONNX format using the optimum library.
- Ran onnx-simplifier:
python -m onnxsim model.onnx sim_model.onnx --overwrite-input-shape "input_ids:1,512" "attention_mask:1,512" "token_type_ids:1,512"
Note: in another test I also ran nnx prepare model.onnx model-prepared.onnx --set batch=1 --set sequence=512, but it didn't solve the issue.
- Ran the following test:
use std::collections::HashMap;
use wonnx::{utils::InputTensor, Session};

#[test]
fn test_load_model() {
    // PATH_XXX is a placeholder for the path to the simplified model
    let session = tokio_test::block_on(Session::from_path("PATH_XXX.onnx"))
        .expect("Can't create session");

    let mut input: HashMap<String, InputTensor> = HashMap::new();
    let tokens = vec![1f32; 512];
    let attention_mask = vec![1f32; 512];
    let token_type_ids = vec![0f32; 512];
    // Model inputs: ['input_ids', 'token_type_ids', 'attention_mask']
    input.insert("input_ids".to_string(), tokens[..].into());
    input.insert("attention_mask".to_string(), attention_mask[..].into());
    input.insert("token_type_ids".to_string(), token_type_ids[..].into());

    let output = tokio_test::block_on(session.run(&input)).unwrap();
    assert!(output.len() > 1);
}
Expected behavior
The session is created and inference runs successfully.
Desktop:
- OS: MacOS 13.5.1
Thanks for your help! I'm also open to contributing a fix if you can point me in the right direction 😄