-
Notifications
You must be signed in to change notification settings - Fork 663
Expand file tree
/
Copy pathTosaLegalizeUtils.cpp
More file actions
599 lines (512 loc) · 23.5 KB
/
TosaLegalizeUtils.cpp
File metadata and controls
599 lines (512 loc) · 23.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
namespace tosa {
// Builds the 1-D multiplier constant consumed by tosa.rescale.
//
// With `scale32` the multipliers are emitted as an int32 tensor. Otherwise they
// are narrowed element-wise to int16 and emitted as an int16 tensor. The
// tensor has one element per multiplier.
Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
                             Operation *op, ArrayRef<int32_t> multipliers) {
  if (!scale32) {
    // 16-bit scaling: store the multipliers as int16.
    SmallVector<int16_t> narrowed(multipliers.begin(), multipliers.end());
    const int64_t count = static_cast<int64_t>(narrowed.size());
    return tosa::getConstTensor<int16_t>(rewriter, op, narrowed, {count})
        .value();
  }
  const int64_t count = static_cast<int64_t>(multipliers.size());
  return tosa::getConstTensor<int32_t>(rewriter, op, multipliers, {count})
      .value();
}
// Create a TOSA rescale op from an input framework tensor, zero points and a
// rounding mode.
//
// `scale` is converted into an integer multiplier/shift pair. The multiplier
// is 32-bit when `scale32` is set and 16-bit otherwise. `input_zp` and
// `output_zp` are materialized as zero-point tensors typed after the input
// tensor and `output_type` respectively. The input_unsigned/output_unsigned
// attributes are derived from the element types of the input and output
// tensors.
//
// NOTE(review): every failure below is only reported through op->emitError().
// The function still goes on to build the rescale, including calling .value()
// on the zero-point optionals, so callers get no recoverable failure signal.
Value buildRescale(PatternRewriter &rewriter, Operation *op,
                   ShapedType output_type, Value input_val, double scale,
                   int64_t input_zp, int64_t output_zp,
                   tosa::RoundingMode rounding_mode, bool scale32) {
  int32_t multiplier;
  int32_t shift;
  int32_t scale_width = scale32 ? 32 : 16;
  if (!computeMultiplierAndShift(scale, multiplier, shift, scale_width))
    op->emitError("buildRescale: shift must be in the range 2 <= shift <= 62");
  Value multiplier_val =
      buildRescaleMultiplier(scale32, rewriter, op, {multiplier});
  auto shift_val = tosa::getConstTensor<int8_t>(
                       rewriter, op, {static_cast<int8_t>(shift)}, {1})
                       .value();

  // Type::isUnsignedInteger() is false for any tensor/shaped type, so
  // signedness has to be read from the element type. Querying the shaped
  // type itself would always yield false for both flags.
  auto isUnsignedElementType = [](Type type) {
    if (auto shapedTy = dyn_cast<ShapedType>(type))
      return shapedTy.getElementType().isUnsignedInteger();
    return type.isUnsignedInteger();
  };
  bool input_unsigned = isUnsignedElementType(input_val.getType());
  bool output_unsigned = output_type.getElementType().isUnsignedInteger();

  // The input zero-point tensor is typed after the input tensor; the output
  // zero-point tensor is typed after the RescaleOp's output type.
  const auto input_zp_val = tosa::createZeroPointTensor(
      rewriter, op->getLoc(), dyn_cast<TensorType>(input_val.getType()),
      input_zp);
  if (!input_zp_val.has_value())
    op->emitError("Failed to create input zero-point tensor for RescaleOp.");
  const auto output_zp_val = tosa::createZeroPointTensor(
      rewriter, op->getLoc(), output_type, output_zp);
  if (!output_zp_val.has_value())
    op->emitError("Failed to create output zero-point tensor for RescaleOp.");
  auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
      rewriter, op->getLoc(), output_type, input_val, multiplier_val, shift_val,
      input_zp_val.value(), output_zp_val.value(),
      rewriter.getBoolAttr(scale32),
      tosa::RoundingModeAttr::get(rewriter.getContext(), rounding_mode),
      rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
      rewriter.getBoolAttr(output_unsigned));
  return rescale_op.getResult();
}
// Creates a TOSA rescale op whose result has the input's shape and an int32
// element type.
//
// The output zero point is fixed at 0, 32-bit scaling is used, and rounding
// is SINGLE_ROUND. `input_val` must have a shaped type (asserted).
Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
                          Value input_val, double input_scale,
                          int64_t input_zp) {
  auto shapedInputTy = dyn_cast<mlir::ShapedType>(input_val.getType());
  assert(shapedInputTy);
  // Same shape as the input, int32 elements.
  ShapedType i32ResultTy = shapedInputTy.clone(rewriter.getI32Type());
  constexpr int64_t kOutputZp = 0;
  constexpr bool kUseScale32 = true;
  return buildRescale(rewriter, op, i32ResultTy, input_val, input_scale,
                      input_zp, kOutputZp, tosa::RoundingMode::SINGLE_ROUND,
                      kUseScale32);
}
// Reports whether scale32 mode should be used for `output_element_type`.
// Returns true exactly when the quantized type's integral storage width is 8
// bits; any other storage width returns false.
bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
  const unsigned storageWidth =
      output_element_type.getStorageTypeIntegralWidth();
  return storageWidth == 8;
}
// Materializes `val` as a rank-0 (scalar) f32 tosa.const and returns its
// result. The op is created at `op`'s location.
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                  float val) {
  RankedTensorType scalarF32Ty =
      RankedTensorType::get({}, rewriter.getF32Type());
  DenseElementsAttr scalarAttr = DenseElementsAttr::get(scalarF32Ty, val);
  return tosa::ConstOp::create(rewriter, op->getLoc(), scalarF32Ty, scalarAttr)
      .getResult();
}
// Materializes `shift` as a one-element i8 tosa.const, in the form expected
// for tosa.mul's shift operand. The op is created at `op`'s location.
Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op,
                                 int32_t shift) {
  Type i8Ty = rewriter.getIntegerType(8);
  auto singleI8Ty = RankedTensorType::get({1}, i8Ty);
  auto shiftValueAttr =
      DenseElementsAttr::get(singleI8Ty, rewriter.getIntegerAttr(i8Ty, shift));
  return tosa::ConstOp::create(rewriter, op->getLoc(), singleI8Ty,
                               shiftValueAttr)
      .getResult();
}
// Creates an all-zero tosa.const with the shape and element type of `type`.
//
// Returns std::nullopt (after recording a match failure on `op`) when `type`
// is not a RankedTensorType. The constant's type is rebuilt from the shape and
// element type only, so any encoding on `type` is not carried over.
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
                                        Operation *op, Type type) {
  auto rankedTy = dyn_cast<RankedTensorType>(type);
  if (!rankedTy) {
    (void)rewriter.notifyMatchFailure(op, "not ranked tensor type");
    return std::nullopt;
  }
  ShapedType zerosTy =
      RankedTensorType::get(rankedTy.getShape(), rankedTy.getElementType());
  Attribute zerosAttr = rewriter.getZeroAttr(zerosTy);
  auto zerosOp = CreateOpAndInfer<tosa::ConstOp>(
      rewriter, op->getLoc(), zerosTy, cast<ElementsAttr>(zerosAttr));
  return zerosOp.getResult();
}
// Templated function to create a constant op for given type and shape.
// T: storage C type.
// Default template creates a constant tensor in T.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape,
std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return std::nullopt;
}
auto width = sizeof(T) * 8;
if constexpr (std::is_same_v<T, bool>)
width = 1;
auto const_type =
RankedTensorType::get(shape, rewriter.getIntegerType(width));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
tosa::ConstOp::create(rewriter, op->getLoc(), const_type, const_attr);
if (dtype) {
return tosa::tosaCastTensorToType(rewriter, const_op,
RankedTensorType::get(shape, *dtype))
.value();
}
return const_op.getResult();
}
// Template specialization for APInt
template <>
std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
Operation *op, ArrayRef<APInt> vec,
ArrayRef<int64_t> shape,
std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return std::nullopt;
}
auto const_type = RankedTensorType::get(
shape, rewriter.getIntegerType(vec[0].getBitWidth()));
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
tosa::ConstOp::create(rewriter, op->getLoc(), const_type, const_attr);
if (dtype) {
return tosa::tosaCastTensorToType(rewriter, const_op,
RankedTensorType::get(shape, *dtype))
.value();
}
return const_op.getResult();
}
// Template specialization for float
template <>
std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
Operation *op, ArrayRef<float> vec,
ArrayRef<int64_t> shape,
std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}
if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return std::nullopt;
}
auto const_type = RankedTensorType::get(shape, rewriter.getF32Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);
auto const_op =
tosa::ConstOp::create(rewriter, op->getLoc(), const_type, const_attr);
if (dtype) {
return tosa::tosaCastTensorToType(rewriter, const_op,
RankedTensorType::get(shape, *dtype))
.value();
}
return const_op.getResult();
}
// Valid TOSA casting pairs according to TOSA spec:
// https://www.mlplatform.org/tosa/tosa_spec.html#_cast
// Note: currently TOSA doesn't support casting to and from I64 and F64
//
// Returns success() when `src` == `dest`, or when (src, dest) is one of the
// element-type pairs listed below. Returns failure() otherwise. Integer
// checks use Type::isInteger(width), so they match regardless of signedness.
[[maybe_unused]] static LogicalResult checkValidityOfCast(Type src, Type dest) {
  if (src == dest)
    return success();

  // Destination families that recur across several source types.
  auto isInt8To32 = [](Type t) {
    return t.isInteger(8) || t.isInteger(16) || t.isInteger(32);
  };
  auto isFp8 = [](Type t) {
    return isa<Float8E4M3Type>(t) || isa<Float8E5M2Type>(t);
  };
  auto isF32F16OrBF16 = [](Type t) {
    return t.isF32() || t.isF16() || t.isBF16();
  };

  bool allowed = false;
  if (isInt8To32(src)) {
    // int8/int16/int32 -> a different width among {1, 8, 16, 32}, or
    // f32/f16/bf16. A same-width integer destination (e.g. a change of
    // signedness only) is not a listed pair.
    const unsigned srcWidth = cast<IntegerType>(src).getWidth();
    allowed = !dest.isInteger(srcWidth) &&
              (isInt8To32(dest) || dest.isInteger(1) || isF32F16OrBF16(dest));
  } else if (src.isInteger(1)) {
    allowed = isInt8To32(dest);
  } else if (src.isF32()) {
    allowed =
        isInt8To32(dest) || dest.isBF16() || dest.isF16() || isFp8(dest);
  } else if (src.isF16()) {
    allowed =
        isInt8To32(dest) || dest.isBF16() || dest.isF32() || isFp8(dest);
  } else if (src.isBF16()) {
    // Note: bf16 -> f16 is not a listed pair.
    allowed = isInt8To32(dest) || dest.isF32() || isFp8(dest);
  } else if (isFp8(src)) {
    // fp8e4m3 / fp8e5m2 -> bf16, f32 or f16 only.
    allowed = isF32F16OrBF16(dest);
  }
  return success(allowed);
}
// Default function to create tosa.cast op. This should be called instead of
// directly calling rewriter.create<tosa::CastOp>.
//
// Converts `src` to the element type of `destType`, keeping the shape of
// `src`. Several conversions are expanded beyond a plain tosa.cast:
//   * float -> i1         : logical_not(equal(src, 0)), i.e. "non-zero".
//   * i1 -> float         : i1 -> i8 -> float.
//   * float -> int (!= i1): round toward zero (select(0 > src, ceil, floor))
//                           before the tosa.cast.
// Returns `src` unchanged when the element types already match. Returns
// std::nullopt when `src` is not a tensor, or when a helper constant cannot
// be created or rank-equalized.
std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
                                          TensorType destType) {
  auto srcType = dyn_cast<TensorType>(src.getType());
  if (!srcType)
    return std::nullopt;
  Type srcElemTy = srcType.getElementType();
  Type destElemTy = destType.getElementType();

  // Value::getLoc() equals the defining op's location for op results. It is
  // also valid for block arguments, where getDefiningOp() returns null.
  Location loc = src.getLoc();

  // getConstTensor() needs an anchor operation for its location and
  // diagnostics. Prefer the defining op; for a block argument, fall back to
  // the op that owns the block. This replaces a null dereference.
  Operation *op = src.getDefiningOp();
  if (!op) {
    if (Block *parentBlock = src.getParentBlock())
      op = parentBlock->getParentOp();
  }

  // Temporarily disable checkValidityOfCast as it's currently strictly
  // following TOSA spec and might cause many e2e tests to fail. This is because
  // even though there are some casting pairs that are not congruent to TOSA
  // spec, they are still permissible. TOSA validation should flag these illegal
  // constructs in a per-profile manner. This strict validity check will be
  // enabled later in a potential `--strict` mode which checks for strict
  // casting only when needed (the default value of `--strict` mode will be
  // off).
  // if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
  //   return std::nullopt;

  if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger(1)) {
    // TOSA does not support casting from float->i1.
    // In PyTorch the bool value will be True if any element is non-zero
    if (!op)
      return std::nullopt;
    std::optional<Value> zero =
        getConstTensor<float>(rewriter, op, 0.0f, {}, srcElemTy);
    if (!zero)
      return std::nullopt;
    Value zeroValue = *zero;
    if (failed(mlir::tosa::EqualizeRanks(rewriter, loc, src, zeroValue)))
      return std::nullopt;
    auto cmpTy = srcType.clone(rewriter.getIntegerType(1));
    Value isEq = tosa::EqualOp::create(rewriter, loc, cmpTy, src, zeroValue);
    return tosa::LogicalNotOp::create(rewriter, loc,
                                      srcType.clone(destElemTy), isEq);
  }

  if (srcElemTy.isInteger(1) && llvm::isa<FloatType>(destElemTy)) {
    // TOSA does not support casting from i1->float.
    // Instead, we cast to i8 and then to the float.
    TensorType midType = srcType.clone(rewriter.getIntegerType(8));
    Value mid = tosa::CastOp::create(rewriter, loc, midType, src);
    return tosa::CastOp::create(rewriter, loc, srcType.clone(destElemTy), mid);
  }

  if (srcElemTy == destElemTy)
    return src;

  if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger() &&
      !destElemTy.isInteger(1)) {
    // For float->int conversion, tosa.cast performs round-to-nearest.
    // PyTorch performs round-to-zero instead.
    // Generate round-to-zero conversion prior to tosa.cast to match with
    // expected torch behavior.
    // Check the anchor before creating any op so a bail-out leaves no IR.
    if (!op)
      return std::nullopt;
    auto floor = tosa::FloorOp::create(rewriter, loc, srcType, src);
    auto ceil = tosa::CeilOp::create(rewriter, loc, srcType, src);
    std::optional<Value> zero =
        tosa::getConstTensor<float>(rewriter, op, 0.0f, {}, srcElemTy);
    if (!zero)
      return std::nullopt;
    Value zeroValue = *zero;
    if (failed(mlir::tosa::EqualizeRanks(rewriter, loc, src, zeroValue)))
      return std::nullopt;
    auto boolType = srcType.clone(rewriter.getIntegerType(1));
    auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
        rewriter, loc, boolType, zeroValue, src);
    src = tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, loc, srcType,
                                                 isNegative, ceil, floor);
  }

  TensorType castedSrcType = srcType.clone(destElemTy);
  return tosa::CastOp::create(rewriter, loc, castedSrcType, src);
}
// Casts a ranked tensor with i32 or i64 elements to an f32 tensor of the same
// shape. Every other element type is returned untouched.
//
// `input` must be a RankedTensorType (enforced by cast<>). The cast op is
// created at `op`'s location.
Value ensureF32Input(PatternRewriter &rewriter, Operation *op, Value input) {
  auto rankedTy = cast<RankedTensorType>(input.getType());
  Type elementTy = rankedTy.getElementType();
  const bool isWideInt = elementTy.isInteger(32) || elementTy.isInteger(64);
  if (isWideInt) {
    auto f32Ty =
        RankedTensorType::get(rankedTy.getShape(), rewriter.getF32Type());
    return tosa::CastOp::create(rewriter, op->getLoc(), f32Ty, input);
  }
  return input;
}
// Explicit instantiations of the generic getConstTensor<T> for bool and the
// signed fixed-width integer types. APInt and float use the explicit
// specializations defined earlier in this file instead.
template std::optional<Value>
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool>,
                     ArrayRef<int64_t>, std::optional<Type>);
template std::optional<Value>
getConstTensor<int8_t>(PatternRewriter &, Operation *, ArrayRef<int8_t>,
                       ArrayRef<int64_t>, std::optional<Type>);
template std::optional<Value>
getConstTensor<int16_t>(PatternRewriter &, Operation *, ArrayRef<int16_t>,
                        ArrayRef<int64_t>, std::optional<Type>);
template std::optional<Value>
getConstTensor<int32_t>(PatternRewriter &, Operation *, ArrayRef<int32_t>,
                        ArrayRef<int64_t>, std::optional<Type>);
template std::optional<Value>
getConstTensor<int64_t>(PatternRewriter &, Operation *, ArrayRef<int64_t>,
                        ArrayRef<int64_t>, std::optional<Type>);
// Chooses the accumulator type for a TOSA avg_pool2d on `input`.
//
// Uniform-quantized element types are looked through to their storage type.
// Float inputs get an f32 accumulator; every other element type gets i32.
// Returns failure() only when `input` does not have a shaped type.
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
                                  TypeAttr &accType) {
  auto shapedTy = llvm::dyn_cast<ShapedType>(input.getType());
  if (!shapedTy)
    return failure();

  Type elemTy = shapedTy.getElementType();
  if (auto uniformQTy =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(elemTy))
    elemTy = uniformQTy.getStorageType();

  // TOSA allows either an FP16 or an FP32 accumulator for FP16 input. Once
  // FP16 accumulation is supported here, the choice could trade performance
  // against accuracy; for now FP32 is always used for float inputs.
  Type accElemTy;
  if (isa<FloatType>(elemTy))
    accElemTy = rewriter.getF32Type();
  else
    accElemTy = rewriter.getIntegerType(32);
  accType = mlir::TypeAttr::get(accElemTy);
  return success();
}
// Get accumulator type for TOSA convolution ops.
//
// Selects the accumulator element type from the (input, weight, output)
// element-type triple. Quantized element types on any of the three tensors
// are first replaced by their storage type. Previously only the input was
// unwrapped, so a quantized weight or output fell through to the fallback and
// could yield a quantized type as acc_type. Always returns success().
LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
                                RankedTensorType inputTy,
                                RankedTensorType weightTy,
                                RankedTensorType outputTy, TypeAttr &accType) {
  // Look through quantized element types to their integer storage type.
  auto storageTypeOf = [](Type elemTy) -> Type {
    if (auto quantTy = dyn_cast<quant::QuantizedType>(elemTy))
      return quantTy.getStorageType();
    return elemTy;
  };
  Type inputElemTy = storageTypeOf(inputTy.getElementType());
  Type weightElemTy = storageTypeOf(weightTy.getElementType());
  Type outputElemTy = storageTypeOf(outputTy.getElementType());

  // Get TOSA conv ops acc type based on input, weight, and output types
  // according to the spec:
  // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
  // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
  // https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d
  //
  // For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the
  // output (storage) type but does not offer any guarantee on the numerical
  // precision since such cases will fail TOSA validation.
  if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) ||
      (inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) ||
      (inputElemTy.isBF16() && weightElemTy.isBF16() &&
       outputElemTy.isBF16())) {
    accType = mlir::TypeAttr::get(rewriter.getF32Type());
  } else if (inputElemTy.isInteger(8) &&
             (weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) &&
             outputElemTy.isInteger(32)) {
    accType = mlir::TypeAttr::get(rewriter.getIntegerType(32));
  } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
             outputElemTy.isInteger(48)) {
    accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
  } else if ((isa<Float8E4M3Type>(inputElemTy) &&
              isa<Float8E4M3Type>(weightElemTy) && outputElemTy.isF16()) ||
             (isa<Float8E5M2Type>(inputElemTy) &&
              isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
    accType = mlir::TypeAttr::get(rewriter.getF16Type());
  } else {
    accType = mlir::TypeAttr::get(outputElemTy);
  }
  return success();
}
// Synthesizes an all-zero 1-D bias with `numOutputChannels` elements for a
// convolution that has no bias operand.
//
// Quantized output: the input must also be quantized with 8-bit storage, and
// the bias is an int32 zero tensor. Non-quantized output: the bias element
// type is `outputElemTy`. Integer element types produce an int32 zero
// constant; otherwise an f32 zero constant is created and cast to that type.
// Fails (via notifyMatchFailure) for unsupported quantization combinations or
// a dynamic channel count.
FailureOr<Value> getConvBiasForNoneType(Operation *op,
                                        PatternRewriter &rewriter,
                                        Type inputElemTy, Type outputElemTy,
                                        int64_t numOutputChannels) {
  Type biasElemTy = outputElemTy;
  if (isa<quant::QuantizedType>(outputElemTy)) {
    auto inputQuantTy = dyn_cast<mlir::quant::QuantizedType>(inputElemTy);
    if (!inputQuantTy)
      return rewriter.notifyMatchFailure(op,
                                         "output is qtype but input is not");
    // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
    // accumulator) are 48-bit and not 32-bit, and requires the use of APInt
    // to define a 48-bit int.
    if (inputQuantTy.getStorageTypeIntegralWidth() != 8)
      return rewriter.notifyMatchFailure(
          op, "Only int8 input tensor to conv2d is supported.");
    // Signed int8 input => int32 bias (and int32 output).
    biasElemTy = rewriter.getIntegerType(32);
  }

  if (ShapedType::isDynamic(numOutputChannels))
    return rewriter.notifyMatchFailure(
        op, "cannot synthesize conv bias with dynamic output channels");

  int32_t oc = static_cast<int32_t>(numOutputChannels);
  if (!biasElemTy.isInteger()) {
    SmallVector<float> floatZeros(oc, 0);
    return tosa::getConstTensor<float>(rewriter, op, floatZeros, {oc},
                                       biasElemTy)
        .value();
  }
  SmallVector<int32_t> intZeros(oc, 0);
  return tosa::getConstTensor<int32_t>(rewriter, op, intZeros, {oc}).value();
}
// Pads the H and W dimensions of a rank-4 NHWC tensor with zeros using
// tosa.pad.
//
// `padExtents` is [top, bottom, left, right]; N and C are never padded.
// Returns `inputNHWC` unchanged when every extent is zero or when the input is
// not a ranked tensor. Dynamic H/W extents stay dynamic in the result type.
// The pad value is 0 of the input element type: f32 0.0 for float element
// types, otherwise an i32 0, each cast to the element type.
// NOTE(review): for quantized inputs with a non-zero zero point a literal 0
// is not a real-valued zero — confirm callers account for that.
Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
                              Operation *op, Value inputNHWC,
                              ArrayRef<int64_t> padExtents) {
  assert(padExtents.size() == 4 && "expected [top, bottom, left, right]");
  if (llvm::all_of(padExtents, [](int64_t v) { return v == 0; }))
    return inputNHWC;

  // Bail out before materializing any IR, so an unranked input no longer
  // leaves an unused tosa.const_shape behind.
  auto inputTy = dyn_cast<RankedTensorType>(inputNHWC.getType());
  if (!inputTy)
    return inputNHWC;
  // The padding below has 8 entries (one before/after pair per dimension),
  // and H/W are indexed directly, so the input must be rank 4.
  assert(inputTy.getRank() == 4 && "expected a rank-4 NHWC tensor");

  // (before, after) pairs for N, H, W, C.
  SmallVector<int64_t, 8> nhwcPadding = {
      0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
  Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);

  SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
                                      inputTy.getShape().end());
  auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
    if (ShapedType::isDynamic(dim))
      return ShapedType::kDynamic;
    return dim + before + after;
  };
  resultShape[1] = addPad(resultShape[1], padExtents[0], padExtents[1]);
  resultShape[2] = addPad(resultShape[2], padExtents[2], padExtents[3]);
  Type elemTy = inputTy.getElementType();
  auto resultTy = RankedTensorType::get(resultShape, elemTy);

  Value padConst;
  if (isa<mlir::FloatType>(elemTy)) {
    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
  } else {
    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
  }
  return tosa::PadOp::create(rewriter, loc, resultTy, inputNHWC, nhwcPadShape,
                             padConst)
      .getResult();
}
// Builds the TOSA zero-point tensor (element type `elemType`) for `tensor`.
//
// If Torch::getZeroPoint finds a zero point for `tensor`, it must be a
// constant integer, and that value is used. Otherwise the zero point is 0.
// Fails via notifyMatchFailure when the zero point is not a constant, or when
// tosa::createZeroPointTensor cannot build a tensor for `elemType`.
// Previously that second case aborted on an empty optional.
FailureOr<Value> getZeroPointValue(PatternRewriter &rewriter, Operation *op,
                                   Value tensor, Type elemType) {
  // Torch::getZeroPoint looks at the defining op of `tensor` to find
  // the quantization parameters.
  Value torchZp;
  torch::Torch::getZeroPoint(tensor, torchZp);

  // Non-quantized tensors get a zero point of 0.
  int64_t zpConst = 0;
  if (torchZp &&
      !matchPattern(torchZp, torch::Torch::m_TorchConstantInt(&zpConst)))
    return rewriter.notifyMatchFailure(
        op, "zero point must be a scalar constant");

  std::optional<Value> zp =
      tosa::createZeroPointTensor(rewriter, op->getLoc(), elemType, zpConst);
  if (!zp)
    return rewriter.notifyMatchFailure(
        op, "cannot create zero-point tensor for this element type");
  return *zp;
}
// Returns true if any dimension of `type`'s shape has extent 0, i.e. the type
// describes an empty tensor. `type` must be ranked (getShape() is queried).
bool typeHasZeroDim(ShapedType type) {
  return llvm::is_contained(type.getShape(), 0);
}
} // namespace tosa
} // namespace mlir