@@ -211,6 +211,7 @@ def create_self_attention_node(
211211 ) -> NodeProto :
212212 # all_inputs are (B, N, S, H)
213213 if self .enable_packed_qkv :
214+ # Implement Stack via Unsqueeze+Concat
214215 unsqueeze_q_node_name = self .model .create_node_name ("Unsqueeze" )
215216 unsqueeze_k_node_name = self .model .create_node_name ("Unsqueeze" )
216217 unsqueeze_v_node_name = self .model .create_node_name ("Unsqueeze" )
@@ -297,8 +298,98 @@ def create_self_attention_node(
297298 )
298299
299300 return attention_node
300- else : # Not packed
301- raise NotImplementedError ("Unpacked QKV self-attention not implemented." )
301+ else : # Not packed. (CPU-compatible)
302+ # Transpose nodes: (B, N, S, H) -> (B, S, N, H)
303+ transpose_q_node_name = self .model .create_node_name ("Transpose" )
304+ transpose_k_node_name = self .model .create_node_name ("Transpose" )
305+ transpose_v_node_name = self .model .create_node_name ("Transpose" )
306+ transpose_q_node = helper .make_node (
307+ "Transpose" ,
308+ inputs = [matmul_q .output [0 ]],
309+ outputs = [transpose_q_node_name + "_out" ],
310+ name = transpose_q_node_name ,
311+ perm = [0 , 2 , 1 , 3 ],
312+ )
313+ self .node_name_to_graph_name [transpose_q_node .name ] = self .this_graph_name
314+ transpose_k_node = helper .make_node (
315+ "Transpose" ,
316+ inputs = [matmul_k .output [0 ]],
317+ outputs = [transpose_k_node_name + "_out" ],
318+ name = transpose_k_node_name ,
319+ perm = [0 , 2 , 1 , 3 ],
320+ )
321+ self .node_name_to_graph_name [transpose_k_node .name ] = self .this_graph_name
322+ transpose_v_node = helper .make_node (
323+ "Transpose" ,
324+ inputs = [matmul_v .output [0 ]],
325+ outputs = [transpose_v_node_name + "_out" ],
326+ name = transpose_v_node_name ,
327+ perm = [0 , 2 , 1 , 3 ],
328+ )
329+ self .node_name_to_graph_name [transpose_v_node .name ] = self .this_graph_name
330+
331+ # Reshape nodes: (B, S, N, H) -> (B, S, N*H), with N*H given by hidden_size
332+ reshape_q_node_name = self .model .create_node_name ("Reshape" )
333+ reshape_k_node_name = self .model .create_node_name ("Reshape" )
334+ reshape_v_node_name = self .model .create_node_name ("Reshape" )
335+ for n in (reshape_q_node_name , reshape_k_node_name , reshape_v_node_name ):
336+ self .add_initializer (
337+ name = n + "_shape" ,
338+ data_type = TensorProto .INT64 ,
339+ dims = [3 ],
340+ vals = [0 , 0 , hidden_size ],
341+ raw = False ,
342+ )
343+ reshape_q_node = helper .make_node (
344+ "Reshape" ,
345+ inputs = [transpose_q_node_name + "_out" , reshape_q_node_name + "_shape" ],
346+ outputs = [reshape_q_node_name + "_out" ],
347+ name = reshape_q_node_name ,
348+ )
349+ self .node_name_to_graph_name [reshape_q_node .name ] = self .this_graph_name
350+ reshape_k_node = helper .make_node (
351+ "Reshape" ,
352+ inputs = [transpose_k_node_name + "_out" , reshape_k_node_name + "_shape" ],
353+ outputs = [reshape_k_node_name + "_out" ],
354+ name = reshape_k_node_name ,
355+ )
356+ self .node_name_to_graph_name [reshape_k_node .name ] = self .this_graph_name
357+ reshape_v_node = helper .make_node (
358+ "Reshape" ,
359+ inputs = [transpose_v_node_name + "_out" , reshape_v_node_name + "_shape" ],
360+ outputs = [reshape_v_node_name + "_out" ],
361+ name = reshape_v_node_name ,
362+ )
363+ self .node_name_to_graph_name [reshape_v_node .name ] = self .this_graph_name
364+
365+ self .nodes_to_add .extend (
366+ [
367+ transpose_q_node ,
368+ transpose_k_node ,
369+ transpose_v_node ,
370+ reshape_q_node ,
371+ reshape_k_node ,
372+ reshape_v_node ,
373+ ]
374+ )
375+
376+ attention_inputs = [
377+ reshape_q_node_name + "_out" ,
378+ reshape_k_node_name + "_out" ,
379+ reshape_v_node_name + "_out" ,
380+ ]
381+
382+ attention_node_name = self .model .create_node_name ("MultiHeadAttention" )
383+ attention_node = helper .make_node (
384+ "MultiHeadAttention" ,
385+ inputs = attention_inputs ,
386+ outputs = [output ],
387+ name = attention_node_name ,
388+ domain = "com.microsoft" ,
389+ num_heads = num_heads ,
390+ )
391+
392+ return attention_node
302393
303394 def create_cross_attention_node (
304395 self ,
0 commit comments