Skip to content

Commit 79f3970

Browse files
Refactor providers into separate libraries (#1190)
This removes most of the #if USE_CUDA and #if USE_DML blocks from the model-handling code. Device memory management is now also handled through the DeviceSpan structure, and all data copying is done in a device-independent manner. It's a huge change, and there will be some rough edges when submitted. The goal is to unblock other people who need the changes and then to make larger improvements in future PRs. --------- Co-authored-by: aciddelgado <[email protected]>
1 parent 391bce3 commit 79f3970

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+1068
-1430
lines changed

cmake/global_variables.cmake

+4
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ file(GLOB generator_srcs CONFIGURE_DEPENDS
6666
"${GENERATORS_ROOT}/*.cpp"
6767
"${GENERATORS_ROOT}/cpu/*.h"
6868
"${GENERATORS_ROOT}/cpu/*.cpp"
69+
"${GENERATORS_ROOT}/qnn/*.h"
70+
"${GENERATORS_ROOT}/qnn/*.cpp"
71+
"${GENERATORS_ROOT}/webgpu/*.h"
72+
"${GENERATORS_ROOT}/webgpu/*.cpp"
6973
"${MODELS_ROOT}/*.h"
7074
"${MODELS_ROOT}/*.cpp"
7175
)

src/beam_search_scorer.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ BeamSearchScorer::BeamSearchScorer(const GeneratorParams& parameters)
6767

6868
// Space to store intermediate sequence
6969
size_t const per_beam = (max_length_ * (max_length_ + 1)) / 2;
70-
hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam, true);
70+
hypothesis_buffer_ = device.Allocate<int32_t>(batch_beam_size * per_beam);
7171

7272
memset(next_beam_scores_.Span().data(), 0, next_beam_scores_.Span().size_bytes());
7373

src/cpu/interface.cpp

+62-10
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33

44
#include "../generators.h"
55
#include "../search.h"
6+
#include "../models/utils.h"
67
#include "interface.h"
78

89
namespace Generators {
910

11+
static Ort::Allocator* ort_allocator_{};
1012
const char* label_cpu = "cpu";
1113

1214
struct CpuMemory final : DeviceBuffer {
1315
CpuMemory(size_t size) : owned_{true} {
1416
size_in_bytes_ = size;
15-
p_cpu_ = p_device_ = new uint8_t[size_in_bytes_];
17+
p_cpu_ = p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
1618
}
1719

1820
CpuMemory(void* p, size_t size) : owned_{false} {
@@ -22,39 +24,89 @@ struct CpuMemory final : DeviceBuffer {
2224

2325
~CpuMemory() override {
2426
if (owned_)
25-
delete[] p_device_;
27+
ort_allocator_->Free(p_device_);
2628
}
2729

2830
const char* GetType() const override { return label_cpu; }
2931
void AllocateCpu() override {} // Nothing to do, device is also CPU
3032
void CopyDeviceToCpu() override {} // Nothing to do, device is also CPU
3133
void CopyCpuToDevice() override {} // Nothing to do, device is also CPU
3234
void CopyFrom(size_t begin_dest, DeviceBuffer& source, size_t begin_source, size_t size_in_bytes) override {
33-
if (GetType() == label_cpu)
34-
memcpy(p_device_ + begin_dest, source.p_device_ + begin_source, size_in_bytes);
35-
else
36-
throw std::runtime_error("CpuMemory::CopyFromDevice not implemented for " + std::string(source.GetType()));
35+
CopyThroughCpu(*this, begin_dest, source, begin_source, size_in_bytes);
36+
}
37+
38+
void Zero() override {
39+
memset(p_device_, 0, size_in_bytes_);
3740
}
3841

3942
bool owned_;
4043
};
4144

4245
struct CpuInterface : DeviceInterface {
43-
std::shared_ptr<DeviceBuffer> AllocateBase(size_t size, bool cpu_accessible) override {
44-
// cpu_accessible is ignored, as with the cpu, the device is also the cpu
46+
CpuInterface() {
47+
}
48+
49+
DeviceType GetType() const override { return DeviceType::CPU; }
50+
51+
void InitOrt(const OrtApi& /*api*/, Ort::Allocator& allocator) override {
52+
assert(!ort_allocator_);
53+
ort_allocator_ = &allocator;
54+
}
55+
56+
Ort::Allocator& GetAllocator() override {
57+
return *ort_allocator_;
58+
}
59+
60+
std::shared_ptr<DeviceBuffer> AllocateBase(size_t size) override {
4561
return std::make_shared<CpuMemory>(size);
4662
}
4763

4864
std::shared_ptr<DeviceBuffer> WrapMemoryBase(void* p, size_t size) override {
4965
return std::make_shared<CpuMemory>(p, size);
5066
}
5167

68+
bool Cast(OrtValue& input, OrtValue& output) override {
69+
auto input_info = input.GetTensorTypeAndShapeInfo();
70+
auto output_info = output.GetTensorTypeAndShapeInfo();
71+
72+
auto input_type = input_info->GetElementType();
73+
auto output_type = output_info->GetElementType();
74+
75+
auto element_count = input_info->GetElementCount();
76+
if (element_count != output_info->GetElementCount())
77+
throw std::runtime_error("Cast - input and output element counts do not match");
78+
if (input_type == output_type)
79+
throw std::runtime_error("Cast - input and output types are the same");
80+
81+
if (input_type == Ort::TypeToTensorType<float> && output_type == Ort::TypeToTensorType<Ort::Float16_t>) {
82+
auto* fp32 = input.GetTensorData<float>();
83+
auto* fp16 = output.GetTensorMutableData<uint16_t>();
84+
for (size_t i = 0; i < element_count; i++)
85+
fp16[i] = FastFloat32ToFloat16(fp32[i]);
86+
} else if (input_type == Ort::TypeToTensorType<Ort::Float16_t> && output_type == Ort::TypeToTensorType<float>) {
87+
auto* fp16 = input.GetTensorData<uint16_t>();
88+
auto* fp32 = output.GetTensorMutableData<float>();
89+
for (size_t i = 0; i < element_count; i++)
90+
fp32[i] = FastFloat16ToFloat32(fp16[i]);
91+
} else if (input_type == Ort::TypeToTensorType<int32_t> && output_type == Ort::TypeToTensorType<int64_t>) {
92+
auto* input_data = input.GetTensorData<int32_t>();
93+
auto* output_data = output.GetTensorMutableData<int64_t>();
94+
for (size_t i = 0; i < element_count; i++)
95+
output_data[i] = input_data[i];
96+
} else
97+
throw std::runtime_error("Cast - Unimplemented cast");
98+
return true;
99+
}
100+
52101
std::unique_ptr<Search> CreateGreedy(const GeneratorParams& params) override { return std::make_unique<GreedySearch_Cpu>(params); }
53102
std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }
54103

55104
void Synchronize() override {} // Nothing to do as CPU is always in sync with itself
56-
} g_cpu;
105+
};
57106

58-
DeviceInterface* GetCpuInterface() { return &g_cpu; }
107+
DeviceInterface* GetCpuInterface() {
108+
static std::unique_ptr<CpuInterface> g_cpu = std::make_unique<CpuInterface>();
109+
return g_cpu.get();
110+
}
59111

60112
} // namespace Generators

src/cuda/beam_search_scorer_cuda.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
14
#include "generators.h"
25
#include "search.h"
36
#include "search_cuda.h"
@@ -8,7 +11,7 @@
811
namespace Generators {
912

1013
BeamSearchScorer_Cuda::BeamSearchScorer_Cuda(const GeneratorParams& parameters)
11-
: stream_{parameters.cuda_stream} {
14+
: stream_{GetStream()} {
1215
state_cpu_ = CudaMallocHostArray<cuda::BeamScorerState>(1);
1316
state_cpu_->batch_size_ = static_cast<size_t>(parameters.search.batch_size);
1417
state_cpu_->num_beams_ = static_cast<size_t>(parameters.search.num_beams);

src/cuda/beam_search_scorer_cuda.cu

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
14
#include <cuda_runtime.h>
25
#include <assert.h>
36
#include <algorithm>

src/cuda/beam_search_scorer_cuda.cuh

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "models/onnxruntime_api.h"
15
#include "smartptrs.h"
26

37
namespace Generators {

src/cuda/beam_search_scorer_cuda.h

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
14
namespace Generators {
25

36
struct BeamSearchScorer_Cuda {

src/cuda/beam_search_topk.cu

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
14
#include <cuda_runtime.h>
25
#include <cub/cub.cuh>
36
#include <limits>

src/cuda/cuda_sampling.cu

+7-6
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "span.h"
99
#include "beam_search_topk.h"
1010
#include "cuda_sampling.cuh"
11+
#include "models/onnxruntime_api.h"
1112
#include "smartptrs.h"
1213
#include <cuda_runtime.h>
1314
#include <cub/cub.cuh>
@@ -297,22 +298,22 @@ __global__ void SoftmaxBlockForward(outscalar_t* output, scalar_t* input, int cl
297298
}
298299

299300
template <bool is_log_softmax>
300-
void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements,
301+
void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements,
301302
int input_stride, int output_stride, int batch_count, float temperature) {
302303
dim3 grid(batch_count);
303304
constexpr int ILP = sizeof(float4) / sizeof(float);
304305
dim3 block = SoftmaxGetBlockSize(ILP, softmax_elements);
305306
if (is_log_softmax) {
306307
SoftmaxBlockForward<ILP, float, float, float, LogSoftmaxForwardEpilogue>
307-
<<<grid, block, block.x * sizeof(float), *stream>>>(output, const_cast<float*>(input),
308+
<<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
308309
softmax_elements, input_stride, output_stride, temperature);
309310
} else {
310311
SoftmaxBlockForward<ILP, float, float, float, SoftmaxForwardEpilogue>
311-
<<<grid, block, block.x * sizeof(float), *stream>>>(output, const_cast<float*>(input),
312+
<<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
312313
softmax_elements, input_stride, output_stride, temperature);
313314
}
314315
}
315-
template void DispatchBlockwiseSoftmaxForward<true>(cudaStream_t*, float*, const float*, int, int, int, int, float);
316+
template void DispatchBlockwiseSoftmaxForward<true>(cudaStream_t, float*, const float*, int, int, int, int, float);
316317

317318
// Populate Kernels and Launchers
318319

@@ -521,7 +522,7 @@ void LaunchSampleKernel(SamplingData* data, cudaStream_t stream, float* scores,
521522
void SoftmaxAndSort(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, float temperature) {
522523
// Softmax scores
523524
std::span<float> scores{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
524-
DispatchBlockwiseSoftmaxForward<false>(&stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
525+
DispatchBlockwiseSoftmaxForward<false>(stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
525526
// Sort indices by scores
526527
std::span<int> offsets_gpu{data->offsets.get(), static_cast<size_t>(batch_size + 1)};
527528
LaunchPopulateOffsets(offsets_gpu.data(), vocab_size, batch_size, stream);
@@ -550,7 +551,7 @@ void LaunchGetTopKSubsetFullSort(SamplingData* data, cudaStream_t stream, float*
550551
void GetTopKSubset(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, int k, float temperature) {
551552
// Softmax scores
552553
std::span<float> scores_softmaxed{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
553-
DispatchBlockwiseSoftmaxForward<false>(&stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
554+
DispatchBlockwiseSoftmaxForward<false>(stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
554555
// Get top k subset
555556
#define GetTopK(max_k) \
556557
LaunchGetTopKSubset<max_k>(stream, \

src/cuda/cuda_sampling.cuh

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
3+
34
#include <assert.h>
45
#include "cuda_common.h"
56
#include <curand_kernel.h>
@@ -25,7 +26,7 @@ void LaunchPopulateIndices(int* indices, int size, int batch_size, cudaStream_t
2526
void GetSample(SamplingData* data, cudaStream_t stream, int32_t* d_next_token, float* d_scores, int vocab_size, int batch_size, int k, float p, float temperature);
2627

2728
template <bool is_log_softmax>
28-
void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature = 1.0);
29+
void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements, int input_stride, int output_stride, int batch_count, float temperature = 1.0);
2930

3031
} // namespace cuda
3132
} // namespace Generators

0 commit comments

Comments
 (0)