Skip to content

Commit 3116e02

Browse files
authored
pnnx drop sdap scale=None for compatiblity with old torch (#5107)
1 parent 14e14a9 commit 3116e02

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@ pnnx.Output output 1 0 out
6565
{
6666
return "F.scaled_dot_product_attention";
6767
}
68+
69+
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
70+
{
71+
GraphRewriterPass::write(op, captured_params, captured_attrs);
72+
73+
if (captured_params.at("scale").type == 0)
74+
{
75+
// drop scale=None for compatiblity with old torch
76+
op->params.erase("scale");
77+
}
78+
}
6879
};
6980

7081
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)

0 commit comments

Comments
 (0)