diff --git a/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td b/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td index aee865529aa..a638b76139e 100644 --- a/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td +++ b/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td @@ -80,6 +80,7 @@ class CIR_VariadicTensorOf allowedRuntimeTypes, Variadic>, CIR_RuntimeType>>; +def CIR_I4 : I<4>; def CIR_Int32Or64 : SignlessIntOfWidths<[32, 64]>; def CIR_BoolTensor : CIR_TensorOf<[I1]>; @@ -259,6 +260,8 @@ class CIR_Op traits = []> : // Whether the Circle operator has options in the schema representation. bit hasOptions = 0b0; + // Whether the Circle operator has options2 in the schema representation. + bit hasOptions2 = 0b0; // Use to specify a custom options type for Circle operators where // the option's name does not match the Cirlce operator's name. diff --git a/circle-mlir/circle-mlir/lib/tools/converter-gen/converter_gen.cc b/circle-mlir/circle-mlir/lib/tools/converter-gen/converter_gen.cc index 3bcc0b66f64..86d39effee6 100644 --- a/circle-mlir/circle-mlir/lib/tools/converter-gen/converter_gen.cc +++ b/circle-mlir/circle-mlir/lib/tools/converter-gen/converter_gen.cc @@ -91,6 +91,16 @@ static inline bool IsLstmOp(const StringRef op_name) { return op_name.take_back(6) == "LSTMOp"; } +static int HasOptions(const Record &def) { + if (def.getValueAsBit("hasOptions")) { + return 1; + } + if (def.getValueAsBit("hasOptions2")) { + return 2; + } + return 0; +} + static void EmitOptionBuilders(const RecordKeeper &record_keeper, const std::vector &defs, raw_ostream *ostream) { @@ -98,8 +108,11 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper, const auto attr_type = record_keeper.getClass("Attr"); for (const auto *def : defs) { + const int has_options = HasOptions(*def); // Circle ops without options are skipped over. - if (!def->getValueAsBit("hasOptions")) continue; + if (!has_options) { + continue; + } StringRef op_name = def->getName().drop_front(4); // Strip 'CIR_' prefix std::string option_name = GetOperatorOptionName(*def); @@ -204,7 +217,8 @@ static void EmitOperatorBuilders(const std::vector &defs, // Build the FlatBuffer operator os << " return circle::CreateOperator(\n" " *fbb, opcode_index, inputs, outputs,\n"; - if (def->getValueAsBit("hasOptions")) { + const int has_options = HasOptions(*def); + if (has_options == 1) { auto option_name = GetOperatorOptionName(*def); std::string circle_option_name = option_name == "BasicLSTMOptions" ? "LSTMOptions" : option_name; @@ -217,8 +231,26 @@ static void EmitOperatorBuilders(const std::vector &defs, // used by custom or flex ops and those ops are handled manually. os << " /*custom_options=*/0, " << "circle::CustomOptionsFormat_FLEXBUFFERS,\n" - << " /*mutating_variable_inputs=*/0" - << (has_intermediates ? ", intermediates" : "") << ");\n}\n\n"; + << " /*mutating_variable_inputs=*/0," + << (has_intermediates ? "intermediates" : "/*intermediates=*/0"); + + if (has_options == 2) { + os << ",\n" + << " /*large_custom_options_offset=*/0,\n" + << " /*large_custom_options_size=*/0"; + os << ",\n"; + const std::string option_name = GetOperatorOptionName(*def); + os << " circle::BuiltinOptions2_" << option_name << ", " + << "Create" << option_name << "(tflOp, fbb).Union()"; + } else { + os << ",\n" + << " /*large_custom_options_offset=*/0,\n" + << " /*large_custom_options_size=*/0"; + os << ",\n"; + os << " circle::BuiltinOptions2_NONE, /*builtin_options2=*/0"; + } + + os << ");\n}\n\n"; } } @@ -355,24 +387,43 @@ static void EmitBuildOperator(const std::vector &defs, // Emit a function that converts a BuiltinOptionsUnion to a vector of attributes // Signature: -// void mlir::BuiltinOptionsToAttributes( -// circle::BuiltinOptionsUnion op_union, +// void mlir::BuiltinOptions{id}ToAttributes( +// circle::BuiltinOptions{id}Union op_union, // mlir::Builder builder, // llvm::SmallVectorImpl &attributes); -static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, - const std::vector &defs, - raw_ostream *ostream) { +// +// where id is an empty string if builtin_options_id is 1, or builtin_options_id +// otherwise. +static void EmitBuiltinOptionsToAttributes( + const RecordKeeper &record_keeper, const std::vector &defs, + raw_ostream *ostream, const int builtin_options_id) { raw_ostream &os = *ostream; + const std::string builtin_options_suffix = [&] { + switch (builtin_options_id) { + case 1: + return ""; + case 2: + return "2"; + } + return "UnknownId"; + }(); + // Signature - os << "void mlir::BuiltinOptionsToAttributes(" - "circle::BuiltinOptionsUnion op_union, " + os << "void mlir::BuiltinOptions" << builtin_options_suffix + << "ToAttributes(" + "circle::BuiltinOptions" + << builtin_options_suffix + << "Union op_union, " "mlir::Builder builder, " "llvm::SmallVectorImpl &attributes) {\n"; const auto attr_type = record_keeper.getClass("Attr"); for (const auto *def : defs) { - if (!def->getValueAsBit("hasOptions")) continue; + const int has_options = HasOptions(*def); + if (has_options != builtin_options_id) { + continue; + } auto option_name = GetOperatorOptionName(*def); // Basic LSTM and LSTM ops share the same option to attribute converter. if (option_name == "BasicLSTMOptions") { @@ -405,9 +456,14 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, os << " return;\n"; os << " }\n"; } + if (builtin_options_id == 2) { + os << " BuiltinOptions2ToAttributesManual(op_union, builder, " + "attributes);\n"; + } // Fallthrough case is no attributes os << "}"; } + // The function below has a non-constant reference as that is required by LLVM's // TableGenMain. // NOLINTNEXTLINE @@ -440,8 +496,11 @@ static bool OperatorWritersMain(raw_ostream &os, const RecordKeeper &records) { os << "\n\n"; EmitBuildOperator(defs, &os); os << "\n\n"; - EmitBuiltinOptionsToAttributes(records, defs, &os); + EmitBuiltinOptionsToAttributes(records, defs, &os, /*builtin_options_id=*/1); os << "\n\n"; + // TODO support options2 + //EmitBuiltinOptionsToAttributes(records, defs, &os, /*builtin_options_id=*/2); + //os << "\n\n"; EmitOperandNumbers(records, defs, &os); return false;