
Commit 943f149

feat(default.py): support ViT (#145)
* feat(default.py): add support for ViT ops
* docs(torch.py): remove useless docs
1 parent b107203 commit 943f149

File tree

2 files changed

+76
-29
lines changed

2 files changed

+76
-29
lines changed

ppq/executor/op/torch/default.py

Lines changed: 74 additions & 27 deletions
@@ -333,6 +333,35 @@ def Mul_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext
     return multiplicand * multiplier


+def MultiHeadAttention_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
+    if len(values) != 11:
+        raise NotImplementedError('Simplified MultiHeadAttention is not implemented')
+
+    q, k, v, q_w, q_b, k_w, k_b, v_w, v_b, o_w, o_b = values
+    embed_dim = op.attributes.get('embed_dim')
+    num_heads = op.attributes.get('num_heads')
+
+    if embed_dim is None or num_heads is None:
+        raise ValueError('Cannot fetch embed_dim or num_heads')
+
+    # set up parameters
+    batch_size = q.shape[0]
+    head_dim = embed_dim // num_heads
+    scale = head_dim ** -0.5
+
+    # project q / k / v and split into heads:
+    # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
+    q = F.linear(q, q_w, q_b).reshape(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+    k = F.linear(k, k_w, k_b).reshape(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+    v = F.linear(v, v_w, v_b).reshape(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+
+    # scaled dot-product attention, per head
+    energy = (q @ k.transpose(-2, -1)) * scale
+    attn = energy.softmax(dim=-1)
+
+    # merge heads back to (batch, seq_len, embed_dim), then output projection
+    x = (attn @ v).transpose(1, 2).reshape(batch_size, -1, embed_dim)
+    x = F.linear(x, o_w, o_b)
+
+    return x
+
 def Add_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
     """Performs element-wise binary addition (with Numpy-style broadcasting
     support).
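For review purposes, a minimal smoke test of the new operator (a sketch, not part of the commit; FakeOp is a hypothetical stand-in for ppq's Operation, of which only the attributes dict is used here):

    import torch

    class FakeOp:  # hypothetical stand-in, not ppq's Operation
        def __init__(self, attributes): self.attributes = attributes

    B, N, E, H = 2, 16, 192, 3      # batch, tokens, embed_dim, num_heads
    q = k = v = torch.randn(B, N, E)
    w = lambda: torch.randn(E, E)   # (out_features, in_features), as F.linear expects
    b = lambda: torch.randn(E)

    out = MultiHeadAttention_forward(
        FakeOp({'embed_dim': E, 'num_heads': H}),
        [q, k, v, w(), b(), w(), b(), w(), b(), w(), b()])
    assert out.shape == (B, N, E)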
@@ -786,6 +815,9 @@ def GatherND_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext
     reshaped_output = output.reshape(*shape_i, *shape_j, *shape_k)
     return output

+def Gelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
+    [input_value] = values
+    return F.gelu(input_value)

 def Greater_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
     input_a, input_b = values
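F.gelu defaults to the exact erf formulation, gelu(x) = x * Phi(x) with Phi the standard normal CDF; a one-line check (a sketch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    phi = 0.5 * (1.0 + torch.erf(x / 2 ** 0.5))  # standard normal CDF
    assert torch.allclose(F.gelu(x), x * phi, atol=1e-6)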
@@ -1436,7 +1468,7 @@ def Split_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext
     split = op.attributes.get('split', 0)
     [input_value] = values
     if 'split' not in op.attributes:
-    split = input_value.shape[axis] // len(op.outputs)
+        split = input_value.shape[axis] // len(op.outputs)
     outputs = torch.split(input_value, split, axis)
     return outputs
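When the node carries no 'split' attribute, the chunk size falls out of the output count; for example (a sketch):

    import torch

    x = torch.arange(24).reshape(6, 4)    # input with shape[axis] == 6
    outs = torch.split(x, 6 // 3, dim=0)  # 3 outputs -> split == 2
    assert [t.shape for t in outs] == [torch.Size([2, 4])] * 3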

@@ -1525,6 +1557,18 @@ def LeakyRelu_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext
     return output


+def LayerNorm_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs):
+    if len(values) != 3:
+        raise ValueError('Unsupported LayerNorm without affine')
+
+    input_data, weight, bias = values
+    eps = op.attributes.get('epsilon', 1e-5)
+    normalized_shape = weight.shape  # infer normalized_shape from the affine weight
+
+    output = F.layer_norm(input_data, normalized_shape, weight, bias, eps)
+    return output
+
+
 def Pad_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs):
     mode = op.attributes.get('mode', 'constant')
     input_data = values[0]
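What F.layer_norm computes with these arguments, checked against a manual version (a sketch; float32 tolerance kept loose):

    import torch
    import torch.nn.functional as F

    x, w, b = torch.randn(2, 16, 192), torch.randn(192), torch.randn(192)
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, unbiased=False, keepdim=True)  # biased variance, as layer_norm uses
    ref = (x - mu) / torch.sqrt(var + 1e-5) * w + b
    assert torch.allclose(F.layer_norm(x, w.shape, w, b, 1e-5), ref, atol=1e-4)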
@@ -2118,20 +2162,20 @@ def Identity_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext
     return values[0]

 def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
-    """
-    Produces a one-hot tensor based on inputs. The locations represented by the index values in the 'indices'
-    input tensor will have 'on_value' and the other locations will have 'off_value' in the output tensor,
-
-    where 'on_value' and 'off_value' are specified as part of required input argument 'values',
-    which is a two-element tensor of format [off_value, on_value].
-
-    The rank of the output tensor will be one greater than the rank of the input tensor.
-    The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
-    If 'axis' is not specified then then additional dimension will be inserted as the innermost dimension,
-    i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
-
-    The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
-    input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
+    """Produces a one-hot tensor based on inputs. The locations represented by
+    the index values in the 'indices' input tensor will have 'on_value' and the
+    other locations will have 'off_value' in the output tensor,
+
+    where 'on_value' and 'off_value' are specified as part of required input argument 'values',
+    which is a two-element tensor of format [off_value, on_value].
+
+    The rank of the output tensor will be one greater than the rank of the input tensor.
+    The additional dimension is for one-hot representation. The additional dimension will be inserted at the position specified by 'axis'.
+    If 'axis' is not specified then the additional dimension will be inserted as the innermost dimension,
+    i.e. axis=-1. The size of the additional dimension is specified by required scalar input 'depth'.
+
+    The type of the output tensor is the same as the type of the 'values' input. Any entries in the 'indices'
+    input tensor with values outside the range [-depth, depth-1] will result in one-hot representation
     with all 'off_value' values in the output tensor.

     when axis = 0:
@@ -2144,30 +2188,30 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext

     Attributes
         axis : int (default is -1)
-            (Optional) Axis along which one-hot representation in added. Default: axis=-1. axis=-1 means that
-            the additional dimension will be inserted as the innermost/last dimension in the output tensor.
+            (Optional) Axis along which one-hot representation is added. Default: axis=-1. axis=-1 means that
+            the additional dimension will be inserted as the innermost/last dimension in the output tensor.
             Negative value means counting dimensions from the back. Accepted range is [-r-1, r] where r = rank(indices).
-
+
     Inputs
         indices (non-differentiable) : T1
             Input tensor containing indices. Any entries in the 'indices' input tensor with values outside the range [-depth, depth-1]
-            will result in one-hot representation with all 'off_value' values in the output tensor.In case 'indices' is of non-integer type,
+            will result in one-hot representation with all 'off_value' values in the output tensor. In case 'indices' is of non-integer type,
             the values will be casted to int64 before use.
-
+
         depth (non-differentiable) : T2
-            Scalar specifying the number of classes in one-hot tensor.
+            Scalar specifying the number of classes in one-hot tensor.
             This is also the size of the one-hot dimension (specified by 'axis' attribute) added on in the output tensor.
-            The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
+            The values in the 'indices' input tensor are expected to be in the range [-depth, depth-1].
             In case 'depth' is of non-integer type, it will be casted to int64 before use.

         values (non-differentiable) : T3
-            Rank 1 tensor containing exactly two elements,
-            in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
+            Rank 1 tensor containing exactly two elements,
+            in the format [off_value, on_value], where 'on_value' is the value used for filling locations specified in 'indices' input tensor,
             and 'off_value' is the value used for filling locations other than those specified in 'indices' input tensor.

     Outputs
         output (non-differentiable) : T3
-            Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
+            Tensor of rank one greater than input tensor 'indices', i.e. rank(output) = rank(indices) + 1.
             The data type for the elements of the output tensor is the same as the type of input 'values' is used.
     """
     # implementation from https://github.com/ToriML/onnx2pytorch/blob/master/onnx2pytorch/operations/onehot.py
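The semantics above, reproduced with plain torch.nn.functional.one_hot for reference (a sketch, not the operator implementation itself):

    import torch
    import torch.nn.functional as F

    indices, depth = torch.tensor([1, 0]), 3
    off_value, on_value = 0, 9  # i.e. values = [0, 9], default axis = -1
    out = F.one_hot(indices, depth) * (on_value - off_value) + off_value
    assert out.tolist() == [[0, 9, 0], [9, 0, 0]]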
@@ -2187,10 +2231,10 @@ def Onehot_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext
     order.insert(axis, -1)
     out = out.permute(order)
     return out
-
+
 def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor:
     """
-    Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal is,
+    Reciprocal takes one input data (Tensor) and produces one output data (Tensor) where the reciprocal,
     y = 1/x, is applied to the tensor elementwise.

     Version
@@ -2231,18 +2275,21 @@ def Reciprocal_forward(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext
     'Gather': Gather_forward,
     'GatherElements': Gather_forward,
     'GatherND': GatherND_forward,
+    'Gelu': Gelu_forward,
     'Gemm': Gemm_forward,
     'grid_sampler': Grid_sampler_forward,
     'GlobalAveragePool': AveragePool_forward,
     'GlobalMaxPool': MaxPool2d_forward,
     'Greater': Greater_forward,
+    'LayerNorm': LayerNorm_forward,
     'LeakyRelu': LeakyRelu_forward,
     'Less': Less_forward,
     'MatMul': MatMul_forward,
     'Max': Eltwise_forward,
     'MaxPool': MaxPool2d_forward,
     'Min': Eltwise_forward,
     'Mul': Mul_forward,
+    'MultiHeadAttention': MultiHeadAttention_forward,
     'NonMaxSuppression': _NMS_forward,
     'NonZero': NonZero_forward,
     'Not': Not_forward,
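These entries bind Operation.type strings to their forward functions; the executor then dispatches through the table. In miniature (a sketch; DEFAULT_BACKEND_TABLE is the name assumed for this dict and may differ in default.py):

    # resolve and invoke the backend implementation for one operation
    forward_func = DEFAULT_BACKEND_TABLE[op.type]  # e.g. 'Gelu' -> Gelu_forward
    outputs = forward_func(op, values, ctx=ctx)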

ppq/quantization/quantizer/base.py

Lines changed: 2 additions & 2 deletions
@@ -110,12 +110,12 @@ def quantize_operations(
             operation_platforms[op_name] = self.target_platform
         else: operation_platforms[op_name] = self.default_platform

-        # maunnl override.
+        # manual override.
         if op_name in operation_platforms:
             operation.platform = operation_platforms[op_name]

         # build operation_quantization_configs
-        # every quantable op MUST have a quantization config
+        # every quantizable op MUST have a quantization config
         # if operation.type is listed in quantable_operation_types while a operation_quantization_configs is given
         # it will override the setting of quantable_operation_types
         for op_name, operation in self._graph.operations.items():
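The override rule the comment describes, in miniature (a sketch with hypothetical names; strings stand in for real config objects):

    # an explicit per-op entry wins over membership in quantable_operation_types
    quantable_operation_types = {'Conv', 'Gemm'}
    operation_quantization_configs = {'conv_1': 'explicit_cfg'}

    def config_for(op_name, op_type):
        if op_name in operation_quantization_configs:
            return operation_quantization_configs[op_name]
        return 'default_cfg' if op_type in quantable_operation_types else None

    assert config_for('conv_1', 'Conv') == 'explicit_cfg'
    assert config_for('conv_2', 'Conv') == 'default_cfg'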
