Skip to content

Commit 2363bdd

Browse files
committed
make sog export less vram heavy
1 parent 92a8635 commit 2363bdd

4 files changed

Lines changed: 363 additions & 157 deletions

File tree

src/io/cuda/kmeans.cu

Lines changed: 50 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
#include <random>
1212
#include <thrust/device_ptr.h>
1313
#include <thrust/sequence.h>
14-
#include <thrust/shuffle.h>
1514
#include <thrust/sort.h>
15+
#include <unordered_map>
1616
#include <vector>
1717

1818
namespace lfs::io {
@@ -25,6 +25,14 @@ namespace lfs::io {
2525
constexpr int CHUNK_SIZE = 128;
2626
constexpr int BLOCK_SIZE = 256;
2727

28+
// Tensor::cuda() deep-copies CUDA tensors in this codebase; borrow when possible.
29+
Tensor as_cuda_contiguous(const Tensor& data) {
30+
if (data.device() == Device::CUDA) {
31+
return data.is_contiguous() ? data : data.contiguous();
32+
}
33+
return data.cuda().contiguous();
34+
}
35+
2836
template <int N_DIMS>
2937
__global__ void gather_centroids_kernel(
3038
const float* __restrict__ data,
@@ -42,11 +50,30 @@ namespace lfs::io {
4250
}
4351
}
4452

45-
void init_random_indices_gpu(int* d_indices, const int n, const unsigned int seed) {
46-
thrust::device_ptr<int> indices_ptr(d_indices);
47-
thrust::sequence(indices_ptr, indices_ptr + n);
48-
thrust::default_random_engine rng(seed);
49-
thrust::shuffle(indices_ptr, indices_ptr + n, rng);
53+
std::vector<int> sample_unique_indices(const int n, const int count, const unsigned int seed) {
54+
std::vector<int> indices(static_cast<size_t>(count));
55+
std::unordered_map<int, int> swaps;
56+
swaps.reserve(static_cast<size_t>(count) * 2);
57+
58+
auto value_for = [&swaps](const int index) {
59+
auto it = swaps.find(index);
60+
return it == swaps.end() ? index : it->second;
61+
};
62+
63+
std::mt19937 rng(seed);
64+
for (int i = 0; i < count; ++i) {
65+
std::uniform_int_distribution<int> dist(i, n - 1);
66+
const int pick = dist(rng);
67+
indices[static_cast<size_t>(i)] = value_for(pick);
68+
swaps[pick] = value_for(i);
69+
}
70+
71+
return indices;
72+
}
73+
74+
Tensor sample_unique_indices_gpu(const int n, const int count, const unsigned int seed) {
75+
auto indices = sample_unique_indices(n, count, seed);
76+
return Tensor::from_vector(indices, {static_cast<size_t>(count)}, Device::CUDA);
5077
}
5178

5279
void build_csr_offsets_gpu(
@@ -263,17 +290,16 @@ namespace lfs::io {
263290
return {centroids, labels};
264291
}
265292

266-
auto data_gpu = data.cuda().contiguous();
293+
auto data_gpu = as_cuda_contiguous(data);
267294
const float* d_data = data_gpu.ptr<float>();
268295

269296
auto centroids = Tensor::zeros({static_cast<size_t>(k), static_cast<size_t>(N_DIMS)},
270297
Device::CUDA, DataType::Float32);
271298
float* d_centroids = centroids.ptr<float>();
272299

273300
{
274-
auto perm = Tensor::zeros({static_cast<size_t>(n)}, Device::CUDA, DataType::Int32);
275301
std::random_device rd;
276-
init_random_indices_gpu(perm.ptr<int>(), n, rd());
302+
auto perm = sample_unique_indices_gpu(n, k, rd());
277303

278304
const int grid_k = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
279305
gather_centroids_kernel<N_DIMS><<<grid_k, BLOCK_SIZE>>>(
@@ -379,10 +405,13 @@ namespace lfs::io {
379405
constexpr int SUPER_CHUNK_SIZE = 64;
380406

381407
template <int N_DIMS>
382-
__global__ void find_nearest_super_clusters_kernel(
408+
__global__ void hierarchical_search_fused_kernel(
383409
const float* __restrict__ data,
410+
const float* __restrict__ centroids,
384411
const float* __restrict__ super_centroids,
385-
int* __restrict__ nearest_supers,
412+
const int* __restrict__ super_offsets,
413+
const int* __restrict__ super_indices,
414+
int* __restrict__ labels,
386415
const int n_points) {
387416
__shared__ float shared_supers[SUPER_CHUNK_SIZE * N_DIMS];
388417

@@ -449,41 +478,11 @@ namespace lfs::io {
449478
}
450479

451480
if (tid < n_points) {
452-
for (int i = 0; i < NUM_NEAREST_SUPERS; ++i) {
453-
nearest_supers[tid * NUM_NEAREST_SUPERS + i] = best_idxs[i];
454-
}
455-
}
456-
}
481+
float min_dist = 1e30f;
482+
int min_idx = 0;
457483

458-
template <int N_DIMS>
459-
__global__ void hierarchical_search_kernel(
460-
const float* __restrict__ data,
461-
const float* __restrict__ centroids,
462-
const int* __restrict__ nearest_supers,
463-
const int* __restrict__ super_offsets,
464-
const int* __restrict__ super_indices,
465-
int* __restrict__ labels,
466-
const int n_points) {
467-
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
468-
469-
float point[N_DIMS];
470-
if (tid < n_points) {
471-
#pragma unroll
472-
for (int d = 0; d < N_DIMS; ++d) {
473-
point[d] = data[tid * N_DIMS + d];
474-
}
475-
}
476-
477-
float min_dist = 1e30f;
478-
int min_idx = 0;
479-
480-
for (int si = 0; si < NUM_NEAREST_SUPERS; ++si) {
481-
int super_idx = 0;
482-
if (tid < n_points) {
483-
super_idx = nearest_supers[tid * NUM_NEAREST_SUPERS + si];
484-
}
485-
486-
if (tid < n_points) {
484+
for (int si = 0; si < NUM_NEAREST_SUPERS; ++si) {
485+
const int super_idx = best_idxs[si];
487486
const int start = super_offsets[super_idx];
488487
const int end = super_offsets[super_idx + 1];
489488

@@ -502,9 +501,7 @@ namespace lfs::io {
502501
}
503502
}
504503
}
505-
}
506504

507-
if (tid < n_points) {
508505
labels[tid] = min_idx;
509506
}
510507
}
@@ -522,17 +519,16 @@ namespace lfs::io {
522519
return {centroids, labels};
523520
}
524521

525-
auto data_gpu = data.cuda().contiguous();
522+
auto data_gpu = as_cuda_contiguous(data);
526523
const float* d_data = data_gpu.ptr<float>();
527524

528525
auto centroids = Tensor::zeros({static_cast<size_t>(k), static_cast<size_t>(N_DIMS)},
529526
Device::CUDA, DataType::Float32);
530527
float* d_centroids = centroids.ptr<float>();
531528

532529
{
533-
auto perm = Tensor::zeros({static_cast<size_t>(n)}, Device::CUDA, DataType::Int32);
534530
std::random_device rd;
535-
init_random_indices_gpu(perm.ptr<int>(), n, rd());
531+
auto perm = sample_unique_indices_gpu(n, k, rd());
536532

537533
const int grid_k = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
538534
gather_centroids_kernel<N_DIMS><<<grid_k, BLOCK_SIZE>>>(
@@ -544,9 +540,8 @@ namespace lfs::io {
544540
auto super_membership = Tensor::zeros({static_cast<size_t>(k)}, Device::CUDA, DataType::Int32);
545541

546542
{
547-
auto perm = Tensor::zeros({static_cast<size_t>(k)}, Device::CUDA, DataType::Int32);
548543
std::random_device rd;
549-
init_random_indices_gpu(perm.ptr<int>(), k, rd());
544+
auto perm = sample_unique_indices_gpu(k, NUM_SUPER_CLUSTERS, rd());
550545

551546
const int grid_super = (NUM_SUPER_CLUSTERS + BLOCK_SIZE - 1) / BLOCK_SIZE;
552547
gather_centroids_kernel<N_DIMS><<<grid_super, BLOCK_SIZE>>>(
@@ -585,8 +580,6 @@ namespace lfs::io {
585580
super_indices.ptr<int>(), k, NUM_SUPER_CLUSTERS);
586581

587582
auto labels = Tensor::zeros({static_cast<size_t>(n)}, Device::CUDA, DataType::Int32);
588-
auto nearest_supers = Tensor::zeros({static_cast<size_t>(n), static_cast<size_t>(NUM_NEAREST_SUPERS)},
589-
Device::CUDA, DataType::Int32);
590583
auto centroid_sums = Tensor::zeros({static_cast<size_t>(k), static_cast<size_t>(N_DIMS)},
591584
Device::CUDA, DataType::Float32);
592585
auto centroid_counts = Tensor::zeros({static_cast<size_t>(k)}, Device::CUDA, DataType::Int32);
@@ -616,11 +609,8 @@ namespace lfs::io {
616609
assign_nearest_bruteforce_kernel<N_DIMS><<<grid_n, BLOCK_SIZE>>>(
617610
d_data, d_centroids, labels.ptr<int>(), n, k);
618611
} else {
619-
find_nearest_super_clusters_kernel<N_DIMS><<<grid_n, BLOCK_SIZE>>>(
620-
d_data, super_centroids.ptr<float>(), nearest_supers.ptr<int>(), n);
621-
622-
hierarchical_search_kernel<N_DIMS><<<grid_n, BLOCK_SIZE>>>(
623-
d_data, d_centroids, nearest_supers.ptr<int>(),
612+
hierarchical_search_fused_kernel<N_DIMS><<<grid_n, BLOCK_SIZE>>>(
613+
d_data, d_centroids, super_centroids.ptr<float>(),
624614
super_offsets.ptr<int>(), super_indices.ptr<int>(),
625615
labels.ptr<int>(), n);
626616
}
@@ -710,7 +700,7 @@ namespace lfs::io {
710700
return {Tensor(), Tensor()};
711701
}
712702

713-
auto data_gpu = data_2d.cuda().contiguous();
703+
auto data_gpu = as_cuda_contiguous(data_2d);
714704
const int n = data_gpu.shape()[0];
715705

716706
if (n <= k) {

0 commit comments

Comments
 (0)