Skip to content

[FEA] Matrix shift rows and columns #2634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: branch-25.06
Choose a base branch
from
120 changes: 120 additions & 0 deletions cpp/include/raft/matrix/detail/shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

namespace raft::matrix::detail {

template <typename T>
RAFT_KERNEL col_right_shift(T* in_out, size_t n_rows, size_t n_cols, size_t k, T val)
{
size_t row = blockIdx.x * blockDim.x + threadIdx.x;
if (row < n_rows) {
size_t base_idx = row * n_cols;
size_t cols_to_shift = n_cols - k;
for (size_t i = 1; i <= cols_to_shift; i++) {
in_out[base_idx + (n_cols - i)] = in_out[base_idx + (n_cols - k - i)];
}
for (size_t i = 0; i < k; i++) {
in_out[base_idx + i] = val;
}
}
}

template <typename math_t, typename matrix_idx_t>
void col_right_shift(raft::resources const& handle,
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out,
math_t val,
size_t k)
{
size_t n_rows = in_out.extent(0);
size_t n_cols = in_out.extent(1);
size_t TPB = 256;
size_t num_blocks = static_cast<size_t>((n_rows + TPB) / TPB);

col_right_shift<math_t><<<num_blocks, TPB, 0, raft::resource::get_cuda_stream(handle)>>>(
in_out.data_handle(), n_rows, n_cols, k, val);
}

template <typename T>
RAFT_KERNEL col_right_shift(T* in_out, size_t n_rows, size_t n_cols, size_t k, const T* values)
{
size_t row = blockIdx.x * blockDim.x + threadIdx.x;
if (row < n_rows) {
size_t base_idx = row * n_cols;
size_t cols_to_shift = n_cols - k;
for (size_t i = 1; i <= cols_to_shift; i++) {
in_out[base_idx + (n_cols - i)] = in_out[base_idx + (n_cols - k - i)];
}
for (size_t i = 0; i < k; i++) {
in_out[base_idx + i] = values[row * k + i];
}
}
}

template <typename math_t, typename matrix_idx_t>
void col_right_shift(raft::resources const& handle,
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out,
raft::device_matrix_view<const math_t, matrix_idx_t> values)
{
size_t n_rows = in_out.extent(0);
size_t n_cols = in_out.extent(1);
size_t TPB = 256;
size_t num_blocks = static_cast<size_t>((n_rows + TPB) / TPB);

size_t k = values.extent(1);

col_right_shift<math_t><<<num_blocks, TPB, 0, raft::resource::get_cuda_stream(handle)>>>(
in_out.data_handle(), n_rows, n_cols, k, values.data_handle());
return;
}

template <typename T>
RAFT_KERNEL col_right_shift_self(T* in_out, size_t n_rows, size_t n_cols, size_t k)
{
size_t row = blockIdx.x * blockDim.x + threadIdx.x;
if (row < n_rows) {
size_t base_idx = row * n_cols;
size_t cols_to_shift = n_cols - k;
for (size_t i = 1; i <= cols_to_shift; i++) {
in_out[base_idx + (n_cols - i)] = in_out[base_idx + (n_cols - k - i)];
}
for (size_t i = 0; i < k; i++) {
in_out[base_idx + i] = row;
}
}
}

template <typename math_t, typename matrix_idx_t>
void col_right_shift_self(raft::resources const& handle,
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out,
size_t k)
{
size_t n_rows = in_out.extent(0);
size_t n_cols = in_out.extent(1);
size_t TPB = 256;
size_t num_blocks = static_cast<size_t>((n_rows + TPB) / TPB);

col_right_shift_self<math_t><<<num_blocks, TPB, 0, raft::resource::get_cuda_stream(handle)>>>(
in_out.data_handle(), n_rows, n_cols, k);
return;
}

} // namespace raft::matrix::detail
80 changes: 80 additions & 0 deletions cpp/include/raft/matrix/shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/matrix/detail/shift.cuh>

namespace raft::matrix {

/**
* @brief col_shift: in-place shifts all columns by k columns to the right and fills the first k
* columns in with "val"
* @param[in] handle: raft handle
* @param[in out] in_out: input matrix of size (n_rows, n_cols)
* @param[in] val: value to fill in the first k columns (same for all rows)
* @param[in] k: shift size
*/
template <typename math_t, typename matrix_idx_t>
void col_right_shift(raft::resources const& handle,
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out,
math_t val,
size_t k)
{
RAFT_EXPECTS(static_cast<size_t>(in_out.extent(1)) > k,
"Shift size k should be smaller than the number of columns in matrix.");
detail::col_right_shift(handle, in_out, val, k);
}

/**
* @brief col_shift: in-place shifts all columns by k columns to the right and replaces the first
* n_rows x k part of the in_out matrix with the "values" matrix
* @param[in] handle: raft handle
* @param[in out] in_out: input matrix of size (n_rows, n_cols)
* @param[in] values: value matrix of size (n_rows x k) to fill in the first k columns
*/
template <typename math_t, typename matrix_idx_t>
void col_right_shift(raft::resources const& handle,
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out,
raft::device_matrix_view<const math_t, matrix_idx_t> values)
{
RAFT_EXPECTS(in_out.extent(0) == values.extent(0),
"in_out matrix and the values matrix should haver the same number of rows");
RAFT_EXPECTS(in_out.extent(1) > values.extent(1),
"number of columns in in_out should be > number of columns in values");
detail::col_right_shift(handle, in_out, values);
}

/**
* @brief col_shift: in-place shifts all columns by k columns to the right and fills the first k
* columns with its row id
* @param[in] handle: raft handle
* @param[in out] in_out: input matrix of size (n_rows, n_cols)
* @param[in] k: shift size
*/
template <typename math_t, typename matrix_idx_t>
void col_right_shift_self(raft::resources const& handle,
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out,
size_t k)
{
RAFT_EXPECTS(static_cast<size_t>(in_out.extent(1)) > k,
"Shift size k should be smaller than the number of columns in matrix.");
detail::col_right_shift_self(handle, in_out, k);
}

} // namespace raft::matrix
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ if(BUILD_TESTS)
matrix/diagonal.cu
matrix/gather.cu
matrix/scatter.cu
matrix/shift.cu
matrix/eye.cu
matrix/linewise_op.cu
matrix/math.cu
Expand Down
Loading