Commit ee517cd

Merge pull request #3907 from naxingyu/sync-pybind11-with-master
Sync pybind11 with master
2 parents a1091f4 + d0f7b5e commit ee517cd

10 files changed: +167 additions, -79 deletions

src/cudafeat/feature-online-batched-cmvn-cuda-kernels.cu

Lines changed: 2 additions & 2 deletions
@@ -18,13 +18,13 @@
 #include <cub/cub.cuh>
 #include "cudafeat/feature-online-batched-cmvn-cuda-kernels.h"

-__device__ inline float2 operator-(const float2 &a, const float2 &b) {
+__host__ __device__ inline float2 operator-(const float2 &a, const float2 &b) {
   float2 retval;
   retval.x = a.x - b.x;
   retval.y = a.y - b.y;
   return retval;
 }
-__device__ inline float2 operator+(const float2 &a, const float2 &b) {
+__host__ __device__ inline float2 operator+(const float2 &a, const float2 &b) {
   float2 retval;
   retval.x = a.x + b.x;
   retval.y = a.y + b.y;
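
Note on this change: marking the float2 operators __host__ __device__ lets the same overloads be used from both host code and kernels instead of device-only code. A minimal standalone sketch of the pattern (not taken from this commit; the kernel and values below are illustrative):

#include <cstdio>
#include <cuda_runtime.h>

// A __host__ __device__ overload compiles for both the CPU and the GPU.
__host__ __device__ inline float2 operator+(const float2 &a, const float2 &b) {
  return make_float2(a.x + b.x, a.y + b.y);
}

__global__ void add_one_kernel(float2 *v) {
  *v = *v + make_float2(1.0f, 1.0f);  // uses the overload on the device
}

int main() {
  float2 h = make_float2(1.0f, 2.0f);
  h = h + make_float2(0.5f, 0.5f);    // uses the same overload on the host

  float2 *d;
  cudaMalloc(&d, sizeof(float2));
  cudaMemcpy(d, &h, sizeof(float2), cudaMemcpyHostToDevice);
  add_one_kernel<<<1, 1>>>(d);
  cudaMemcpy(&h, d, sizeof(float2), cudaMemcpyDeviceToHost);
  cudaFree(d);
  printf("%f %f\n", h.x, h.y);        // prints 2.500000 3.500000
  return 0;
}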

src/cudafeat/feature-online-batched-ivector-cuda-kernels.cu

Lines changed: 109 additions & 55 deletions
@@ -21,18 +21,17 @@
 #include "cudamatrix/cu-common.h"
 namespace kaldi {

-// computes feats^2. This works in place and out of place.
+// computes pointwise square of each matrix
 __global__ void square_batched_matrix_kernel(
     int32_t chunk_frames, int32_t num_cols, const float *feats, int32_t ldf,
     int32_t stridef, float *feats_sq, int32_t lds, int32_t strides,
     const LaneDesc *lanes, int32_t num_lanes) {
   int32_t lane = blockIdx.z;
-  int32_t num_chunk_frames = lanes[lane].num_chunk_frames;

   feats = feats + lane * stridef;
   feats_sq = feats_sq + lane * strides;

-  for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_chunk_frames;
+  for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < chunk_frames;
        i += blockDim.y * gridDim.y) {
     for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_cols;
          j += blockDim.x * gridDim.x) {
@@ -56,6 +55,55 @@ void square_batched_matrix(int32_t chunk_frames, int32_t num_cols,
   CU_SAFE_CALL(cudaGetLastError());
 }

+// after computing posteriors some rows are invalid because they were created
+// with rows with undefined data. This kernel zeros those rows out so that
+// they will not contribute to stats.
+__global__ void zero_invalid_posteriors_kernel(
+    int32_t chunk_size, int32_t num_gauss, float *posteriors, int32_t ldp,
+    int32_t stridep, int32_t right, const LaneDesc *lanes, int32_t num_lanes) {
+  int32_t lane = blockIdx.z;
+
+  LaneDesc desc = lanes[lane];
+  int32_t num_chunk_frames = desc.num_chunk_frames;
+  int32_t current_frame = desc.current_frame;
+  bool last = desc.last;
+
+  // last valid frame for reading
+  int32_t num_computed_rows = current_frame + num_chunk_frames;
+
+  // if not the last frame remove right context
+  if (!last) {
+    num_computed_rows -= right;
+  }
+
+  // offset by lane
+  posteriors = posteriors + lane * stridep;
+
+  for (int r = blockIdx.y * blockDim.y + threadIdx.y; r < chunk_size;
+       r += blockDim.y * gridDim.y) {
+    int global_row = current_frame + r - right;
+    if (global_row < 0 || global_row >= num_computed_rows) {
+      // zero this row out
+      for (int c = blockIdx.x * blockDim.x + threadIdx.x; c < num_gauss;
+           c += blockDim.x * gridDim.x) {
+        posteriors[r * ldp + c] = 0.0f;
+      }
+    }
+  }
+}
+
+void zero_invalid_posteriors(int32_t num_chunk_frames, int32_t num_gauss,
+                             float *posteriors, int32_t ldp, int32_t stridep,
+                             int32_t right, const LaneDesc *lanes,
+                             int32_t num_lanes) {
+  dim3 threads(32, 32);
+  dim3 blocks((num_gauss + 31) / 32, (num_chunk_frames + 31) / 32, num_lanes);
+
+  zero_invalid_posteriors_kernel<<<blocks, threads>>>(
+      num_chunk_frames, num_gauss, posteriors, ldp, stridep, right, lanes,
+      num_lanes);
+}
+
 // Meant to be called with blockDim= 32x32
 // takes features in feat and writes them into sfeats while applying
 // the splicing algorithm for the left and right context.
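
Note on the new kernel: a posterior row r is kept only if its delayed global index current_frame + r - right falls inside [0, num_computed_rows). A small host-side sketch of that index arithmetic with made-up sizes (the values are illustrative; the formulas mirror the kernel):

#include <cstdint>
#include <cstdio>

int main() {
  int32_t chunk_size = 10;        // posterior rows per lane
  int32_t num_chunk_frames = 10;  // frames produced in this chunk
  int32_t current_frame = 0;      // frames seen before this chunk (first chunk)
  int32_t right = 3;              // right context of the splicer
  bool last = false;              // not the final chunk of the utterance

  int32_t num_computed_rows = current_frame + num_chunk_frames;
  if (!last) num_computed_rows -= right;  // trailing rows wait for right context

  for (int r = 0; r < chunk_size; r++) {
    int global_row = current_frame + r - right;  // output is delayed by `right`
    bool zeroed = (global_row < 0 || global_row >= num_computed_rows);
    printf("row %d -> global %d : %s\n", r, global_row, zeroed ? "zeroed" : "kept");
  }
  return 0;
}

With these values rows 0-2 map to negative global frames and are zeroed; rows 3-9 map to global frames 0-6, all below num_computed_rows = 7, so they are kept.
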
@@ -67,39 +115,48 @@ __global__ void splice_features_batched_kernel(
     float *__restrict__ feats_out, int32_t ldo, int32_t strideo,
     const LaneDesc *lanes, int32_t num_lanes) {
   int32_t lane = blockIdx.y;
-  int32_t frame = blockIdx.x;
+  // output frame index
+  int32_t oframe = blockIdx.x;
   int32_t tid = threadIdx.x;

   LaneDesc desc = lanes[lane];
   int32_t num_chunk_frames = desc.num_chunk_frames;
   int32_t channel = desc.channel;
-  int32_t start_frame = desc.current_frame;
+  int32_t current_frame = desc.current_frame;
+  bool last = desc.last;

-  bool valid_frame = true;
-  // check that we have valid input
-  if (frame >= num_chunk_frames) {
-    valid_frame = false;
-  }
+  // offset by lane
+  feats_in = feats_in + lane * stridei;
+  feats_out = feats_out + lane * strideo;

-  // for first chunk we process less frames
-  if (start_frame == 0 && frame >= num_chunk_frames - right) {
-    valid_frame = false;
-  }
+  // offset by channel
+  feats_stash = feats_stash + channel * stridest;

-  // the stash size
+  // offset feature output to process oframe
+  feats_out = feats_out + ldo * oframe;
+
+  // the size of the stash
   int32_t ssize = left + right;
+  // the size of the window
   int32_t size = ssize + 1;

-  // offset by lane
-  feats_in = feats_in + lane * stridei;
-  feats_out = feats_out + lane * strideo;
-  feats_stash = feats_stash + channel * stridest;
+  // number of valid frames for reading
+  int32_t num_valid_frames = current_frame + num_chunk_frames;

-  // offset feature output to process frame
-  feats_out = feats_out + ldo * frame;
+  // number of valid frames for writing
+  int32_t num_computed_frames = num_valid_frames;

-  if (!valid_frame) {
-    // this frames output is not valid, zero it here
+  // if not the last frame remove right context
+  if (!last) {
+    num_computed_frames -= right;
+  }
+
+  // subtract right context from logical frame to delay output
+  int32_t local_frame = oframe - right;
+  int32_t global_frame = current_frame + local_frame;
+
+  // these frames are set to zeros
+  if (global_frame < 0 || global_frame >= num_computed_frames) {
     for (int i = 0; i < size; i++) {
       for (int c = tid; c < feat_dim; c += blockDim.x) {
         feats_out[i * feat_dim + c] = 0.0f;
@@ -108,44 +165,40 @@ __global__ void splice_features_batched_kernel(
     return;
   }

-  // for each splice of input
-  for (int i = 0; i < size; i++) {
-    const float *feats_src = feats_in;
-    int32_t ld = ldi;
-
-    // shift input row by left context
-    int r = frame + i - left;
+  for (int i = -left; i <= right; i++) {
+    int32_t g_in = global_frame + i;  // global frame index
+    int32_t l_in = local_frame + i;   // local frame index

-    // clamp input row if necessary
-    if (start_frame + r < 0) {
-      r = 0;
-    }
+    // if global row is below zero clamp local to zero
+    if (g_in < 0) l_in = 0;

-    // if we have a right context shift input row by that too
-    if (start_frame > 0) {
-      r = r - right;
+    // if global row is larger than the number of valid frames
+    if (g_in >= num_valid_frames) {
+      // should only happen on last chunk
+      assert(last);
+      // clamp input
+      l_in = num_chunk_frames - 1;
     }

-    if (r > num_chunk_frames - 1) {
-      // This should only happen on the last chunk
-      assert(desc.last == true);
-      r = num_chunk_frames - 1;
-    }
+    // set default input location
+    const float *feats = feats_in;
+    int32_t ld = ldi;

-    if (r < 0) {
-      // feats are located in stash from previous chunk
-      feats_src = feats_stash;
+    // if l < 0 then feats come from the stash
+    if (l_in < 0) {
+      // input is from stash
+      feats = feats_stash;
       ld = ldst;
-      r = r + ssize;
+      l_in += ssize;  // offset by stash size
     }

     // for each column of input in parallel
     for (int c = tid; c < feat_dim; c += blockDim.x) {
       // read feature from input row offset by column
-      float val = feats_src[r * ld + c];
+      float val = feats[l_in * ld + c];

       // write feature to output offset by splice index and column
-      feats_out[i * feat_dim + c] = val;
+      feats_out[(i + left) * feat_dim + c] = val;
     }
   }
 }
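
Note on the rewritten splicing logic: each output frame oframe is delayed by the right context, the window runs from -left to +right, reads before the start of the utterance are clamped to local frame 0 or pulled from the per-channel stash of the previous chunk, and reads past the end are clamped to the last frame of the final chunk. A host-side reference sketch of the same per-frame logic (the function name and container types are illustrative, not Kaldi's; the stash is assumed to hold the last left + right frames of the previous chunk):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<float> SpliceOneFrame(const std::vector<std::vector<float>> &chunk,
                                  const std::vector<std::vector<float>> &stash,
                                  int32_t feat_dim, int32_t left, int32_t right,
                                  int32_t current_frame, int32_t oframe,
                                  bool last) {
  int32_t num_chunk_frames = static_cast<int32_t>(chunk.size());
  int32_t ssize = left + right;                       // stash size
  int32_t num_valid_frames = current_frame + num_chunk_frames;
  int32_t num_computed_frames = num_valid_frames;
  if (!last) num_computed_frames -= right;            // wait for right context

  int32_t local_frame = oframe - right;               // output delayed by `right`
  int32_t global_frame = current_frame + local_frame;

  std::vector<float> out((ssize + 1) * feat_dim, 0.0f);
  if (global_frame < 0 || global_frame >= num_computed_frames)
    return out;                                       // kernel zeros this frame

  for (int32_t i = -left; i <= right; i++) {
    int32_t g_in = global_frame + i;
    int32_t l_in = local_frame + i;
    if (g_in < 0) l_in = 0;                           // clamp at utterance start
    if (g_in >= num_valid_frames) {                   // clamp at utterance end
      assert(last);
      l_in = num_chunk_frames - 1;
    }
    const std::vector<std::vector<float>> *src = &chunk;
    if (l_in < 0) {                                   // read from previous chunk
      src = &stash;
      l_in += ssize;
    }
    for (int32_t c = 0; c < feat_dim; c++)
      out[(i + left) * feat_dim + c] = (*src)[l_in][c];
  }
  return out;
}
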
@@ -159,6 +212,7 @@ void splice_features_batched(int32_t num_chunk_frames, int32_t feat_dim,
                              const LaneDesc *lanes, int32_t num_lanes) {
   int threads = (feat_dim + 31) / 32 * 32;  // round up to the nearest warp size
   if (threads > 1024) threads = 1024;       // Max block size is 1024 threads
+
   dim3 blocks(num_chunk_frames, num_lanes);

   splice_features_batched_kernel<<<blocks, threads>>>(
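
Note on the launch configuration: the block size is the feature dimension rounded up to a full 32-thread warp and capped at CUDA's 1024-thread block limit. A tiny sketch of that arithmetic:

#include <algorithm>
#include <cstdio>

// Round feat_dim up to a whole warp, but never exceed 1024 threads per block.
int BlockThreads(int feat_dim) {
  int threads = (feat_dim + 31) / 32 * 32;
  return std::min(threads, 1024);
}

int main() {
  printf("%d %d %d\n", BlockThreads(40), BlockThreads(512), BlockThreads(2048));
  // prints: 64 512 1024
  return 0;
}
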
@@ -302,12 +356,15 @@ __global__ void batched_update_linear_and_quadratic_terms_kernel(
   linear = linear + lane * stridel;
   quadratic = quadratic + lane * strideq;

-  // This is always zero. not 100% certain as why we don't need
-  // to account for earlier chunk. maybe Dan knows.
+  // This is always zero because linear and quadratic terms are not
+  // being carried forward. Thus we don't need to remove old prior
+  // scale. Keeping the code below so that it logically matches
+  // the CPU code in case someone is looking at this in the future.
   float old_num_frames = 0;
   // float old_num_frames = desc.current_frame;
   float new_num_frames = desc.current_frame + desc.num_chunk_frames;

+  // in CPU code the frame counts are scaled by posterior scale
   new_num_frames *= posterior_scale;
   old_num_frames *= posterior_scale;

@@ -458,9 +515,6 @@ __global__ void batched_sum_posteriors_kernel(
     int32_t stridep, float *gamma, int32_t strideg, float post_scale,
     const LaneDesc *lanes, int32_t num_lanes) {
   int32_t lane = blockIdx.y;
-  LaneDesc desc = lanes[lane];
-
-  int32_t num_rows = desc.num_chunk_frames;

   // offset input and output by lane
   posteriors = posteriors + lane * stridep;
@@ -471,7 +525,7 @@
        col += blockDim.x * gridDim.x) {
     // compute sum across rows for this column
     float sum = 0.0f;
-    for (int row = 0; row < num_rows; row++) {
+    for (int row = 0; row < chunk_size; row++) {
       sum += posteriors[row * ldp + col];
     }

@@ -509,7 +563,7 @@ __global__ void initialize_channels_kernel(int32_t num_gauss, int32_t feat_dim,

   // initialize stashes to zero
   for (int i = threadIdx.y * blockDim.x + threadIdx.x; i < num_gauss;
-       i += blockDim.x * gridDim.x) {
+       i += blockDim.y * blockDim.x) {
     gamma[i] = 0.0f;
   }
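
Note on this fix: the loop is flattened over the threads of a single block, so the stride has to be the block's total thread count, blockDim.y * blockDim.x; the old stride blockDim.x * gridDim.x only matched it when gridDim.x happened to equal blockDim.y. A minimal sketch of the corrected block-stride pattern with a hypothetical kernel:

#include <cuda_runtime.h>

// All threads of one block cooperate to zero `n` elements; the stride is the
// block's thread count, independent of how many blocks were launched.
__global__ void zero_with_block_kernel(float *data, int n) {
  for (int i = threadIdx.y * blockDim.x + threadIdx.x; i < n;
       i += blockDim.y * blockDim.x) {
    data[i] = 0.0f;
  }
}

// Illustrative launch: one 32x32 block per array.
void zero_with_block(float *data, int n) {
  zero_with_block_kernel<<<1, dim3(32, 32)>>>(data, n);
}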

src/cudafeat/feature-online-batched-ivector-cuda-kernels.h

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,11 @@ namespace kaldi {
 // Thus to compute the matrix pointer of a matrix you use this formula:
 // matrix_pointer = base_pointer + batch_number * stride

+void zero_invalid_posteriors(int32_t num_chunk_frames, int32_t num_gauss,
+                             float *posteriors, int32_t ldp, int32_t stridep,
+                             int32_t right, const LaneDesc *lanes,
+                             int32_t num_lanes);
+
 void splice_features_batched(int32_t num_chunk_frames, int32_t feat_dim,
                              int32_t left, int32_t right, const float *feats,
                              int32_t ldf, int32_t stridef,

src/cudafeat/feature-online-batched-ivector-cuda.cc

Lines changed: 18 additions & 15 deletions
@@ -26,6 +26,11 @@ BatchedIvectorExtractorCuda::BatchedIvectorExtractorCuda(
       chunk_size_(chunk_size),
       max_lanes_(num_lanes),
       num_channels_(num_channels) {
+#if CUDA_VERSION < 9010
+  // some components require newer cuda versions. If you see this error
+  // upgrade to a more recent CUDA version.
+  KALDI_ERR << "BatchedIvectorExtractorCuda requires CUDA 9.1 or newer.";
+#endif
   info_.Init(config);
   Read(config);
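
Note on the new guard: CUDA_VERSION (defined in cuda.h, e.g. 9010 for CUDA 9.1) is checked so the extractor fails loudly in its constructor instead of silently skipping the cuBLAS/cuSOLVER batched calls gated on the same macro later in this file. A minimal sketch of the pattern with a hypothetical class (Kaldi's KALDI_ERR is replaced by a plain exception here):

#include <stdexcept>
#include <cuda.h>  // defines CUDA_VERSION

struct BatchedComponent {  // hypothetical name, not Kaldi's
  BatchedComponent() {
#if CUDA_VERSION < 9010
    // Compiled against an old toolkit: refuse to construct rather than
    // run with the batched code paths silently compiled out.
    throw std::runtime_error("This component requires CUDA 9.1 or newer.");
#endif
  }
};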

@@ -255,6 +260,8 @@ void BatchedIvectorExtractorCuda::LDATransform(const CuMatrix<BaseFloat> &feats,
 void BatchedIvectorExtractorCuda::ComputePosteriors(CuMatrix<BaseFloat> &feats,
                                                     const LaneDesc *lanes,
                                                     int32_t num_lanes) {
+  int right = info_.splice_opts.right_context;
+
   // inititalize posteriors
   posteriors_.CopyRowsFromVec(ubm_gconsts_);

@@ -271,6 +278,13 @@ void BatchedIvectorExtractorCuda::ComputePosteriors(CuMatrix<BaseFloat> &feats,
   posteriors_.AddMatMat(-0.5, feats, kNoTrans, ubm_inv_vars_, kTrans, 1.0);

   posteriors_.ApplySoftMaxPerRow();
+
+  // At this point some rows of posteriors are invalid because they
+  // didn't have valid input rows. Zero those out now so that
+  // they don't impact stats
+  zero_invalid_posteriors(
+      chunk_size_, num_gauss_, posteriors_.Data(), posteriors_.Stride(),
+      posteriors_.Stride() * chunk_size_, right, lanes, num_lanes);
 }

 void BatchedIvectorExtractorCuda::ComputeIvectorStats(
@@ -281,6 +295,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorStats(
       posteriors_.Stride() * chunk_size_, gamma_.Data(),
       num_gauss_, info_.posterior_scale, lanes, num_lanes);

+#if CUDA_VERSION >= 9010
   int32_t m = feat_dim_;
   int32_t n = num_gauss_;
   int32_t k = chunk_size_;
@@ -296,11 +311,12 @@
   int32_t ldc = X_.Stride();
   int32_t strideC = ldc * num_gauss_;

-  // multiplying X = stash * feats
+  // multiplying X = post * feats
   CUBLAS_SAFE_CALL(cublasGemmStridedBatchedEx(
       GetCublasHandle(), CUBLAS_OP_N, CUBLAS_OP_T, m, n, k, &alpha, A,
       CUDA_R_32F, lda, strideA, B, CUDA_R_32F, ldb, strideB, &beta, C,
       CUDA_R_32F, ldc, strideC, num_lanes, CUDA_R_32F, CUBLAS_GEMM_DEFAULT))
+#endif

   apply_and_update_stash(
       num_gauss_, feat_dim_, gamma_.Data(), gamma_stash_.Data(), num_gauss_,
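
Note on the batched GEMM: cublasGemmStridedBatchedEx runs num_lanes independent GEMMs in one call, with each lane's A, B and C located at a fixed stride from the previous lane's. A self-contained sketch with small illustrative sizes (not the extractor's dimensions; error checking omitted):

#include <cstdio>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
  const int m = 2, n = 2, k = 2, batch = 3;
  const long long strideA = m * k, strideB = k * n, strideC = m * n;

  std::vector<float> hA(strideA * batch, 1.0f);  // every A entry is 1
  std::vector<float> hB(strideB * batch, 2.0f);  // every B entry is 2
  std::vector<float> hC(strideC * batch, 0.0f);

  float *dA, *dB, *dC;
  cudaMalloc(&dA, hA.size() * sizeof(float));
  cudaMalloc(&dB, hB.size() * sizeof(float));
  cudaMalloc(&dC, hC.size() * sizeof(float));
  cudaMemcpy(dA, hA.data(), hA.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), hB.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;
  // C_i = A_i * B_i for i = 0 .. batch-1, column-major, in a single call.
  cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
      dA, CUDA_R_32F, /*lda=*/m, strideA,
      dB, CUDA_R_32F, /*ldb=*/k, strideB, &beta,
      dC, CUDA_R_32F, /*ldc=*/m, strideC,
      batch, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);

  cudaMemcpy(hC.data(), dC, hC.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("C[0](0,0) = %f\n", hC[0]);  // 1*2 + 1*2 = 4 for every entry
  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}
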
@@ -311,9 +327,6 @@ void BatchedIvectorExtractorCuda::ComputeIvectorsFromStats(
 void BatchedIvectorExtractorCuda::ComputeIvectorsFromStats(
     CuVectorBase<BaseFloat> *ivectors, const LaneDesc *lanes,
     int32_t num_lanes) {
-  static int current_frame = 0;
-  current_frame++;
-
   // Computing Linear Term
   {
     // need to set this term to zero because batched_compute_linear_term
@@ -348,9 +361,7 @@
       ivector_dim_, lanes, num_lanes);

 #if CUDA_VERSION >= 9010
-
   int nrhs = 1;
-
   // perform factorization in batched
   CUSOLVER_SAFE_CALL(cusolverDnSpotrfBatched(
       GetCusolverDnHandle(), CUBLAS_FILL_MODE_LOWER, ivector_dim_, quad_array_,
@@ -361,18 +372,10 @@
       GetCusolverDnHandle(), CUBLAS_FILL_MODE_LOWER, ivector_dim_, nrhs,
       quad_array_, ivector_dim_, ivec_array_, ivector_dim_, d_infoArray_,
       num_lanes));
+#endif

   // cusolver solves in place. Ivectors are now in linear_

-#else
-  // We could make a fallback if necessary. This would likely just loop
-  // over each matrix and call Invert not batched. This would be very slow and
-  // throwing an error is probably better force people to use a more recent
-  // version of CUDA.
-  KALDI_ERR << "Online Ivectors in CUDA is not supported by your CUDA version. "
-            << "Upgrade to CUDA 9.1 or later";
-#endif
-
   // Create a submatrix which points to the first element of each ivector
   CuSubMatrix<BaseFloat> ivector0(linear_.Data(), num_lanes, 1, ivector_dim_);
   // remove prior