-
Notifications
You must be signed in to change notification settings - Fork 208
[FEA] Matrix shift rows and columns #2634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 7 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
dfa40f3
col shift logic
jinsolp 5d83e1a
shift logic
jinsolp ead3cbc
static cast
jinsolp fad8b09
redundant comp
jinsolp 616e01f
Merge branch 'branch-25.06' into matrix-shift
jinsolp 4fca0f7
update comments
jinsolp 35ef5e8
Merge branch 'branch-25.06' into matrix-shift
cjnolet 124c1a8
generalized shift logic
jinsolp ae8b3d2
Merge branch 'matrix-shift' of https://github.com/jinsolp/raft into m…
jinsolp a3e4e49
Merge branch 'branch-25.06' into matrix-shift
jinsolp dee6f41
Merge branch 'branch-25.06' into matrix-shift
jinsolp ff02241
detailed documentation
jinsolp b1a29ca
change to towards_beginning
jinsolp 2f98b93
Merge branch 'branch-25.06' into matrix-shift
jinsolp be95a4a
addressing reviews
jinsolp 7f12f71
shift types
jinsolp 3c52fba
Merge branch 'branch-25.06' into matrix-shift
jinsolp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
* Copyright (c) 2025, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/core/detail/macros.hpp> | ||
#include <raft/core/device_mdspan.hpp> | ||
#include <raft/core/resources.hpp> | ||
|
||
namespace raft::matrix::detail { | ||
|
||
template <typename T> | ||
RAFT_KERNEL col_right_shift(T* in_out, size_t n_rows, size_t n_cols, size_t k, T val) | ||
{ | ||
size_t row = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (row < n_rows) { | ||
size_t base_idx = row * n_cols; | ||
size_t cols_to_shift = n_cols - k; | ||
for (size_t i = 1; i <= cols_to_shift; i++) { | ||
in_out[base_idx + (n_cols - i)] = in_out[base_idx + (n_cols - k - i)]; | ||
} | ||
for (size_t i = 0; i < k; i++) { | ||
in_out[base_idx + i] = val; | ||
} | ||
} | ||
} | ||
|
||
template <typename math_t, typename matrix_idx_t> | ||
void col_right_shift(raft::resources const& handle, | ||
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out, | ||
math_t val, | ||
size_t k) | ||
{ | ||
size_t n_rows = in_out.extent(0); | ||
size_t n_cols = in_out.extent(1); | ||
size_t TPB = 256; | ||
size_t num_blocks = static_cast<size_t>((n_rows + TPB) / TPB); | ||
|
||
col_right_shift<math_t><<<num_blocks, TPB, 0, raft::resource::get_cuda_stream(handle)>>>( | ||
in_out.data_handle(), n_rows, n_cols, k, val); | ||
} | ||
|
||
template <typename T> | ||
RAFT_KERNEL col_right_shift(T* in_out, size_t n_rows, size_t n_cols, size_t k, const T* values) | ||
{ | ||
size_t row = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (row < n_rows) { | ||
size_t base_idx = row * n_cols; | ||
size_t cols_to_shift = n_cols - k; | ||
for (size_t i = 1; i <= cols_to_shift; i++) { | ||
in_out[base_idx + (n_cols - i)] = in_out[base_idx + (n_cols - k - i)]; | ||
} | ||
for (size_t i = 0; i < k; i++) { | ||
in_out[base_idx + i] = values[row * k + i]; | ||
} | ||
} | ||
} | ||
|
||
template <typename math_t, typename matrix_idx_t> | ||
void col_right_shift(raft::resources const& handle, | ||
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out, | ||
raft::device_matrix_view<const math_t, matrix_idx_t> values) | ||
{ | ||
size_t n_rows = in_out.extent(0); | ||
size_t n_cols = in_out.extent(1); | ||
size_t TPB = 256; | ||
size_t num_blocks = static_cast<size_t>((n_rows + TPB) / TPB); | ||
|
||
size_t k = values.extent(1); | ||
|
||
col_right_shift<math_t><<<num_blocks, TPB, 0, raft::resource::get_cuda_stream(handle)>>>( | ||
in_out.data_handle(), n_rows, n_cols, k, values.data_handle()); | ||
return; | ||
} | ||
|
||
template <typename T> | ||
RAFT_KERNEL col_right_shift_self(T* in_out, size_t n_rows, size_t n_cols, size_t k) | ||
{ | ||
size_t row = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (row < n_rows) { | ||
size_t base_idx = row * n_cols; | ||
size_t cols_to_shift = n_cols - k; | ||
for (size_t i = 1; i <= cols_to_shift; i++) { | ||
in_out[base_idx + (n_cols - i)] = in_out[base_idx + (n_cols - k - i)]; | ||
} | ||
for (size_t i = 0; i < k; i++) { | ||
in_out[base_idx + i] = row; | ||
} | ||
} | ||
} | ||
|
||
template <typename math_t, typename matrix_idx_t> | ||
void col_right_shift_self(raft::resources const& handle, | ||
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out, | ||
size_t k) | ||
{ | ||
size_t n_rows = in_out.extent(0); | ||
size_t n_cols = in_out.extent(1); | ||
size_t TPB = 256; | ||
size_t num_blocks = static_cast<size_t>((n_rows + TPB) / TPB); | ||
|
||
col_right_shift_self<math_t><<<num_blocks, TPB, 0, raft::resource::get_cuda_stream(handle)>>>( | ||
in_out.data_handle(), n_rows, n_cols, k); | ||
return; | ||
} | ||
|
||
} // namespace raft::matrix::detail |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/* | ||
* Copyright (c) 2025, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/core/device_mdspan.hpp> | ||
#include <raft/core/resource/cuda_stream.hpp> | ||
#include <raft/matrix/detail/shift.cuh> | ||
|
||
namespace raft::matrix { | ||
|
||
/** | ||
* @brief col_shift: in-place shifts all columns by k columns to the right and fills the first k | ||
* columns in with "val" | ||
* @param[in] handle: raft handle | ||
* @param[in out] in_out: input matrix of size (n_rows, n_cols) | ||
* @param[in] val: value to fill in the first k columns (same for all rows) | ||
* @param[in] k: shift size | ||
*/ | ||
template <typename math_t, typename matrix_idx_t> | ||
void col_right_shift(raft::resources const& handle, | ||
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out, | ||
math_t val, | ||
size_t k) | ||
{ | ||
RAFT_EXPECTS(static_cast<size_t>(in_out.extent(1)) > k, | ||
"Shift size k should be smaller than the number of columns in matrix."); | ||
detail::col_right_shift(handle, in_out, val, k); | ||
} | ||
|
||
/** | ||
* @brief col_shift: in-place shifts all columns by k columns to the right and replaces the first | ||
* n_rows x k part of the in_out matrix with the "values" matrix | ||
* @param[in] handle: raft handle | ||
* @param[in out] in_out: input matrix of size (n_rows, n_cols) | ||
* @param[in] values: value matrix of size (n_rows x k) to fill in the first k columns | ||
*/ | ||
template <typename math_t, typename matrix_idx_t> | ||
void col_right_shift(raft::resources const& handle, | ||
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out, | ||
raft::device_matrix_view<const math_t, matrix_idx_t> values) | ||
{ | ||
RAFT_EXPECTS(in_out.extent(0) == values.extent(0), | ||
"in_out matrix and the values matrix should haver the same number of rows"); | ||
RAFT_EXPECTS(in_out.extent(1) > values.extent(1), | ||
"number of columns in in_out should be > number of columns in values"); | ||
detail::col_right_shift(handle, in_out, values); | ||
} | ||
|
||
/** | ||
* @brief col_shift: in-place shifts all columns by k columns to the right and fills the first k | ||
* columns with its row id | ||
* @param[in] handle: raft handle | ||
* @param[in out] in_out: input matrix of size (n_rows, n_cols) | ||
* @param[in] k: shift size | ||
*/ | ||
template <typename math_t, typename matrix_idx_t> | ||
void col_right_shift_self(raft::resources const& handle, | ||
raft::device_matrix_view<math_t, matrix_idx_t, row_major> in_out, | ||
size_t k) | ||
{ | ||
RAFT_EXPECTS(static_cast<size_t>(in_out.extent(1)) > k, | ||
"Shift size k should be smaller than the number of columns in matrix."); | ||
detail::col_right_shift_self(handle, in_out, k); | ||
} | ||
|
||
} // namespace raft::matrix |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.