Description
In my project, I am using PointCNN for a segmentation task. Recently, I profiled inference with NVIDIA Nsight Systems to identify potential bottlenecks. During these tests, I observed that the KNN kernel consumed approximately 89% of the total inference time, which seems abnormally high.
Below, I have included several screenshots that highlight this performance issue:
- CUDA Kernel Summary. The KNN kernel took ~89% of the total inference time.

- Single-batch inference analysis. The KNN operation within the `dec3` layer consumed nearly half of the inference time.

- KNN/FPS execution time and input shapes. The tests were conducted with a batch size of 24, where each sample consisted of 8192 points.
The execution time of the KNN operation grows much faster than linearly as the `k` parameter increases: going from k = 8 to k = 48 (6×) increases the time from 13 ms to 681 ms (~52×). Below are execution times for different `k` values (same number of input points, different numbers of neighbors):
Layer | k | Execution Time (ms) |
---|---|---|
conv1 | 8 | 13 |
dec4 | 32 | 241 |
dec3 | 48 | 681 |
I reviewed the CUDA implementation of KNN and suspect that the main reason for this slowdown is the maintenance of the sorted `best_dist` and `best_idx` arrays:
// n_y is the current query point: we compute its k nearest neighbors
// among the candidate points n_x in [ptr_x[example_idx], ptr_x[example_idx + 1]).
// Visit every candidate point in that range.
for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
// ...
// compute the distance from n_y to n_x and store it in tmp_dist
// ...
// Insert tmp_dist into best_dist/best_idx, which this loop keeps sorted in
// ascending order: find the first slot e1 holding a larger distance, shift
// entries e1..k-2 one position to the right (the last entry is dropped) and
// write the new candidate into slot e1. A candidate farther than all current
// entries still scans all k slots, so this is O(k) work per candidate, i.e.
// O(k * number_of_candidates) per query point.
// NOTE(review): for large k these per-thread arrays may not stay in
// registers; confirm against the full kernel source.
for (int64_t e1 = 0; e1 < k; e1++) {
if (best_dist[e1] > tmp_dist) {
for (int64_t e2 = k - 1; e2 > e1; e2--) {
best_dist[e2] = best_dist[e2 - 1];
best_idx[e2] = best_idx[e2 - 1];
}
best_dist[e1] = tmp_dist;
best_idx[e1] = n_x;
break;
}
}
}
So, I think there are two main issues with the current code:
- Maintaining the `best_dist` array inside `knn_kernel` appears to take significant time and is not the most efficient pattern to run on a GPU.
- Recomputing distances for each FPS/KNN call seems inefficient; it might make sense to compute them only once.
Questions
- Is there a fundamental issue with my implementation or an incorrect usage of the KNN/FPS operations?
- Would pre-computing the distances between points on the CPU within the data loader be a good option to consider?
Implementation Details
Below is the PointCNN implementation used in this project (the model is run through `torch.compile`, excluding the KNN and FPS operations):
class PointCnnSegm(torch.nn.Module):
    """PointCNN-style encoder/decoder producing per-point class logits.

    The encoder applies four ``XConv`` layers and down-samples the point set
    with ``fps`` between them. The decoder up-samples back with
    ``knn_interpolate``, concatenates the matching encoder features (skip
    connections) and applies an ``XConv`` at every level. A three-layer MLP
    head maps the final 256-channel features to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits per point.
        x_features: Input feature channels per point (``conv1``'s
            ``in_channels``).
        fps_ratio: Sampling ratios of the three ``fps`` down-sampling steps,
            in encoder order. At least three values are required. Defaults
            to ``(0.1, 0.5, 0.334)``.

    Raises:
        ValueError: If ``fps_ratio`` has fewer than three entries.

    Note:
        Every ``XConv`` builds its own k-NN graph with
        ``k = kernel_size * dilation``. ``dec3`` (8 * 6 = 48) and ``dec4``
        (8 * 4 = 32) search over the full-resolution point set with the
        largest k of all full-resolution layers, so their neighbor searches
        are the most expensive ones in the model.
    """

    def __init__(
        self,
        num_classes: int,
        x_features: int = 3,
        fps_ratio: list[float] | tuple[float, ...] = (0.1, 0.5, 0.334),
    ):
        super().__init__()
        # forward() indexes fps_ratio[0], [1] and [2]; fail early instead of
        # with an IndexError in the middle of a forward pass.
        if len(fps_ratio) < 3:
            raise ValueError(
                f'fps_ratio needs 3 values (one per down-sampling step), '
                f'got {len(fps_ratio)}')
        self.conv1 = XConv(x_features, 256, dim=3,
                           kernel_size=8, hidden_channels=128)
        self.conv2 = XConv(256, 256, dim=3, kernel_size=12, dilation=2)
        self.conv3 = XConv(256, 512, dim=3, kernel_size=16, dilation=2)
        self.conv4 = XConv(512, 1024, dim=3, kernel_size=16, dilation=6)
        self.dec1 = XConv(1024 + 512, 512, dim=3, kernel_size=16, dilation=6)
        self.dec2 = XConv(512 + 256, 256, dim=3, kernel_size=12, dilation=6)
        # dec3/dec4 run on the full-resolution points with k = 48 and k = 32.
        self.dec3 = XConv(256 + 256, 256, dim=3, kernel_size=8, dilation=6)
        self.dec4 = XConv(256 + 256, 256, dim=3, kernel_size=8, dilation=4)
        self.head = torch.nn.Sequential(
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, num_classes)
        )
        # Store a private copy: the old list default was shared by every
        # instance, and a caller-owned list must not alias model state.
        self.fps_ratio = list(fps_ratio)

    def forward(self, enc1_x, enc1_pos, enc1_batch):
        """Compute per-point logits.

        Args:
            enc1_x: Input point features; presumably ``(N, x_features)`` —
                confirm against the data loader.
            enc1_pos: Point coordinates, ``(N, 3)`` (every ``XConv`` here
                uses ``dim=3``).
            enc1_batch: Example index of every point, as expected by
                ``fps``/``knn_interpolate``/``XConv``.

        Returns:
            Logits from ``head``, one row per input point.
        """
        # ---- encoder: XConv, then fps down-sampling ----
        enc1_x = relu(self.conv1(enc1_x, enc1_pos, enc1_batch))
        idx = fps(enc1_pos, enc1_batch, ratio=self.fps_ratio[0])
        enc2_x, enc2_pos, enc2_batch = \
            enc1_x[idx], enc1_pos[idx], enc1_batch[idx]
        enc2_x = relu(self.conv2(enc2_x, enc2_pos, enc2_batch))
        idx = fps(enc2_pos, enc2_batch, ratio=self.fps_ratio[1])
        enc3_x, enc3_pos, enc3_batch = \
            enc2_x[idx], enc2_pos[idx], enc2_batch[idx]
        enc3_x = relu(self.conv3(enc3_x, enc3_pos, enc3_batch))
        idx = fps(enc3_pos, enc3_batch, ratio=self.fps_ratio[2])
        enc4_x, enc4_pos, enc4_batch = \
            enc3_x[idx], enc3_pos[idx], enc3_batch[idx]
        enc4_x = relu(self.conv4(enc4_x, enc4_pos, enc4_batch))
        # ---- decoder: interpolate to the finer level, concat skip, XConv ----
        dec1_x = knn_interpolate(enc4_x, enc4_pos, enc3_pos,
                                 enc4_batch, enc3_batch, k=3)
        dec1_x = torch.cat([dec1_x, enc3_x], dim=1)
        dec1_x = relu(self.dec1(dec1_x, enc3_pos, enc3_batch))
        dec2_x = knn_interpolate(dec1_x, enc3_pos, enc2_pos,
                                 enc3_batch, enc2_batch, k=3)
        dec2_x = torch.cat([dec2_x, enc2_x], dim=1)
        dec2_x = relu(self.dec2(dec2_x, enc2_pos, enc2_batch))
        dec3_x = knn_interpolate(dec2_x, enc2_pos, enc1_pos,
                                 enc2_batch, enc1_batch, k=3)
        dec3_x = torch.cat([dec3_x, enc1_x], dim=1)
        dec3_x = relu(self.dec3(dec3_x, enc1_pos, enc1_batch))
        # Same resolution as dec3: reuse the enc1 skip features again.
        dec4_x = torch.cat([dec3_x, enc1_x], dim=1)
        dec4_x = relu(self.dec4(dec4_x, enc1_pos, enc1_batch))
        out = self.head(dec4_x)
        return out
XConv implementation:
class XConv(torch.nn.Module):
    """X-Conv operator as used by PointCNN.

    For each point, ``kernel_size`` neighbors are selected from a k-NN graph
    of size ``kernel_size * dilation`` by keeping every ``dilation``-th
    neighbor. The pipeline then:

    * lifts relative neighbor positions to ``hidden_channels`` features
      (``mlp1``);
    * concatenates them with the gathered input features;
    * multiplies the result by a ``K x K`` matrix predicted from the
      relative positions (``mlp2``);
    * reduces it with a depthwise convolution followed by a linear layer
      (``conv``).

    Args:
        in_channels: Input feature channels per point.
        out_channels: Output feature channels per point.
        dim: Dimensionality of the point coordinates.
        kernel_size: Number of neighbors ``K`` used per point.
        hidden_channels: Channels of the lifted position features; defaults
            to ``in_channels // 4``.
        dilation: The k-NN search uses ``kernel_size * dilation`` neighbors,
            of which every ``dilation``-th one is kept.
        bias: Whether the final linear layer has a bias.
        num_workers: Forwarded to ``knn_graph``.

    Raises:
        ValueError: If the (possibly defaulted) ``hidden_channels`` is not
            positive.
    """

    def __init__(self, in_channels: int, out_channels: int, dim: int,
                 kernel_size: int, hidden_channels: int | None = None,
                 dilation: int = 1, bias: bool = True, num_workers: int = 1):
        super().__init__()
        self.in_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels // 4
        # Explicit check instead of `assert`, which is stripped under -O.
        if hidden_channels <= 0:
            raise ValueError(
                f'hidden_channels must be positive, got {hidden_channels} '
                f'(defaults to in_channels // 4 = {in_channels // 4})')
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.dim = dim
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.num_workers = num_workers
        C_in, C_delta, C_out = in_channels, hidden_channels, out_channels
        D, K = dim, kernel_size
        # Relative positions (E, D) -> lifted features (N, K, C_delta).
        self.mlp1 = torch.nn.Sequential(
            torch.nn.Linear(dim, C_delta),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(C_delta),
            torch.nn.Linear(C_delta, C_delta),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(C_delta),
            Reshape(-1, K, C_delta),
        )
        # Flattened relative positions (N, K * D) -> transform (N, K, K).
        self.mlp2 = torch.nn.Sequential(
            torch.nn.Linear(D * K, K**2),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(K**2),
            Reshape(-1, K, K),
            torch.nn.Conv1d(K, K**2, K, groups=K),
            torch.nn.ELU(),
            torch.nn.BatchNorm1d(K**2),
            Reshape(-1, K, K),
            torch.nn.Conv1d(K, K**2, K, groups=K),
            torch.nn.BatchNorm1d(K**2),
            Reshape(-1, K, K),
        )
        C_in = C_in + C_delta
        depth_multiplier = int(ceil(C_out / C_in))
        # Depthwise conv over the K neighbors, then a pointwise projection.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in),
            Reshape(-1, C_in * depth_multiplier),
            torch.nn.Linear(C_in * depth_multiplier, C_out, bias=bias),
        )
        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        reset(self.mlp1)
        reset(self.mlp2)
        reset(self.conv)

    def forward(
        self,
        x: torch.Tensor | None,
        pos: torch.Tensor,
        batch: torch.Tensor | None = None
    ):
        r"""Runs the forward pass of the module.

        Args:
            x: Point features ``(N, in_channels)`` (a 1-D tensor is treated
                as one channel), or ``None`` to use positions only.
            pos: Point coordinates ``(N, dim)`` (a 1-D tensor is treated as
                ``dim == 1``).
            batch: Optional batch vector, passed through to ``knn_graph``.

        Returns:
            Tensor of shape ``(N, out_channels)``.

        Raises:
            RuntimeError: If ``knn_graph`` does not return exactly
                ``kernel_size * dilation`` neighbors for every point (e.g.
                an example in the batch has fewer points than that).
        """
        pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos
        (N, D), K = pos.size(), self.kernel_size
        num_neighbors = K * self.dilation
        edge_index = knn_graph(pos, num_neighbors, batch, loop=True,
                               flow='target_to_source',
                               num_workers=self.num_workers)
        # The dilation stride and the view()s below assume exactly
        # K * dilation neighbors per point, stored contiguously. Fail with a
        # clear message instead of an opaque shape error further down.
        if edge_index.size(1) != N * num_neighbors:
            raise RuntimeError(
                f'{self.__class__.__name__} expected {num_neighbors} '
                f'neighbors for each of the {N} points, but knn_graph '
                f'returned {edge_index.size(1)} edges in total; every '
                f'example must contain at least {num_neighbors} points')
        if self.dilation > 1:
            # Keep every dilation-th neighbor of each point's block
            # (assumes knn_graph orders each block nearest-first).
            edge_index = edge_index[:, ::self.dilation]
        row, col = edge_index[0], edge_index[1]
        # Neighbor coordinates relative to their center point.
        pos = pos[col] - pos[row]
        x_star = self.mlp1(pos)
        if x is not None:
            x = x.unsqueeze(-1) if x.dim() == 1 else x
            x = x[col].view(N, K, self.in_channels)
            x_star = torch.cat([x_star, x], dim=-1)
        x_star = x_star.transpose(1, 2).contiguous()
        transform_matrix = self.mlp2(pos.view(N, K * D))
        x_transformed = torch.matmul(x_star, transform_matrix)
        out = self.conv(x_transformed)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')
Environment Details
- PyTorch == 2.2.1
- PyG == 2.5.0
- Torch Cluster == 1.6.3
- Python 3.10
- NVIDIA GeForce RTX 4090
- CUDA Version 12.5
Thank you for your assistance!
Activity