#include "cudamatrix/cu-common.h"

namespace kaldi {

- // computes feats^2. This works in place and out of place.
+ // computes the pointwise square of each matrix element; works in place
+ // and out of place
__global__ void square_batched_matrix_kernel(
    int32_t chunk_frames, int32_t num_cols, const float *feats, int32_t ldf,
    int32_t stridef, float *feats_sq, int32_t lds, int32_t strides,
    const LaneDesc *lanes, int32_t num_lanes) {
  int32_t lane = blockIdx.z;
-   int32_t num_chunk_frames = lanes[lane].num_chunk_frames;

  feats = feats + lane * stridef;
  feats_sq = feats_sq + lane * strides;

-   for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_chunk_frames;
+   for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < chunk_frames;
       i += blockDim.y * gridDim.y) {
    for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_cols;
         j += blockDim.x * gridDim.x) {
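// For illustration (not shown in the hunk above): a sketch of the elided
// loop body, assuming the usual row-major layout with leading dimensions
// ldf and lds. Each thread squares one element, so in-place operation is
// safe: every output depends only on the input at the same position.
//
//   float f = feats[i * ldf + j];
//   feats_sq[i * lds + j] = f * f;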
@@ -56,6 +55,55 @@ void square_batched_matrix(int32_t chunk_frames, int32_t num_cols,
  CU_SAFE_CALL(cudaGetLastError());
}

+ // After computing posteriors, some rows are invalid because they were
+ // computed from rows with undefined data. This kernel zeros those rows
+ // out so that they will not contribute to stats.
+ __global__ void zero_invalid_posteriors_kernel(
+     int32_t chunk_size, int32_t num_gauss, float *posteriors, int32_t ldp,
+     int32_t stridep, int32_t right, const LaneDesc *lanes, int32_t num_lanes) {
+   int32_t lane = blockIdx.z;
+
+   LaneDesc desc = lanes[lane];
+   int32_t num_chunk_frames = desc.num_chunk_frames;
+   int32_t current_frame = desc.current_frame;
+   bool last = desc.last;
+
+   // one past the last valid row for reading
+   int32_t num_computed_rows = current_frame + num_chunk_frames;
+
+   // if this is not the last chunk, remove the right context
+   if (!last) {
+     num_computed_rows -= right;
+   }
+
+   // offset by lane
+   posteriors = posteriors + lane * stridep;
+
+   for (int r = blockIdx.y * blockDim.y + threadIdx.y; r < chunk_size;
+        r += blockDim.y * gridDim.y) {
+     int global_row = current_frame + r - right;
+     if (global_row < 0 || global_row >= num_computed_rows) {
+       // zero this row out
+       for (int c = blockIdx.x * blockDim.x + threadIdx.x; c < num_gauss;
+            c += blockDim.x * gridDim.x) {
+         posteriors[r * ldp + c] = 0.0f;
+       }
+     }
+   }
+ }
+
+ void zero_invalid_posteriors(int32_t num_chunk_frames, int32_t num_gauss,
+                              float *posteriors, int32_t ldp, int32_t stridep,
+                              int32_t right, const LaneDesc *lanes,
+                              int32_t num_lanes) {
+   dim3 threads(32, 32);
+   dim3 blocks((num_gauss + 31) / 32, (num_chunk_frames + 31) / 32, num_lanes);
+
+   zero_invalid_posteriors_kernel<<<blocks, threads>>>(
+       num_chunk_frames, num_gauss, posteriors, ldp, stridep, right, lanes,
+       num_lanes);
+ }
+
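// For illustration, a scalar reference of the validity rule used by the
// kernel above; a sketch assuming the same LaneDesc fields, and the helper
// name is illustrative, not part of the commit. Row r of a posterior chunk
// is kept iff its delayed global index lands among the fully computed rows.
static inline bool posterior_row_is_valid(int32_t r, int32_t right,
                                          int32_t current_frame,
                                          int32_t num_chunk_frames,
                                          bool last) {
  int32_t num_computed_rows = current_frame + num_chunk_frames;
  if (!last) num_computed_rows -= right;  // right context not yet usable
  int32_t global_row = current_frame + r - right;
  return global_row >= 0 && global_row < num_computed_rows;
}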
// Meant to be called with blockDim = 32x32
// takes features in feat and writes them into sfeats while applying
// the splicing algorithm for the left and right context.
@@ -67,39 +115,48 @@ __global__ void splice_features_batched_kernel(
    float *__restrict__ feats_out, int32_t ldo, int32_t strideo,
    const LaneDesc *lanes, int32_t num_lanes) {
  int32_t lane = blockIdx.y;
-   int32_t frame = blockIdx.x;
+   // output frame index
+   int32_t oframe = blockIdx.x;
  int32_t tid = threadIdx.x;

  LaneDesc desc = lanes[lane];
  int32_t num_chunk_frames = desc.num_chunk_frames;
  int32_t channel = desc.channel;
-   int32_t start_frame = desc.current_frame;
+   int32_t current_frame = desc.current_frame;
+   bool last = desc.last;

-   bool valid_frame = true;
-   // check that we have valid input
-   if (frame >= num_chunk_frames) {
-     valid_frame = false;
-   }
+   // offset by lane
+   feats_in = feats_in + lane * stridei;
+   feats_out = feats_out + lane * strideo;

-   // for first chunk we process less frames
-   if (start_frame == 0 && frame >= num_chunk_frames - right) {
-     valid_frame = false;
-   }
+   // offset by channel
+   feats_stash = feats_stash + channel * stridest;

-   // the stash size
+   // offset feature output to process oframe
+   feats_out = feats_out + ldo * oframe;
+
+   // the size of the stash
  int32_t ssize = left + right;
+   // the size of the splice window
  int32_t size = ssize + 1;

-   // offset by lane
-   feats_in = feats_in + lane * stridei;
-   feats_out = feats_out + lane * strideo;
-   feats_stash = feats_stash + channel * stridest;
+   // number of valid frames for reading
+   int32_t num_valid_frames = current_frame + num_chunk_frames;

-   // offset feature output to process frame
-   feats_out = feats_out + ldo * frame;
+   // number of valid frames for writing
+   int32_t num_computed_frames = num_valid_frames;

-   if (!valid_frame) {
-     // this frames output is not valid, zero it here
+   // if this is not the last chunk, remove the right context
+   if (!last) {
+     num_computed_frames -= right;
+   }
+
+   // subtract the right context from the logical frame to delay output
+   int32_t local_frame = oframe - right;
+   int32_t global_frame = current_frame + local_frame;
+
+   // these frames are set to zeros
+   if (global_frame < 0 || global_frame >= num_computed_frames) {
    for (int i = 0; i < size; i++) {
      for (int c = tid; c < feat_dim; c += blockDim.x) {
        feats_out[i * feat_dim + c] = 0.0f;
@@ -108,44 +165,40 @@ __global__ void splice_features_batched_kernel(
    return;
  }

-   // for each splice of input
-   for (int i = 0; i < size; i++) {
-     const float *feats_src = feats_in;
-     int32_t ld = ldi;
-
-     // shift input row by left context
-     int r = frame + i - left;
+   for (int i = -left; i <= right; i++) {
+     int32_t g_in = global_frame + i;  // global frame index
+     int32_t l_in = local_frame + i;   // local frame index

-     // clamp input row if necessary
-     if (start_frame + r < 0) {
-       r = 0;
-     }
+     // if the global row is below zero, clamp the local row to zero
+     if (g_in < 0) l_in = 0;

-     // if we have a right context shift input row by that too
-     if (start_frame > 0) {
-       r = r - right;
+     // if the global row is past the number of valid frames
+     if (g_in >= num_valid_frames) {
+       // this should only happen on the last chunk
+       assert(last);
+       // clamp input
+       l_in = num_chunk_frames - 1;
    }

-     if (r > num_chunk_frames - 1) {
-       // This should only happen on the last chunk
-       assert(desc.last == true);
-       r = num_chunk_frames - 1;
-     }
+     // set default input location
+     const float *feats = feats_in;
+     int32_t ld = ldi;

-     if (r < 0) {
-       // feats are located in stash from previous chunk
-       feats_src = feats_stash;
+     // if l_in < 0 then the feats come from the stash
+     if (l_in < 0) {
+       // input is from the stash
+       feats = feats_stash;
      ld = ldst;
-       r = r + ssize;
+       l_in += ssize;  // offset by stash size
    }

    // for each column of input in parallel
    for (int c = tid; c < feat_dim; c += blockDim.x) {
      // read feature from input row offset by column
-       float val = feats_src[r * ld + c];
+       float val = feats[l_in * ld + c];

      // write feature to output offset by splice index and column
-       feats_out[i * feat_dim + c] = val;
+       feats_out[(i + left) * feat_dim + c] = val;
    }
  }
}
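// For illustration, a scalar sketch of the input-row selection above; the
// helper name is illustrative, not part of the commit. Given an output
// frame oframe and a splice offset i in [-left, right], it returns the
// local input row and whether that row is read from the previous chunk's
// stash rather than the current input chunk.
static inline int32_t splice_input_row(int32_t oframe, int32_t i,
                                       int32_t left, int32_t right,
                                       int32_t current_frame,
                                       int32_t num_chunk_frames, bool last,
                                       bool *from_stash) {
  int32_t local_frame = oframe - right;  // delay output by the right context
  int32_t g_in = current_frame + local_frame + i;
  int32_t l_in = local_frame + i;
  if (g_in < 0) l_in = 0;  // clamp at the start of the stream
  if (g_in >= current_frame + num_chunk_frames) {
    assert(last);                  // only legal on the last chunk
    l_in = num_chunk_frames - 1;   // clamp at the last valid input row
  }
  *from_stash = (l_in < 0);
  if (*from_stash) l_in += left + right;  // index into the stash instead
  return l_in;
}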
@@ -159,6 +212,7 @@ void splice_features_batched(int32_t num_chunk_frames, int32_t feat_dim,
    const LaneDesc *lanes, int32_t num_lanes) {
  int threads = (feat_dim + 31) / 32 * 32;  // round up to the nearest warp size
  if (threads > 1024) threads = 1024;       // max block size is 1024 threads
+
  dim3 blocks(num_chunk_frames, num_lanes);

  splice_features_batched_kernel<<<blocks, threads>>>(
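// Worked example of the round-up above: for feat_dim = 40,
// threads = (40 + 31) / 32 * 32 = 71 / 32 * 32 = 2 * 32 = 64, i.e. two
// full warps; any feat_dim that would need more than 1024 threads is
// clamped to the 1024-thread block limit.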
@@ -302,12 +356,15 @@ __global__ void batched_update_linear_and_quadratic_terms_kernel(
  linear = linear + lane * stridel;
  quadratic = quadratic + lane * strideq;

-   // This is always zero. not 100% certain as why we don't need
-   // to account for earlier chunk. maybe Dan knows.
+   // This is always zero because the linear and quadratic terms are not
+   // carried forward, so there is no old prior scale to remove. The code
+   // below is kept so that it logically matches the CPU implementation,
+   // in case someone compares the two in the future.
  float old_num_frames = 0;
  // float old_num_frames = desc.current_frame;
  float new_num_frames = desc.current_frame + desc.num_chunk_frames;

+   // in the CPU code the frame counts are scaled by the posterior scale
  new_num_frames *= posterior_scale;
  old_num_frames *= posterior_scale;
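// Worked example of the scaling above (posterior_scale = 0.1 is just an
// illustrative value): with current_frame = 0 and num_chunk_frames = 50,
// new_num_frames = 50 * 0.1 = 5.0, while old_num_frames stays 0, so no
// old prior contribution is removed.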
@@ -458,9 +515,6 @@ __global__ void batched_sum_posteriors_kernel(
    int32_t stridep, float *gamma, int32_t strideg, float post_scale,
    const LaneDesc *lanes, int32_t num_lanes) {
  int32_t lane = blockIdx.y;
-   LaneDesc desc = lanes[lane];
-
-   int32_t num_rows = desc.num_chunk_frames;

  // offset input and output by lane
  posteriors = posteriors + lane * stridep;
@@ -471,7 +525,7 @@ __global__ void batched_sum_posteriors_kernel(
       col += blockDim.x * gridDim.x) {
    // compute sum across rows for this column
    float sum = 0.0f;
-     for (int row = 0; row < num_rows; row++) {
+     for (int row = 0; row < chunk_size; row++) {
      sum += posteriors[row * ldp + col];
    }
@@ -509,7 +563,7 @@ __global__ void initialize_channels_kernel(int32_t num_gauss, int32_t feat_dim,
  // initialize stashes to zero
  for (int i = threadIdx.y * blockDim.x + threadIdx.x; i < num_gauss;
-        i += blockDim.x * gridDim.x) {
+        i += blockDim.y * blockDim.x) {
    gamma[i] = 0.0f;
  }
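// For illustration, the corrected stride above equals the total number of
// threads in the 2D block, so the flattened block-stride loop visits each
// index exactly once. A minimal sketch of the same pattern (the kernel
// name is illustrative, not part of the commit):
__global__ void block_stride_zero_kernel(float *data, int32_t n) {
  int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;  // flattened thread id
  int32_t stride = blockDim.y * blockDim.x;              // threads per block
  for (int32_t i = tid; i < n; i += stride) {
    data[i] = 0.0f;
  }
}
// The old stride, blockDim.x * gridDim.x, only equals the block size when
// gridDim.x happens to equal blockDim.y.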