@@ -333,6 +333,35 @@ def Mul_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCont
333333 return multiplicand * multiplier
334334
335335
def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
    """Compute scaled dot-product multi-head attention with explicit projections.

    The op carries separate query/key/value input projections and one output
    projection. Only the full 11-input form is supported.

    Args:
        op: Operation whose ``attributes`` provide ``embed_dim`` and
            ``num_heads``.
        values: Exactly eleven tensors, in order:
            ``q, k, v, q_w, q_b, k_w, k_b, v_w, v_b, o_w, o_b``.
            NOTE(review): q/k/v are assumed to be batch-first
            ``(batch, seq, embed_dim)``, which is what the final reshape to
            ``(batch, -1, embed_dim)`` implies; confirm against the exporter.
        ctx: Backend context; not used by this op.

    Returns:
        Attention output after the output projection, shaped
        ``(batch, query_seq, embed_dim)``.

    Raises:
        NotImplementedError: If ``values`` does not hold exactly 11 tensors.
        ValueError: If ``embed_dim`` or ``num_heads`` is missing, or
            ``embed_dim`` is not divisible by a positive ``num_heads``.
    """
    if len(values) != 11:
        raise NotImplementedError('Not implement simplified MultiHeadAttention')

    q, k, v, q_w, q_b, k_w, k_b, v_w, v_b, o_w, o_b = values
    embed_dim = op.attributes.get('embed_dim')
    num_heads = op.attributes.get('num_heads')

    if embed_dim is None or num_heads is None:
        raise ValueError('Cannot fetch embed_dim or num_heads')
    if num_heads <= 0 or embed_dim % num_heads != 0:
        raise ValueError('embed_dim must be divisible by num_heads')

    # setup parameters
    batch_size = q.shape[0]
    head_dim = embed_dim // num_heads
    scale = head_dim ** -0.5

    def _split_heads(x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)
        return x.reshape(x.shape[0], -1, num_heads, head_dim).transpose(1, 2)

    # Project first, then split the embedding axis into independent heads.
    # Without the split the attention below would run as a single head over
    # the whole embedding and the merge step would mix sequence and features.
    q = _split_heads(F.linear(q, q_w, q_b))
    k = _split_heads(F.linear(k, k_w, k_b))
    v = _split_heads(F.linear(v, v_w, v_b))

    energy = (q @ k.transpose(-2, -1)) * scale
    attn = energy.softmax(dim=-1)

    # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
    x = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim)
    x = F.linear(x, o_w, o_b)

    return x
363+
364+
336365def Add_forward (op : Operation , values : List [torch .Tensor ], ctx : TorchBackendContext = None , ** kwargs ) -> torch .Tensor :
337366 """Performs element-wise binary addition (with Numpy-style broadcasting
338367 support).
@@ -786,6 +815,9 @@ def GatherND_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken
786815 reshaped_output = output .reshape (* shape_i , * shape_j , * shape_k )
787816 return output
788817
def Gelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
    """Apply the GELU activation element-wise to the op's single input.

    Args:
        op: The Gelu operation; not read by this function.
        values: A list holding exactly one tensor (unpacking fails otherwise).
        ctx: Backend context; not used.

    Returns:
        ``F.gelu`` of the input tensor.
    """
    (tensor,) = values
    activated = F.gelu(tensor)
    return activated
789821
790822def Greater_forward (op : Operation , values : List [torch .Tensor ], ctx : TorchBackendContext = None , ** kwargs ) -> torch .Tensor :
791823 input_a , input_b = values
@@ -1436,7 +1468,7 @@ def Split_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendCo
14361468 split = op .attributes .get ('split' , 0 )
14371469 [input_value ] = values
14381470 if 'split' not in op .attributes :
1439- split = input_value .shape [axis ] // len (op .outputs )
1471+ split = input_value .shape [axis ] // len (op .outputs )
14401472 outputs = torch .split (input_value , split , axis )
14411473 return outputs
14421474
@@ -1525,6 +1557,18 @@ def LeakyRelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacke
15251557 return output
15261558
15271559
def LayerNorm_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs):
    """Apply affine layer normalization to the input tensor.

    The normalized shape is taken from the weight tensor's shape, so the
    trailing dimensions of the input that match the weight are normalized.

    Args:
        op: Operation whose ``attributes`` may carry ``epsilon``
            (defaults to 1e-5 when absent).
        values: Exactly three tensors: input, weight, bias.
        ctx: Backend context; not used.

    Returns:
        The normalized tensor produced by ``F.layer_norm``.

    Raises:
        ValueError: If ``values`` does not hold exactly three tensors.
    """
    if len(values) != 3:
        raise ValueError('Unsupported LayerNorm without affine')

    x, gamma, beta = values
    epsilon = op.attributes.get('epsilon', 1e-5)
    return F.layer_norm(x, gamma.shape, weight=gamma, bias=beta, eps=epsilon)
1570+
1571+
15281572def Pad_forward (op : Operation , values : List [torch .Tensor ], ctx : TorchBackendContext = None , ** kwargs ):
15291573 mode = op .attributes .get ('mode' , 'constant' )
15301574 input_data = values [0 ]
@@ -2118,20 +2162,20 @@ def Identity_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBacken
21182162 return values [0 ]
21192163
21202164def Onehot_forward (op : Operation , values : List [torch .Tensor ], ctx : TorchBackendContext = None , ** kwargs ) -> torch .Tensor :
2121- """
2122- Produces a one-hot tensor based on inputs. The locations represented by the index values in the 'indices'
2123- input tensor will have 'on_value' and the other locations will have 'off_value' in the output tensor,
2124-
2125- where 'on_value' and 'off_value' are specified as part of required input argument 'values',
2126- which is a two-element tensor of format [off_value, on_value].
2127-
2128- The rank of the output tensor will be one greater than the rank of the input tensor.
2129- The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
2130- If 'axis' is not specified then then additional dimension will be inserted as the innermost dimension,
2131- i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
2132-
2133- The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
2134- input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
2165+ """Produces a one-hot tensor based on inputs. The locations represented by
2166+ the index values in the 'indices' input tensor will have 'on_value' and the
2167+ other locations will have 'off_value' in the output tensor,
2168+
2169+ where 'on_value' and 'off_value' are specified as part of required input argument 'values',
2170+ which is a two-element tensor of format [off_value, on_value].
2171+
2172+ The rank of the output tensor will be one greater than the rank of the input tensor.
2173+ The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
2174+ If 'axis' is not specified then the additional dimension will be inserted as the innermost dimension,
2175+ i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
2176+
2177+ The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
2178+ input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
21352179 with all 'off_value' values in the output tensor.
21362180
21372181 when axis = 0:
@@ -2144,30 +2188,30 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC
21442188
21452189 Attributes
21462190 axis : int (default is -1)
2147- (Optional) Axis along which one-hot representation in added. Default: axis=-1. axis=-1 means that
2148- the additional dimension will be inserted as the innermost/last dimension in the output tensor.
2191+ (Optional) Axis along which one-hot representation is added. Default: axis=-1. axis=-1 means that
2192+ the additional dimension will be inserted as the innermost/last dimension in the output tensor.
21492193 Negative value means counting dimensions from the back. Accepted range is [-r-1, r] where r = rank(indices).
2150-
2194+
21512195 Inputs
21522196 indices (non-differentiable) : T1
21532197 Input tensor containing indices. Any entries in the 'indices' input tensor with values outside the range [-depth, depth-1]
2154- will result in one-hot representation with all 'off_value' values in the output tensor.In case 'indices' is of non-integer type,
2198+ will result in one-hot representation with all 'off_value' values in the output tensor. In case 'indices' is of non-integer type,
21552199 the values will be casted to int64 before use.
2156-
2200+
21572201 depth (non-differentiable) : T2
2158- Scalar specifying the number of classes in one-hot tensor.
2202+ Scalar specifying the number of classes in one-hot tensor.
21592203 This is also the size of the one-hot dimension (specified by 'axis' attribute) added on in the output tensor.
2160- The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
2204+ The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
21612205 In case 'depth' is of non-integer type, it will be casted to int64 before use.
21622206
21632207 values (non-differentiable) : T3
2164- Rank 1 tensor containing exactly two elements,
2165- in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
2208+ Rank 1 tensor containing exactly two elements,
2209+ in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
21662210 and 'off_value' is the value used for filling locations other than those specified in 'indices' input tensor.
21672211
21682212 Outputs
21692213 output (non-differentiable) : T3
2170- Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
2214+ Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
21712215 The data type for the elements of the output tensor is the same as the type of input 'values' is used.
21722216 """
21732217 # implementation from https://github.com/ToriML/onnx2pytorch/blob/master/onnx2pytorch/operations/onehot.py
@@ -2187,10 +2231,10 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendC
21872231 order .insert (axis , - 1 )
21882232 out = out .permute (order )
21892233 return out
2190-
2234+
21912235def Reciprocal_forward (op : Operation , values : List [torch .Tensor ], ctx : TorchBackendContext = None , ** kwargs ) -> torch .Tensor :
21922236 """
2193- Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is,
2237+ Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is,
21942238 y = 1/x, is applied to the tensor elementwise.
21952239
21962240 Version
@@ -2231,18 +2275,21 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBack
22312275 'Gather' : Gather_forward ,
22322276 'GatherElements' : Gather_forward ,
22332277 'GatherND' : GatherND_forward ,
2278+ 'Gelu' : Gelu_forward ,
22342279 'Gemm' : Gemm_forward ,
22352280 'grid_sampler' : Grid_sampler_forward ,
22362281 'GlobalAveragePool' : AveragePool_forward ,
22372282 'GlobalMaxPool' : MaxPool2d_forward ,
22382283 'Greater' : Greater_forward ,
2284+ 'LayerNorm' : LayerNorm_forward ,
22392285 'LeakyRelu' : LeakyRelu_forward ,
22402286 'Less' : Less_forward ,
22412287 'MatMul' : MatMul_forward ,
22422288 'Max' : Eltwise_forward ,
22432289 'MaxPool' : MaxPool2d_forward ,
22442290 'Min' : Eltwise_forward ,
22452291 'Mul' : Mul_forward ,
2292+ 'MultiHeadAttention' : MultiHeadAttention_forward ,
22462293 'NonMaxSuppression' : _NMS_forward ,
22472294 'NonZero' : NonZero_forward ,
22482295 'Not' : Not_forward ,
0 commit comments