Skip to content

[FEA] Matrix column shift #2634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: branch-25.06
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions cpp/include/raft/matrix/detail/shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

namespace raft::matrix {
// Direction of a shift along the chosen axis. The dispatcher in this file maps TOWARDS_ZERO to
// the *_towards_zero kernels (element at index i + k moves to index i) and TOWARDS_END to the
// *_towards_end kernels (element at index i - k moves to index i).
enum ShiftDirection { TOWARDS_ZERO, TOWARDS_END };
// Axis along which elements move: ROW selects the row_shift_* kernels (whole rows move up/down),
// COL selects the col_shift_* kernels (whole columns move left/right).
enum ShiftType { ROW, COL };
} // namespace raft::matrix

namespace raft::matrix::detail {
// Compile-time selector for how the k vacated rows/columns are refilled by the shift kernels:
//   SINGLE_VAL - every vacated element receives the same scalar.
//   VALUES     - vacated elements are copied from a caller-supplied device array.
//   SELF       - vacated elements receive the index handled by the thread (row index for column
//                shifts, column index for row shifts).
enum FillType { SINGLE_VAL, VALUES, SELF };

/**
 * @brief Kernel: shifts the columns of a row-major matrix by k towards the last column, in place,
 * and refills the first k columns of every row.
 *
 * One thread processes one row. Within a row the copy runs from the last column downwards, so each
 * source element (at target_col - k) is read before it is overwritten.
 *
 * Precondition: k <= n_cols (the public raft::matrix wrappers require k < n_cols).
 *
 * @tparam T          element type of the matrix
 * @tparam fill_value type of `value`: a scalar convertible to T for SINGLE_VAL, a `const T*` for
 *                    VALUES, ignored for SELF
 * @tparam fill_type  how the first k columns are refilled
 * @param[in,out] in_out row-major matrix of n_rows x n_cols elements
 * @param[in] n_rows number of rows
 * @param[in] n_cols number of columns
 * @param[in] k      shift size
 * @param[in] value  fill scalar (SINGLE_VAL) or pointer to a row-major n_rows x k array (VALUES)
 */
template <typename T, typename fill_value, FillType fill_type>
RAFT_KERNEL col_shift_towards_end(
  T* in_out, size_t n_rows, size_t n_cols, size_t k, fill_value value)
{
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated in 32-bit unsigned
  // arithmetic and wraps for very large row counts.
  size_t row = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (row < n_rows) {
    size_t base_idx = row * n_cols;
    // Visit target_col = n_cols - 1 down to k. The decrement lives in the condition so the
    // unsigned counter is never compared after wrapping below zero (k == 0 or n_cols == 0 would
    // otherwise loop out of bounds).
    for (size_t target_col = n_cols; target_col-- > k;) {
      in_out[base_idx + target_col] = in_out[base_idx + (target_col - k)];
    }
    if constexpr (fill_type == FillType::SINGLE_VAL) {
      T val = static_cast<T>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[base_idx + i] = val;
      }
    } else if constexpr (fill_type == FillType::VALUES) {
      const T* values = static_cast<const T*>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[base_idx + i] = values[row * k + i];
      }
    } else {  // FillType::SELF
      // Each vacated element receives the index of its own row.
      for (size_t i = 0; i < k; i++) {
        in_out[base_idx + i] = static_cast<T>(row);
      }
    }
  }
}

/**
 * @brief Kernel: shifts the columns of a row-major matrix by k towards column 0, in place, and
 * refills the last k columns of every row.
 *
 * One thread processes one row. Within a row the copy runs from column 0 upwards, so each source
 * element (at target_col + k) is read before it is overwritten.
 *
 * Precondition: k <= n_cols; otherwise `n_cols - k` wraps (the public raft::matrix wrappers
 * require k < n_cols).
 *
 * @tparam T          element type of the matrix
 * @tparam fill_value type of `value`: a scalar convertible to T for SINGLE_VAL, a `const T*` for
 *                    VALUES, ignored for SELF
 * @tparam fill_type  how the last k columns are refilled
 * @param[in,out] in_out row-major matrix of n_rows x n_cols elements
 * @param[in] n_rows number of rows
 * @param[in] n_cols number of columns
 * @param[in] k      shift size
 * @param[in] value  fill scalar (SINGLE_VAL) or pointer to a row-major n_rows x k array (VALUES)
 */
template <typename T, typename fill_value, FillType fill_type>
RAFT_KERNEL col_shift_towards_zero(
  T* in_out, size_t n_rows, size_t n_cols, size_t k, fill_value value)
{
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated in 32-bit unsigned
  // arithmetic and wraps for very large row counts.
  size_t row = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (row < n_rows) {
    size_t base_idx = row * n_cols;
    // First column of the vacated tail; also the number of elements that survive the shift.
    size_t base_col = n_cols - k;
    for (size_t target_col = 0; target_col < base_col; target_col++) {
      in_out[base_idx + target_col] = in_out[base_idx + (target_col + k)];
    }
    if constexpr (fill_type == FillType::SINGLE_VAL) {
      T val = static_cast<T>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[base_idx + base_col + i] = val;
      }
    } else if constexpr (fill_type == FillType::VALUES) {
      const T* values = static_cast<const T*>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[base_idx + base_col + i] = values[row * k + i];
      }
    } else {  // FillType::SELF
      // Each vacated element receives the index of its own row.
      for (size_t i = 0; i < k; i++) {
        in_out[base_idx + base_col + i] = static_cast<T>(row);
      }
    }
  }
}

/**
 * @brief Kernel: shifts the rows of a row-major matrix by k towards the last row, in place, and
 * refills the first k rows.
 *
 * One thread processes one column. Within a column the copy runs from the last row downwards, so
 * each source element (at target_row - k) is read before it is overwritten.
 *
 * Precondition: k <= n_rows (the public raft::matrix wrappers require k < n_rows).
 *
 * @tparam T          element type of the matrix
 * @tparam fill_value type of `value`: a scalar convertible to T for SINGLE_VAL, a `const T*` for
 *                    VALUES, ignored for SELF
 * @tparam fill_type  how the first k rows are refilled
 * @param[in,out] in_out row-major matrix of n_rows x n_cols elements
 * @param[in] n_rows number of rows
 * @param[in] n_cols number of columns
 * @param[in] k      shift size
 * @param[in] value  fill scalar (SINGLE_VAL) or pointer to a row-major k x n_cols array (VALUES)
 */
template <typename T, typename fill_value, FillType fill_type>
RAFT_KERNEL row_shift_towards_end(
  T* in_out, size_t n_rows, size_t n_cols, size_t k, fill_value value)
{
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated in 32-bit unsigned
  // arithmetic and wraps for very large column counts.
  size_t col = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (col < n_cols) {
    // Visit target_row = n_rows - 1 down to k. The decrement lives in the condition so the
    // unsigned counter is never compared after wrapping below zero (k == 0 or n_rows == 0 would
    // otherwise loop out of bounds).
    for (size_t target_row = n_rows; target_row-- > k;) {
      in_out[target_row * n_cols + col] = in_out[(target_row - k) * n_cols + col];
    }

    if constexpr (fill_type == FillType::SINGLE_VAL) {
      T val = static_cast<T>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[i * n_cols + col] = val;
      }
    } else if constexpr (fill_type == FillType::VALUES) {
      const T* values = static_cast<const T*>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[i * n_cols + col] = values[i * n_cols + col];
      }
    } else {  // FillType::SELF
      // Each vacated element receives the index of its own column.
      for (size_t i = 0; i < k; i++) {
        in_out[i * n_cols + col] = static_cast<T>(col);
      }
    }
  }
}

/**
 * @brief Kernel: shifts the rows of a row-major matrix by k towards row 0, in place, and refills
 * the last k rows.
 *
 * One thread processes one column. Within a column the copy runs from row 0 upwards, so each
 * source element (at target_row + k) is read before it is overwritten.
 *
 * Precondition: k <= n_rows; otherwise `n_rows - k` wraps (the public raft::matrix wrappers
 * require k < n_rows).
 *
 * @tparam T          element type of the matrix
 * @tparam fill_value type of `value`: a scalar convertible to T for SINGLE_VAL, a `const T*` for
 *                    VALUES, ignored for SELF
 * @tparam fill_type  how the last k rows are refilled
 * @param[in,out] in_out row-major matrix of n_rows x n_cols elements
 * @param[in] n_rows number of rows
 * @param[in] n_cols number of columns
 * @param[in] k      shift size
 * @param[in] value  fill scalar (SINGLE_VAL) or pointer to a row-major k x n_cols array (VALUES)
 */
template <typename T, typename fill_value, FillType fill_type>
RAFT_KERNEL row_shift_towards_zero(
  T* in_out, size_t n_rows, size_t n_cols, size_t k, fill_value value)
{
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated in 32-bit unsigned
  // arithmetic and wraps for very large column counts.
  size_t col = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (col < n_cols) {
    // First row of the vacated tail; also the number of rows that survive the shift.
    size_t base_row = n_rows - k;
    for (size_t target_row = 0; target_row < base_row; target_row++) {
      in_out[target_row * n_cols + col] = in_out[(target_row + k) * n_cols + col];
    }
    if constexpr (fill_type == FillType::SINGLE_VAL) {
      T val = static_cast<T>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[(base_row + i) * n_cols + col] = val;
      }
    } else if constexpr (fill_type == FillType::VALUES) {
      const T* values = static_cast<const T*>(value);
      for (size_t i = 0; i < k; i++) {
        in_out[(base_row + i) * n_cols + col] = values[i * n_cols + col];
      }
    } else {  // FillType::SELF
      // Each vacated element receives the index of its own column.
      for (size_t i = 0; i < k; i++) {
        in_out[(base_row + i) * n_cols + col] = static_cast<T>(col);
      }
    }
  }
}

/**
 * @brief Picks the shift kernel matching (shift_type, shift_direction), launches it on the
 * handle's CUDA stream and synchronizes that stream before returning.
 *
 * Column shifts launch one thread per row; row shifts launch one thread per column.
 *
 * @tparam ValueT     element type of the matrix
 * @tparam IdxT       index type of the matrix view
 * @tparam fill_value type of the fill argument forwarded to the kernel
 * @tparam fill_type  fill strategy compiled into the kernel
 * @param[in] handle raft handle providing the stream
 * @param[in,out] in_out row-major matrix shifted in place
 * @param[in] value fill argument forwarded unchanged to the kernel
 * @param[in] k shift size
 * @param[in] shift_direction TOWARDS_ZERO or TOWARDS_END
 * @param[in] shift_type COL or ROW
 */
template <typename ValueT, typename IdxT, typename fill_value, FillType fill_type>
void shift_dispatch(raft::resources const& handle,
                    raft::device_matrix_view<ValueT, IdxT, row_major> in_out,
                    fill_value value,
                    size_t k,
                    ShiftDirection shift_direction = ShiftDirection::TOWARDS_END,
                    ShiftType shift_type           = ShiftType::COL)
{
  const size_t threads_per_block = 256;
  const size_t rows              = in_out.extent(0);
  const size_t cols              = in_out.extent(1);
  auto stream                    = raft::resource::get_cuda_stream(handle);
  auto* data                     = in_out.data_handle();

  const bool shifting_cols = (shift_type == ShiftType::COL);
  const bool towards_zero  = (shift_direction == ShiftDirection::TOWARDS_ZERO);

  // Threads map to rows for column shifts and to columns for row shifts.
  const size_t work_items = shifting_cols ? rows : cols;
  const size_t grid       = (work_items + threads_per_block) / threads_per_block;

  if (shifting_cols) {
    if (towards_zero) {
      col_shift_towards_zero<ValueT, fill_value, fill_type>
        <<<grid, threads_per_block, 0, stream>>>(data, rows, cols, k, value);
    } else {
      col_shift_towards_end<ValueT, fill_value, fill_type>
        <<<grid, threads_per_block, 0, stream>>>(data, rows, cols, k, value);
    }
  } else {
    if (towards_zero) {
      row_shift_towards_zero<ValueT, fill_value, fill_type>
        <<<grid, threads_per_block, 0, stream>>>(data, rows, cols, k, value);
    } else {
      row_shift_towards_end<ValueT, fill_value, fill_type>
        <<<grid, threads_per_block, 0, stream>>>(data, rows, cols, k, value);
    }
  }
  raft::resource::sync_stream(handle);
}

/**
 * @brief Shifts in_out by k and fills every vacated element with the scalar `val`.
 *
 * Forwards to shift_dispatch with the SINGLE_VAL fill strategy.
 *
 * @param[in] handle raft handle
 * @param[in,out] in_out row-major matrix shifted in place
 * @param[in] val scalar written into the vacated rows/columns
 * @param[in] k shift size
 * @param[in] shift_direction TOWARDS_ZERO or TOWARDS_END
 * @param[in] shift_type COL or ROW
 */
template <typename ValueT, typename IdxT>
void shift(raft::resources const& handle,
           raft::device_matrix_view<ValueT, IdxT, row_major> in_out,
           ValueT val,
           size_t k,
           ShiftDirection shift_direction = ShiftDirection::TOWARDS_END,
           ShiftType shift_type           = ShiftType::COL)
{
  constexpr FillType fill_mode = FillType::SINGLE_VAL;
  shift_dispatch<ValueT, IdxT, ValueT, fill_mode>(
    handle, in_out, val, k, shift_direction, shift_type);
}

/**
 * @brief Shifts in_out and fills the vacated rows/columns from the `values` matrix.
 *
 * The shift size is taken from `values`: its column count for COL shifts, its row count for ROW
 * shifts. Forwards to shift_dispatch with the VALUES fill strategy.
 *
 * @param[in] handle raft handle
 * @param[in,out] in_out row-major matrix shifted in place
 * @param[in] values matrix supplying the fill elements
 * @param[in] shift_direction TOWARDS_ZERO or TOWARDS_END
 * @param[in] shift_type COL or ROW
 */
template <typename ValueT, typename IdxT>
void shift(raft::resources const& handle,
           raft::device_matrix_view<ValueT, IdxT, row_major> in_out,
           raft::device_matrix_view<const ValueT, IdxT> values,
           ShiftDirection shift_direction = ShiftDirection::TOWARDS_END,
           ShiftType shift_type           = ShiftType::COL)
{
  const ValueT* fill_data = values.data_handle();

  // The extent of `values` along the shifted axis is the shift size.
  size_t shift_size;
  if (shift_type == ShiftType::COL) {
    shift_size = values.extent(1);
  } else {
    shift_size = values.extent(0);
  }

  shift_dispatch<ValueT, IdxT, const ValueT*, FillType::VALUES>(
    handle, in_out, fill_data, shift_size, shift_direction, shift_type);
}

/**
 * @brief Shifts in_out by k and fills the vacated elements with their own row index (COL shifts)
 * or column index (ROW shifts).
 *
 * Forwards to shift_dispatch with the SELF fill strategy; the fill argument is a placeholder
 * that the kernels do not read in that mode.
 *
 * @param[in] handle raft handle
 * @param[in,out] in_out row-major matrix shifted in place
 * @param[in] k shift size
 * @param[in] shift_direction TOWARDS_ZERO or TOWARDS_END
 * @param[in] shift_type COL or ROW
 */
template <typename ValueT, typename IdxT>
void shift_self(raft::resources const& handle,
                raft::device_matrix_view<ValueT, IdxT, row_major> in_out,
                size_t k,
                ShiftDirection shift_direction = ShiftDirection::TOWARDS_END,
                ShiftType shift_type           = ShiftType::COL)
{
  const ValueT placeholder = static_cast<ValueT>(0);
  shift_dispatch<ValueT, IdxT, ValueT, FillType::SELF>(
    handle, in_out, placeholder, k, shift_direction, shift_type);
}

} // namespace raft::matrix::detail
122 changes: 122 additions & 0 deletions cpp/include/raft/matrix/shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/matrix/detail/shift.cuh>

namespace raft::matrix {

/**
 * @brief Shifts the rows or columns of a matrix by k positions, in place, and fills the vacated
 * rows/columns with a single value.
 *
 * Example 1: row-major 3x4 matrix in_out = [[1,2,3,4], [5,6,7,8], [9,10,11,12]], val=100, k=2,
 * shift_direction = ShiftDirection::TOWARDS_END, shift_type = ShiftType::COL gives
 * [[100,100,1,2], [100,100,5,6], [100,100,9,10]].
 *
 * Example 2: the same matrix with val=100, k=1, shift_direction = ShiftDirection::TOWARDS_ZERO,
 * shift_type = ShiftType::ROW gives [[5,6,7,8], [9,10,11,12], [100,100,100,100]].
 *
 * @param[in] handle: raft handle
 * @param[in,out] in_out: matrix of size (n_rows, n_cols), modified in place
 * @param[in] val: value written into every vacated element
 * @param[in] k: shift size; must be smaller than the number of columns (ShiftType::COL) or rows
 * (ShiftType::ROW)
 * @param[in] shift_direction: ShiftDirection::TOWARDS_ZERO shifts towards the 0th row/col,
 * ShiftDirection::TOWARDS_END shifts towards the last row/col
 * @param[in] shift_type: ShiftType::ROW shifts rows, ShiftType::COL shifts columns
 */
template <typename ValueT, typename IdxT>
void shift(raft::resources const& handle,
           raft::device_matrix_view<ValueT, IdxT, row_major> in_out,
           ValueT val,
           size_t k,
           ShiftDirection shift_direction = ShiftDirection::TOWARDS_END,
           ShiftType shift_type           = ShiftType::COL)
{
  const auto n_rows = static_cast<size_t>(in_out.extent(0));
  const auto n_cols = static_cast<size_t>(in_out.extent(1));

  // Only the extent along the shifted axis constrains k.
  if (shift_type == ShiftType::COL) {
    RAFT_EXPECTS(n_cols > k,
                 "Shift size k should be smaller than the number of columns in matrix.");
  } else {
    RAFT_EXPECTS(n_rows > k, "Shift size k should be smaller than the number of rows in matrix.");
  }
  detail::shift(handle, in_out, val, k, shift_direction, shift_type);
}

/**
 * @brief Shifts the rows or columns of a matrix in place and fills the vacated rows/columns with
 * the contents of the "values" matrix.
 *
 * The shift size is the extent of "values" along the shifted axis: its number of columns for
 * ShiftType::COL, its number of rows for ShiftType::ROW.
 *
 * @param[in] handle: raft handle
 * @param[in,out] in_out: matrix of size (n_rows, n_cols), modified in place
 * @param[in] values: fill matrix; (n_rows x k) with k < n_cols for ShiftType::COL, or
 * (k x n_cols) with k < n_rows for ShiftType::ROW
 * @param[in] shift_direction: ShiftDirection::TOWARDS_ZERO shifts towards the 0th row/col,
 * ShiftDirection::TOWARDS_END shifts towards the last row/col
 * @param[in] shift_type: ShiftType::ROW shifts rows, ShiftType::COL shifts columns
 */
template <typename ValueT, typename IdxT>
void shift(raft::resources const& handle,
           raft::device_matrix_view<ValueT, IdxT, row_major> in_out,
           raft::device_matrix_view<const ValueT, IdxT> values,
           ShiftDirection shift_direction = ShiftDirection::TOWARDS_END,
           ShiftType shift_type           = ShiftType::COL)
{
  const auto matrix_rows = in_out.extent(0);
  const auto matrix_cols = in_out.extent(1);
  const auto fill_rows   = values.extent(0);
  const auto fill_cols   = values.extent(1);

  if (shift_type == ShiftType::COL) {
    // The fill block spans every row and strictly fewer columns than the matrix.
    RAFT_EXPECTS(matrix_rows == fill_rows,
                 "in_out matrix and the values matrix should haver the same number of rows when "
                 "using shift_type=ShiftType::COL");
    RAFT_EXPECTS(matrix_cols > fill_cols,
                 "number of columns in in_out should be > number of columns in values when using "
                 "shift_type=ShiftType::COL");
  } else {
    // The fill block spans every column and strictly fewer rows than the matrix.
    RAFT_EXPECTS(matrix_cols == fill_cols,
                 "in_out matrix and the values matrix should haver the same number of cols when "
                 "using shift_type=ShiftType::ROW");
    RAFT_EXPECTS(matrix_rows > fill_rows,
                 "number of rows in in_out should be > number of rows in values when using "
                 "shift_type=ShiftType::ROW");
  }

  detail::shift(handle, in_out, values, shift_direction, shift_type);
}

/**
 * @brief Shifts the rows or columns of a matrix by k positions, in place, and fills the vacated
 * elements with their own index: the row id for ShiftType::COL, the column id for
 * ShiftType::ROW.
 *
 * @param[in] handle: raft handle
 * @param[in,out] in_out: matrix of size (n_rows, n_cols), modified in place
 * @param[in] k: shift size; must be smaller than the number of columns (ShiftType::COL) or rows
 * (ShiftType::ROW)
 * @param[in] shift_direction: ShiftDirection::TOWARDS_ZERO shifts towards the 0th row/col,
 * ShiftDirection::TOWARDS_END shifts towards the last row/col
 * @param[in] shift_type: ShiftType::ROW shifts rows, ShiftType::COL shifts columns
 */
template <typename ValueT, typename IdxT>
void shift_self(raft::resources const& handle,
                raft::device_matrix_view<ValueT, IdxT, row_major> in_out,
                size_t k,
                ShiftDirection shift_direction = ShiftDirection::TOWARDS_END,
                ShiftType shift_type           = ShiftType::COL)
{
  const auto n_rows = static_cast<size_t>(in_out.extent(0));
  const auto n_cols = static_cast<size_t>(in_out.extent(1));

  // Only the extent along the shifted axis constrains k.
  if (shift_type == ShiftType::COL) {
    RAFT_EXPECTS(n_cols > k,
                 "Shift size k should be smaller than the number of columns in matrix.");
  } else {
    RAFT_EXPECTS(n_rows > k, "Shift size k should be smaller than the number of rows in matrix.");
  }
  detail::shift_self(handle, in_out, k, shift_direction, shift_type);
}

} // namespace raft::matrix
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ if(BUILD_TESTS)
matrix/diagonal.cu
matrix/gather.cu
matrix/scatter.cu
matrix/shift.cu
matrix/eye.cu
matrix/linewise_op.cu
matrix/math.cu
Expand Down
Loading
Loading