diff --git a/include/cudaq/Optimizer/Dialect/CC/CCTypes.td b/include/cudaq/Optimizer/Dialect/CC/CCTypes.td index 18bce4e156a..fa7706a4ac9 100644 --- a/include/cudaq/Optimizer/Dialect/CC/CCTypes.td +++ b/include/cudaq/Optimizer/Dialect/CC/CCTypes.td @@ -83,6 +83,7 @@ def cc_StructType : CCType<"Struct", "struct", ); let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; let builders = [ TypeBuilder<(ins CArg<"llvm::StringRef">:$name, @@ -154,6 +155,7 @@ def cc_ArrayType : CCType<"Array", "array"> { ); let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; let extraClassDeclaration = [{ using SizeType = std::int64_t; @@ -249,6 +251,8 @@ def cc_StdVectorType : CCType<"Stdvec", "stdvec", [], let assemblyFormat = "`<` qualified($elementType) `>`"; + let genVerifyDecl = 1; + let builders = [ TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{ return Base::get(elementType.getContext(), elementType); diff --git a/lib/Frontend/nvqpp/ASTBridge.cpp b/lib/Frontend/nvqpp/ASTBridge.cpp index 5d33b6d06f4..74ada130f28 100644 --- a/lib/Frontend/nvqpp/ASTBridge.cpp +++ b/lib/Frontend/nvqpp/ASTBridge.cpp @@ -430,14 +430,17 @@ namespace cudaq::details { bool QuakeBridgeVisitor::generateFunctionDeclaration( StringRef funcName, const clang::FunctionDecl *x) { - auto loc = toLocation(x); allowUnknownRecordType = true; - if (!TraverseType(x->getType())) - emitFatalError(loc, "failed to generate type for kernel function"); + if (!TraverseType(x->getType())) { + reportClangError(x, mangler, "failed to generate type for kernel function"); + typeStack.clear(); + return false; + } allowUnknownRecordType = false; if (!doSyntaxChecks(x)) return false; auto funcTy = cast(popType()); + auto loc = toLocation(x); [[maybe_unused]] auto fnPair = getOrAddFunc(loc, funcName, funcTy); assert(fnPair.first && "expected FuncOp to be created"); if (!isa(x) || x->isStatic()) diff --git a/lib/Frontend/nvqpp/ConvertDecl.cpp b/lib/Frontend/nvqpp/ConvertDecl.cpp index 101e34d653f..1c25ea319e3 100644 --- a/lib/Frontend/nvqpp/ConvertDecl.cpp +++ b/lib/Frontend/nvqpp/ConvertDecl.cpp @@ -180,7 +180,15 @@ bool QuakeBridgeVisitor::interceptRecordDecl(clang::RecordDecl *x) { // Traverse template argument 0 to get the vector's element type. if (!cts || !TraverseType(cts->getTemplateArgs()[0].getAsType())) return false; - return pushType(cc::StdvecType::get(ctx, popType())); + auto ty = popType(); + if (quake::isQuantumType(ty)) { + if (ty == quake::RefType::get(ctx)) + return pushType(quake::VeqType::getUnsized(ctx)); + cudaq::emitFatalError(toLocation(x->getSourceRange()), + "std::vector element type is not supported"); + return false; + } + return pushType(cc::StdvecType::get(ctx, ty)); } // std::vector => cc.stdvec if (name.equals("_Bit_reference") || name.equals("__bit_reference")) { diff --git a/lib/Frontend/nvqpp/ConvertType.cpp b/lib/Frontend/nvqpp/ConvertType.cpp index d71be7745c7..14d1eaf170f 100644 --- a/lib/Frontend/nvqpp/ConvertType.cpp +++ b/lib/Frontend/nvqpp/ConvertType.cpp @@ -242,10 +242,21 @@ bool QuakeBridgeVisitor::VisitRecordDecl(clang::RecordDecl *x) { SmallVector fieldTys = lastTypes(std::distance(x->field_begin(), x->field_end())); auto [width, alignInBytes] = getWidthAndAlignment(x); + + // This is a struq if it is not empty and all members are quantum references. bool isStruq = !fieldTys.empty(); - for (auto ty : fieldTys) + bool quantumMembers = false; + for (auto ty : fieldTys) { + if (quake::isQuantumType(ty)) + quantumMembers = true; if (!quake::isQuantumReferenceType(ty)) isStruq = false; + } + if (quantumMembers && !isStruq) { + reportClangError(x, mangler, + "hybrid quantum-classical struct types are not allowed"); + return false; + } auto ty = [&]() -> Type { if (isStruq) @@ -458,7 +469,16 @@ bool QuakeBridgeVisitor::VisitRValueReferenceType( bool QuakeBridgeVisitor::VisitConstantArrayType(clang::ConstantArrayType *t) { auto size = t->getSize().getZExtValue(); - return pushType(cc::ArrayType::get(builder.getContext(), popType(), size)); + auto ty = popType(); + if (quake::isQuantumType(ty)) { + auto *ctx = builder.getContext(); + if (ty == quake::RefType::get(ctx)) + return pushType(quake::VeqType::getUnsized(ctx)); + emitFatalError(builder.getUnknownLoc(), + "array element type is not supported"); + return false; + } + return pushType(cc::ArrayType::get(builder.getContext(), ty, size)); } bool QuakeBridgeVisitor::pushType(Type t) { diff --git a/lib/Optimizer/Dialect/CC/CCTypes.cpp b/lib/Optimizer/Dialect/CC/CCTypes.cpp index e3b98de5e61..5ed781ce997 100644 --- a/lib/Optimizer/Dialect/CC/CCTypes.cpp +++ b/lib/Optimizer/Dialect/CC/CCTypes.cpp @@ -8,6 +8,7 @@ #include "cudaq/Optimizer/Dialect/CC/CCTypes.h" #include "cudaq/Optimizer/Dialect/CC/CCDialect.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -103,6 +104,16 @@ cc::StructType::getPreferredAlignment(const DataLayout &dataLayout, return getAlignment(); } +LogicalResult +cc::StructType::verify(llvm::function_ref emitError, + mlir::StringAttr, llvm::ArrayRef members, + bool, bool, unsigned long, unsigned int) { + for (auto ty : members) + if (quake::isQuantumType(ty)) + return emitError() << "cc.struct may not contain quake types: " << ty; + return success(); +} + //===----------------------------------------------------------------------===// // ArrayType //===----------------------------------------------------------------------===// @@ -139,6 +150,24 @@ void cc::ArrayType::print(AsmPrinter &printer) const { printer << '>'; } +LogicalResult +cc::ArrayType::verify(function_ref emitError, Type eleTy, + long) { + if (quake::isQuantumType(eleTy)) + return emitError() << "cc.array may not have a quake element type: " + << eleTy; + return success(); +} + +LogicalResult +cc::StdvecType::verify(function_ref emitError, + Type eleTy) { + if (quake::isQuantumType(eleTy)) + return emitError() << "cc.stdvec may not have a quake element type: " + << eleTy; + return success(); +} + } // namespace cudaq //===----------------------------------------------------------------------===// diff --git a/test/AST-error/struct_quantum_and_classical.cpp b/test/AST-error/struct_quantum_and_classical.cpp index f62c661f779..ae9f25544a9 100644 --- a/test/AST-error/struct_quantum_and_classical.cpp +++ b/test/AST-error/struct_quantum_and_classical.cpp @@ -20,6 +20,7 @@ struct test { __qpu__ void hello(cudaq::qubit &q) { h(q); } +// expected-error@+1 {{failed to generate type for kernel function}} __qpu__ void kernel(test t) { h(t.q); hello(t.q[0]);