8 | 8 | #include "span.h"
9 | 9 | #include "beam_search_topk.h"
10 | 10 | #include "cuda_sampling.cuh"
   | 11 | +#include "models/onnxruntime_api.h"
11 | 12 | #include "smartptrs.h"
12 | 13 | #include <cuda_runtime.h>
13 | 14 | #include <cub/cub.cuh>
@@ -297,22 +298,22 @@ __global__ void SoftmaxBlockForward(outscalar_t* output, scalar_t* input, int cl
297 | 298 | }
298 | 299 |
299 | 300 | template <bool is_log_softmax>
300 |     | -void DispatchBlockwiseSoftmaxForward(cudaStream_t* stream, float* output, const float* input, int softmax_elements,
    | 301 | +void DispatchBlockwiseSoftmaxForward(cudaStream_t stream, float* output, const float* input, int softmax_elements,
301 | 302 |                                      int input_stride, int output_stride, int batch_count, float temperature) {
302 | 303 |   dim3 grid(batch_count);
303 | 304 |   constexpr int ILP = sizeof(float4) / sizeof(float);
304 | 305 |   dim3 block = SoftmaxGetBlockSize(ILP, softmax_elements);
305 | 306 |   if (is_log_softmax) {
306 | 307 |     SoftmaxBlockForward<ILP, float, float, float, LogSoftmaxForwardEpilogue>
307 |     | -        <<<grid, block, block.x * sizeof(float), *stream>>>(output, const_cast<float*>(input),
    | 308 | +        <<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
308 | 309 |                                                              softmax_elements, input_stride, output_stride, temperature);
309 | 310 |   } else {
310 | 311 |     SoftmaxBlockForward<ILP, float, float, float, SoftmaxForwardEpilogue>
311 |     | -        <<<grid, block, block.x * sizeof(float), *stream>>>(output, const_cast<float*>(input),
    | 312 | +        <<<grid, block, block.x * sizeof(float), stream>>>(output, const_cast<float*>(input),
312 | 313 |                                                              softmax_elements, input_stride, output_stride, temperature);
313 | 314 |   }
314 | 315 | }
315 |     | -template void DispatchBlockwiseSoftmaxForward<true>(cudaStream_t*, float*, const float*, int, int, int, int, float);
    | 316 | +template void DispatchBlockwiseSoftmaxForward<true>(cudaStream_t, float*, const float*, int, int, int, int, float);
316 | 317 |
317 | 318 | // Populate Kernels and Launchers
318 | 319 |
@@ -521,7 +522,7 @@ void LaunchSampleKernel(SamplingData* data, cudaStream_t stream, float* scores,
521 | 522 | void SoftmaxAndSort(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, float temperature) {
522 | 523 |   // Softmax scores
523 | 524 |   std::span<float> scores{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
524 |     | -  DispatchBlockwiseSoftmaxForward<false>(&stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
    | 525 | +  DispatchBlockwiseSoftmaxForward<false>(stream, scores.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
525 | 526 |   // Sort indices by scores
526 | 527 |   std::span<int> offsets_gpu{data->offsets.get(), static_cast<size_t>(batch_size + 1)};
527 | 528 |   LaunchPopulateOffsets(offsets_gpu.data(), vocab_size, batch_size, stream);
@@ -550,7 +551,7 @@ void LaunchGetTopKSubsetFullSort(SamplingData* data, cudaStream_t stream, float*
550 | 551 | void GetTopKSubset(SamplingData* data, cudaStream_t stream, float* scores_in, float* scores_out, int* indices_out, int vocab_size, int batch_size, int k, float temperature) {
551 | 552 |   // Softmax scores
552 | 553 |   std::span<float> scores_softmaxed{data->scores_softmaxed.get(), static_cast<size_t>(vocab_size * batch_size)};
553 |     | -  DispatchBlockwiseSoftmaxForward<false>(&stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
    | 554 | +  DispatchBlockwiseSoftmaxForward<false>(stream, scores_softmaxed.data(), const_cast<const float*>(scores_in), vocab_size, vocab_size, vocab_size, batch_size, temperature);
554 | 555 |   // Get top k subset
555 | 556 | #define GetTopK(max_k) \
556 | 557 |   LaunchGetTopKSubset<max_k>(stream, \