Skip to content

Commit 5c44a14

Browse files
Merge pull request #462 from janhq/update-dev-from-master-2026-03-23-00-51
Sync master with upstream release b8475
2 parents 794d4c5 + 49bfdde commit 5c44a14

32 files changed

Lines changed: 387 additions & 127 deletions

common/jinja/parser.cpp

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ class parser {
5353
return tokens[current + offset];
5454
}
5555

56+
// Consume one token: return a reference to the token at the cursor (`current`)
// and advance the cursor by one.
// Throws parser_exception ("Parser Error: Unexpected EOF") when the cursor is
// already at or past the end of `tokens`. The position passed to the exception
// is that of the last token, or 0 when `tokens` is empty.
const token & next() {
57+
if (current >= tokens.size()) {
58+
throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos);
59+
}
60+
// Post-increment: index with the old cursor value, then move past it.
return tokens[current++];
61+
}
62+
5663
token expect(token::type type, const std::string& error) {
5764
const auto & t = peek();
5865
if (t.t != type) {
@@ -90,9 +97,9 @@ class parser {
9097
size_t start_pos = current;
9198
switch (peek().t) {
9299
case token::comment:
93-
return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
100+
return mk_stmt<comment_statement>(start_pos, next().value);
94101
case token::text:
95-
return mk_stmt<string_literal>(start_pos, tokens[current++].value);
102+
return mk_stmt<string_literal>(start_pos, next().value);
96103
case token::open_statement:
97104
return parse_jinja_statement();
98105
case token::open_expression:
@@ -119,8 +126,7 @@ class parser {
119126
}
120127

121128
size_t start_pos = current;
122-
std::string name = peek().value;
123-
current++; // consume identifier
129+
std::string name = next().value;
124130

125131
statement_ptr result;
126132
if (name == "set") {
@@ -202,7 +208,7 @@ class parser {
202208
// Ignore generation blocks (transformers-specific)
203209
// See https://github.com/huggingface/transformers/pull/30650 for more information.
204210
result = mk_stmt<noop_statement>(start_pos);
205-
current++;
211+
++current;
206212

207213
} else {
208214
throw std::runtime_error("Unknown statement: " + name);
@@ -217,7 +223,7 @@ class parser {
217223
statements body;
218224

219225
if (is(token::equals)) {
220-
current++;
226+
++current;
221227
value = parse_expression_sequence();
222228
} else {
223229
// parsing multiline set here
@@ -280,7 +286,7 @@ class parser {
280286
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
281287
bool is_tuple = is(token::comma);
282288
while (is(token::comma)) {
283-
current++; // consume comma
289+
++current; // consume comma
284290
exprs.push_back(primary ? parse_primary_expression() : parse_expression());
285291
}
286292
return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
@@ -290,7 +296,7 @@ class parser {
290296
// e.g., `message` in `for message in messages`
291297
auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
292298
if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
293-
current++;
299+
++current; // consume 'in'
294300

295301
// `messages` in `for message in messages`
296302
auto iterable = parse_expression();
@@ -305,7 +311,8 @@ class parser {
305311
}
306312

307313
if (is_statement({"else"})) {
308-
current += 2;
314+
++current; // consume {%
315+
++current; // consume 'else'
309316
expect(token::close_statement, "Expected %}");
310317
while (!is_statement({"endfor"})) {
311318
alternate.push_back(parse_any());
@@ -347,7 +354,7 @@ class parser {
347354
auto left = parse_logical_and_expression();
348355
while (is_identifier("or")) {
349356
size_t start_pos = current;
350-
token op = tokens[current++];
357+
token op = next();
351358
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
352359
}
353360
return left;
@@ -357,7 +364,7 @@ class parser {
357364
auto left = parse_logical_negation_expression();
358365
while (is_identifier("and")) {
359366
size_t start_pos = current;
360-
auto op = tokens[current++];
367+
auto op = next();
361368
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
362369
}
363370
return left;
@@ -367,7 +374,7 @@ class parser {
367374
// Try parse unary operators
368375
if (is_identifier("not")) {
369376
size_t start_pos = current;
370-
auto op = tokens[current++];
377+
auto op = next();
371378
return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
372379
}
373380
return parse_comparison_expression();
@@ -382,11 +389,12 @@ class parser {
382389
size_t start_pos = current;
383390
if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
384391
op = {token::identifier, "not in", tokens[current].pos};
385-
current += 2;
392+
++current; // consume 'not'
393+
++current; // consume 'in'
386394
} else if (is_identifier("in")) {
387-
op = tokens[current++];
395+
op = next();
388396
} else if (is(token::comparison_binary_operator)) {
389-
op = tokens[current++];
397+
op = next();
390398
} else break;
391399
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
392400
}
@@ -397,7 +405,7 @@ class parser {
397405
auto left = parse_multiplicative_expression();
398406
while (is(token::additive_binary_operator)) {
399407
size_t start_pos = current;
400-
auto op = tokens[current++];
408+
auto op = next();
401409
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
402410
}
403411
return left;
@@ -407,7 +415,7 @@ class parser {
407415
auto left = parse_test_expression();
408416
while (is(token::multiplicative_binary_operator)) {
409417
size_t start_pos = current;
410-
auto op = tokens[current++];
418+
auto op = next();
411419
left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
412420
}
413421
return left;
@@ -417,9 +425,9 @@ class parser {
417425
auto operand = parse_filter_expression();
418426
while (is_identifier("is")) {
419427
size_t start_pos = current;
420-
current++;
428+
++current; // consume 'is'
421429
bool negate = false;
422-
if (is_identifier("not")) { current++; negate = true; }
430+
if (is_identifier("not")) { ++current; negate = true; }
423431
auto test_id = parse_primary_expression();
424432
// FIXME: tests can also be expressed like this: if x is eq 3
425433
if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
@@ -432,7 +440,7 @@ class parser {
432440
auto operand = parse_call_member_expression();
433441
while (is(token::pipe)) {
434442
size_t start_pos = current;
435-
current++;
443+
++current; // consume pipe
436444
auto filter = parse_primary_expression();
437445
if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
438446
operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
@@ -490,7 +498,7 @@ class parser {
490498
statement_ptr parse_member_expression(statement_ptr object) {
491499
size_t start_pos = current;
492500
while (is(token::dot) || is(token::open_square_bracket)) {
493-
auto op = tokens[current++];
501+
auto op = next();
494502
bool computed = op.t == token::open_square_bracket;
495503
statement_ptr prop;
496504
if (computed) {
@@ -536,7 +544,7 @@ class parser {
536544

537545
statement_ptr parse_primary_expression() {
538546
size_t start_pos = current;
539-
auto t = tokens[current++];
547+
auto t = next();
540548
switch (t.t) {
541549
case token::numeric_literal:
542550
if (t.value.find('.') != std::string::npos) {
@@ -547,7 +555,7 @@ class parser {
547555
case token::string_literal: {
548556
std::string val = t.value;
549557
while (is(token::string_literal)) {
550-
val += tokens[current++].value;
558+
val += next().value;
551559
}
552560
return mk_stmt<string_literal>(start_pos, val);
553561
}
@@ -562,9 +570,9 @@ class parser {
562570
statements vals;
563571
while (!is(token::close_square_bracket)) {
564572
vals.push_back(parse_expression());
565-
if (is(token::comma)) current++;
573+
if (is(token::comma)) ++current;
566574
}
567-
current++;
575+
++current;
568576
return mk_stmt<array_literal>(start_pos, std::move(vals));
569577
}
570578
case token::open_curly_bracket: {
@@ -573,9 +581,9 @@ class parser {
573581
auto key = parse_expression();
574582
expect(token::colon, "Expected :");
575583
pairs.push_back({std::move(key), parse_expression()});
576-
if (is(token::comma)) current++;
584+
if (is(token::comma)) ++current;
577585
}
578-
current++;
586+
++current;
579587
return mk_stmt<object_literal>(start_pos, std::move(pairs));
580588
}
581589
default:

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND)
116116
list(APPEND GGML_SOURCES_CUDA ${SRCS})
117117
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
118118
else()
119-
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
120-
list(APPEND GGML_SOURCES_CUDA ${SRCS})
121-
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
122-
list(APPEND GGML_SOURCES_CUDA ${SRCS})
123-
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
124-
list(APPEND GGML_SOURCES_CUDA ${SRCS})
119+
list(APPEND GGML_SOURCES_CUDA
120+
template-instances/fattn-vec-instance-f16-f16.cu
121+
template-instances/fattn-vec-instance-q4_0-q4_0.cu
122+
template-instances/fattn-vec-instance-q8_0-q8_0.cu
123+
template-instances/fattn-vec-instance-bf16-bf16.cu)
125124
endif()
126125

127126
ggml_add_backend_library(ggml-cuda

ggml/src/ggml-cuda/convert.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ template<typename dst_t, typename src_t>
4141
return __bfloat162float(x);
4242
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
4343
return __float22half2_rn(x);
44+
} else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
45+
#ifdef GGML_USE_HIP
46+
return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
47+
#else
48+
#if __CUDA_ARCH__ >= 800
49+
return __bfloat1622float2(x);
50+
#else
51+
return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
52+
#endif // __CUDA_ARCH__ >= 800
53+
#endif // GGML_USE_HIP
4454
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
4555
// bypass compile error on cuda 12.0.1
4656
#ifdef GGML_USE_HIP

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
7474
return sum;
7575
}
7676

77+
// Dot product between a K row stored as BF16 and the Q values in Q_v, accumulated in FP32.
//   K_c     - K data, reinterpreted below as an array of nv_bfloat162 pairs.
//   Q_v     - Q data; read as half2 when V_DOT2_F32_F16_AVAILABLE is defined, as float2 otherwise.
//   Q_q8    - unused by this variant (marked GGML_UNUSED).
//   Q_ds_v  - unused by this variant (marked GGML_UNUSED).
// Returns the sum over the K elements this thread loaded; no cross-thread
// reduction is performed inside this function.
template <int D, int nthreads>
78+
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16(
79+
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
80+
81+
const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c;
82+
GGML_UNUSED(Q_q8);
83+
GGML_UNUSED(Q_ds_v);
84+
85+
// cpy_nb: byte count reported by ggml_cuda_get_max_cpy_bytes().
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
86+
// cpy_ne: cpy_nb / 4 -- presumably the number of 4-byte nv_bfloat162 pairs per copy; confirm sizeof(nv_bfloat162).
constexpr int cpy_ne = cpy_nb / 4;
87+
88+
float sum = 0.0f;
89+
90+
// Walk the D/2 bf16 pairs in steps of nthreads*cpy_ne; within each step this
// thread loads the cpy_ne pairs at the offset selected by threadIdx.x % nthreads.
#pragma unroll
91+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
92+
__align__(16) nv_bfloat162 tmp[cpy_ne];
93+
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
94+
#pragma unroll
95+
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
96+
// Both branches convert the K pair to float2 and index Q at k_KQ_0/nthreads + k_KQ_1;
// they differ only in whether Q_v is interpreted as half2 or float2.
// NOTE(review): ggml_cuda_mad is presumably a multiply-accumulate into `sum` -- confirm against its definition.
#ifdef V_DOT2_F32_F16_AVAILABLE
97+
// FIXME replace macros in vector FA kernel with templating and use FP32 for BF16
98+
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]));
99+
#else
100+
ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
101+
#endif // V_DOT2_F32_F16_AVAILABLE
102+
}
103+
}
104+
105+
return sum;
106+
}
107+
77108
template<int D, int nthreads>
78109
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
79110
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
@@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
321352
}
322353
}
323354

355+
// Convert `ne` consecutive BF16 values of V to FP32.
//   vx  - source buffer, read as nv_bfloat16 elements.
//   dst - destination buffer, written as ne/2 float2 values.
//   i0  - element offset (in nv_bfloat16 units) of the first value to read from vx.
// Only T == float is accepted, and ne must be even (both enforced by static_assert).
template <typename T, int ne>
356+
static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
357+
static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output");
358+
static_assert(ne % 2 == 0, "bad ne");
359+
// Stage the ne bf16 values as ne/2 packed pairs before converting.
__align__(16) nv_bfloat162 tmp[ne/2];
360+
ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0);
361+
float2 * dst_f2 = (float2 *) dst;
362+
// Convert pair by pair: each nv_bfloat162 becomes one float2 in dst.
#pragma unroll
363+
for (int l = 0; l < ne/2; ++l) {
364+
dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]);
365+
}
366+
}
367+
324368
template <typename T, int ne>
325369
static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
326370
const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
547591
return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
548592
} else if constexpr (type_K == GGML_TYPE_Q8_0) {
549593
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
594+
} else if constexpr (type_K == GGML_TYPE_BF16) {
595+
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
550596
} else {
551597
static_assert(type_K == -1, "bad type");
552598
return nullptr;
@@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
567613
return dequantize_V_q5_1<T, ne>;
568614
} else if constexpr (type_V == GGML_TYPE_Q8_0) {
569615
return dequantize_V_q8_0<T, ne>;
616+
} else if constexpr (type_V == GGML_TYPE_BF16) {
617+
return dequantize_V_bf16<float, ne>;
570618
} else {
571619
static_assert(type_V == -1, "bad type");
572620
return nullptr;

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec(
7575
#endif // GGML_USE_HIP
7676

7777
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
78-
constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
79-
constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
78+
constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
79+
constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
8080

8181
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
8282
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
8383

84-
constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
84+
constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
8585
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
8686

8787
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
88-
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
88+
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
8989
#ifdef V_DOT2_F32_F16_AVAILABLE
9090
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
9191
#else
@@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec(
323323
#pragma unroll
324324
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
325325
half2 tmp[V_rows_per_thread/2];
326-
dequantize_V(V + k*nb21, tmp,
327-
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
326+
if constexpr (type_V == GGML_TYPE_BF16) {
327+
float2 tmp_f[V_rows_per_thread/2];
328+
dequantize_V(V + k*nb21, tmp_f,
329+
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
330+
#pragma unroll
331+
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
332+
tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]);
333+
}
334+
} else {
335+
dequantize_V(V + k*nb21, tmp,
336+
2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
337+
}
328338
#pragma unroll
329339
for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
330340
#pragma unroll
@@ -563,24 +573,28 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
563573
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
564574
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
565575
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
576+
extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \
566577

567578
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
568579
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
569580
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
570581
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
571582
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
572583
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
584+
EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16)
573585

574586
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
575587
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
576588
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
577589
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
578590
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
579591
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
592+
EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16)
580593

581594
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
582595
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
583596
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
584597
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
585598
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
586599
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
600+
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)

0 commit comments

Comments
 (0)