@@ -46,10 +46,27 @@ class OpenACCClauseCIREmitter final
46
46
// diagnostics are gone.
47
47
SourceLocation dirLoc;
48
48
49
// The most recently visited 'device_type' clause, or null when none has been
// visited yet. Recorded by VisitDeviceTypeClause and read by
// VisitNumWorkersClause to pick the device types its int-expr is added for.
+ const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr ;
50
+
49
51
void clauseNotImplemented (const OpenACCClause &c) {
50
52
cgf.cgm .errorNYI (c.getSourceRange (), " OpenACC Clause" , c.getClauseKind ());
51
53
}
52
54
55
+ mlir::Value createIntExpr (const Expr *intExpr) {
56
+ mlir::Value expr = cgf.emitScalarExpr (intExpr);
57
+ mlir::Location exprLoc = cgf.cgm .getLoc (intExpr->getBeginLoc ());
58
+
59
+ mlir::IntegerType targetType = mlir::IntegerType::get (
60
+ &cgf.getMLIRContext (), cgf.getContext ().getIntWidth (intExpr->getType ()),
61
+ intExpr->getType ()->isSignedIntegerOrEnumerationType ()
62
+ ? mlir::IntegerType::SignednessSemantics::Signed
63
+ : mlir::IntegerType::SignednessSemantics::Unsigned);
64
+
65
+ auto conversionOp = builder.create <mlir::UnrealizedConversionCastOp>(
66
+ exprLoc, targetType, expr);
67
+ return conversionOp.getResult (0 );
68
+ }
69
+
53
70
// 'condition' as an OpenACC grammar production is used for 'if' and (some
54
71
// variants of) 'self'. It needs to be emitted as a signless-1-bit value, so
55
72
// this function emits the expression, then sets the unrealized conversion
@@ -109,14 +126,15 @@ class OpenACCClauseCIREmitter final
109
126
}
110
127
111
128
void VisitDeviceTypeClause (const OpenACCDeviceTypeClause &clause) {
129
+ lastDeviceTypeClause = &clause;
112
130
if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
113
131
llvm::SmallVector<mlir::Attribute> deviceTypes;
114
132
std::optional<mlir::ArrayAttr> existingDeviceTypes =
115
133
operation.getDeviceTypes ();
116
134
117
135
// Ensure we keep the existing ones, and in the correct 'new' order.
118
136
if (existingDeviceTypes) {
119
- for (const mlir::Attribute & Attr : *existingDeviceTypes)
137
+ for (mlir::Attribute Attr : *existingDeviceTypes)
120
138
deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
121
139
builder.getContext (),
122
140
cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
@@ -136,6 +154,51 @@ class OpenACCClauseCIREmitter final
136
154
if (!clause.getArchitectures ().empty ())
137
155
operation.setDeviceType (
138
156
decodeDeviceType (clause.getArchitectures ()[0 ].getIdentifierInfo ()));
157
+ } else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp>) {
158
+ // Nothing to do here, these constructs don't have any IR for these, as
159
+ // they just modify the other clauses' IR. So setting `lastDeviceTypeClause`
160
+ // (done above) is all we need.
161
+ } else {
162
+ return clauseNotImplemented (clause);
163
+ }
164
+ }
165
+
166
+ void VisitNumWorkersClause (const OpenACCNumWorkersClause &clause) {
167
+ if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
168
+ // Collect the 'existing' device-type attributes so we can re-create them
169
+ // and insert them.
170
+ llvm::SmallVector<mlir::Attribute> deviceTypes;
171
+ mlir::ArrayAttr existingDeviceTypes =
172
+ operation.getNumWorkersDeviceTypeAttr ();
173
+
174
+ if (existingDeviceTypes) {
175
+ for (mlir::Attribute Attr : existingDeviceTypes)
176
+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
177
+ builder.getContext (),
178
+ cast<mlir::acc::DeviceTypeAttr>(Attr).getValue ()));
179
+ }
180
+
181
+ // Insert 1 version of the 'int-expr' to the NumWorkers list per-current
182
+ // device type.
183
+ mlir::Value intExpr = createIntExpr (clause.getIntExpr ());
184
+ if (lastDeviceTypeClause) {
185
+ for (const DeviceTypeArgument &arg :
186
+ lastDeviceTypeClause->getArchitectures ()) {
187
+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
188
+ builder.getContext (), decodeDeviceType (arg.getIdentifierInfo ())));
189
+ operation.getNumWorkersMutable ().append (intExpr);
190
+ }
191
+ } else {
192
+ // Otherwise, add a single entry for the 'none' device type.
193
+ deviceTypes.push_back (mlir::acc::DeviceTypeAttr::get (
194
+ builder.getContext (), mlir::acc::DeviceType::None));
195
+ operation.getNumWorkersMutable ().append (intExpr);
196
+ }
197
+
198
+ operation.setNumWorkersDeviceTypeAttr (
199
+ mlir::ArrayAttr::get (builder.getContext (), deviceTypes));
200
+ } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
201
+ llvm_unreachable (" num_workers not valid on serial" );
139
202
} else {
140
203
return clauseNotImplemented (clause);
141
204
}
0 commit comments