Skip to content

Commit 40f321a

Browse files
banasraf, klecki
authored and committed
GPU MFCC operator. (#2423)
Extend GPU DCT kernel to support liftering and add MFCC operator for GPU. Signed-off-by: Rafal <[email protected]>
1 parent d67c0a3 commit 40f321a

File tree

9 files changed

+293
-99
lines changed

9 files changed

+293
-99
lines changed

dali/kernels/signal/dct/dct_gpu.cu

+19-8
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ namespace dct {
3131

3232
// The kernel processes data with the shape reduced to 3D.
3333
// Transform is applied over the middle axis.
34-
template <typename OutputType, typename InputType>
34+
template <typename OutputType, typename InputType, bool HasLifter>
3535
__global__ void ApplyDct(const typename Dct1DGpu<OutputType, InputType>::SampleDesc *samples,
36-
const BlockDesc<3> *blocks) {
36+
const BlockDesc<3> *blocks, const float *lifter_coeffs) {
3737
int bid = blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z);
3838
auto block = blocks[bid];
3939
const auto &sample = samples[block.sample_idx];
@@ -43,6 +43,7 @@ __global__ void ApplyDct(const typename Dct1DGpu<OutputType, InputType>::SampleD
4343
for (int z = block.start.z + threadIdx.z; z < block.end.z; z += blockDim.z) {
4444
for (int y = block.start.y + threadIdx.y; y < block.end.y; y += blockDim.y) {
4545
const OutputType *cos_row = sample.cos_table + sample.input_length * y;
46+
float coeff = HasLifter ? lifter_coeffs[y] : 1.f;
4647
for (int x = block.start.x + threadIdx.x; x < block.end.x; x += blockDim.x) {
4748
int output_idx = dot(out_stride, ivec3{z, y, x});
4849
const InputType *input = sample.input + dot(in_stride, ivec3{z, 0, x});
@@ -51,7 +52,7 @@ __global__ void ApplyDct(const typename Dct1DGpu<OutputType, InputType>::SampleD
5152
out_val += *input * cos_row[i];
5253
input += in_stride[1];
5354
}
54-
sample.output[output_idx] = out_val;
55+
sample.output[output_idx] = HasLifter ? out_val * coeff : out_val;
5556
}
5657
}
5758
}
@@ -60,8 +61,7 @@ __global__ void ApplyDct(const typename Dct1DGpu<OutputType, InputType>::SampleD
6061
template <typename OutputType, typename InputType>
6162
KernelRequirements Dct1DGpu<OutputType, InputType>::Setup(KernelContext &ctx,
6263
const InListGPU<InputType> &in,
63-
span<const DctArgs> args,
64-
int axis) {
64+
span<const DctArgs> args, int axis) {
6565
DALI_ENFORCE(args.size() == in.num_samples());
6666
KernelRequirements req{};
6767
ScratchpadEstimator se{};
@@ -120,7 +120,7 @@ template <typename OutputType, typename InputType>
120120
DLL_PUBLIC void Dct1DGpu<OutputType, InputType>::Run(KernelContext &ctx,
121121
const OutListGPU<OutputType> &out,
122122
const InListGPU<InputType> &in,
123-
span<const DctArgs>, int) {
123+
InTensorGPU<float, 1> lifter_coeffs) {
124124
OutputType *cpu_cos_table[2];
125125
cpu_cos_table[0] =
126126
ctx.scratchpad->Allocate<OutputType>(AllocType::Pinned, max_cos_table_size_);
@@ -148,6 +148,10 @@ DLL_PUBLIC void Dct1DGpu<OutputType, InputType>::Run(KernelContext &ctx,
148148
for (auto arg : args_) {
149149
auto in_shape = reduce_shape(in.tensor_shape_span(s), axis_);
150150
auto out_shape = reduce_shape(out.tensor_shape_span(s), axis_);
151+
DALI_ENFORCE(lifter_coeffs.num_elements() == 0 || out_shape[1] <= lifter_coeffs.num_elements(),
152+
make_string("Not enough lifter coefficients. NDCT for sample ", s, " is ",
153+
out_shape[1], " and only ", lifter_coeffs.num_elements(),
154+
" coefficients were passed."));
151155
ivec3 out_stride = GetStrides(ivec3{out_shape[0], out_shape[1], out_shape[2]});
152156
ivec3 in_stride = GetStrides(ivec3{in_shape[0], in_shape[1], in_shape[2]});;
153157
int n = in_shape[1];
@@ -162,8 +166,15 @@ DLL_PUBLIC void Dct1DGpu<OutputType, InputType>::Run(KernelContext &ctx,
162166
ctx.scratchpad->ToContiguousGPU(ctx.gpu.stream, sample_descs_, block_setup_.Blocks());
163167
dim3 grid_dim = block_setup_.GridDim();
164168
dim3 block_dim = block_setup_.BlockDim();
165-
ApplyDct<OutputType, InputType>
166-
<<<grid_dim, block_dim, 0, ctx.gpu.stream>>>(sample_descs_gpu, block_descs_gpu);
169+
if (lifter_coeffs.num_elements() > 0) {
170+
ApplyDct<OutputType, InputType, true>
171+
<<<grid_dim, block_dim, 0, ctx.gpu.stream>>>(sample_descs_gpu, block_descs_gpu,
172+
lifter_coeffs.data);
173+
} else {
174+
ApplyDct<OutputType, InputType, false>
175+
<<<grid_dim, block_dim, 0, ctx.gpu.stream>>>(sample_descs_gpu, block_descs_gpu,
176+
nullptr);
177+
}
167178
}
168179

169180
template class Dct1DGpu<float, float>;

dali/kernels/signal/dct/dct_gpu.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class DLL_PUBLIC Dct1DGpu {
8282
DLL_PUBLIC void Run(KernelContext &context,
8383
const OutListGPU<OutputType> &out,
8484
const InListGPU<InputType> &in,
85-
span<const DctArgs> args, int axis);
85+
InTensorGPU<float, 1> lifter_coeffs);
8686

8787
private:
8888
std::map<std::pair<int, DctArgs>, OutputType*> cos_tables_{};

dali/kernels/signal/dct/dct_gpu_test.cc

+28-6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "dali/test/tensor_test_utils.h"
2020
#include "dali/test/test_tensors.h"
2121
#include "dali/kernels/signal/dct/dct_test.h"
22+
#include "dali/core/dev_buffer.h"
2223

2324
namespace dali {
2425
namespace kernels {
@@ -27,13 +28,21 @@ namespace dct {
2728
namespace test {
2829

2930
class Dct1DGpuTest : public ::testing::TestWithParam<
30-
std::tuple<int, std::pair<int, std::vector<int>>>> {
31+
std::tuple<int, float, std::pair<int, std::vector<int>>>> {
3132
public:
3233
Dct1DGpuTest()
3334
: batch_size_(std::get<0>(GetParam()))
34-
, dims_(std::get<1>(GetParam()).first)
35-
, axes_(std::get<1>(GetParam()).second)
35+
, lifter_(std::get<1>(GetParam()))
36+
, dims_(std::get<2>(GetParam()).first)
37+
, axes_(std::get<2>(GetParam()).second)
3638
, in_shape_(batch_size_, dims_) {
39+
if (lifter_) {
40+
FillLifter();
41+
lifter_coeffs_gpu_buffer.resize(max_ndct);
42+
lifter_coeffs_gpu_ = make_tensor_gpu<1>(lifter_coeffs_gpu_buffer.data(), {max_ndct});
43+
cudaMemcpy(lifter_coeffs_gpu_.data, lifter_coeffs_.data(),
44+
lifter_coeffs_.size() * sizeof(float), cudaMemcpyHostToDevice);
45+
}
3746
while (args_.size() < static_cast<size_t>(batch_size_) * axes_.size()) {
3847
for (auto dct : dct_type) {
3948
for (auto norm : normalize) {
@@ -49,6 +58,13 @@ class Dct1DGpuTest : public ::testing::TestWithParam<
4958
~Dct1DGpuTest() override = default;
5059

5160
protected:
61+
// Fills lifter_coeffs_ with max_ndct reference liftering coefficients:
//   coeff[i] = 1 + lifter_ / 2 * sin(pi / lifter_ * (i + 1))
// The expression divides by lifter_, so lifter_ must be non-zero; the fixture's
// constructor only calls this method when lifter_ is non-zero.
void FillLifter() {
62+
lifter_coeffs_.resize(max_ndct);
63+
for (int i = 0; i < max_ndct; ++i) {
64+
// Same formula as the `coeff` expression used by the ReferenceDctType* helpers.
lifter_coeffs_[i] = 1.0 + lifter_ / 2 * std::sin(M_PI / lifter_ * (i + 1));
65+
}
66+
}
67+
5268
void PrepareInput() {
5369
std::mt19937_64 rng{12345};
5470
std::uniform_int_distribution<> dim_dist(1, 3);
@@ -82,17 +98,22 @@ class Dct1DGpuTest : public ::testing::TestWithParam<
8298
}
8399

84100
int batch_size_;
101+
float lifter_;
85102
int dims_;
86103
std::vector<int> axes_;
87104
TensorListShape<> in_shape_;
88105
TestTensorList<float> ttl_in_;
89106
TestTensorList<float> ttl_out_;
90107
std::vector<DctArgs> args_;
108+
std::vector<float> lifter_coeffs_;
109+
DeviceBuffer<float> lifter_coeffs_gpu_buffer;
110+
OutTensorGPU<float, 1> lifter_coeffs_gpu_{};
91111
int args_idx_ = 0;
92112
span<const DctArgs> args_span_;
93113
const std::array<int, 4> dct_type = {{1, 2, 3, 4}};
94114
const std::array<bool, 2> normalize = {{false, true}};
95115
const std::array<int, 3> ndct = {{-1, 10, 20}};
116+
const int max_ndct = 40;
96117
};
97118

98119

@@ -112,7 +133,7 @@ TEST_P(Dct1DGpuTest, DctTest) {
112133
ASSERT_EQ(out_shape, req.output_shapes[0]);
113134
ttl_out_.reshape(out_shape);
114135
auto out_view = ttl_out_.gpu();
115-
kmgr.Run<Kernel>(0, 0, ctx, out_view, in_view, args_span_, axis);
136+
kmgr.Run<Kernel>(0, 0, ctx, out_view, in_view, lifter_coeffs_gpu_);
116137
cudaStreamSynchronize(ctx.gpu.stream);
117138
auto cpu_in_view = ttl_in_.cpu();
118139
auto cpu_out_view = ttl_out_.cpu();
@@ -148,7 +169,7 @@ TEST_P(Dct1DGpuTest, DctTest) {
148169
LOG_LINE << "\n";
149170
int ndct = args.ndct > 0 ? args.ndct : in_shape_[s][axis];
150171
std::vector<float> ref(ndct, 0);
151-
ReferenceDct(args.dct_type, make_span(ref), make_cspan(in_buf), args.normalize);
172+
ReferenceDct(args.dct_type, make_span(ref), make_cspan(in_buf), args.normalize, lifter_);
152173
LOG_LINE << "DCT (type " << args.dct_type << "):";
153174
for (int k = 0; k < ndct; k++) {
154175
EXPECT_NEAR(ref[k], out[out_idx], 1e-5);
@@ -163,7 +184,8 @@ TEST_P(Dct1DGpuTest, DctTest) {
163184
}
164185

165186
INSTANTIATE_TEST_SUITE_P(Dct1DGpuTest, Dct1DGpuTest, testing::Combine(
166-
testing::Values(1, 6, 12), // batch_size
187+
testing::Values(1, 12), // batch_size
188+
testing::Values(0.f, 0.5f), // lifter
167189
testing::Values(std::make_pair(2, std::vector<int>{1}),
168190
std::make_pair(4, std::vector<int>{0, 3, 1}),
169191
std::make_pair(1, std::vector<int>{0, 0})) // dims, axes

dali/kernels/signal/dct/dct_test.h

+17-13
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace dct {
2424
namespace test {
2525

2626
template <typename T>
27-
void ReferenceDctTypeI(span<T> out, span<const T> in, bool normalize) {
27+
void ReferenceDctTypeI(span<T> out, span<const T> in, bool normalize, float lifter) {
2828
int64_t in_length = in.size();
2929
int64_t out_length = out.size();
3030
double phase_mul = M_PI / (in_length - 1);
@@ -34,12 +34,13 @@ void ReferenceDctTypeI(span<T> out, span<const T> in, bool normalize) {
3434
for (int64_t n = 1; n < in_length - 1; n++) {
3535
out_val += in[n] * std::cos(phase_mul * n * k);
3636
}
37-
out[k] = out_val;
37+
float coeff = lifter ? (1.0 + lifter / 2 * std::sin(M_PI / lifter * (k + 1))) : 1.f;
38+
out[k] = out_val * coeff;
3839
}
3940
}
4041

4142
template <typename T>
42-
void ReferenceDctTypeII(span<T> out, span<const T> in, bool normalize) {
43+
void ReferenceDctTypeII(span<T> out, span<const T> in, bool normalize, float lifter) {
4344
int64_t in_length = in.size();
4445
int64_t out_length = out.size();
4546
double phase_mul = M_PI / in_length;
@@ -54,12 +55,13 @@ void ReferenceDctTypeII(span<T> out, span<const T> in, bool normalize) {
5455
out_val += in[n] * std::cos(phase_mul * (n + 0.5) * k);
5556
}
5657
double factor = (k == 0) ? factor_k_0 : factor_k_i;
57-
out[k] = factor * out_val;
58+
float coeff = lifter ? (1.0 + lifter / 2 * std::sin(M_PI / lifter * (k + 1))) : 1.f;
59+
out[k] = factor * out_val * coeff;
5860
}
5961
}
6062

6163
template <typename T>
62-
void ReferenceDctTypeIII(span<T> out, span<const T> in, bool normalize) {
64+
void ReferenceDctTypeIII(span<T> out, span<const T> in, bool normalize, float lifter) {
6365
int64_t in_length = in.size();
6466
int64_t out_length = out.size();
6567
double phase_mul = M_PI / in_length;
@@ -74,12 +76,13 @@ void ReferenceDctTypeIII(span<T> out, span<const T> in, bool normalize) {
7476
for (int64_t n = 1; n < in_length; n++) {
7577
out_val += factor_n_i * in[n] * std::cos(phase_mul * n * (k + 0.5));
7678
}
77-
out[k] = out_val;
79+
float coeff = lifter ? (1.0 + lifter / 2 * std::sin(M_PI / lifter * (k + 1))) : 1.f;
80+
out[k] = out_val * coeff;
7881
}
7982
}
8083

8184
template <typename T>
82-
void ReferenceDctTypeIV(span<T> out, span<const T> in, bool normalize) {
85+
void ReferenceDctTypeIV(span<T> out, span<const T> in, bool normalize, float lifter) {
8386
int64_t in_length = in.size();
8487
int64_t out_length = out.size();
8588
double phase_mul = M_PI / in_length;
@@ -89,28 +92,29 @@ void ReferenceDctTypeIV(span<T> out, span<const T> in, bool normalize) {
8992
for (int64_t n = 0; n < in_length; n++) {
9093
out_val += factor * in[n] * std::cos(phase_mul * (n + 0.5) * (k + 0.5));
9194
}
92-
out[k] = out_val;
95+
float coeff = lifter ? (1.0 + lifter / 2 * std::sin(M_PI / lifter * (k + 1))) : 1.f;
96+
out[k] = out_val * coeff;
9397
}
9498
}
9599

96100

97101
template <typename T>
98-
void ReferenceDct(int dct_type, span<T> out, span<const T> in, bool normalize) {
102+
void ReferenceDct(int dct_type, span<T> out, span<const T> in, bool normalize, float lifter = 0) {
99103
switch (dct_type) {
100104
case 1:
101-
ReferenceDctTypeI(out, in, normalize);
105+
ReferenceDctTypeI(out, in, normalize, lifter);
102106
break;
103107

104108
case 2:
105-
ReferenceDctTypeII(out, in, normalize);
109+
ReferenceDctTypeII(out, in, normalize, lifter);
106110
break;
107111

108112
case 3:
109-
ReferenceDctTypeIII(out, in, normalize);
113+
ReferenceDctTypeIII(out, in, normalize, lifter);
110114
break;
111115

112116
case 4:
113-
ReferenceDctTypeIV(out, in, normalize);
117+
ReferenceDctTypeIV(out, in, normalize, lifter);
114118
break;
115119

116120
default:

dali/operators/audio/mfcc/mfcc.cc

+30-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
1+
// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -19,8 +19,6 @@
1919
#include "dali/kernels/common/for_axis.h"
2020
#include "dali/pipeline/data/views.h"
2121

22-
23-
#define MFCC_SUPPORTED_TYPES (float)
2422
#define MFCC_SUPPORTED_NDIMS (2, 3, 4)
2523

2624
static constexpr int kNumInputs = 1;
@@ -30,6 +28,28 @@ namespace dali {
3028

3129
namespace detail {
3230

31+
// CPU specialization: makes sure at least `target_length` liftering coefficients for the
// given `lifter` value are stored in coeffs_.
//  - A `lifter` different from the cached lifter_ discards the stored coefficients.
//  - A lifter of 0 disables liftering; the function returns without computing anything.
//  - Otherwise coeffs_ is grown to `target_length` if it is shorter; it is never shrunk here.
// The cudaStream_t parameter is unnamed and therefore unused in this specialization.
template <>
32+
DLL_PUBLIC void LifterCoeffs<CPUBackend>::Calculate(int64_t target_length, float lifter,
33+
cudaStream_t) {
34+
// If different lifter argument, clear previous coefficients
35+
if (lifter_ != lifter) {
36+
coeffs_.clear();
37+
lifter_ = lifter;
38+
}
39+
40+
// 0 means no liftering
41+
if (lifter_ == 0.0f)
42+
return;
43+
44+
// Calculate remaining coefficients (if necessary)
45+
if (static_cast<int64_t>(coeffs_.size()) < target_length) {
46+
// NOTE(review): `end` is declared but never read in this function.
int64_t start = coeffs_.size(), end = target_length;
47+
coeffs_.resize(target_length);
48+
// Passes a pointer to the first newly added element, its index, and the number of new
// elements; presumably CalculateCoeffs fills exactly that range - confirm against its
// definition.
CalculateCoeffs(coeffs_.data() + start, start, target_length - start);
49+
}
50+
}
51+
52+
3353
template <typename T, int Dims>
3454
void ApplyLifter(const kernels::OutTensorCPU<T, Dims> &inout, int axis, const T* lifter_coeffs) {
3555
assert(axis >= 0 && axis < Dims);
@@ -93,6 +113,7 @@ the following formula::
93113
template <>
94114
bool MFCC<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
95115
const workspace_t<CPUBackend> &ws) {
116+
GetArguments(ws);
96117
output_desc.resize(kNumOutputs);
97118
const auto &input = ws.InputRef<CPUBackend>(0);
98119
auto &output = ws.OutputRef<CPUBackend>(0);
@@ -116,11 +137,11 @@ bool MFCC<CPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
116137
output_desc[0].shape.resize(nsamples, Dims);
117138
for (int i = 0; i < nsamples; i++) {
118139
const auto in_view = view<const T, Dims>(input[i]);
119-
auto &req = kmgr_.Setup<DctKernel>(i, ctx, in_view, args_, axis_);
120-
output_desc[0].shape.set_tensor_shape(i, req.output_shapes[0][0].shape);
121-
122-
if (in_view.shape[axis_] > max_length) {
123-
max_length = in_view.shape[axis_];
140+
auto &req = kmgr_.Setup<DctKernel>(i, ctx, in_view, args_[i], axis_);
141+
auto out_shape = req.output_shapes[0][0];
142+
output_desc[0].shape.set_tensor_shape(i, out_shape);
143+
if (out_shape[axis_] > max_length) {
144+
max_length = out_shape[axis_];
124145
}
125146
}
126147
), DALI_FAIL(make_string("Unsupported number of dimensions ", in_shape.size()))); // NOLINT
@@ -147,7 +168,7 @@ void MFCC<CPUBackend>::RunImpl(workspace_t<CPUBackend> &ws) {
147168
kernels::KernelContext ctx;
148169
auto in_view = view<const T, Dims>(input[i]);
149170
auto out_view = view<T, Dims>(output[i]);
150-
kmgr_.Run<DctKernel>(thread_id, i, ctx, out_view, in_view, args_, axis_);
171+
kmgr_.Run<DctKernel>(thread_id, i, ctx, out_view, in_view, args_[i], axis_);
151172
if (lifter_ != 0.0f) {
152173
assert(static_cast<int64_t>(lifter_coeffs_.size()) >= out_view.shape[axis_]);
153174
detail::ApplyLifter(out_view, axis_, lifter_coeffs_.data());

0 commit comments

Comments
 (0)