Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions compiler/AST/build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1609,12 +1609,21 @@ BlockStmt* buildCoforallLoopStmt(Expr* indices,
beginBlk->blockInfoSet(new CallExpr(PRIM_BLOCK_COFORALL));
addByrefVars(beginBlk, byref_vars);
beginBlk->insertAtHead(body);
#ifndef TARGET_HSA
beginBlk->insertAtTail(new CallExpr("_downEndCount", coforallCount));
#else
beginBlk->insertAtTail(new CallExpr("_completeTaskGroup", coforallCount));
#endif
BlockStmt* block = ForLoop::buildForLoop(indices, new SymExpr(tmpIter), beginBlk, true, zippered);
block->insertAtHead(new CallExpr(PRIM_MOVE, coforallCount, new CallExpr("_endCountAlloc", /*forceLocalTypes=*/gTrue)));
block->insertAtHead(new DefExpr(coforallCount));
#ifndef TARGET_HSA
beginBlk->insertBefore(new CallExpr("_upEndCount", coforallCount));
block->insertAtTail(new CallExpr("_waitEndCount", coforallCount));
#else
beginBlk->insertBefore(new CallExpr("_initTaskGroup", coforallCount));
block->insertAtTail(new CallExpr("_finalizeTaskGroup", coforallCount));
#endif
block->insertAtTail(new CallExpr("_endCountFree", coforallCount));
nonVectorCoforallBlk->insertAtTail(block);
}
Expand All @@ -1624,15 +1633,27 @@ BlockStmt* buildCoforallLoopStmt(Expr* indices,
beginBlk->blockInfoSet(new CallExpr(PRIM_BLOCK_COFORALL));
addByrefVars(beginBlk, byref_vars);
beginBlk->insertAtHead(body->copy());
#ifndef TARGET_HSA
beginBlk->insertAtTail(new CallExpr("_downEndCount", coforallCount));
#else
beginBlk->insertAtTail(new CallExpr("_completeTaskGroup", coforallCount));
#endif
VarSymbol* numTasks = newTemp("numTasks");
vectorCoforallBlk->insertAtTail(new DefExpr(numTasks));
vectorCoforallBlk->insertAtTail(new CallExpr(PRIM_MOVE, numTasks, new CallExpr(".", tmpIter, new_CStringSymbol("size"))));
#ifndef TARGET_HSA
vectorCoforallBlk->insertAtTail(new CallExpr("_upEndCount", coforallCount, /*countRunningTasks=*/gTrue, numTasks));
#else
vectorCoforallBlk->insertAtTail(new CallExpr("_initTaskGroup", coforallCount, /*countRunningTasks=*/gTrue, numTasks));
#endif
BlockStmt* block = ForLoop::buildForLoop(indices, new SymExpr(tmpIter), beginBlk, true, zippered);
vectorCoforallBlk->insertAtHead(new CallExpr(PRIM_MOVE, coforallCount, new CallExpr("_endCountAlloc", /*forceLocalTypes=*/gTrue)));
vectorCoforallBlk->insertAtHead(new DefExpr(coforallCount));
#ifndef TARGET_HSA
block->insertAtTail(new CallExpr("_waitEndCount", coforallCount, /*countRunningTasks=*/gTrue, numTasks));
#else
block->insertAtTail(new CallExpr("_finalizeTaskGroup", coforallCount, /*countRunningTasks=*/gTrue, numTasks));
#endif
block->insertAtTail(new CallExpr("_endCountFree", coforallCount));
vectorCoforallBlk->insertAtTail(block);
}
Expand Down Expand Up @@ -2890,12 +2911,20 @@ buildBeginStmt(CallExpr* byref_vars, Expr* stmt) {
return body;
} else {
BlockStmt* block = buildChapelStmt();
#ifndef TARGET_HSA
block->insertAtTail(new CallExpr("_upEndCount"));
#else
block->insertAtTail(new CallExpr("_addToTaskGroup"));
#endif
BlockStmt* beginBlock = new BlockStmt();
beginBlock->blockInfoSet(new CallExpr(PRIM_BLOCK_BEGIN));
addByrefVars(beginBlock, byref_vars);
beginBlock->insertAtHead(stmt);
#ifndef TARGET_HSA
beginBlock->insertAtTail(new CallExpr("_downEndCount"));
#else
beginBlock->insertAtTail(new CallExpr("_completeTaskGroup"));
#endif
block->insertAtTail(beginBlock);
return block;
}
Expand All @@ -2911,7 +2940,11 @@ buildSyncStmt(Expr* stmt) {
block->insertAtTail(new CallExpr(PRIM_MOVE, endCountSave, new CallExpr(PRIM_GET_END_COUNT)));
block->insertAtTail(new CallExpr(PRIM_SET_END_COUNT, new CallExpr("_endCountAlloc", /* forceLocalTypes= */gFalse)));
block->insertAtTail(stmt);
#ifndef TARGET_HSA
block->insertAtTail(new CallExpr("_waitEndCount"));
#else
block->insertAtTail(new CallExpr("_finalizeTaskGroup"));
#endif
block->insertAtTail(new CallExpr("_endCountFree", new CallExpr(PRIM_GET_END_COUNT)));
block->insertAtTail(new CallExpr(PRIM_SET_END_COUNT, endCountSave));
return block;
Expand Down Expand Up @@ -2947,13 +2980,24 @@ buildCobeginStmt(CallExpr* byref_vars, BlockStmt* block) {
addByrefVars(beginBlk, byref_vars ? byref_vars->copy() : NULL);
stmt->insertBefore(beginBlk);
beginBlk->insertAtHead(stmt->remove());
#ifndef TARGET_HSA
beginBlk->insertAtTail(new CallExpr("_downEndCount", cobeginCount));
block->insertAtHead(new CallExpr("_upEndCount", cobeginCount));
#else
beginBlk->insertAtTail(new CallExpr("_completeTaskGroup", cobeginCount));
#endif
}
#ifdef TARGET_HSA
block->insertAtHead(new CallExpr("_initTaskGroup", cobeginCount));
#endif

block->insertAtHead(new CallExpr(PRIM_MOVE, cobeginCount, new CallExpr("_endCountAlloc", /* forceLocalTypes= */gTrue)));
block->insertAtHead(new DefExpr(cobeginCount));
#ifndef TARGET_HSA
block->insertAtTail(new CallExpr("_waitEndCount", cobeginCount));
#else
block->insertAtTail(new CallExpr("_finalizeTaskGroup", cobeginCount));
#endif
block->insertAtTail(new CallExpr("_endCountFree", cobeginCount));

block->astloc = cobeginCount->astloc; // grab the location of 'cobegin' kw
Expand Down
16 changes: 12 additions & 4 deletions compiler/AST/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,8 @@ CallExpr::CallExpr(BaseAST* base,
BaseAST* arg2,
BaseAST* arg3,
BaseAST* arg4,
BaseAST* arg5) :
BaseAST* arg5,
BaseAST* arg6) :
Expr(E_CallExpr),
primitive(NULL),
baseExpr(NULL),
Expand All @@ -815,6 +816,7 @@ CallExpr::CallExpr(BaseAST* base,
callExprHelper(this, arg3);
callExprHelper(this, arg4);
callExprHelper(this, arg5);
callExprHelper(this, arg6);

argList.parent = this;

Expand All @@ -827,7 +829,8 @@ CallExpr::CallExpr(PrimitiveOp* prim,
BaseAST* arg2,
BaseAST* arg3,
BaseAST* arg4,
BaseAST* arg5) :
BaseAST* arg5,
BaseAST* arg6) :
Expr(E_CallExpr),
primitive(prim),
baseExpr(NULL),
Expand All @@ -840,6 +843,7 @@ CallExpr::CallExpr(PrimitiveOp* prim,
callExprHelper(this, arg3);
callExprHelper(this, arg4);
callExprHelper(this, arg5);
callExprHelper(this, arg6);

argList.parent = this;

Expand All @@ -851,7 +855,8 @@ CallExpr::CallExpr(PrimitiveTag prim,
BaseAST* arg2,
BaseAST* arg3,
BaseAST* arg4,
BaseAST* arg5) :
BaseAST* arg5,
BaseAST* arg6) :
Expr(E_CallExpr),
primitive(primitives[prim]),
baseExpr(NULL),
Expand All @@ -864,6 +869,7 @@ CallExpr::CallExpr(PrimitiveTag prim,
callExprHelper(this, arg3);
callExprHelper(this, arg4);
callExprHelper(this, arg5);
callExprHelper(this, arg6);

argList.parent = this;

Expand All @@ -876,7 +882,8 @@ CallExpr::CallExpr(const char* name,
BaseAST* arg2,
BaseAST* arg3,
BaseAST* arg4,
BaseAST* arg5) :
BaseAST* arg5,
BaseAST* arg6) :
Expr(E_CallExpr),
primitive(NULL),
baseExpr(new UnresolvedSymExpr(name)),
Expand All @@ -889,6 +896,7 @@ CallExpr::CallExpr(const char* name,
callExprHelper(this, arg3);
callExprHelper(this, arg4);
callExprHelper(this, arg5);
callExprHelper(this, arg6);

argList.parent = this;

Expand Down
10 changes: 6 additions & 4 deletions compiler/codegen/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5350,11 +5350,12 @@ void CallExpr::codegenInvokeOnFun() {
void CallExpr::codegenInvokeTaskFun(const char* name) {
FnSymbol* fn = isResolved();
GenRet taskList = codegenValue(get(1));
GenRet taskGroup = codegenValue(get(6));
GenRet taskListNode;
GenRet taskBundle;
GenRet bundleSize;

std::vector<GenRet> args(8);
std::vector<GenRet> args(9);

// get(1) is a ref/wide ref to a task list value
// get(2) is the node ID owning the task list
Expand All @@ -5374,9 +5375,10 @@ void CallExpr::codegenInvokeTaskFun(const char* name) {
args[2] = codegenCast("chpl_task_bundle_p", taskBundle);
args[3] = bundleSize;
args[4] = taskList;
args[5] = codegenValue(taskListNode);
args[6] = fn->linenum();
args[7] = new_IntSymbol(gFilenameLookupCache[fn->fname()], INT_SIZE_32);
args[5] = taskGroup;
args[6] = codegenValue(taskListNode);
args[7] = fn->linenum();
args[8] = new_IntSymbol(gFilenameLookupCache[fn->fname()], INT_SIZE_32);

genComment(fn->cname, true);

Expand Down
12 changes: 8 additions & 4 deletions compiler/include/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,28 +199,32 @@ class CallExpr : public Expr {
BaseAST* arg2 = NULL,
BaseAST* arg3 = NULL,
BaseAST* arg4 = NULL,
BaseAST* arg5 = NULL);
BaseAST* arg5 = NULL,
BaseAST* arg6 = NULL);

// Builds a call to the primitive operation `prim`; the definition stores
// `prim` as the call's primitive and leaves baseExpr NULL.
// arg1..arg6 are optional actuals (each defaults to NULL); the definition
// hands them to callExprHelper() -- presumably NULL ones are ignored there,
// TODO confirm.
CallExpr(PrimitiveOp* prim,
BaseAST* arg1 = NULL,
BaseAST* arg2 = NULL,
BaseAST* arg3 = NULL,
BaseAST* arg4 = NULL,
BaseAST* arg5 = NULL,
BaseAST* arg6 = NULL);

// Builds a call to a primitive identified by tag; the definition looks the
// primitive up as primitives[prim] and leaves baseExpr NULL.
// arg1..arg6 are optional actuals handed to callExprHelper().
CallExpr(PrimitiveTag prim,
BaseAST* arg1 = NULL,
BaseAST* arg2 = NULL,
BaseAST* arg3 = NULL,
BaseAST* arg4 = NULL,
BaseAST* arg5 = NULL,
BaseAST* arg6 = NULL);

// Builds a call whose callee is referenced by name; the definition wraps
// `name` in a new UnresolvedSymExpr stored as baseExpr and sets the
// primitive to NULL.
// arg1..arg6 are optional actuals handed to callExprHelper().
CallExpr(const char* name,
BaseAST* arg1 = NULL,
BaseAST* arg2 = NULL,
BaseAST* arg3 = NULL,
BaseAST* arg4 = NULL,
BaseAST* arg5 = NULL,
BaseAST* arg6 = NULL);

~CallExpr();

Expand Down
4 changes: 4 additions & 0 deletions compiler/passes/buildDefaultFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,11 @@ static void build_chpl_entry_points() {
// endcount (see comment above)
//
if (fMinimalModules == false) {
#ifndef TARGET_HSA
chpl_gen_main->insertAtTail(new CallExpr("_waitEndCount"));
#else
chpl_gen_main->insertAtTail(new CallExpr("_finalizeTaskGroup"));
#endif
}

chpl_gen_main->insertAtTail(new CallExpr(PRIM_RETURN, main_ret));
Expand Down
4 changes: 2 additions & 2 deletions compiler/passes/insertLineNumbers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ static void moveLinenoInsideArgBundle()
// than the expected number. Both block types below expect an
// argument bundle, and the on-block expects an additional argument
// that is the locale on which it should be executed.
if ((fn->numFormals() > 4 && fn->hasFlag(FLAG_ON_BLOCK)) ||
(fn->numFormals() > 5 && !fn->hasFlag(FLAG_ON_BLOCK) &&
if ((fn->numFormals() > 5 && fn->hasFlag(FLAG_ON_BLOCK)) ||
(fn->numFormals() > 6 && !fn->hasFlag(FLAG_ON_BLOCK) &&
(fn->hasFlag(FLAG_BEGIN_BLOCK) ||
fn->hasFlag(FLAG_COBEGIN_OR_COFORALL_BLOCK)))) {

Expand Down
28 changes: 22 additions & 6 deletions compiler/passes/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ static BundleArgsFnData bundleArgsFnDataInit = { true, NULL, NULL };
static void insertEndCounts();
static void passArgsToNestedFns();
static void create_block_fn_wrapper(FnSymbol* fn, CallExpr* fcall, BundleArgsFnData &baData);
static void call_block_fn_wrapper(FnSymbol* fn, CallExpr* fcall, VarSymbol* args_buf, VarSymbol* args_buf_len, VarSymbol* tempc, FnSymbol *wrap_fn, Symbol* taskList, Symbol* taskListNode);
static void call_block_fn_wrapper(FnSymbol* fn, CallExpr* fcall, VarSymbol* args_buf, VarSymbol* args_buf_len, VarSymbol* tempc, FnSymbol *wrap_fn, Symbol* taskList, Symbol* taskListNode, Symbol *taskGroup);
static void findBlockRefActuals(Vec<Symbol*>& refSet, Vec<Symbol*>& refVec);
static void findHeapVarsAndRefs(Map<Symbol*,Vec<SymExpr*>*>& defMap,
Vec<Symbol*>& refSet, Vec<Symbol*>& refVec,
Expand Down Expand Up @@ -410,6 +410,7 @@ bundleArgs(CallExpr* fcall, BundleArgsFnData &baData) {
// first argument to a task launch function.
Symbol* endCount = NULL;
VarSymbol *taskList = NULL;
VarSymbol *taskGroup = NULL;
VarSymbol *taskListNode = NULL;

if (!fn->hasFlag(FLAG_ON)) {
Expand Down Expand Up @@ -466,6 +467,16 @@ bundleArgs(CallExpr* fcall, BundleArgsFnData &baData) {
endCount,
endCount->typeInfo()->getField("taskList"))));

// Now get the taskGroup field out of the end count.

taskGroup = newTemp(astr("_taskGroup", fn->name), dtCVoidPtr);

fcall->insertBefore(new DefExpr(taskGroup));
fcall->insertBefore(new CallExpr(PRIM_MOVE, taskGroup,
new CallExpr(PRIM_GET_MEMBER,
endCount,
endCount->typeInfo()->getField("taskGroup"))));


// Now get the node ID field for the end count,
// which is where the task list is stored.
Expand All @@ -481,7 +492,7 @@ bundleArgs(CallExpr* fcall, BundleArgsFnData &baData) {
create_block_fn_wrapper(fn, fcall, baData);

// call wrapper-function
call_block_fn_wrapper(fn, fcall, allocated_args, tmpsz, tempc, baData.wrap_fn, taskList, taskListNode);
call_block_fn_wrapper(fn, fcall, allocated_args, tmpsz, tempc, baData.wrap_fn, taskList, taskListNode, taskGroup);
baData.firstCall = false;
}

Expand All @@ -492,7 +503,8 @@ static CallExpr* helpFindDownEndCount(BlockStmt* block)
while (cur && (isCallExpr(cur) || isDefExpr(cur) || isBlockStmt(cur))) {
if (CallExpr* call = toCallExpr(cur)) {
if (call->isResolved())
if (strcmp(call->resolvedFunction()->name, "_downEndCount") == 0)
if (strcmp(call->resolvedFunction()->name, "_downEndCount") == 0 ||
strcmp(call->resolvedFunction()->name, "_completeTaskGroup") == 0)
return call;
} else if (BlockStmt* inner = toBlockStmt(cur)) {
// Need to handle local blocks since the compiler
Expand Down Expand Up @@ -646,6 +658,10 @@ static void create_block_fn_wrapper(FnSymbol* fn, CallExpr* fcall, BundleArgsFnD
dtCVoidPtr->refType );
taskListArg->addFlag(FLAG_NO_CODEGEN);
wrap_fn->insertFormalAtTail(taskListArg);
ArgSymbol *taskGroupArg = new ArgSymbol( INTENT_IN, "dummy_taskGroup",
dtCVoidPtr->refType );
taskGroupArg->addFlag(FLAG_NO_CODEGEN);
wrap_fn->insertFormalAtTail(taskGroupArg);
ArgSymbol *taskListNode = new ArgSymbol( INTENT_IN, "dummy_taskListNode",
dtInt[INT_SIZE_DEFAULT]);
taskListNode->addFlag(FLAG_NO_CODEGEN);
Expand Down Expand Up @@ -749,20 +765,20 @@ static void create_block_fn_wrapper(FnSymbol* fn, CallExpr* fcall, BundleArgsFnD

static void call_block_fn_wrapper(FnSymbol* fn, CallExpr* fcall, VarSymbol*
args_buf, VarSymbol* args_buf_len, VarSymbol* tempc, FnSymbol *wrap_fn,
Symbol* taskList, Symbol* taskListNode)
Symbol* taskList, Symbol* taskListNode, Symbol *taskGroup)
{
// The wrapper function is called with the bundled argument list.
if (fn->hasFlag(FLAG_ON)) {
// For an on block, the first argument is also passed directly
// to the wrapper function.
// The forking function uses this to fork a task on the target locale.
fcall->insertBefore(new CallExpr(wrap_fn, fcall->get(1)->remove(), args_buf, args_buf_len, tempc));
fcall->insertBefore(new CallExpr(wrap_fn, fcall->get(1)->remove(), args_buf, args_buf_len, tempc));//, new SymExpr(taskGroup)));
} else {
// For non-on blocks, the task list is passed directly to the function
// (so that codegen can find it).
// We need the taskList.
INT_ASSERT(taskList);
fcall->insertBefore(new CallExpr(wrap_fn, new SymExpr(taskList), new SymExpr(taskListNode), args_buf, args_buf_len, tempc));
fcall->insertBefore(new CallExpr(wrap_fn, new SymExpr(taskList), new SymExpr(taskListNode), args_buf, args_buf_len, tempc, new SymExpr(taskGroup)));
}

fcall->remove(); // rm orig. call
Expand Down
6 changes: 3 additions & 3 deletions make/compiler/Makefile.hsa
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ include $(CHPL_MAKE_HOME)/make/compiler/Makefile.gnu
ifdef CHPL_ROCM
# ROCm locations
CLOC=/opt/rocm/cloc/bin/cloc.sh
LIBS+=-lhsa-runtime64 -lhsakmt -lm
LIBS+=-latmi_runtime -lm

# TODO: move these in third-party directory?
GEN_LFLAGS+=-L/opt/rocm/lib -L/opt/rocm/hsa/lib
HSA_INCLUDES=-I/opt/rocm/hsa/include
GEN_LFLAGS+=-L/opt/rocm/lib -L/opt/rocm/hsa/lib -L/opt/rocm/atmi/lib
HSA_INCLUDES=-I/opt/rocm/atmi/include
else
# HSA locations
CLOC=$(THIRD_PARTY_DIR)/hsa/cloc/bin/cloc.sh
Expand Down
19 changes: 19 additions & 0 deletions make/tasks/Makefile.atmi
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2004-2015 Cray Inc.
# Other additional copyright holders may be indicated within.
#
# The entirety of this work is licensed under the Apache License,
# Version 2.0 (the "License"); you may not use this file except
# in compliance with the License.
#
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Append $(QTHREAD_INCLUDE_DIR) to the runtime include flags.
# NOTE(review): this is the atmi tasking makefile, yet it references the
# qthreads include variable -- confirm this is intended and not a leftover
# from the makefile it was copied from.
RUNTIME_INCLS += -I$(QTHREAD_INCLUDE_DIR)
# Presumably selects no separate threading layer for this tasking
# configuration -- verify against how CHPL_MAKE_THREADS is consumed.
CHPL_MAKE_THREADS=none
Loading