WIP: add custom cuda kernel for GPU accumulation into output image#27

Draft
McHaillet wants to merge 1 commit into teamtomo:main from McHaillet:custom_cuda_kernel_index_put

Conversation

@McHaillet (Contributor) commented Jul 17, 2025

@alisterburt I couldn't help myself; this is my naive implementation of a custom CUDA kernel for value accumulation (a replacement for index_put with accumulate=True).

This is a naive implementation that assumes idx_c, idx_z, etc. are all 1-dimensional tensors rather than (b, c, z, y, x)-shaped. I realised later that they are in fact (b, c, z, y, x)-shaped, which will probably make the indexing in the kernels a bit more complicated, as you would have to deal with broadcasting.
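For reference, the operation the kernel replaces has scatter-add semantics: duplicate indices must accumulate rather than overwrite. A minimal NumPy sketch of those semantics (illustrative only, not the PR's code):

```python
import numpy as np

# flat output buffer and 1-D index/value arrays, as the naive kernel assumes
out = np.zeros(5, dtype=np.float32)
idx = np.array([0, 2, 2, 4])
vals = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

# np.add.at accumulates duplicate indices, like index_put_ with accumulate=True;
# plain fancy assignment (out[idx] = vals) would drop the first write at index 2
np.add.at(out, idx, vals)
# out is now [1., 0., 5., 0., 4.]
```

A CUDA kernel gets the same guarantee with a per-element atomicAdd, since multiple threads may target the same output pixel.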

For now, I am not going to keep working on finishing this, but feel free to continue it (or anyone else is welcome to).

When using it in torch-fourier-slice with this exact code:

import mrcfile
import torch
import time
from scipy.stats import special_ortho_group

from torch_fourier_slice import backproject_2d_to_3d, project_3d_to_2d

N_IMAGES = 1000
torch.manual_seed(42)

# load a volume and normalise
volume = torch.tensor(
    mrcfile.read("/home/marten/data/datasets/emdb/test/emd_48372_10A.mrc"),
)
volume -= torch.mean(volume)
volume /= torch.std(volume)

# rotation matrices for projection (operate on xyz column vectors)
rotations = torch.tensor(
    special_ortho_group.rvs(dim=3, size=N_IMAGES, random_state=42),
).float()

for device in (torch.device("cuda:0"), torch.device("cpu")):
    volume = volume.to(device)
    rotations = rotations.to(device)

    # warm-up to compile custom kernels
    projections = project_3d_to_2d(
        volume,
        rotation_matrices=rotations,
        pad_factor=1.0,
    )  # (b, h, w)
    reconstruction = backproject_2d_to_3d(
        images=projections,
        rotation_matrices=rotations,
        pad_factor=1.0,
    )

    torch.cuda.synchronize()
    print('start the timer')
    start = time.perf_counter()

    for x in [1.0, 1.5, 2.0]:
        # make projections
        projections = project_3d_to_2d(
            volume,
            rotation_matrices=rotations,
            pad_factor=x,
        )  # (b, h, w)

        # reconstruct volume from projections
        reconstruction = backproject_2d_to_3d(
            images=projections,
            rotation_matrices=rotations,
            pad_factor=x,
        )
        reconstruction -= torch.mean(reconstruction)
        reconstruction = reconstruction / torch.std(reconstruction)

    torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"Projected time on {device}: {end - start}")

I now get these timings, which are a speed-up over what I reported before in teamtomo/torch-fourier-slice#27 :
Projected time on cuda:0: 0.53 sec.
Projected time on cpu: 1.93 sec. (multithreaded CPU)

@alisterburt (Collaborator)

Hahah very nice, love to see it and very cool that you got it working! I'm a bit surprised there isn't more of a speedup, I guess the real wins would come from not generating/manipulating the index arrays and doing it all in the kernel 🥲

@alisterburt (Collaborator)

Questions I have about adding compiled stuff into teamtomo packages

  • how do we deal with backward passes? (Explicit separate code paths? Write the gradient func too and register as torch.autograd.Function?)
  • how do we deal with packaging?
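On the backward-pass question: the scatter-add is linear in the accumulated values, so its gradient is just a gather at the same indices. A sketch of wrapping that as a torch.autograd.Function (hypothetical names and 1-D indices; the index_put_ line stands in for where the custom kernel call would go):

```python
import torch

class AccumulateValues(torch.autograd.Function):
    # Hypothetical wrapper, not code from this PR: forward stands in for the
    # custom CUDA kernel; backward gathers the output gradient at the
    # scattered indices (scatter-add is linear in `values`).

    @staticmethod
    def forward(ctx, values, indices, output):
        ctx.save_for_backward(indices)
        out = output.clone()
        out.index_put_((indices,), values, accumulate=True)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # d(out)/d(values) is a gather; indices and output get no gradient here
        return grad_output[indices], None, None


values = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
indices = torch.tensor([0, 1, 1])
out = AccumulateValues.apply(values, indices, torch.zeros(4))
out.sum().backward()
# values.grad is [1., 1., 1.]
```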

@McHaillet (Contributor, author)

Hahah very nice, love to see it and very cool that you got it working! I'm a bit surprised there isn't more of a speedup, I guess the real wins would come from not generating/manipulating the index arrays and doing it all in the kernel 🥲

Yeah, I know... I think there are multiple reasons:

  • The CPU implementation is also multithreaded and uses my full chip. Although the GPU has more cores, CPUs have much higher clock speeds.
  • We are flattening coordinates. If we used some sort of texture-memory approach that exploits the spatial relation of neighbouring pixels, the GPU would probably excel further.
  • Overhead of all the calls to the GPU.
  • My example does projection and backprojection, but only backprojection uses these kernels. The projection part is done with grid sample (right?). (This I could easily check, by the way... might do tomorrow.)
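Isolating one stage is straightforward with a small helper that synchronises before and after the timed region (a generic sketch, not code from the PR; pass sync=torch.cuda.synchronize for GPU work so pending kernels are flushed):

```python
import time

def time_stage(fn, *args, n_repeats=3, sync=None, **kwargs):
    """Average wall time of one stage; `sync` flushes pending async (GPU) work."""
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n_repeats):
        result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    elapsed = (time.perf_counter() - start) / n_repeats
    return result, elapsed
```

For example, time_stage(backproject_2d_to_3d, images=projections, rotation_matrices=rotations, sync=torch.cuda.synchronize) would time just the backprojection stage.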

@McHaillet (Contributor, author)

Questions I have about adding compiled stuff into teamtomo packages

  • how do we deal with backward passes? (Explicit separate code paths? Write the gradient func too and register as torch.autograd.Function?)

haha, no clue

  • how do we deal with packaging?

The easiest is probably JIT compilation with caching on the machine (I believe this already does that). The first run of the function will require some compilation time, but subsequent runs can just use the cache.

@alisterburt (Collaborator)

Heh, right. There's some hidden complexity with the JIT too: compilation is only done for a specific shape, which means fighting the JIT overhead if you use it in code paths with dynamic shapes.

Not discounting anything just things to keep in mind :-)

@McHaillet (Contributor, author) commented Jul 22, 2025

Some answers:

  • we do need to write the backward pass as well, BUT Claude can probably get us a long way

  • the JIT compilation is not shape specific, so you just need to compile once when you first run the kernel

  • when only measuring the backproject code, the speed-up is twofold for CUDA; I think that could already be worth putting some time into

I will make some measurements of the pure kernel as well to get a better sense of the performance without the boilerplate.

@McHaillet (Contributor, author)

Here some plots of the speed-up:

[figure: accumulate_benchmark speed-up plots]

Torch seems to have a different implementation for complex values, as the custom kernel suffers a larger performance drop for them.
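That would fit with how complex accumulation typically has to be done in a kernel: CUDA's atomicAdd has no complex overload, so each complex element becomes two real atomic adds. A NumPy illustration of the equivalence (illustrative only, not the PR's code):

```python
import numpy as np

idx = np.array([1, 1, 2])
vals = np.array([1 + 2j, 3 + 4j, 5 + 0j], dtype=np.complex64)

# complex scatter-add in one shot
out = np.zeros(3, dtype=np.complex64)
np.add.at(out, idx, vals)

# same result via separate real/imag accumulation, mirroring a kernel
# that issues two real atomicAdds per complex element
out2 = np.zeros(3, dtype=np.complex64)
np.add.at(out2.real, idx, vals.real)
np.add.at(out2.imag, idx, vals.imag)
# both give [0, 4+6j, 5+0j]
```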

@alisterburt (Collaborator)

All useful info, thanks!

I haven't thought too carefully, but my intuition is that if we're going to write a kernel, it might as well cover the whole backprojection rather than just this step...
