Skip to content

Exploring Global Reduce Optimization: Could Reducing Memory Roundtrips Improve Performance? #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions marlin/marlin_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -527,18 +527,14 @@ __global__ void Marlin(
int row = (threadIdx.x % 32) / 4;

if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,
// hence we also use async-copies even though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
);
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
int4 c_val = C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)];
sh[c_sh_wr + c_sh_wr_delta * i] = c_val;
}
}
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
}

#pragma unroll
Expand Down