Skip to content

Commit a52b76a

Browse files
Expose the cudnn algo in the conv ops. (huggingface#2892)
* Set the algo.
* Expose the cudnn preferred algo for conv ops.
1 parent fb660b8 commit a52b76a

12 files changed

Lines changed: 63 additions & 26 deletions

File tree

candle-core/examples/cuda_basics.rs

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,18 @@ extern crate intel_mkl_src;
66

77
use anyhow::Result;
88
use candle_core::{Device, Tensor};
9-
9+
// Benchmark setup: input xs: [1024, 64, 1924]; kernel c: Tensor[dims 128, 64, 8; f32, cuda:0]; Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
1010
fn main() -> Result<()> {
1111
let device = Device::new_cuda(0)?;
12-
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
13-
.to_dtype(candle_core::DType::BF16)?;
14-
candle_core::cuda::set_gemm_reduced_precision_f32(false);
15-
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
16-
let _x1 = x.matmul(&x)?;
17-
drop(_x1);
18-
let start_time = std::time::Instant::now();
19-
let _x1 = x.matmul(&x)?;
20-
device.synchronize()?;
21-
println!("fp32: {:?}", start_time.elapsed());
22-
drop(_x1);
23-
candle_core::cuda::set_gemm_reduced_precision_f32(true);
24-
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
25-
let _x1 = x.matmul(&x)?;
26-
drop(_x1);
27-
let start_time = std::time::Instant::now();
28-
let _x1 = x.matmul(&x)?;
29-
device.synchronize()?;
30-
println!("tf32: {:?}", start_time.elapsed());
12+
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
13+
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
14+
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
3115
drop(_x1);
16+
for _ in 0..20 {
17+
let start_time = std::time::Instant::now();
18+
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
19+
device.synchronize()?;
20+
println!("conv1d: {:?}", start_time.elapsed());
21+
}
3222
Ok(())
3323
}

candle-core/src/conv.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl ParamsConvTranspose1D {
5555
}
5656
}
5757

58-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
58+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
5959
pub enum CudnnFwdAlgo {
6060
ImplicitGemm,
6161
ImplicitPrecompGemm,
@@ -152,6 +152,19 @@ impl Tensor {
152152
stride: usize,
153153
dilation: usize,
154154
groups: usize,
155+
) -> Result<Self> {
156+
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
157+
}
158+
159+
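/// Applies a 1D convolution over the input tensor, with an optional preferred cuDNN forward algorithm (`cudnn_fwd_algo`); passing `None` is equivalent to calling `conv1d`.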
160+
pub fn conv1d_with_algo(
161+
&self,
162+
kernel: &Self,
163+
padding: usize,
164+
stride: usize,
165+
dilation: usize,
166+
groups: usize,
167+
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
155168
) -> Result<Self> {
156169
let (c_out, c_in_k, k_size) = kernel.dims3()?;
157170
let (b_size, c_in, l_in) = self.dims3()?;
@@ -175,7 +188,7 @@ impl Tensor {
175188
padding,
176189
stride,
177190
dilation,
178-
cudnn_fwd_algo: None,
191+
cudnn_fwd_algo,
179192
};
180193
if groups == 1 {
181194
self.conv1d_single_group(kernel, &params)
@@ -280,6 +293,18 @@ impl Tensor {
280293
stride: usize,
281294
dilation: usize,
282295
groups: usize,
296+
) -> Result<Self> {
297+
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
298+
}
299+
300+
pub fn conv2d_with_algo(
301+
&self,
302+
kernel: &Self,
303+
padding: usize,
304+
stride: usize,
305+
dilation: usize,
306+
groups: usize,
307+
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
283308
) -> Result<Self> {
284309
let (b_size, c_in, i_h, i_w) = self.dims4()?;
285310
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@@ -299,7 +324,7 @@ impl Tensor {
299324
padding,
300325
stride,
301326
dilation,
302-
cudnn_fwd_algo: None,
327+
cudnn_fwd_algo,
303328
};
304329
if groups == 1 {
305330
self.conv2d_single_group(kernel, &params)

candle-examples/examples/yolo-v3/darknet.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
133133
padding,
134134
groups: 1,
135135
dilation: 1,
136+
cudnn_fwd_algo: None,
136137
};
137138
let conv = if bias {
138139
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?

candle-examples/examples/yolo-v8/model.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ impl ConvBlock {
9292
stride,
9393
groups: 1,
9494
dilation: 1,
95+
cudnn_fwd_algo: None,
9596
};
9697
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
9798
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;

candle-nn/src/conv.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
//! Convolution Layers.
22
use crate::BatchNorm;
3-
use candle::{Result, Tensor};
3+
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
44

55
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
66
pub struct Conv1dConfig {
77
pub padding: usize,
88
pub stride: usize,
99
pub dilation: usize,
1010
pub groups: usize,
11+
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
1112
}
1213

1314
impl Default for Conv1dConfig {
@@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
1718
stride: 1,
1819
dilation: 1,
1920
groups: 1,
21+
cudnn_fwd_algo: None,
2022
}
2123
}
2224
}
@@ -52,12 +54,13 @@ impl Conv1d {
5254

5355
impl crate::Module for Conv1d {
5456
fn forward(&self, x: &Tensor) -> Result<Tensor> {
55-
let x = x.conv1d(
57+
let x = x.conv1d_with_algo(
5658
&self.weight,
5759
self.config.padding,
5860
self.config.stride,
5961
self.config.dilation,
6062
self.config.groups,
63+
self.config.cudnn_fwd_algo,
6164
)?;
6265
match &self.bias {
6366
None => Ok(x),
@@ -147,6 +150,7 @@ pub struct Conv2dConfig {
147150
pub stride: usize,
148151
pub dilation: usize,
149152
pub groups: usize,
153+
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
150154
}
151155

152156
impl Default for Conv2dConfig {
@@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
156160
stride: 1,
157161
dilation: 1,
158162
groups: 1,
163+
cudnn_fwd_algo: None,
159164
}
160165
}
161166
}
@@ -211,12 +216,13 @@ impl Conv2d {
211216

212217
impl crate::Module for Conv2d {
213218
fn forward(&self, x: &Tensor) -> Result<Tensor> {
214-
let x = x.conv2d(
219+
let x = x.conv2d_with_algo(
215220
&self.weight,
216221
self.config.padding,
217222
self.config.stride,
218223
self.config.dilation,
219224
self.config.groups,
225+
self.config.cudnn_fwd_algo,
220226
)?;
221227
match &self.bias {
222228
None => Ok(x),

candle-transformers/src/models/depth_anything_v2.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ impl ResidualConvUnit {
124124
stride: 1,
125125
dilation: 1,
126126
groups: 1,
127+
cudnn_fwd_algo: None,
127128
};
128129
let conv1 = conv2d(
129130
conf.num_features,
@@ -208,6 +209,7 @@ impl FeatureFusionBlock {
208209
stride: 1,
209210
dilation: 1,
210211
groups: 1,
212+
cudnn_fwd_algo: None,
211213
};
212214
let output_conv = conv2d(
213215
conf.num_features,
@@ -258,6 +260,7 @@ impl Scratch {
258260
stride: 1,
259261
dilation: 1,
260262
groups: 1,
263+
cudnn_fwd_algo: None,
261264
};
262265

263266
let layer1_rn = conv2d_no_bias(
@@ -319,6 +322,7 @@ impl Scratch {
319322
stride: 1,
320323
dilation: 1,
321324
groups: 1,
325+
cudnn_fwd_algo: None,
322326
};
323327
let output_conv1 = conv2d(
324328
conf.num_features,
@@ -425,6 +429,7 @@ impl DPTHead {
425429
stride: 2,
426430
dilation: 1,
427431
groups: 1,
432+
cudnn_fwd_algo: None,
428433
},
429434
vb.pp("resize_layers").pp("3"),
430435
)?),

candle-transformers/src/models/encodec.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ impl EncodecConv1d {
468468
stride,
469469
groups: 1,
470470
dilation: 1,
471+
cudnn_fwd_algo: None,
471472
},
472473
vb.pp("conv"),
473474
)?,

candle-transformers/src/models/mimi/conv.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ impl StreamableConv1d {
267267
stride,
268268
dilation,
269269
groups,
270+
cudnn_fwd_algo: None,
270271
};
271272
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
272273
if k_size < stride {

candle-transformers/src/models/stable_diffusion/resnet.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ impl ResnetBlock2D {
6868
padding: 1,
6969
groups: 1,
7070
dilation: 1,
71+
cudnn_fwd_algo: None,
7172
};
7273
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
7374
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -83,6 +84,7 @@ impl ResnetBlock2D {
8384
padding: 0,
8485
groups: 1,
8586
dilation: 1,
87+
cudnn_fwd_algo: None,
8688
};
8789
Some(conv2d(
8890
in_channels,

candle-transformers/src/models/whisper/model.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,14 @@ impl AudioEncoder {
248248
stride: 1,
249249
groups: 1,
250250
dilation: 1,
251+
cudnn_fwd_algo: None,
251252
};
252253
let cfg2 = Conv1dConfig {
253254
padding: 1,
254255
stride: 2,
255256
groups: 1,
256257
dilation: 1,
258+
cudnn_fwd_algo: None,
257259
};
258260
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
259261
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;

0 commit comments

Comments
 (0)