Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Builder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_onnx_mlir_library(OMBuilder

LINK_LIBS PUBLIC
OMHasOnnxSubgraphOpInterface
OMMlirUtilities
OMONNXOps
OMResultTypeInferenceOpInterface
MLIRFuncDialect
Expand Down
56 changes: 3 additions & 53 deletions src/Builder/ModelInputShaper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"

#include "src/Support/TypeUtilities.hpp"

#include <algorithm>
#include <cstdlib>
#include <sstream>
Expand Down Expand Up @@ -53,59 +55,7 @@ ModelInputShaper::ModelInputShaper() : force_dim_dynamic_enabled_(false) {

void ModelInputShaper::setShapeInformation(
const std::string &shapeInformation) {
if (!shapeInformation.empty()) {
std::stringstream shapeInfoString(shapeInformation);
std::string shapeString;
while (std::getline(shapeInfoString, shapeString, INPUT_SEP)) {
size_t pos = shapeString.find(INPUT_DIM_SEP);
std::string inputString = shapeString.substr(0, pos);
std::string dimString = shapeString.substr(pos + 1);

// Parse dimString.
std::stringstream dimSizes(dimString);
std::string dimStr;
std::vector<int64_t> dims;
while (std::getline(dimSizes, dimStr, DIM_SEP)) {
int64_t dimSize = std::stoi(dimStr);
assert((dimSize == ModelInputShaper::kUserDynamic || dimSize > 0) &&
"dim must be -1 or > 0");
if (dimSize == ModelInputShaper::kUserDynamic)
dimSize = ShapedType::kDynamic;
dims.emplace_back(dimSize);
}

// Parse inputString.
assert(std::count(inputString.begin(), inputString.end(),
INPUT_RANGE_SEP) <= 1 &&
"input_id is invalid");
// Check if users input a range or not.
size_t rangePos = inputString.find(INPUT_RANGE_SEP);
std::string startString = inputString.substr(0, rangePos);
std::string endString = inputString.substr(rangePos + 1);
assert(endString != "" && "input_id has _ at the end");
bool isRangeInput = (startString != "");
// Insert (input_id, dim_value) to the shape info.
SmallVector<int64_t> inputIDs;
if (isRangeInput) {
int64_t startID = std::stoi(startString);
int64_t endID = std::stoi(endString);
assert((startID >= 0) && "start_id must be >= 0");
assert((endID >= 0) && "end_id must be >= 0");
for (int64_t i = startID; i <= endID; ++i)
inputIDs.emplace_back(i);
} else {
int64_t inputID = std::stoi(inputString);
assert((inputID >= 0 || inputID == kUserAllInputs) &&
"input_id must be -1 or >= 0");
inputIDs.emplace_back(inputID);
}
for (int64_t inputID : inputIDs) {
// The semantics of c++ map.insert() makes sure that only the first
// setting of inputID is inserted.
inputs_shape_information_.insert(std::make_pair(inputID, dims));
}
}
}
inputs_shape_information_ = parseShapeInformation(shapeInformation);
}

namespace {
Expand Down
35 changes: 35 additions & 0 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ bool enableSafeCodeGen; // common for both
bool disableMemRefPrefetch; // common for both
uint64_t compilationNumThreads; // common for both
std::vector<std::string> decomposeOpsInONNX; // common for both
std::string shapeInformationUB; // common for both
std::string shapeInformationLB; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
bool invokeOnnxVersionConverter; // onnx-mlir only
bool preserveLocations; // onnx-mlir only
Expand Down Expand Up @@ -261,6 +263,39 @@ static llvm::cl::opt<bool, true> enableSafeCodeGenOpt("enable-safe-code-gen",
llvm::cl::location(enableSafeCodeGen), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

// Upper bound (inclusive) on input dimension sizes; same
// "INPUT_ID:D1xD2x...xDn,..." grammar as --shapeInformation. The bound is
// checked at runtime by the generated input-verification code.
static llvm::cl::opt<std::string, true> shapeInformationUBOpt(
    "shapeInformationUB",
    llvm::cl::desc(
        "Specify the upper bound (inclusive) of dimension size for the inputs "
        "of\n"
        "the ONNX model. A popular use case is the maximum sequence length of "
        "an encoder model.\n"
        "\"value\" is in the format of "
        "\"INPUT_ID1:D1xD2x...xDn,INPUT_ID2:D1xD2x...xDn, ...\",\n"
        "where \"INPUT_ID1, INPUT_ID2, ...\" are input indices (They can be an "
        "integer starting from 0, a range e.g. 5-17, or -1 for all input "
        "indices), and\n \"D1, D2, ...\" are the UB (positive "
        "integers or -1 for unknown UB).\n"
        "Such information will be used by verifyInputTensors and "
        "optimizations"),
    llvm::cl::value_desc("value"), llvm::cl::location(shapeInformationUB),
    llvm::cl::cat(OnnxMlirCommonOptions));

// Lower bound (inclusive) on input dimension sizes; same
// "INPUT_ID:D1xD2x...xDn,..." grammar as --shapeInformationUB. The bound is
// checked at runtime by the generated input-verification code.
static llvm::cl::opt<std::string, true> shapeInformationLBOpt(
    "shapeInformationLB",
    llvm::cl::desc(
        "Specify the lower bound (inclusive) of dimension size\n"
        "for the inputs of the ONNX model. A possible use case is the batch\n"
        "size if the scheduler can guarantee a minimum batch size.\n"
        "\"value\" is in the format of "
        "\"INPUT_ID1:D1xD2x...xDn,INPUT_ID2:D1xD2x...xDn, ...\",\n"
        "where \"INPUT_ID1, INPUT_ID2, ...\" are input indices (They can be an "
        "integer starting from 0, a range e.g. 5-17, or -1 for all input "
        "indices), and\n \"D1, D2, ...\" are the LB (positive "
        "integers or -1 for unknown LB).\n"
        "Such information will be used by verifyInputTensors and "
        "optimizations"),
    llvm::cl::value_desc("value"), llvm::cl::location(shapeInformationLB),
    llvm::cl::cat(OnnxMlirCommonOptions));

// TODO(alexe) re-enable prefetch.
static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
Expand Down
2 changes: 2 additions & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ extern bool doNotEmitFullMLIRCode; // onnx-mlir only
extern bool useOnnxModelTypes; // onnx-mlir only
extern int repeatOnnxTransform; // onnx-mlir only
extern std::string shapeInformation; // onnx-mlir only
extern std::string shapeInformationUB; // onnx-mlir only
extern std::string shapeInformationLB; // onnx-mlir only
extern std::string dimParams; // onnx-mlir only
extern ModelSize modelSize; // onnx-mlir only
extern bool storeConstantsToFile; // onnx-mlir only
Expand Down
99 changes: 99 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/Support/JSON.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp"
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Support/TypeUtilities.hpp"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "krnl_to_llvm"
Expand All @@ -49,6 +51,13 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
std::map<std::string, SmallVector<MemRefType, 4>> &inputMemRefTypes;
std::map<std::string, SmallVector<MemRefType, 4>> &outputMemRefTypes;
bool verifyInputTensors;
/*
typedef enum {
UB,
LB
} ShapeInfoType;
*/
enum class ShapeInfoType { UB, LB };

KrnlEntryPointOpLowering(LLVMTypeConverter &typeConverter, MLIRContext *ctx,
ArrayRef<bool> outputOMTensorOwnerships, bool singleEntryPoint,
Expand Down Expand Up @@ -190,6 +199,16 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
std::tie(inSigJSON, std::ignore) = sigAttr.getValue().split('@');
emitVerificationCodeForInputTensors(
module, rewriter, loc, apiRegistry, omTensorInputs, inSigJSON);
// Check input tensor dimension specified with option
// shapeInformationUB and shapeInformationLB
if (!shapeInformationUB.empty())
emitVerificationCodeForDimensionBoundOfInputTensors(module, rewriter,
loc, apiRegistry, omTensorInputs, inSigJSON, shapeInformationUB,
ShapeInfoType::UB);
if (!shapeInformationLB.empty())
emitVerificationCodeForDimensionBoundOfInputTensors(module, rewriter,
loc, apiRegistry, omTensorInputs, inSigJSON, shapeInformationLB,
ShapeInfoType::LB);
}

// 3. Emit code to prepare MemRefs from OMTensor inputs and call
Expand Down Expand Up @@ -527,6 +546,86 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
}
}

/// Emit runtime checks that each input tensor dimension respects the
/// user-provided bound (--shapeInformationUB / --shapeInformationLB).
/// \p boundInfoString uses the "INPUT_ID:D1xD2x...xDn,..." grammar parsed by
/// parseShapeInformation; an INPUT_ID of -1 applies the bounds to all inputs,
/// and a dim value of -1 (stored as ShapedType::kDynamic by the parser) means
/// "no bound for this dimension". On violation the generated code prints a
/// message, sets errno to EINVAL and returns NULL from the entry point.
void emitVerificationCodeForDimensionBoundOfInputTensors(ModuleOp &module,
    PatternRewriter &rewriter, Location loc,
    const RuntimeAPIRegistry &apiRegistry, Value omTensorInputs,
    StringRef inSigJSON, const std::string &boundInfoString,
    ShapeInfoType boundType) const {
  std::map<int64_t, std::vector<int64_t>> boundsPerInput =
      parseShapeInformation(boundInfoString);
  MLIRContext *context = rewriter.getContext();
  MultiDialectBuilder<KrnlBuilder, LLVMBuilder> create(rewriter, loc);
  Type int64Ty = rewriter.getI64Type();
  Type opaquePtrTy = getPointerType(context, rewriter.getI8Type());
  // Get a pointer to the list of input omTensors.
  Value omTensorPtrArr = RuntimeAPI::callApi(rewriter, loc, apiRegistry,
      RuntimeAPI::API::GET_OMT_ARRAY, {omTensorInputs});
  auto JSONInput = llvm::json::parse(inSigJSON.data());
  assert(JSONInput && "failed to parse json");
  auto JSONArray = JSONInput->getAsArray();
  assert(JSONArray && "failed to parse json as array");
  int64_t inputNum = JSONArray->size();
  for (const auto &pair : boundsPerInput) {
    int64_t startInput = pair.first;
    int64_t endInput = pair.first + 1;
    // Input -1 means that this shape information is for all inputs.
    if (pair.first == -1) {
      startInput = 0;
      endInput = inputNum;
    }
    for (int64_t inputID = startInput; inputID < endInput; inputID++) {
      // omTensorPtr = omTensorPtrArr[inputID].
      Value omTensorPtrAddr = create.llvm.getElemPtr(
          getPointerType(context, opaquePtrTy), opaquePtrTy, omTensorPtrArr,
          ArrayRef<LLVM::GEPArg>{static_cast<int32_t>(inputID)});
      Value omTensorPtr = create.llvm.load(opaquePtrTy, omTensorPtrAddr);
      Value sizesArrayPtr = RuntimeAPI::callApi(rewriter, loc, apiRegistry,
          RuntimeAPI::API::GET_DATA_SHAPE, {omTensorPtr});

      // Check whether the bound info is valid: no more dims than tensor.
      auto JSONItem = (*JSONArray)[inputID].getAsObject();
      auto JSONDimArray = JSONItem->getArray("dims");
      uint64_t rank = JSONDimArray->size();
      assert(pair.second.size() <= rank && "invalid shapeInformation ");

      // Check each dimension.
      for (uint64_t dimID = 0; dimID < pair.second.size(); ++dimID) {
        int64_t bound = pair.second[dimID];
        // No bound info for this dimension: parseShapeInformation maps a
        // user "-1" to ShapedType::kDynamic, so that is the sentinel to
        // test here (testing against -1 would never match and would emit a
        // comparison against kDynamic, failing every check).
        if (bound == ShapedType::kDynamic)
          continue;
        // Load the actual runtime dim size from the omTensor's shape array.
        Value actualDim = create.llvm.load(int64Ty,
            create.llvm.getElemPtr(getPointerType(context, int64Ty), int64Ty,
                sizesArrayPtr,
                ArrayRef<LLVM::GEPArg>{static_cast<int32_t>(dimID)}));

        switch (boundType) {
        case ShapeInfoType::UB:
          // Require bound >= actualDim; fail when the actual dim exceeds UB.
          noLessOrFailed(module, rewriter, loc,
              create.llvm.constant(int64Ty, static_cast<int64_t>(bound)),
              actualDim,
              "verifyInputTensors failed: the upper bound for the input " +
                  std::to_string(inputID) + " of dimension " +
                  std::to_string(dimID) +
                  " is set by --shapeInformationUB as " +
                  std::to_string(bound) + ", but got ");
          break;
        case ShapeInfoType::LB:
          // Require bound <= actualDim; fail when the actual dim is below LB.
          noGreaterOrFailed(module, rewriter, loc,
              create.llvm.constant(int64Ty, static_cast<int64_t>(bound)),
              actualDim,
              "verifyInputTensors failed: the lower bound for the input " +
                  std::to_string(inputID) + " of dimension " +
                  std::to_string(dimID) +
                  " is set by --shapeInformationLB as " +
                  std::to_string(bound) + ", but got ");
          break;
        default:
          assert(false && "Unsupported BoundType");
        }
      }
    }
  }
}

void recordEntryPointSignatures(ModuleOp &module,
std::string currentEntryPointName, KrnlEntryPointOp entryOp,
SmallVectorImpl<LLVM::GlobalOp> &entryGlobalOps,
Expand Down
25 changes: 25 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,31 @@ void noGreaterOrFailed(ModuleOp &module, OpBuilder &rewriter, Location loc,
});
}

/// Emit code for `IF lhs < rhs THEN print errorMsg, set errno to EINVAL and
/// return null ELSE do nothing`, i.e. require lhs to be no less than rhs.
/// Mirror image of noGreaterOrFailed above (which uses sgt). When appendRHS
/// is true, the runtime value of rhs is printed (as i64) after errorMsg.
void noLessOrFailed(ModuleOp &module, OpBuilder &rewriter, Location loc,
    Value lhs, Value rhs, std::string errorMsg, bool appendRHS) {
  MLIRContext *context = rewriter.getContext();
  MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
  create.llvm.ifThenElse(/*cond=*/
      [&](const LLVMBuilder &createLLVM) {
        // Signed comparison: the check fires when lhs < rhs.
        return createLLVM.icmp(LLVM::ICmpPredicate::slt, lhs, rhs);
      }, /*then=*/
      [&](const LLVMBuilder &createLLVM) {
        // Rebuild a multi-dialect builder at the insertion point of the
        // then-block (intentionally shadows the outer `create`).
        MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
        // Print an error message.
        if (!errorMsg.empty()) {
          if (appendRHS)
            create.krnl.printf(
                StringRef(errorMsg), rhs, rewriter.getI64Type(), true);
          else
            create.krnl.printf(StringRef(errorMsg + "\n"));
        }
        // Set errno.
        emitErrNo(module, rewriter, loc, EINVAL);
        // Return NULL.
        create.llvm._return(create.llvm.null(getI8PointerType(context)));
      });
}

void equalOrReturn(ModuleOp &module, OpBuilder &rewriter, Location loc,
Value lhs, Value rhs, Value retVal, std::string errorMsg) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
Expand Down
5 changes: 5 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ void noGreaterOrFailed(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
std::string errorMsg = "", bool appendRHS = true);

/// Emit code for `IF lhs < rhs THEN return null ELSE do nothing`.
void noLessOrFailed(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
std::string errorMsg = "", bool appendRHS = true);

/// Emit code for `IF lhs != rhs THEN return retVal ELSE do nothing`.
void equalOrReturn(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs, mlir::Value retVal,
Expand Down
73 changes: 73 additions & 0 deletions src/Support/TypeUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,77 @@ unsigned getIntOrFloatByteWidth(Type ty) {
return llvm::divideCeil(ty.getIntOrFloatBitWidth(), 8);
}

/// Parse a shape-information string of the form
///   "INPUT_ID1:D1xD2x...xDn,INPUT_ID2:D1xD2x...xDn,..."
/// where each INPUT_ID is a non-negative input index, a range "start-end"
/// (inclusive), or -1 for all inputs, and each D is a positive integer or -1
/// for a dynamic/unknown dimension (stored as ShapedType::kDynamic).
/// \return a map from input index to its dimension list. If the same input
/// index is specified more than once, only the first occurrence is kept
/// (std::map::insert semantics). An empty input yields an empty map.
std::map<int64_t, std::vector<int64_t>> parseShapeInformation(
    const std::string &shapeInformation) {
  // For users of onnx-mlir.
  // -1 is used for dynamic/unknown dimension.
  static constexpr int64_t kUserDynamic = -1;
  // -1 is used to indicate all input indices.
  static constexpr int64_t kUserAllInputs = -1;

  // Separator between inputs.
  static constexpr char INPUT_SEP = ',';
  // Separator between dimensions.
  static constexpr char DIM_SEP = 'x';
  // Separator between one input and its dimensions.
  static constexpr char INPUT_DIM_SEP = ':';
  // Separator to define a range of input indices, e.g. 2-5.
  static constexpr char INPUT_RANGE_SEP = '-';
  std::map<int64_t, std::vector<int64_t>> inputs_shape_information;
  if (!shapeInformation.empty()) {
    std::stringstream shapeInfoString(shapeInformation);
    std::string shapeString;
    while (std::getline(shapeInfoString, shapeString, INPUT_SEP)) {
      size_t pos = shapeString.find(INPUT_DIM_SEP);
      // Without the ':' separator, npos+1 wraps to 0 and the whole token
      // would silently be reparsed as the dim list; catch that early.
      assert(pos != std::string::npos && "missing ':' in shape information");
      std::string inputString = shapeString.substr(0, pos);
      std::string dimString = shapeString.substr(pos + 1);

      // Parse dimString.
      std::stringstream dimSizes(dimString);
      std::string dimStr;
      std::vector<int64_t> dims;
      while (std::getline(dimSizes, dimStr, DIM_SEP)) {
        // Use stoll, not stoi: the target is int64_t and stoi would throw
        // std::out_of_range for dims larger than INT_MAX.
        int64_t dimSize = std::stoll(dimStr);
        assert((dimSize == kUserDynamic || dimSize > 0) &&
               "dim must be -1 or > 0");
        if (dimSize == kUserDynamic)
          dimSize = ShapedType::kDynamic;
        dims.emplace_back(dimSize);
      }

      // Parse inputString.
      assert(std::count(inputString.begin(), inputString.end(),
                 INPUT_RANGE_SEP) <= 1 &&
             "input_id is invalid");
      // Check if users input a range or not. Note that "-1" parses as a
      // non-range input: startString is empty, so stoll sees the full "-1".
      size_t rangePos = inputString.find(INPUT_RANGE_SEP);
      std::string startString = inputString.substr(0, rangePos);
      std::string endString = inputString.substr(rangePos + 1);
      assert(endString != "" && "input_id has _ at the end");
      bool isRangeInput = (startString != "");
      // Insert (input_id, dim_value) to the shape info.
      SmallVector<int64_t> inputIDs;
      if (isRangeInput) {
        int64_t startID = std::stoll(startString);
        int64_t endID = std::stoll(endString);
        assert((startID >= 0) && "start_id must be >= 0");
        assert((endID >= 0) && "end_id must be >= 0");
        for (int64_t i = startID; i <= endID; ++i)
          inputIDs.emplace_back(i);
      } else {
        int64_t inputID = std::stoll(inputString);
        assert((inputID >= 0 || inputID == kUserAllInputs) &&
               "input_id must be -1 or >= 0");
        inputIDs.emplace_back(inputID);
      }
      for (int64_t inputID : inputIDs) {
        // The semantics of c++ map.insert() makes sure that only the first
        // setting of inputID is inserted.
        inputs_shape_information.insert(std::make_pair(inputID, dims));
      }
    }
  }
  return inputs_shape_information;
}

} // namespace onnx_mlir
3 changes: 3 additions & 0 deletions src/Support/TypeUtilities.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,8 @@ bool sameEncodingAttr(mlir::Type t1, mlir::Type t2);
/// Get the byte width of an int or float type.
unsigned getIntOrFloatByteWidth(mlir::Type ty);

std::map<int64_t, std::vector<int64_t>> parseShapeInformation(
const std::string &);

} // namespace onnx_mlir
#endif
Loading