Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def cc_StructType : CCType<"Struct", "struct",
);

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;

let builders = [
TypeBuilder<(ins CArg<"llvm::StringRef">:$name,
Expand Down Expand Up @@ -154,6 +155,7 @@ def cc_ArrayType : CCType<"Array", "array"> {
);

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;

let extraClassDeclaration = [{
using SizeType = std::int64_t;
Expand Down Expand Up @@ -249,6 +251,8 @@ def cc_StdVectorType : CCType<"Stdvec", "stdvec", [],

let assemblyFormat = "`<` qualified($elementType) `>`";

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
Expand Down
9 changes: 6 additions & 3 deletions lib/Frontend/nvqpp/ASTBridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,17 @@ namespace cudaq::details {

bool QuakeBridgeVisitor::generateFunctionDeclaration(
StringRef funcName, const clang::FunctionDecl *x) {
auto loc = toLocation(x);
allowUnknownRecordType = true;
if (!TraverseType(x->getType()))
emitFatalError(loc, "failed to generate type for kernel function");
if (!TraverseType(x->getType())) {
reportClangError(x, mangler, "failed to generate type for kernel function");
typeStack.clear();
return false;
}
allowUnknownRecordType = false;
if (!doSyntaxChecks(x))
return false;
auto funcTy = cast<FunctionType>(popType());
auto loc = toLocation(x);
[[maybe_unused]] auto fnPair = getOrAddFunc(loc, funcName, funcTy);
assert(fnPair.first && "expected FuncOp to be created");
if (!isa<clang::CXXMethodDecl>(x) || x->isStatic())
Expand Down
10 changes: 9 additions & 1 deletion lib/Frontend/nvqpp/ConvertDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,15 @@ bool QuakeBridgeVisitor::interceptRecordDecl(clang::RecordDecl *x) {
// Traverse template argument 0 to get the vector's element type.
if (!cts || !TraverseType(cts->getTemplateArgs()[0].getAsType()))
return false;
return pushType(cc::StdvecType::get(ctx, popType()));
auto ty = popType();
if (quake::isQuantumType(ty)) {
if (ty == quake::RefType::get(ctx))
return pushType(quake::VeqType::getUnsized(ctx));
cudaq::emitFatalError(toLocation(x->getSourceRange()),
"std::vector element type is not supported");
return false;
}
return pushType(cc::StdvecType::get(ctx, ty));
}
// std::vector<bool> => cc.stdvec<i1>
if (name.equals("_Bit_reference") || name.equals("__bit_reference")) {
Expand Down
24 changes: 22 additions & 2 deletions lib/Frontend/nvqpp/ConvertType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,21 @@ bool QuakeBridgeVisitor::VisitRecordDecl(clang::RecordDecl *x) {
SmallVector<Type> fieldTys =
lastTypes(std::distance(x->field_begin(), x->field_end()));
auto [width, alignInBytes] = getWidthAndAlignment(x);

// This is a struq if it is not empty and all members are quantum references.
bool isStruq = !fieldTys.empty();
for (auto ty : fieldTys)
bool quantumMembers = false;
for (auto ty : fieldTys) {
if (quake::isQuantumType(ty))
quantumMembers = true;
if (!quake::isQuantumReferenceType(ty))
isStruq = false;
}
if (quantumMembers && !isStruq) {
reportClangError(x, mangler,
"hybrid quantum-classical struct types are not allowed");
return false;
}

auto ty = [&]() -> Type {
if (isStruq)
Expand Down Expand Up @@ -458,7 +469,16 @@ bool QuakeBridgeVisitor::VisitRValueReferenceType(

bool QuakeBridgeVisitor::VisitConstantArrayType(clang::ConstantArrayType *t) {
auto size = t->getSize().getZExtValue();
return pushType(cc::ArrayType::get(builder.getContext(), popType(), size));
auto ty = popType();
if (quake::isQuantumType(ty)) {
auto *ctx = builder.getContext();
if (ty == quake::RefType::get(ctx))
return pushType(quake::VeqType::getUnsized(ctx));
emitFatalError(builder.getUnknownLoc(),
"array element type is not supported");
return false;
}
return pushType(cc::ArrayType::get(builder.getContext(), ty, size));
}

bool QuakeBridgeVisitor::pushType(Type t) {
Expand Down
29 changes: 29 additions & 0 deletions lib/Optimizer/Dialect/CC/CCTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
Expand Down Expand Up @@ -103,6 +104,16 @@ cc::StructType::getPreferredAlignment(const DataLayout &dataLayout,
return getAlignment();
}

LogicalResult
cc::StructType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::StringAttr, llvm::ArrayRef<mlir::Type> members,
bool, bool, unsigned long, unsigned int) {
for (auto ty : members)
if (quake::isQuantumType(ty))
return emitError() << "cc.struct may not contain quake types: " << ty;
return success();
}

//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -139,6 +150,24 @@ void cc::ArrayType::print(AsmPrinter &printer) const {
printer << '>';
}

LogicalResult
cc::ArrayType::verify(function_ref<InFlightDiagnostic()> emitError, Type eleTy,
long) {
if (quake::isQuantumType(eleTy))
return emitError() << "cc.array may not have a quake element type: "
<< eleTy;
return success();
}

LogicalResult
cc::StdvecType::verify(function_ref<InFlightDiagnostic()> emitError,
Type eleTy) {
if (quake::isQuantumType(eleTy))
return emitError() << "cc.stdvec may not have a quake element type: "
<< eleTy;
return success();
}

} // namespace cudaq

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions test/AST-error/struct_quantum_and_classical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct test {

__qpu__ void hello(cudaq::qubit &q) { h(q); }

// expected-error@+1 {{failed to generate type for kernel function}}
__qpu__ void kernel(test t) {
h(t.q);
hello(t.q[0]);
Expand Down
Loading