Commit e7d98ca

sryap authored and facebook-github-bot committed
Refactor bounds_check_indices (pytorch#4049)
Summary:
X-link: facebookresearch/FBGEMM#1140
Pull Request resolved: pytorch#4049

This diff refactors the `bounds_check_indices` operator:

- Moved the common host code into the common host function
- Moved the common device code into `embedding_bounds_check_common.cuh`
- Removed `const` from scalar args

Reviewed By: q10

Differential Revision: D73905710

fbshipit-source-id: 9c0d7c94b617b1a453904e0170cc3da5d9ab9bf0
1 parent 0c177e9 · commit e7d98ca

File tree: 4 files changed, +117 −150 lines

fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp (+65 −16)

```diff
@@ -13,6 +13,7 @@
 #include <torch/library.h>
 
 #include "fbgemm_gpu/utils/ops_utils.h"
+#include "fbgemm_gpu/utils/tensor_utils.h"
 
 #include "fbgemm_gpu/config/feature_gates.h"
 #include "fbgemm_gpu/embedding_common.h"
@@ -25,29 +26,37 @@ void _bounds_check_indices_cuda_v1(
     Tensor& rows_per_table,
     Tensor& indices,
     Tensor& offsets,
-    int64_t bounds_check_mode,
+    fbgemm_gpu::BoundsCheckMode bounds_check_mode,
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B,
+    int64_t max_B,
     const std::optional<Tensor>& b_t_map,
-    const int32_t info_B_num_bits,
-    const uint32_t info_B_mask,
-    const bool prefetch_pipeline);
+    int32_t info_B_num_bits,
+    uint32_t info_B_mask,
+    int64_t T,
+    int64_t B,
+    int64_t total_B,
+    bool vbe,
+    bool prefetch_pipeline);
 
 void _bounds_check_indices_cuda_v2(
     Tensor& rows_per_table,
     Tensor& indices,
     Tensor& offsets,
-    int64_t bounds_check_mode,
+    fbgemm_gpu::BoundsCheckMode bounds_check_mode,
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B,
+    int64_t max_B,
     const std::optional<Tensor>& b_t_map,
-    const int32_t info_B_num_bits,
-    const uint32_t info_B_mask,
-    const bool prefetch_pipeline);
+    int32_t info_B_num_bits,
+    uint32_t info_B_mask,
+    int64_t T,
+    int64_t B,
+    int64_t total_B,
+    bool vbe,
+    bool prefetch_pipeline);
 
 ///@ingroup embedding-cuda
 void bounds_check_indices_cuda(
@@ -58,12 +67,12 @@ void bounds_check_indices_cuda(
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B,
+    int64_t max_B,
     const std::optional<Tensor>& b_t_map,
-    const int64_t info_B_num_bits,
-    const int64_t info_B_mask,
-    const int8_t bounds_check_version,
-    const bool prefetch_pipeline) {
+    int64_t info_B_num_bits,
+    int64_t info_B_mask,
+    int8_t bounds_check_version,
+    bool prefetch_pipeline) {
   TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2);
   const static bool use_v2 =
       fbgemm_gpu::config::is_feature_enabled(
@@ -73,25 +82,65 @@ void bounds_check_indices_cuda(
       use_v2 ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1;
   const auto bounds_check_mode_ =
       static_cast<fbgemm_gpu::BoundsCheckMode>(bounds_check_mode);
+
   TORCH_CHECK(
       bounds_check_mode_ == fbgemm_gpu::BoundsCheckMode::WARNING ||
           bounds_check_mode_ == fbgemm_gpu::BoundsCheckMode::FATAL ||
          bounds_check_mode_ == fbgemm_gpu::BoundsCheckMode::IGNORE,
       "bounds_check_indices: bounds_check_mode=",
       bounds_check_mode,
       " is not supported");
+
+  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
+      rows_per_table, indices, offsets, warning, weights, B_offsets, b_t_map);
+
+  TENSOR_NDIM_EQUALS(rows_per_table, 1);
+  TENSOR_NDIM_EQUALS(indices, 1);
+  TENSOR_NDIM_EQUALS(offsets, 1);
+  TENSOR_NDIM_EQUALS(warning, 1);
+
+  const auto T = rows_per_table.size(0);
+  const auto total_B = offsets.size(0) - 1;
+  const auto B = total_B / T;
+  if (total_B == 0 || T == 0) {
+    return;
+  }
+
+  const auto vbe = B_offsets.has_value();
+  if (vbe) {
+    TENSOR_NDIM_EQUALS(B_offsets.value(), 1);
+    TORCH_CHECK(max_B >= 0);
+  } else if (!vbe) {
+    TORCH_CHECK(
+        offsets.size(0) == B * T + 1,
+        "offsets size " + std::to_string(offsets.size(0)) +
+            " is not equal to B (" + std::to_string(B) + ") * T (" +
+            std::to_string(T) + ") + 1");
+  }
+  if (weights.has_value() && weights->numel() != 0) {
+    const auto num_indices = indices.size(0);
+    TORCH_CHECK(
+        weights->size(0) == num_indices,
+        "weights size " + std::to_string(weights->size(0)) +
+            " is not equal to indices size " + std::to_string(num_indices));
+  }
+
   bounds_check_indices_fn(
       rows_per_table,
       indices,
       offsets,
-      bounds_check_mode,
+      bounds_check_mode_,
       warning,
       weights,
       B_offsets,
       max_B,
       b_t_map,
       static_cast<int32_t>(info_B_num_bits),
       static_cast<uint32_t>(info_B_mask),
+      T,
+      B,
+      total_B,
+      vbe,
       prefetch_pipeline);
 }
 // Deprecated for fb namespace! Please use fbgemm namespace instead!
```
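The non-VBE check that now lives in this common host path enforces a CSR-style layout: with T tables and a per-table batch size B, `offsets` carries one entry per (table, sample) pair plus a terminal entry, so its length must be B * T + 1, and `offsets[b_t]` / `offsets[b_t + 1]` delimit the index range of bag `b_t`. A minimal standalone sketch of that invariant (illustrative values only, not FBGEMM code):

```cpp
// Standalone illustration of the offsets shape invariant checked above:
// T tables, per-table batch size B; offsets is CSR-style with a terminal
// entry, so offsets.size() must equal B * T + 1.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t T = 2; // number of embedding tables
  const int64_t B = 3; // per-table batch size
  // One offset per (table, sample) pair plus a final terminal entry:
  // bag b_t covers indices[offsets[b_t] .. offsets[b_t + 1]).
  const std::vector<int64_t> offsets = {0, 2, 2, 5, 7, 7, 9};
  const auto total_B = static_cast<int64_t>(offsets.size()) - 1;
  assert(total_B == B * T);  // i.e., offsets.size() == B * T + 1
  assert(total_B / T == B);  // B recovered exactly as in the host code
  return 0;
}
```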

fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu (+12 −73)

```diff
@@ -6,28 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
-#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
-
-#include <c10/cuda/CUDADeviceAssertion.h>
-#include <c10/cuda/CUDAException.h>
-
-using Tensor = at::Tensor;
-using namespace fbgemm_gpu;
-
-template <typename index_t>
-__device__ void adjust_offset_kernel(
-    index_t& indices_start,
-    index_t& indices_end,
-    const index_t num_indices,
-    index_t* const offset_acc_start,
-    index_t* const offset_acc_end) {
-  indices_start =
-      std::max(static_cast<index_t>(0), std::min(indices_start, num_indices));
-  indices_end = std::max(indices_start, std::min(indices_end, num_indices));
-  *offset_acc_start = indices_start;
-  *offset_acc_end = indices_end;
-}
+#include "fbgemm_gpu/utils/embedding_bounds_check_common.cuh"
 
 template <typename index_t, bool vbe>
 __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
@@ -37,7 +16,7 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
     pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
     const int32_t* const B_offsets, // Use a raw pointer to avoid creating a
                                     // dummy PackedTensorAccessor
-    const int64_t bounds_check_mode_,
+    BoundsCheckMode bounds_check_mode,
     pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> warning,
     FixedDivisor fd,
     TORCH_DSA_KERNEL_ARGS) {
@@ -71,8 +50,6 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
     B = total_B / T;
   }
 
-  const auto bounds_check_mode =
-      static_cast<BoundsCheckMode>(bounds_check_mode_);
   const auto num_rows = rows_per_table[t];
   auto indices_start = offsets[b_t];
   auto indices_end = offsets[b_t + 1];
@@ -179,70 +156,32 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v1(
   }
 }
 
-void check_weights_dim_matches_indices(
-    const std::optional<Tensor>& weights,
-    int64_t num_indices) {
-  if (weights.has_value() && weights->numel() != 0) {
-    TORCH_CHECK(
-        weights.value().size(0) == num_indices,
-        "weights size " + std::to_string(weights.value().size(0)) +
-            " is not equal to indices size " + std::to_string(num_indices));
-  }
-}
-
 void _bounds_check_indices_cuda_v1(
     Tensor& rows_per_table,
     Tensor& indices,
     Tensor& offsets,
-    int64_t bounds_check_mode_,
+    BoundsCheckMode bounds_check_mode,
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B,
+    int64_t max_B,
     const std::optional<Tensor>& /*b_t_map*/,
-    const int32_t /*info_b_num_bits*/,
-    const uint32_t /*info_B_mask*/,
-    const bool prefetch_pipeline) {
-  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      rows_per_table, indices, offsets, warning, weights, B_offsets);
-  TENSOR_NDIM_EQUALS(rows_per_table, 1);
-  TENSOR_NDIM_EQUALS(indices, 1);
-  TENSOR_NDIM_EQUALS(offsets, 1);
-  TENSOR_NDIM_EQUALS(warning, 1);
+    int32_t /*info_b_num_bits*/,
+    uint32_t /*info_B_mask*/,
+    int64_t T,
+    int64_t B,
+    int64_t /*total_B*/,
+    bool vbe,
+    bool prefetch_pipeline) {
   TORCH_CHECK(
       !prefetch_pipeline,
       "bounds_check_indices_v1 does not support prefetch_pipeline=true")
 
-  const auto vbe = B_offsets.has_value();
-  if (vbe) {
-    TENSOR_NDIM_EQUALS(B_offsets.value(), 1);
-  }
-
   CUDA_DEVICE_GUARD(rows_per_table);
 
-  const int32_t T = rows_per_table.size(0);
-  const int32_t total_B = offsets.size(0) - 1;
-  const int32_t B = (total_B) / T;
-  if (total_B == 0 || T == 0) {
-    return;
-  }
-  const auto bounds_check_mode =
-      static_cast<BoundsCheckMode>(bounds_check_mode_);
   if (bounds_check_mode == BoundsCheckMode::WARNING) {
     warning.zero_();
   }
-  const int64_t num_indices = indices.size(0);
-
-  if (vbe) {
-    TORCH_CHECK(max_B >= 0);
-  } else {
-    TORCH_CHECK(
-        offsets.size(0) == B * T + 1,
-        "offsets size " + std::to_string(offsets.size(0)) +
-            " is not equal to B (" + std::to_string(B) + ") * T (" +
-            std::to_string(T) + ") + 1");
-  }
-  check_weights_dim_matches_indices(weights, num_indices);
 
   constexpr size_t kNumThreads = 256;
   const auto max_B_ = vbe ? max_B : B;
@@ -265,7 +204,7 @@ void _bounds_check_indices_cuda_v1(
         MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
         MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
         vbe ? B_offsets.value().data_ptr<int32_t>() : nullptr,
-        bounds_check_mode_,
+        bounds_check_mode,
        MAKE_PTA_WITH_NAME(func_name, warning, int64_t, 1, 32),
         FixedDivisor(max_B_));
   });
```
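The `adjust_offset_kernel` helper deleted here, and its near-twin deleted from the v2 file below, differ only in whether the write-back to the offsets array is guarded by `threadIdx.x == 0`. Both copies presumably moved into the new `embedding_bounds_check_common.cuh`, the fourth changed file (its roughly +30 lines are not shown in this view). A hedged reconstruction of what the shared helper plausibly looks like, inferred from the two removed copies rather than taken from the header itself:

```cuda
// Hedged sketch of the shared device helper that the new common header
// plausibly provides; reconstructed from the two removed copies, not
// copied from the (unshown) embedding_bounds_check_common.cuh.
template <typename index_t>
__device__ void adjust_offset_kernel(
    index_t& indices_start,
    index_t& indices_end,
    const index_t num_indices,
    index_t* const offset_acc_start,
    index_t* const offset_acc_end) {
  // Clamp the bag boundaries into [0, num_indices] while preserving
  // indices_start <= indices_end.
  indices_start =
      std::max(static_cast<index_t>(0), std::min(indices_start, num_indices));
  indices_end = std::max(indices_start, std::min(indices_end, num_indices));
  // The v1 copy wrote back unconditionally; the v2 copy guarded the write
  // with threadIdx.x == 0 because v2 runs multiple threads per bag. A
  // unified helper has to reconcile the two, e.g. by leaving the guard to
  // the caller (as assumed here) or by adding a bool template parameter.
  *offset_acc_start = indices_start;
  *offset_acc_end = indices_end;
}
```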

fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu (+10 −61)

```diff
@@ -6,30 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
-#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
-
-#include <c10/cuda/CUDADeviceAssertion.h>
-#include <c10/cuda/CUDAException.h>
-
-using Tensor = at::Tensor;
-using namespace fbgemm_gpu;
-
-template <typename index_t>
-__device__ void adjust_offset_kernel(
-    index_t& indices_start,
-    index_t& indices_end,
-    const index_t num_indices,
-    index_t* const offset_acc_start,
-    index_t* const offset_acc_end) {
-  indices_start =
-      std::max(static_cast<index_t>(0), std::min(indices_start, num_indices));
-  indices_end = std::max(indices_start, std::min(indices_end, num_indices));
-  if (threadIdx.x == 0) {
-    *offset_acc_start = indices_start;
-    *offset_acc_end = indices_end;
-  }
-}
+#include "fbgemm_gpu/utils/embedding_bounds_check_common.cuh"
 
 template <typename index_t, bool vbe, BoundsCheckMode bounds_check_mode>
 __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
@@ -195,57 +172,29 @@ void _bounds_check_indices_cuda_v2(
     Tensor& rows_per_table,
     Tensor& indices,
     Tensor& offsets,
-    int64_t bounds_check_mode_,
+    BoundsCheckMode bounds_check_mode,
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t /*max_B*/,
+    int64_t /*max_B*/,
     const std::optional<Tensor>& b_t_map,
-    const int32_t info_B_num_bits,
-    const uint32_t info_B_mask,
-    const bool prefetch_pipeline) {
-  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      rows_per_table, indices, offsets, warning, weights, B_offsets, b_t_map);
-  TENSOR_NDIM_EQUALS(rows_per_table, 1);
-  TENSOR_NDIM_EQUALS(indices, 1);
-  TENSOR_NDIM_EQUALS(offsets, 1);
-  TENSOR_NDIM_EQUALS(warning, 1);
-
-  const auto vbe = B_offsets.has_value();
+    int32_t info_B_num_bits,
+    uint32_t info_B_mask,
+    int64_t /*T*/,
+    int64_t B,
+    int64_t total_B,
+    bool vbe,
+    bool prefetch_pipeline) {
   if (vbe) {
-    TENSOR_NDIM_EQUALS(B_offsets.value(), 1);
     TORCH_CHECK(b_t_map.has_value());
     TENSOR_NDIM_EQUALS(b_t_map.value(), 1);
   }
 
   CUDA_DEVICE_GUARD(rows_per_table);
 
-  const int32_t T = rows_per_table.size(0);
-  const int32_t total_B = offsets.size(0) - 1;
-  const int32_t B = (total_B) / T;
-  if (total_B == 0 || T == 0) {
-    return;
-  }
-  const auto bounds_check_mode =
-      static_cast<BoundsCheckMode>(bounds_check_mode_);
   if (bounds_check_mode == BoundsCheckMode::WARNING) {
     warning.zero_();
   }
-  const int64_t num_indices = indices.size(0);
-
-  if (!vbe) {
-    TORCH_CHECK(
-        offsets.size(0) == B * T + 1,
-        "offsets size " + std::to_string(offsets.size(0)) +
-            " is not equal to B (" + std::to_string(B) + ") * T (" +
-            std::to_string(T) + ") + 1");
-  }
-  if (weights.has_value()) {
-    TORCH_CHECK(
-        weights.value().size(0) == num_indices,
-        "weights size " + std::to_string(weights.value().size(0)) +
-            " is not equal to indices size " + std::to_string(num_indices));
-  }
 
   constexpr size_t kNumThreads = 1024;
   auto grid_dim =
```
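Note that the v2 kernel takes `BoundsCheckMode` as a compile-time template parameter, while v1 receives it at runtime, so the v2 host function must bridge a runtime enum value to a template argument at launch time. A minimal sketch of that dispatch pattern (illustrative only; the enum values, function names, and launch details here are assumptions, not the FBGEMM source):

```cpp
// Illustrative runtime-enum -> template-parameter dispatch: the pattern a
// host function needs in order to launch a kernel templated on
// BoundsCheckMode. Names and enum values are placeholders, not FBGEMM's.
#include <cstdio>

enum class BoundsCheckMode { FATAL, WARNING, IGNORE };

template <BoundsCheckMode mode>
void launch_bounds_check_v2() {
  // In the real code this would be something like:
  //   bounds_check_indices_kernel_v2<index_t, vbe, mode><<<grid, block>>>(...);
  std::printf("launched with mode %d\n", static_cast<int>(mode));
}

void dispatch(BoundsCheckMode mode) {
  // Each case instantiates a distinct kernel specialization.
  switch (mode) {
    case BoundsCheckMode::FATAL:
      launch_bounds_check_v2<BoundsCheckMode::FATAL>();
      break;
    case BoundsCheckMode::WARNING:
      launch_bounds_check_v2<BoundsCheckMode::WARNING>();
      break;
    case BoundsCheckMode::IGNORE:
      launch_bounds_check_v2<BoundsCheckMode::IGNORE>();
      break;
  }
}

int main() {
  dispatch(BoundsCheckMode::WARNING);
  return 0;
}
```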
