Skip to content

Commit e8c5489

Browse files
authored
ggml-webgpu: FlashAttention refactor + standardize quantization support (#23834)
* Start work on flash_attn refactor
* Refactor
* Split k/v quantization
* Refactor and abstract quantization logic for flash_attn and mul_mat
* Add quantization support to tile path
* formatting
* Move to functions, add a check
1 parent 3c7450c commit e8c5489

11 files changed

Lines changed: 986 additions & 950 deletions

ggml/src/ggml-webgpu/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
1010

1111
message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
1212

13-
# Find all WGSL files
14-
file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
13+
# Find all WGSL sources
14+
file(GLOB WGSL_SHADER_FILES
15+
"${SHADER_DIR}/*.wgsl"
16+
"${SHADER_DIR}/*.tmpl"
17+
)
1518

1619
# Generate the header using a Python script
1720
add_custom_command(

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Lines changed: 347 additions & 312 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 224 additions & 192 deletions
Large diffs are not rendered by default.

ggml/src/ggml-webgpu/pre_wgsl.hpp

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,33 @@ static std::string trim(const std::string & s) {
3737
}
3838

3939
static std::string trim_value(std::istream & is) {
40-
std::string str;
41-
std::getline(is, str);
42-
return trim(str);
40+
std::ostringstream ss;
41+
ss << is.rdbuf();
42+
return trim(ss.str());
4343
}
4444

4545
// Reports whether `c` is an identifier character: an underscore or anything
// std::isalnum accepts.
static bool isIdentChar(char c) {
    if (c == '_') {
        return true;
    }
    // std::isalnum requires a value representable as unsigned char, so convert
    // before the call rather than passing a possibly negative char.
    const unsigned char uc = static_cast<unsigned char>(c);
    return std::isalnum(uc) != 0;
}
4848

49+
// Returns the length of `line` with trailing whitespace ignored, i.e. the index
// one past its last non-whitespace character (0 when there is none).
// Shared by endsWithContinuation() and stripContinuation() so both agree on
// what counts as "the end of the line".
static size_t lengthWithoutTrailingSpace(const std::string & line) {
    size_t end = line.size();
    // std::isspace needs a value representable as unsigned char; convert first.
    while (end > 0 && std::isspace(static_cast<unsigned char>(line[end - 1]))) {
        --end;
    }
    return end;
}

// Returns true when the last non-whitespace character of `line` is a backslash.
// Whitespace after the backslash (including a stray '\r') is tolerated.
static bool endsWithContinuation(const std::string & line) {
    const size_t end = lengthWithoutTrailingSpace(line);
    return end > 0 && line[end - 1] == '\\';
}

// If `line` ends with a continuation backslash (see endsWithContinuation),
// erases that backslash together with any whitespace following it. Text before
// the backslash, including whitespace, is kept. Otherwise `line` is unchanged.
static void stripContinuation(std::string & line) {
    const size_t end = lengthWithoutTrailingSpace(line);
    if (end > 0 && line[end - 1] == '\\') {
        line.erase(end - 1);
    }
}
66+
4967
static std::string expandMacrosRecursiveInternal(const std::string & line,
5068
const std::unordered_map<std::string, std::string> & macros,
5169
std::unordered_set<std::string> & visiting);
@@ -595,19 +613,31 @@ class Preprocessor {
595613
std::string line;
596614

597615
while (std::getline(in, line)) {
598-
std::string t = trim(line);
616+
std::string logical = line;
617+
std::string t = trim(logical);
618+
if (!t.empty() && t[0] == '#') {
619+
while (endsWithContinuation(logical)) {
620+
stripContinuation(logical);
621+
if (!std::getline(in, line)) {
622+
break;
623+
}
624+
logical += "\n";
625+
logical += line;
626+
}
627+
t = trim(logical);
628+
}
599629

600630
if (!t.empty() && t[0] == '#') {
601631
bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
602632
if (mode == DirectiveMode::IncludesOnly && !handled) {
603-
out << line << "\n";
633+
out << logical << "\n";
604634
}
605635
} else {
606636
if (mode == DirectiveMode::IncludesOnly) {
607-
out << line << "\n";
637+
out << logical << "\n";
608638
} else if (condActive(cond)) {
609639
// Expand macros in the line before outputting
610-
std::string expanded = expandMacrosRecursive(line, macros);
640+
std::string expanded = expandMacrosRecursive(logical, macros);
611641
out << expanded << "\n";
612642
}
613643
}

0 commit comments

Comments
 (0)