Skip to content

Commit 4e99fc0

Browse files
committed
Autovectorization optimizations
1 parent 9968f6d commit 4e99fc0

File tree

1 file changed

+82
-63
lines changed

1 file changed

+82
-63
lines changed

src/denoise/mod.rs

Lines changed: 82 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -312,89 +312,105 @@ where
312312
.map(|f| f[p].data_origin())
313313
.collect::<ArrayVec<_, TB_SIZE>>();
314314

315-
for y in (0..effective_height).step_by(INC) {
316-
for x in (0..=(pad_width - SB_SIZE)).step_by(INC) {
317-
for z in 0..TB_SIZE {
318-
self.proc0(
319-
&src_planes[z][x..],
320-
&self.hw()[(BLOCK_AREA * z)..],
321-
&mut dftr[(BLOCK_AREA * z)..],
322-
src_stride,
315+
// SAFETY: We know the size of the planes we're working on,
316+
// so we can safely ensure we are not out of bounds.
317+
// There are a fair number of unsafe function calls here
318+
// which are unsafe for optimization purposes.
319+
// All are safe as long as we do not pass out-of-bounds parameters.
320+
unsafe {
321+
for y in (0..effective_height).step_by(INC) {
322+
for x in (0..=(pad_width - SB_SIZE)).step_by(INC) {
323+
for z in 0..TB_SIZE {
324+
self.proc0(
325+
src_planes[z].get_unchecked(x..),
326+
self.hw().get_unchecked((BLOCK_AREA * z)..),
327+
dftr.get_unchecked_mut((BLOCK_AREA * z)..),
328+
src_stride,
329+
SB_SIZE,
330+
self.src_scale(),
331+
);
332+
}
333+
334+
self.real_to_complex_3d(&dftr, &mut dftc);
335+
self.remove_mean(&mut dftc, self.dftgc(), &mut means);
336+
337+
self.filter_coeffs(&mut dftc);
338+
339+
self.add_mean(&mut dftc, &means);
340+
self.complex_to_real_3d(&dftc, &mut dftr);
341+
342+
self.proc1(
343+
dftr.get_unchecked((TB_MIDPOINT * BLOCK_AREA)..),
344+
self.hw().get_unchecked((TB_MIDPOINT * BLOCK_AREA)..),
345+
ebuff.get_unchecked_mut((y * ebuff_stride + x)..),
323346
SB_SIZE,
324-
self.src_scale(),
347+
ebuff_stride,
325348
);
326349
}
327350

328-
self.real_to_complex_3d(&dftr, &mut dftc);
329-
self.remove_mean(&mut dftc, self.dftgc(), &mut means);
330-
331-
self.filter_coeffs(&mut dftc);
332-
333-
self.add_mean(&mut dftc, &means);
334-
self.complex_to_real_3d(&dftc, &mut dftr);
335-
336-
self.proc1(
337-
&dftr[(TB_MIDPOINT * BLOCK_AREA)..],
338-
&self.hw()[(TB_MIDPOINT * BLOCK_AREA)..],
339-
&mut ebuff[(y * ebuff_stride + x)..],
340-
SB_SIZE,
341-
ebuff_stride,
342-
);
351+
for q in 0..TB_SIZE {
352+
src_planes[q] = &src_planes[q][(INC * src_stride)..];
353+
}
343354
}
344355

345-
for q in 0..TB_SIZE {
346-
src_planes[q] = &src_planes[q][(INC * src_stride)..];
347-
}
356+
let dest_width = dest.planes[p].cfg.width;
357+
let dest_height = dest.planes[p].cfg.height;
358+
let dest_stride = dest.planes[p].cfg.stride;
359+
let dest_plane = dest.planes[p].data_origin_mut();
360+
let ebp_offset = ebuff_stride * ((pad_height - dest_height) / 2)
361+
+ (pad_width - dest_width) / 2;
362+
let ebp = &ebuff[ebp_offset..];
363+
364+
self.cast(
365+
ebp,
366+
dest_plane,
367+
dest_width,
368+
dest_height,
369+
dest_stride,
370+
ebuff_stride,
371+
);
348372
}
349-
350-
let dest_width = dest.planes[p].cfg.width;
351-
let dest_height = dest.planes[p].cfg.height;
352-
let dest_stride = dest.planes[p].cfg.stride;
353-
let dest_plane = dest.planes[p].data_origin_mut();
354-
let ebp_offset = ebuff_stride * ((pad_height - dest_height) / 2)
355-
+ (pad_width - dest_width) / 2;
356-
let ebp = &ebuff[ebp_offset..];
357-
358-
self.cast(
359-
ebp,
360-
dest_plane,
361-
dest_width,
362-
dest_height,
363-
dest_stride,
364-
ebuff_stride,
365-
);
366373
}
367374
}
368375

369-
fn proc0(
376+
#[inline]
377+
unsafe fn proc0(
370378
&self, s0: &[T], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize,
371379
src_scale: f32,
372380
) {
373-
let s0 = s0.chunks(p0);
374-
let s1 = s1.chunks(p1);
375-
let dest = dest.chunks_mut(p1);
381+
let s0 = s0.as_ptr();
382+
let s1 = s1.as_ptr();
383+
let dest = dest.as_mut_ptr();
376384

377-
for (s0, (s1, dest)) in s0.zip(s1.zip(dest)).take(p1) {
385+
for u in 0..p1 {
378386
for v in 0..p1 {
379-
dest[v] = u16::cast_from(s0[v]) as f32 * src_scale * s1[v];
387+
let s0 = s0.add(u * p0 + v);
388+
let s1 = s1.add(u * p1 + v);
389+
let dest = dest.add(u * p1 + v);
390+
dest.write(u16::cast_from(s0.read()) as f32 * src_scale * s1.read())
380391
}
381392
}
382393
}
383394

384-
fn proc1(
395+
#[inline]
396+
unsafe fn proc1(
385397
&self, s0: &[f32], s1: &[f32], dest: &mut [f32], p0: usize, p1: usize,
386398
) {
387-
let s0 = s0.chunks(p0);
388-
let s1 = s1.chunks(p0);
389-
let dest = dest.chunks_mut(p1);
399+
let s0 = s0.as_ptr();
400+
let s1 = s1.as_ptr();
401+
let dest = dest.as_mut_ptr();
390402

391-
for (s0, (s1, dest)) in s0.zip(s1.zip(dest)).take(p0) {
403+
for u in 0..p0 {
392404
for v in 0..p0 {
393-
dest[v] += s0[v] * s1[v];
405+
let s0 = s0.add(u * p0 + v);
406+
let s1 = s1.add(u * p0 + v);
407+
let dest = dest.add(u * p1 + v);
408+
dest.write(dest.read() + s0.read() * s1.read());
394409
}
395410
}
396411
}
397412

413+
#[inline]
398414
fn remove_mean(
399415
&self, dftc: &mut [Complex<f32>; COMPLEX_COUNT],
400416
dftgc: &[Complex<f32>; COMPLEX_COUNT],
@@ -410,6 +426,7 @@ where
410426
}
411427
}
412428

429+
#[inline]
413430
fn add_mean(
414431
&self, dftc: &mut [Complex<f32>; COMPLEX_COUNT],
415432
means: &[Complex<f32>; COMPLEX_COUNT],
@@ -420,6 +437,7 @@ where
420437
}
421438
}
422439

440+
#[inline]
423441
// Applies a generalized wiener filter
424442
fn filter_coeffs(&self, dftc: &mut [Complex<f32>; COMPLEX_COUNT]) {
425443
let sigmas = self.sigmas();
@@ -508,20 +526,21 @@ where
508526
}
509527
}
510528

511-
fn cast(
529+
unsafe fn cast(
512530
&self, ebuff: &[f32], dest: &mut [T], dest_width: usize,
513531
dest_height: usize, dest_stride: usize, ebp_stride: usize,
514532
) {
515-
let ebuff = ebuff.chunks(ebp_stride);
516-
let dest = dest.chunks_mut(dest_stride);
517533
let dest_scale = self.dest_scale();
518534
let peak = self.peak();
535+
let ebuff = ebuff.as_ptr();
536+
let dest = dest.as_mut_ptr();
519537

520-
for (ebuff, dest) in ebuff.zip(dest).take(dest_height) {
538+
for y in 0..dest_height {
521539
for x in 0..dest_width {
522-
let fval = ebuff[x].mul_add(dest_scale, 0.5);
523-
dest[x] =
524-
clamp(T::cast_from(fval.round() as u16), T::cast_from(0u16), peak);
540+
let dest = dest.add(y * dest_stride + x);
541+
let ebp = ebuff.add(y * ebp_stride + x);
542+
let fval = ebp.read().mul_add(dest_scale, 0.5);
543+
dest.write(clamp(T::cast_from(fval as u16), T::cast_from(0u16), peak));
525544
}
526545
}
527546
}

0 commit comments

Comments
 (0)