Skip to content

Commit bed8991

Browse files
author
eliotwang
committed
amd
1 parent c9f4cd7 commit bed8991

20 files changed

+4502
-191
lines changed

csrc/fused/rocm/dispatch_utils.h

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright (c) 2024 by SageAttention team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
#include <torch/extension.h>
19+
#include <cstdint>
20+
#include <sstream>
21+
#include <stdexcept>
22+
23+
/* Dispatches a runtime head dimension to a compile-time constant.
 * HEAD_DIM names the constexpr int made visible to __VA_ARGS__.
 * Supported values: 64 and 128; any other value throws std::invalid_argument. */
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)     \
  if (head_dim == 64) {                                \
    constexpr int HEAD_DIM = 64;                       \
    __VA_ARGS__                                        \
  } else if (head_dim == 128) {                        \
    constexpr int HEAD_DIM = 128;                      \
    __VA_ARGS__                                        \
  } else {                                             \
    std::ostringstream oss;                            \
    oss << "Unsupported head dim: " << int(head_dim);  \
    throw std::invalid_argument(oss.str());            \
  }
35+
36+
/* Dispatches a runtime 0/1 causal flag to a compile-time bool.
 * IS_CAUSAL names the constexpr bool made visible to __VA_ARGS__.
 * Accepts exactly 1 (true) or 0 (false); anything else throws
 * std::invalid_argument. */
#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...)            \
  if (is_causal == 1) {                                       \
    constexpr bool IS_CAUSAL = true;                          \
    __VA_ARGS__                                               \
  } else if (is_causal == 0) {                                \
    constexpr bool IS_CAUSAL = false;                         \
    __VA_ARGS__                                               \
  } else {                                                    \
    std::ostringstream oss;                                   \
    oss << "Unsupported causal mode: " << int(is_causal);     \
    throw std::invalid_argument(oss.str());                   \
  }
48+
49+
/* Dispatches a runtime QK quantization granularity code to a compile-time
 * constant. QK_QUANT_GRAN names the constexpr int made visible to
 * __VA_ARGS__. Supported codes: 2 and 3; anything else throws
 * std::invalid_argument. */
#define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...)   \
  if (qk_quant_gran == 2) {                                         \
    constexpr int QK_QUANT_GRAN = 2;                                \
    __VA_ARGS__                                                     \
  } else if (qk_quant_gran == 3) {                                  \
    constexpr int QK_QUANT_GRAN = 3;                                \
    __VA_ARGS__                                                     \
  } else {                                                          \
    std::ostringstream oss;                                         \
    oss << "Unsupported qk_quant_gran: " << int(qk_quant_gran);     \
    throw std::invalid_argument(oss.str());                         \
  }
61+
62+
/* Dispatches a runtime 0/1 return-LSE flag to a compile-time bool.
 * RETURN_LSE names the constexpr bool made visible to __VA_ARGS__.
 * Accepts exactly 1 (true) or 0 (false); anything else throws
 * std::invalid_argument.
 * Fix: the error message previously said "Unsupported causal mode"
 * (copy-paste from DISPATCH_CAUSAL); it now names return_lse. */
#define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...)            \
  if (return_lse == 1) {                                            \
    constexpr bool RETURN_LSE = true;                               \
    __VA_ARGS__                                                     \
  } else if (return_lse == 0) {                                     \
    constexpr bool RETURN_LSE = false;                              \
    __VA_ARGS__                                                     \
  } else {                                                          \
    std::ostringstream err_msg;                                     \
    err_msg << "Unsupported return_lse mode: " << int(return_lse);  \
    throw std::invalid_argument(err_msg.str());                     \
  }
74+
75+
/* Maps a runtime torch 16-bit float scalar type to the corresponding HIP
 * C type (Half -> half, BFloat16 -> hip_bfloat16), exposed to __VA_ARGS__
 * via the alias named by `c_type`. Any other dtype fails with TORCH_CHECK. */
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...)  \
  if (pytorch_dtype == at::ScalarType::Half) {                            \
    using c_type = half;                                                  \
    __VA_ARGS__                                                           \
  } else if (pytorch_dtype == at::ScalarType::BFloat16) {                 \
    using c_type = hip_bfloat16;                                          \
    __VA_ARGS__                                                           \
  } else {                                                                \
    std::ostringstream err_msg;                                           \
    err_msg << __PRETTY_FUNCTION__ << " failed to dispatch data type "    \
            << pytorch_dtype;                                             \
    TORCH_CHECK(false, err_msg.str());                                    \
  }
87+
88+
/* Dispatches a runtime block size to a compile-time constant.
 * BLOCK_SIZE names the constexpr int made visible to __VA_ARGS__.
 * Supported values: 64 and 128; anything else throws std::invalid_argument. */
#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...)   \
  if (block_size == 64) {                                  \
    constexpr int BLOCK_SIZE = 64;                         \
    __VA_ARGS__                                            \
  } else if (block_size == 128) {                          \
    constexpr int BLOCK_SIZE = 128;                        \
    __VA_ARGS__                                            \
  } else {                                                 \
    std::ostringstream oss;                                \
    oss << "Unsupported block_size " << int(block_size);   \
    throw std::invalid_argument(oss.str());                \
  }
100+
101+
/* Dispatches a runtime warp block size to a compile-time constant.
 * WARP_BLOCK_SIZE names the constexpr int made visible to __VA_ARGS__.
 * Supported values: 16 and 32; anything else throws std::invalid_argument. */
#define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...)  \
  if (warp_block_size == 16) {                                           \
    constexpr int WARP_BLOCK_SIZE = 16;                                  \
    __VA_ARGS__                                                          \
  } else if (warp_block_size == 32) {                                    \
    constexpr int WARP_BLOCK_SIZE = 32;                                  \
    __VA_ARGS__                                                          \
  } else {                                                               \
    std::ostringstream oss;                                              \
    oss << "Unsupported warp_block_size " << int(warp_block_size);       \
    throw std::invalid_argument(oss.str());                              \
  }

0 commit comments

Comments
 (0)