Skip to content

Commit 9972000

Browse files
kvishnivetskyKonstantin S. Vishnivetsky
and
Konstantin S. Vishnivetsky
authored
[src] FIX: Error control of CUDA allocations in Reallocate method. (#4305)
FIX: Pointers initialization and deinitialization. FIX: Minor logical errors with variables. Summary: Prevents Segmentation Fault uncontrolled termination on NVIDIA Tela T4 multi-cards hardware configuration. Co-authored-by: Konstantin S. Vishnivetsky <[email protected]>
1 parent 811bd21 commit 9972000

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.cc

+13-10
Original file line numberDiff line numberDiff line change
@@ -205,23 +205,26 @@ void BatchedThreadedNnet3CudaPipeline2::ComputeOfflineFeatures() {
205205

206206
cudaEventSynchronize(wave_buffer_->evt);
207207
if (nsamp > wave_buffer_->size) {
208-
wave_buffer_->Reallocate(nsamp);
208+
wave_buffer_->Reallocate(nsamp);
209209
}
210-
std::memcpy(wave_buffer_->h_data, h_wave.Data(),
211-
h_wave.Dim() * sizeof(BaseFloat));
212-
cudaMemcpyAsync(wave_buffer_->d_data, wave_buffer_->h_data,
213-
sizeof(BaseFloat) * nsamp, cudaMemcpyHostToDevice,
214-
cudaStreamPerThread);
210+
211+
std::memcpy(wave_buffer_->h_data, h_wave.Data(), nsamp * sizeof(BaseFloat));
212+
cudaMemcpyAsync(wave_buffer_->d_data, wave_buffer_->h_data, nsamp * sizeof(BaseFloat), cudaMemcpyHostToDevice, cudaStreamPerThread);
215213

216214
task.d_features.reset(new CuMatrix<BaseFloat>());
217215
task.d_ivectors.reset(new CuVector<BaseFloat>());
216+
218217
CuSubVector<BaseFloat> wrapper(wave_buffer_->d_data, nsamp);
219-
cuda_features_->ComputeFeatures(
220-
wrapper, cuda_online_pipeline_.GetModelFrequency(),
221-
task.d_features.get(), task.d_ivectors.get());
218+
219+
cuda_features_->ComputeFeatures(wrapper, cuda_online_pipeline_.GetModelFrequency(), task.d_features.get(), task.d_ivectors.get());
220+
222221
cudaEventRecord(wave_buffer_->evt, cudaStreamPerThread);
222+
223223
std::swap(wave_buffer_, next_wave_buffer_);
224-
if (task.wave_data) task.wave_data.reset(); // delete wave samples on host
224+
225+
if (task.wave_data)
226+
task.wave_data.reset(); // delete wave samples on host
227+
225228
{
226229
std::lock_guard<std::mutex> lk(outstanding_utt_m_);
227230
outstanding_utt_.push(std::move(task));

src/cudadecoder/batched-threaded-nnet3-cuda-pipeline2.h

+24-12
Original file line numberDiff line numberDiff line change
@@ -138,29 +138,41 @@ class BatchedThreadedNnet3CudaPipeline2 {
138138
BaseFloat *d_data;
139139
size_t size;
140140

141-
HostDeviceVector()
141+
HostDeviceVector(const size_t new_size = KALDI_CUDA_DECODER_AUDIO_HOST_DEVICE_BUFFER_SIZE)
142142
: h_data(NULL),
143143
d_data(NULL),
144-
size(KALDI_CUDA_DECODER_AUDIO_HOST_DEVICE_BUFFER_SIZE) {
144+
size(new_size) {
145145
cudaEventCreate(&evt);
146-
Reallocate(size);
146+
Reallocate(new_size);
147147
}
148148

149149
virtual ~HostDeviceVector() {
150150
Deallocate();
151151
cudaEventDestroy(evt);
152152
}
153153

154-
void Reallocate(size_t new_size) {
155-
KALDI_ASSERT(new_size > 0);
156-
Deallocate();
157-
cudaMalloc(&d_data, new_size * sizeof(*d_data));
158-
cudaMallocHost(&h_data, new_size * sizeof(*d_data));
159-
new_size = size;
154+
void Reallocate(const size_t new_size) {
155+
KALDI_ASSERT(new_size > 0);
156+
Deallocate();
157+
158+
cudaError_t cuResult = cudaSuccess;
159+
cuResult = cudaMalloc(&d_data, new_size * sizeof(BaseFloat));
160+
if (cuResult != cudaSuccess) {
161+
KALDI_ERR << "cudaMalloc() failed with error: " << cudaGetErrorString(cuResult);
162+
}
163+
KALDI_ASSERT(d_data != NULL);
164+
165+
cuResult = cudaMallocHost(&h_data, new_size * sizeof(BaseFloat));
166+
if (cuResult != cudaSuccess) {
167+
KALDI_ERR << "cudaMallocHost() failed with error: " << cudaGetErrorString(cuResult);
168+
}
169+
KALDI_ASSERT(h_data != NULL);
170+
171+
size = new_size;
160172
}
161173
void Deallocate() {
162-
if (d_data) cudaFree(d_data);
163-
if (h_data) cudaFreeHost(h_data);
174+
if (d_data) {cudaFree(d_data); d_data = NULL; }
175+
if (h_data) {cudaFreeHost(h_data); h_data = NULL; }
164176
}
165177
};
166178

@@ -245,7 +257,7 @@ class BatchedThreadedNnet3CudaPipeline2 {
245257
void WaitForAllTasks();
246258

247259
// Used for debug
248-
void SetSymbolTable(const fst::SymbolTable &word_syms) {
260+
void SetSymbolTable(fst::SymbolTable *word_syms) {
249261
cuda_online_pipeline_.SetSymbolTable(word_syms);
250262
}
251263

0 commit comments

Comments
 (0)