-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Expand file tree
/
Copy pathcudaUtils.h
More file actions
1439 lines (1272 loc) · 46.3 KB
/
cudaUtils.h
File metadata and controls
1439 lines (1272 loc) · 46.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/config.h"
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#if ENABLE_FP4
#include <cuda_fp4.h>
#endif
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <driver_types.h>
#include <fstream>
#include <functional>
#include <iomanip>
#include <limits>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#ifndef _WIN32 // Linux
#include <sys/sysinfo.h>
#endif // not WIN32
#include <vector>
#ifdef _WIN32 // Windows
#include <windows.h>
#undef ERROR // A Windows header file defines ERROR as 0, but it's used in our logger.h enum. Logging breaks without
// this undef.
#endif // WIN32
TRTLLM_NAMESPACE_BEGIN
namespace common
{
// Workspace size in bytes for cuBLAS GEMM calls: 32 MiB (32 * 1024 * 1024 = 33554432).
#define CUBLAS_WORKSPACE_SIZE 33554432
// Four `half` values packed into one struct (x, y, z, w) with 4-byte alignment requested via __align__(4).
// NOTE(review): the struct is 8 bytes but only 4-byte aligned, so it does not guarantee single 8-byte
// vectorized loads — confirm that is intended.
typedef struct __align__(4)
{
half x, y, z, w;
}
half4;
/* **************************** type definition ***************************** */
// Data-type tag with explicit values: float = 0, half = 1, bfloat16 = 2, int8 = 3, fp8 = 4.
// NOTE(review): the name suggests it selects cuBLAS GEMM data types — confirm against callers.
enum CublasDataType
{
FLOAT_DATATYPE = 0,
HALF_DATATYPE = 1,
BFLOAT16_DATATYPE = 2,
INT8_DATATYPE = 3,
FP8_DATATYPE = 4
};
// Same value assignment as CublasDataType (0..4), with shorter enumerator names. This is an unscoped enum,
// so FP32/FP16/BF16/INT8/FP8 are injected into the enclosing namespace.
enum TRTLLMCudaDataType
{
FP32 = 0,
FP16 = 1,
BF16 = 2,
INT8 = 3,
FP8 = 4
};
// Scoped counterpart using the same enumerator names. Because it is `enum class`, its names do not clash
// with TRTLLMCudaDataType's unscoped ones, and they must be written as OperationType::FP32 and so on.
enum class OperationType
{
FP32,
FP16,
BF16,
INT8,
FP8
};
/* **************************** debug tools ********************************* */
/// @brief Describe a CUDA runtime error code; this simply forwards to cudaGetErrorString().
static char const* _cudaGetErrorEnum(cudaError_t error)
{
    char const* const description = cudaGetErrorString(error);
    return description;
}

/// @brief Return the enumerator name of a cuBLAS status (e.g. "CUBLAS_STATUS_NOT_SUPPORTED").
/// @return The spelled-out status name, or "<unknown>" for values not listed below.
static char const* _cudaGetErrorEnum(cublasStatus_t error)
{
// Each case returns the stringized spelling of its own enumerator.
#define TLLM_CUBLAS_STATUS_NAME_CASE(status)                                                                           \
    case status: return #status
    switch (error)
    {
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_SUCCESS);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_NOT_INITIALIZED);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_ALLOC_FAILED);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_INVALID_VALUE);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_ARCH_MISMATCH);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_MAPPING_ERROR);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_EXECUTION_FAILED);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_INTERNAL_ERROR);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_NOT_SUPPORTED);
        TLLM_CUBLAS_STATUS_NAME_CASE(CUBLAS_STATUS_LICENSE_ERROR);
    }
#undef TLLM_CUBLAS_STATUS_NAME_CASE
    return "<unknown>";
}
/// @brief Throw a TllmException if `ptr` holds a failure status.
///
/// Any non-zero status converts to true and is treated as an error. The message names the failing call
/// (`func`) and the status text from _cudaGetErrorEnum().
/// @param ptr  Status value returned by a CUDA / cuBLAS call.
/// @param func Text of the checked expression.
/// @param file Source file reported in the exception.
/// @param line Source line reported in the exception.
template <typename T>
void check(T ptr, char const* const func, char const* const file, int const line)
{
    if (!ptr)
    {
        return;
    }
    throw TllmException(file, line,
        fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(ptr)).c_str());
}

/// @brief Like check(), but only statuses listed in `validReturns` are accepted.
///
/// Throws a TllmException when `ptr` equals none of `validReturns`.
template <typename T>
void checkEx(
    T ptr, std::initializer_list<T> const& validReturns, char const* const func, char const* const file, int const line)
{
    bool const isAccepted = std::any_of(std::begin(validReturns), std::end(validReturns),
        [&ptr](T const& candidate) { return candidate == ptr; });
    if (isAccepted)
    {
        return;
    }
    throw TllmException(file, line,
        fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(ptr)).c_str());
}
// Check the status of an expression and report the stringized expression at the caller's file/line
// (or at an explicitly supplied file/line for the _2 variant).
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)
/// @brief Query whether `stream` is currently capturing work into a CUDA graph.
/// @return true only when the capture status is cudaStreamCaptureStatusActive.
inline bool isCapturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status;
    check_cuda_error(cudaStreamIsCapturing(stream, &status));
    bool const captureActive = (status == cudaStreamCaptureStatusActive);
    return captureActive;
}
/// @brief Decide whether synchronous error checking should run for `stream`.
///
/// A process-wide flag is computed once on first use:
/// - if CUDA_LAUNCH_BLOCKING is set, checking is enabled exactly when its value is "1";
/// - otherwise it follows the build type (enabled without NDEBUG, disabled with it).
/// When the flag is on, checking is still skipped while `stream` is capturing a CUDA graph.
/// isCapturing() is only queried when the flag is on.
inline bool doCheckError(cudaStream_t stream)
{
    static bool const doCheckIfNotCapturing = []() -> bool
    {
#ifndef NDEBUG
        bool constexpr buildDefault = true;
#else
        bool constexpr buildDefault = false;
#endif
        char const* const launchBlocking = std::getenv("CUDA_LAUNCH_BLOCKING");
        if (launchBlocking == nullptr)
        {
            return buildDefault;
        }
        return std::string(launchBlocking) == "1";
    }();
    if (!doCheckIfNotCapturing)
    {
        return false;
    }
    return !isCapturing(stream);
}
/// @brief Synchronize `stream` and throw if the synchronization or any earlier CUDA call reported an error.
///
/// Does nothing unless doCheckError(stream) returns true. `file` and `line` identify the call site in the
/// thrown TllmException.
inline void syncAndCheck(cudaStream_t stream, char const* const file, int const line)
{
    if (!doCheckError(stream))
    {
        return;
    }
    // Capture both results before throwing, so that cudaGetLastError() always runs and clears the runtime's
    // last-error state. A synchronization failure is then reported under its own name instead of being
    // silently dropped.
    cudaError_t const syncError = cudaStreamSynchronize(stream);
    cudaError_t const lastError = cudaGetLastError();
    check(syncError, "cudaStreamSynchronize", file, line);
    check(lastError, "cudaGetLastError", file, line);
}
#define sync_check_cuda_error(stream) tensorrt_llm::common::syncAndCheck(stream, __FILE__, __LINE__)
// Debug helper: write "[TensorRT-LLM][CALL] <enclosing function name>" to std::cout.
// NOTE(review): this header does not include <iostream> itself; the expansion site must have std::cout
// available.
#define PRINT_FUNC_NAME_() \
do \
{ \
std::cout << "[TensorRT-LLM][CALL] " << __FUNCTION__ << " " << std::endl; \
} while (0)
// clang-format off
// packed_type<T>::type: the two-element vector type for scalar T (float stays float, i.e. it is not packed).
template<typename T> struct packed_type;
template <> struct packed_type<float> { using type = float; }; // we don't need to pack float by default
template <> struct packed_type<half> { using type = half2; };
#ifdef ENABLE_BF16
template<>
struct packed_type<__nv_bfloat16> {
using type = __nv_bfloat162;
};
#endif
#ifdef ENABLE_FP8
template<>
struct packed_type<__nv_fp8_e4m3> {
using type = __nv_fp8x2_e4m3;
};
#endif
// num_elems<T>::value: number of scalar lanes in T (1 for scalars, 2 or 4 for the listed vector types).
template<typename T> struct num_elems;
template <> struct num_elems<float> { static constexpr int value = 1; };
template <> struct num_elems<float2> { static constexpr int value = 2; };
template <> struct num_elems<float4> { static constexpr int value = 4; };
template <> struct num_elems<half> { static constexpr int value = 1; };
template <> struct num_elems<half2> { static constexpr int value = 2; };
#ifdef ENABLE_BF16
template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; };
template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; };
#endif
#ifdef ENABLE_FP8
template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; };
template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; };
#endif
// packed_as<T, num>::type: the type holding `num` lanes of T's scalar type. num == 1 unpacks a 2-lane type
// back to its scalar; the generic <T, 1> case is the identity. Note that int8_t x2 maps to int16_t.
template<typename T, int num> struct packed_as;
template<typename T> struct packed_as<T, 1> { using type = T; };
template<> struct packed_as<half, 2> { using type = half2; };
template<> struct packed_as<float, 2> { using type = float2; };
template<> struct packed_as<int8_t, 2> { using type = int16_t; };
template<> struct packed_as<int32_t, 2> { using type = int2; };
template<> struct packed_as<half2, 1> { using type = half; };
template<> struct packed_as<float2, 1> { using type = float; };
#ifdef ENABLE_BF16
template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; };
template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; };
#endif
#ifdef ENABLE_FP8
template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; };
template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; };
template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; };
template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; };
#endif
// Component-wise float2 arithmetic for device code. In the (float2, float) overloads, the scalar is applied
// to both components.
inline __device__ float2 operator*(float2 a, float2 b)
{
    float2 out;
    out.x = a.x * b.x;
    out.y = a.y * b.y;
    return out;
}
inline __device__ float2 operator+(float2 a, float2 b)
{
    float2 out;
    out.x = a.x + b.x;
    out.y = a.y + b.y;
    return out;
}
inline __device__ float2 operator-(float2 a, float2 b)
{
    float2 out;
    out.x = a.x - b.x;
    out.y = a.y - b.y;
    return out;
}
inline __device__ float2 operator*(float2 a, float b)
{
    float2 out;
    out.x = a.x * b;
    out.y = a.y * b;
    return out;
}
inline __device__ float2 operator+(float2 a, float b)
{
    float2 out;
    out.x = a.x + b;
    out.y = a.y + b;
    return out;
}
inline __device__ float2 operator-(float2 a, float b)
{
    float2 out;
    out.x = a.x - b;
    out.y = a.y - b;
    return out;
}
// clang-format on
// CudaDataType<T>::value maps a C++ element type to its cudaDataType_t tag
// (float -> CUDA_R_32F, half -> CUDA_R_16F, __nv_bfloat16 -> CUDA_R_16BF when ENABLE_BF16).
// The primary template is empty, so unsupported types fail to compile when `::value` is used.
template <typename T>
struct CudaDataType
{
};
template <>
struct CudaDataType<float>
{
static constexpr cudaDataType_t value = cudaDataType::CUDA_R_32F;
};
template <>
struct CudaDataType<half>
{
static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16F;
};
#ifdef ENABLE_BF16
template <>
struct CudaDataType<__nv_bfloat16>
{
static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16BF;
};
#endif
/// @brief Get the SM version of the current device, encoded as 10 * major + minor (e.g. 90 for 9.0).
/// @param queryRealSmArch If false (the default), an sm121 device is reported as 120 so it reuses the sm120
///        code paths. Pass true to get the real architecture, e.g. for LUT tuning.
/// @return The (possibly remapped) SM version of the current device.
inline int getSMVersion(bool queryRealSmArch = false)
{
    int device{-1};
    check_cuda_error(cudaGetDevice(&device));
    int sm_major = 0;
    int sm_minor = 0;
    check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
    check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
    int const realSm = 10 * sm_major + sm_minor;
    bool const reportAsSm120 = !queryRealSmArch && realSm == 121;
    return reportAsSm120 ? 120 : realSm;
}
/// @brief True when the current device reports SM 100 or SM 103 via getSMVersion().
inline bool isSM100Family()
{
    switch (getSMVersion())
    {
    case 100:
    case 103: return true; // To be continued...
    default: return false;
    }
}
/// @brief Index of the CUDA device currently selected for the calling host thread (throws on failure).
inline int getDevice()
{
    int deviceID = 0;
    check_cuda_error(cudaGetDevice(&deviceID));
    return deviceID;
}

/// @brief Number of CUDA devices visible to this process (throws on failure).
inline int getDeviceCount()
{
    int count = 0;
    check_cuda_error(cudaGetDeviceCount(&count));
    return count;
}
/// @brief Identify the memory type of `ptr` as reported by cudaPointerGetAttributes().
/// @return The `type` field of the pointer attributes (e.g. cudaMemoryTypeDevice); throws on API failure.
template <typename T>
cudaMemoryType getPtrCudaMemoryType(T* ptr)
{
    cudaPointerAttributes attributes{};
    check_cuda_error(cudaPointerGetAttributes(&attributes, ptr));
    cudaMemoryType const memoryType = attributes.type;
    return memoryType;
}
/// Get the memory info
/// \return The free and total amount of memory in bytes
inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
{
if (useUvm)
{
size_t freeSysMem = 0;
size_t totalSysMem = 0;
#ifndef _WIN32 // Linux
struct sysinfo info
{
};
sysinfo(&info);
totalSysMem = info.totalram * info.mem_unit;
freeSysMem = info.freeram * info.mem_unit;
#else // Windows
MEMORYSTATUSEX memInfo;
memInfo.dwLength = sizeof(memInfo);
GlobalMemoryStatusEx(&memInfo);
totalSysMem = memInfo.ullTotalPhys;
freeSysMem = memInfo.ullAvailPhys;
#endif // WIN32
TLLM_LOG_INFO("Using UVM based system memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9));
return {freeSysMem, totalSysMem};
}
size_t free = 0;
size_t total = 0;
check_cuda_error(cudaMemGetInfo(&free, &total));
TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
((double) total / 1e9), ((double) free / 1e9));
return {free, total};
}
/// @brief Gets the memory allocation granularity for the current device.
///
/// @return size_t The size of the smallest difference in memory size supported by the current device.
inline size_t getAllocationGranularity()
{
auto const currentDevice = getDevice();
::CUmemAllocationProp prop = {};
prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = currentDevice;
prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE;
// Get the minimum granularity supported for allocation with cuMemCreate()
size_t granularity = 0;
TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
return granularity;
}
/// @brief Number of streaming multiprocessors on the current device.
inline int getMultiProcessorCount()
{
    int deviceID = 0;
    check_cuda_error(cudaGetDevice(&deviceID));
    int nSM = 0;
    check_cuda_error(cudaDeviceGetAttribute(&nSM, cudaDevAttrMultiProcessorCount, deviceID));
    return nSM;
}

/// @brief Maximum shared memory per multiprocessor on the current device, in bytes.
inline int getMaxSharedMemoryPerSM()
{
    int deviceID = 0;
    check_cuda_error(cudaGetDevice(&deviceID));
    int nByteMaxSharedMemoryPerSM = 0;
    check_cuda_error(
        cudaDeviceGetAttribute(&nByteMaxSharedMemoryPerSM, cudaDevAttrMaxSharedMemoryPerMultiprocessor, deviceID));
    return nByteMaxSharedMemoryPerSM;
}

/// @brief Maximum opt-in shared memory per block on the current device, in bytes
/// (cudaDevAttrMaxSharedMemoryPerBlockOptin).
inline int getMaxSharedMemoryPerBlockOptin()
{
    int deviceID = 0;
    check_cuda_error(cudaGetDevice(&deviceID));
    int nByteMaxSharedMemoryPerBlockOptin = 0;
    check_cuda_error(
        cudaDeviceGetAttribute(&nByteMaxSharedMemoryPerBlockOptin, cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceID));
    return nByteMaxSharedMemoryPerBlockOptin;
}
template <typename T>
inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize)
{
static std::unordered_map<T, int> cache;
auto it = cache.find(kernel);
if (it != cache.end())
{
return it->second;
}
int numBlocks;
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize));
cache[kernel] = numBlocks;
return numBlocks;
}
/// @brief Ceiling of a / b, computed after converting both operands to size_t.
/// @note `b` must be non-zero. Negative inputs wrap modulo 2^N when converted to size_t.
template <typename T1, typename T2>
inline size_t divUp(T1 const& a, T2 const& b)
{
    size_t const numerator = static_cast<size_t>(a);
    size_t const divisor = static_cast<size_t>(b);
    size_t const biased = numerator + divisor - 1;
    return biased / divisor;
}

/// @brief Round `a` up to the next multiple of `b` (via divUp), narrowed back to int.
inline int roundUp(int a, int b)
{
    size_t const nChunks = divUp(a, b);
    return nChunks * b;
}

/// @brief constexpr ceiling division for integral operands, in their usual-arithmetic-converted type.
/// @note Assumes a non-negative numerator, a positive denominator, and no overflow in numerator + denominator.
template <typename T, typename U, typename = std::enable_if_t<std::is_integral<T>::value>,
    typename = std::enable_if_t<std::is_integral<U>::value>>
auto constexpr ceilDiv(T numerator, U denominator)
{
    auto const roundedUp = numerator + denominator - 1;
    return roundedUp / denominator;
}
template <typename T>
void printArrayInfo(T const* ptr, uint64_t nElement = 1, std::string name = "", bool const bPrintElement = false)
{
if (ptr == nullptr)
{
TLLM_LOG_WARNING("%s is an nullptr, skip!", name.c_str());
return;
}
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
bool const isDevicePtr = (getPtrCudaMemoryType(ptr) == cudaMemoryTypeDevice);
size_t sizeInByte = sizeof(T) * nElement;
TLLM_LOG_TRACE("addr=%p, location=%s, sizeof(T)=%lu, nElement=%d, sizeInByte=%lu\n", ptr,
(isDevicePtr ? "Device" : "Host"), sizeof(T), nElement, sizeInByte);
T* tmp = const_cast<T*>(ptr);
std::vector<T> tmpVec; // For device pointer
if (isDevicePtr)
{
tmpVec.resize(nElement);
tmp = tmpVec.data(); // Note `data()` is not supported for vector<bool>
check_cuda_error(cudaMemcpy(tmp, ptr, sizeInByte, cudaMemcpyDeviceToHost));
cudaDeviceSynchronize();
}
size_t nInf = 0;
size_t nNaN = 0;
size_t nZero = 0;
double sum = 0.0;
double sqrSum = 0.0;
double absSum = 0.0;
float allMax = -1.0e6f;
float allMin = 1.0e6f;
float allSad = 0.0f; // Sum Abs of Difference, to distinguish A and its transpose
float old = 0.0f;
for (uint64_t i = 0; i < nElement; i++)
{
float val = (float) tmp[i];
if (std::isinf(val))
{
nInf++;
continue;
}
if (std::isnan(val))
{
nNaN++;
continue;
}
nZero += (val == 0.0f);
sum += val;
sqrSum += val * val;
absSum += expf(val);
allMax = std::max(allMax, val);
allMin = std::min(allMin, val);
allSad += abs(val - old);
old = val;
}
float avg = sum / nElement;
float std = sqrtf(sqrSum / nElement - avg * avg);
TLLM_LOG_INFO("%s", name.c_str());
TLLM_LOG_INFO("size=%u, nInf=%zu, nNaN=%zu, nZero=%zu", nElement, nInf, nNaN, nZero);
TLLM_LOG_INFO("avg=%f, absSum: %f, std=%f, max=%f, min=%f, sad=%f", avg, absSum, std, allMax, allMin, allSad);
if (bPrintElement)
{
uint64_t constexpr nHead = 5;
std::stringstream ss;
ss << std::setw(10) << std::fixed << std::setprecision(3);
for (uint64_t i = 0; i < std::min(nElement, nHead); ++i)
{
ss << (float) tmp[i] << ", ";
}
if (nElement > nHead)
{
ss << " ... ";
for (uint64_t i = nElement - nHead; i < nElement; ++i)
{
ss << (float) tmp[i] << ", ";
}
}
TLLM_LOG_INFO("%s", ss.str().c_str());
}
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
}
// Explicit instantiations of printArrayInfo for the supported element types.
// NOTE(review): these are explicit instantiation *definitions* in a header, so every translation unit that
// includes it instantiates them. [temp.spec] allows at most one per program (no diagnostic required);
// consider `extern template` declarations here plus definitions in one .cpp — confirm the build setup first.
// NOTE(review): the FP4 one is guarded by `#ifdef ENABLE_FP4`, while the <cuda_fp4.h> include above uses
// `#if ENABLE_FP4` — confirm both agree for every way ENABLE_FP4 can be defined.
template void printArrayInfo(float const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(half const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#ifdef ENABLE_BF16
template void printArrayInfo(__nv_bfloat16 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#endif
#ifdef ENABLE_FP8
template void printArrayInfo(__nv_fp8_e4m3 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#endif
#ifdef ENABLE_FP4
template void printArrayInfo(__nv_fp4_e2m1 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
#endif
template void printArrayInfo(uint32_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(uint64_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(int const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template void printArrayInfo(uint8_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
template <typename T>
void printToStream(T const* ptr, int const nElement, FILE* strm)
{
bool const split_rows = (strm == stdout);
if (ptr == nullptr)
{
TLLM_LOG_WARNING("Nullptr, skip!\n");
return;
}
std::vector<T> tmp(nElement, 0);
check_cuda_error(cudaMemcpy(tmp.data(), ptr, sizeof(T) * nElement, cudaMemcpyDeviceToHost));
for (int i = 0; i < nElement; ++i)
{
fprintf(strm, "%f, ", static_cast<float>(tmp[i]));
if (split_rows && ((i + 1) % 10) == 0)
fprintf(strm, "\n");
}
if (!split_rows || (nElement % 10) != 0)
{
fprintf(strm, "\n");
}
}
/// @brief Print `nElement` values from `ptr` to stdout (see printToStream for the format).
template <typename T>
void printToScreen(T const* ptr, int const nElement)
{
    printToStream(ptr, nElement, stdout);
}

/// @brief Print an nRow x nCol matrix to `strm`, one printToStream() call per row, followed by a blank line.
///
/// Row `ri` starts at `ptr + ri * nStride`, so nStride is the row pitch in elements. A nullptr logs a warning.
template <typename T>
void print2dToStream(T const* ptr, int const nRow, int const nCol, int const nStride, FILE* strm)
{
    if (ptr == nullptr)
    {
        TLLM_LOG_WARNING("Nullptr, skip!\n");
        return;
    }
    for (int ri = 0; ri < nRow; ++ri)
    {
        // Widen before multiplying: ri * nStride can overflow int for large matrices.
        T const* tmp = ptr + static_cast<std::ptrdiff_t>(ri) * nStride;
        printToStream(tmp, nCol, strm);
    }
    fprintf(strm, "\n");
}

/// @brief print2dToStream() to stdout.
template <typename T>
void print2dToScreen(T const* ptr, int const nRow, int const nCol, int const nStride)
{
    print2dToStream(ptr, nRow, nCol, nStride, stdout);
}

/// @brief print2dToStream() into the file `fname`, which is truncated or created. If the file cannot be
///        opened, a warning is logged and nothing is written.
template <typename T>
void print2dToFile(std::string fname, T const* ptr, int const nRow, int const nCol, int const nStride)
{
    FILE* fp = fopen(fname.c_str(), "wt");
    if (fp == nullptr)
    {
        TLLM_LOG_WARNING("Failed to open %s for writing, skip!\n", fname.c_str());
        return;
    }
    print2dToStream(ptr, nRow, nCol, nStride, fp);
    fclose(fp);
}
// Host/device printf helpers used by print_elements(). Each prints one value followed by a space, with a
// minimum width of 7: "%7.3f" for floating types (converted to float first), PRI* macros for integer types.
__host__ __device__ inline void print_float_(float x)
{
printf("%7.3f ", x);
}
__host__ __device__ inline void print_element_(float x)
{
print_float_(x);
}
__host__ __device__ inline void print_element_(half x)
{
print_float_((float) x);
}
#ifdef ENABLE_BF16
__host__ __device__ inline void print_element_(__nv_bfloat16 x)
{
print_float_((float) x);
}
#endif
#ifdef ENABLE_FP8
__host__ __device__ inline void print_element_(__nv_fp8_e4m3 x)
{
print_float_((float) x);
}
#endif
// bool and uint8_t are widened to unsigned int and printed as integers (0/1 for bool).
__host__ __device__ inline void print_element_(bool ui)
{
printf("%7" PRIu32 " ", (unsigned int) ui);
}
__host__ __device__ inline void print_element_(uint8_t ui)
{
printf("%7" PRIu32 " ", (unsigned int) ui);
}
__host__ __device__ inline void print_element_(uint32_t ul)
{
printf("%7" PRIu32 " ", ul);
}
__host__ __device__ inline void print_element_(uint64_t ull)
{
printf("%7" PRIu64 " ", ull);
}
__host__ __device__ inline void print_element_(int32_t il)
{
printf("%7" PRId32 " ", il);
}
__host__ __device__ inline void print_element_(int64_t ill)
{
printf("%7" PRId64 " ", ill);
}
/// @brief Print a row-major matrix as a table with a column-index heading row and a row-index column.
///
/// Usable from host and device. Element (iRow, iCol) is read from ptr[iRow * nStride + iCol]; only the first
/// nCol columns of each of the nRow rows are printed. Labels and cells are at least 8 characters wide.
template <typename T>
__host__ __device__ inline void print_elements(T const* ptr, int nRow, int nCol, int nStride)
{
    for (int iRow = -1; iRow < nRow; ++iRow)
    {
        if (iRow >= 0)
        {
            printf("%07d|", iRow);
        }
        else
        {
            // Heading row: 7 spaces + '|' is as wide as the "%07d|" row labels, so the column headings line up.
            printf("       |");
        }
        for (int iCol = 0; iCol < nCol; iCol += 1)
        {
            if (iRow >= 0)
            {
                // 64-bit offset: iRow * nStride can overflow int for large matrices.
                print_element_(ptr[static_cast<int64_t>(iRow) * nStride + iCol]);
            }
            else
            {
                printf("%7d|", iCol); // heading column
            }
        }
        printf("\n");
    }
    printf("\n");
}
template <typename T>
inline void printMatrix(T const* ptr, int nRow, int nCol, int nStride)
{
// `nRow` is length of row dimension
// `nStride` is length of column dimension
// `nCol` (<= nStride) is length for print per row
if (ptr == nullptr)
{
TLLM_LOG_WARNING("Nullptr, skip!\n");
return;
}
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
bool const isDevicePtr = (getPtrCudaMemoryType(ptr) == cudaMemoryTypeDevice);
size_t sizeInByte = sizeof(T) * nRow * nStride;
TLLM_LOG_TRACE("addr=%p, location=%s, sizeof(T)=%lu, nRow=%d, nStride=%d, sizeInByte=%lu\n", ptr,
(isDevicePtr ? "Device" : "Host"), sizeof(T), nRow, nStride, sizeInByte);
if (isDevicePtr)
{
std::vector<T> tmpVec;
tmpVec.resize(nRow * nStride);
T* tmp = tmpVec.data(); // Note `data()` is not supported for vector<bool>
check_cuda_error(cudaMemcpy(tmp, ptr, sizeInByte, cudaMemcpyDeviceToHost));
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());
print_elements(tmp, nRow, nCol, nStride);
}
else
{
print_elements(ptr, nRow, nCol, nStride);
}
}
// Explicit instantiations of printMatrix for the supported element types.
// NOTE(review): explicit instantiation definitions in a header are emitted in every including translation
// unit; [temp.spec] allows at most one per program (no diagnostic required) — consider `extern template` here
// plus definitions in one .cpp.
template void printMatrix(float const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(half const* ptr, int nRow, int nCol, int nStride);
#ifdef ENABLE_BF16
template void printMatrix(__nv_bfloat16 const* ptr, int nRow, int nCol, int nStride);
#endif
#ifdef ENABLE_FP8
template void printMatrix(__nv_fp8_e4m3 const* ptr, int nRow, int nCol, int nStride);
#endif
template void printMatrix(uint32_t const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(uint64_t const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(int const* ptr, int nRow, int nCol, int nStride);
template void printMatrix(uint8_t const* ptr, int nRow, int nCol, int nStride);
/// @brief Device-side counterpart of printMatrix(): print a matrix already in memory visible to the kernel.
///
/// Intended to be called by a single thread inside a kernel.
/// - `nRow` is the length of the row dimension.
/// - `nStride` is the length of the column dimension (row pitch in elements).
/// - `nCol` (<= nStride) is the number of values printed per row.
template <typename T>
__device__ inline void printMatrixDevice(T const* ptr, int nRow, int nCol, int nStride)
{
    if (ptr == nullptr)
    {
        printf("Nullptr, skip!\n");
        return;
    }
    size_t sizeInByte = sizeof(T) * nRow * nStride;
    // size_t is not `unsigned long` on every platform (e.g. 64-bit Windows), so cast explicitly for %llu.
    printf("addr=%p, sizeof(T)=%llu, nRow=%d, nStride=%d, sizeInByte=%llu\n", static_cast<void const*>(ptr),
        static_cast<unsigned long long>(sizeof(T)), nRow, nStride, static_cast<unsigned long long>(sizeInByte));
    print_elements(ptr, nRow, nCol, nStride);
}
// Explicit instantiations of the device-side printMatrixDevice for the supported element types.
// NOTE(review): like the host instantiations, these are definitions in a header and are emitted in every
// including translation unit — confirm this is intended.
template __device__ void printMatrixDevice(float const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(half const* ptr, int nRow, int nCol, int nStride);
#ifdef ENABLE_BF16
template __device__ void printMatrixDevice(__nv_bfloat16 const* ptr, int nRow, int nCol, int nStride);
#endif
#ifdef ENABLE_FP8
template __device__ void printMatrixDevice(__nv_fp8_e4m3 const* ptr, int nRow, int nCol, int nStride);
#endif
template __device__ void printMatrixDevice(uint32_t const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(uint64_t const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(int const* ptr, int nRow, int nCol, int nStride);
template __device__ void printMatrixDevice(uint8_t const* ptr, int nRow, int nCol, int nStride);
#ifndef CUDA_CALL
// CUDA_CALL(expr): check a CUDA runtime (cudaError_t) or driver (CUresult) result with gpuAssert(), reporting
// the caller's file and line. On failure it prints to stderr and exits the process.
// NOTE(review): it expands to a braced block rather than `do { } while (0)`, so
// `if (c) CUDA_CALL(x); else ...` does not compile; kept as-is for compatibility with existing call sites.
#define CUDA_CALL(answer) \
    { \
        gpuAssert((answer), __FILE__, __LINE__); \
    }

/// @brief On a CUDA runtime error, print it with its location to stderr and, if `abort`, call exit(code).
inline void gpuAssert(cudaError_t code, char const* file, int line, bool abort = true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr, "CUDA error: %s @ %s:%d\n", cudaGetErrorString(code), file, line);
        if (abort)
            exit(code);
    }
}

/// @brief On a CUDA driver API error, print it with its location to stderr and, if `abort`, call exit(code).
inline void gpuAssert(CUresult code, char const* file, int line, bool abort = true)
{
    if (code != CUresult::CUDA_SUCCESS)
    {
        char const* buf = nullptr;
        // Do the lookup outside assert(): with NDEBUG the whole assert expression, including this call, is
        // compiled out, which left the message permanently as "Unknown error".
        if (cuGetErrorString(code, &buf) != CUresult::CUDA_SUCCESS || buf == nullptr)
        {
            buf = "Unknown error";
        }
        fprintf(stderr, "Driver API error: %s @ %s:%d\n", buf, file, line);
        if (abort)
            exit(code);
    }
}
#endif
// UpperType<T>::Type maps an element type to a wider (or equal) type:
// int8_t -> int, uint32_t -> uint32_t, int -> int, and __nv_bfloat16 / half / float -> double.
// NOTE(review): presumably used as an accumulation type for sums over T — confirm against callers.
// NOTE(review): the __nv_bfloat16 specialization is not wrapped in `#ifdef ENABLE_BF16`, unlike other bf16
// code in this file — confirm __nv_bfloat16 is always declared here.
template <typename T>
struct UpperType;
template <>
struct UpperType<int8_t>
{
using Type = int;
};
template <>
struct UpperType<uint32_t>
{
using Type = uint32_t;
};
template <>
struct UpperType<int>
{
using Type = int;
};
template <>
struct UpperType<__nv_bfloat16>
{
using Type = double;
};
template <>
struct UpperType<half>
{
using Type = double;
};
template <>
struct UpperType<float>
{
using Type = double;
};
// Declaration of the NVVM intrinsic that converts a generic pointer into a 32-bit shared-memory address.
// get_smem_pointer() uses it to produce the "r" (32-bit) operands for the shared-memory PTX in this file.
extern "C"
{
__device__ uint32_t __nvvm_get_smem_pointer(void* ptr);
}
// issue_stas: emit `st.async ... shared::cluster.mbarrier::complete_tx::bytes` to store register data to the
// 32-bit shared::cluster address `dist_buffer_ptr` (%0). Completion of the stored bytes is signalled on the
// mbarrier at `dist_barrier_ptr` (%1). Overloads store 4 bytes (b32), 8 bytes (b64), 8 bytes (v2.b32) and
// 16 bytes (v4.b32). Note the parameter order: barrier first, then buffer, then data.
// The instructions are only compiled for __CUDA_ARCH__ >= 900 with CUDA >= 12; on every other target these
// functions are empty no-ops.
// NOTE(review): presumably both addresses come from get_smem_pointer() (possibly remapped to a peer CTA in the
// cluster) — confirm with callers.
__forceinline__ __device__ void issue_stas(uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint32_t d0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32 [%0], %2, [%1];\n\t"
:
: "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "r"(d0));
#endif
}
__forceinline__ __device__ void issue_stas(uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint64_t d0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b64 [%0], %2, [%1];\n\t"
:
: "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "l"(d0));
#endif
}
__forceinline__ __device__ void issue_stas(
uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint32_t d0, uint32_t d1)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.b32 [%0], {%2, %3}, [%1];\n\t"
:
: "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "r"(d0), "r"(d1));
#endif
}
// NOTE(review): unlike the overloads above, this one omits the `.weak` qualifier. It is presumably equivalent
// because .weak is the default semantics for st.async — confirm against the PTX ISA.
__forceinline__ __device__ void issue_stas(
uint32_t dist_barrier_ptr, uint32_t dist_buffer_ptr, uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.b32 [%0], {%2, %3, %4, %5}, [%1];\n\t"
:
: "r"(dist_buffer_ptr), "r"(dist_barrier_ptr), "r"(d0), "r"(d1), "r"(d2), "r"(d3));
#endif
}
/// @brief Elect a single lane of a fully active warp.
///
/// @return 1 in exactly one lane of the warp and 0 in all other lanes. All 32 lanes must call it, since the
///         membermask is 0xFFFFFFFF.
///
/// When compiled with CUDA >= 12 for __CUDA_ARCH__ >= 900 with __CUDA_ARCH_FEAT_SM90_ALL, this uses the
/// `elect.sync` PTX instruction. On every other device target it falls back to electing lane 0. Previously
/// it returned 0 in every lane there, so work guarded by `if (elect_one_sync())` never ran. The host
/// compilation pass returns 0.
inline __device__ uint32_t elect_one_sync()
{
    uint32_t pred = 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)                                 \
    && defined(__CUDA_ARCH_FEAT_SM90_ALL))
    uint32_t laneid = 0;
    asm volatile(
"\n\
{\n\
.reg .b32 %rx;\n\
.reg .pred %px;\n\
elect.sync %rx|%px, %2;\n\
@%px mov.s32 %1, 1;\n\
mov.s32 %0, %rx;\n\
}\n\
"
        : "+r"(laneid), "+r"(pred)
        : "r"(0xFFFFFFFF));
#elif defined(__CUDA_ARCH__)
    // Fallback without elect.sync: the lane whose %laneid is 0 is the elected one.
    uint32_t laneid = 0;
    asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid));
    pred = (laneid == 0) ? 1u : 0u;
#endif
    return pred;
}
/// @brief Convert a generic pointer into the 32-bit shared-memory address expected by shared:: PTX operands.
///
/// The intrinsic's parameter is non-const, so constness is cast away here; the pointee is not written.
__forceinline__ __device__ uint32_t get_smem_pointer(void const* ptr)
{
    void* const genericPtr = const_cast<void*>(ptr);
    uint32_t const smemAddress = __nvvm_get_smem_pointer(genericPtr);
    return smemAddress;
}
// Initialize the 64-bit mbarrier object at `bar_ptr` with `init_count` as its expected arrival count, using
// `mbarrier.init.shared.b64`. `bar_ptr` must point to shared memory; it is converted with get_smem_pointer().
// The instruction is only compiled for __CUDA_ARCH__ >= 800; on other targets this is a no-op.
__forceinline__ __device__ void bar_create(void* bar_ptr, int init_count)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
unsigned smem_ptr = get_smem_pointer(bar_ptr);
asm volatile(
"{\n\t"
"mbarrier.init.shared.b64 [%1], %0; \n\t"
"}"
:
: "r"(init_count), "r"(smem_ptr));
#endif
}
struct Arrive_wait
{
public:
__forceinline__ __device__ Arrive_wait()
{
bar_base_ = NULL;
}
__forceinline__ __device__ Arrive_wait(uint64_t* bar_base, int id = 0)
{
bar_base_ = bar_base;
id_ = id;
}
__forceinline__ __device__ int bar_peek(int id, unsigned int bar_phase)
{
uint32_t result32{};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
auto* bar_ptr = bar_base_ + id;
unsigned smem_ptr = get_smem_pointer(bar_ptr);
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(result32)
: "r"(smem_ptr), "r"(bar_phase));
#endif
return result32;
}
__forceinline__ __device__ int bar_peek(int id, unsigned int bar_phase, int pred)
{
uint32_t result32{};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
auto* bar_ptr = bar_base_ + id;
unsigned smem_ptr = get_smem_pointer(bar_ptr);
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
".reg .pred P2;\n\t"
"setp.eq.u32 P2, %3, 1;\n\t"
"@P2 mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P1; \n\t"
"}"
: "=r"(result32)
: "r"(smem_ptr), "r"(bar_phase), "r"(pred));
#endif
return result32;
}
__forceinline__ __device__ void bar_wait(int id, unsigned int bar_phase)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
auto* bar_ptr = bar_base_ + id;
unsigned smem_ptr = get_smem_pointer(bar_ptr);
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"