Skip to content

llm(engine): add SSE dequant path for int4/int8 disk embedding#4257

Closed
EricMoin wants to merge 1 commit into alibaba:master from EricMoin:master
Closed

llm(engine): add SSE dequant path for int4/int8 disk embedding#4257
EricMoin wants to merge 1 commit into alibaba:master from EricMoin:master

Conversation

@EricMoin
Copy link
Contributor

补了个有SSE的版本,在q41_dequant_ref里加速比较明显。下面是测试的代码和结果

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <string>
#include <vector>

#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(_M_X64))
#include <smmintrin.h>
#define MNN_BENCH_HAS_SSE 1
#else
#define MNN_BENCH_HAS_SSE 0
#endif

namespace {

using Clock = std::chrono::high_resolution_clock;

// Benchmark knobs; defaults below, overridable via command-line flags
// (see parse_args: --warmup / --repeat / --hidden / --tokens).
struct BenchConfig {
    int warmup = 8;     // untimed calls before measurement starts
    int repeat = 2000;  // timed calls averaged into the reported ms/call
    int hidden = 4096;  // per-token element count (output floats = hidden * tokens)
    int tokens = 512;   // number of token rows dequantized per call
};

// Scalar reference dequantizer for packed int4 ("q4.1") data.
//
// Each source byte packs two 4-bit values: the high nibble fills the even
// output slot and the low nibble the odd one, each mapped through
// v * scale + zero.  `size` is the number of floats produced, so the
// function consumes size / 2 bytes from `src`.
void q41_dequant_ref(const uint8_t* src, float* dst, float scale, float zero, int size) {
    const int byte_count = size / 2;
    for (int idx = 0; idx < byte_count; ++idx) {
        const uint8_t packed = src[idx];
        const int high_nibble = packed >> 4;    // same as packed / 16 for unsigned bytes
        const int low_nibble = packed & 0x0f;   // same as packed % 16
        float* out = dst + 2 * idx;
        out[0] = high_nibble * scale + zero;
        out[1] = low_nibble * scale + zero;
    }
}

#if MNN_BENCH_HAS_SSE
// SSE4.1 variant of q41_dequant_ref: same output layout (dst[2i] = high
// nibble, dst[2i + 1] = low nibble, each mapped through v * scale + zero),
// processing 16 packed bytes (32 int4 values) per vector iteration.
void q41_dequant_sse(const uint8_t* src, float* dst, float scale, float zero, int size) {
    const __m128 scale4 = _mm_set1_ps(scale);
    const __m128 zero4 = _mm_set1_ps(zero);
    const __m128i nibble = _mm_set1_epi8(0x0f);

    const int bytes = size / 2;  // each source byte holds two int4 values
    int i = 0;
    // Scratch for the 16 dequantized high/low values of one vector pass,
    // interleaved into dst by the scalar loop below.
    alignas(16) float high_buffer[16];
    alignas(16) float low_buffer[16];

    for (; i + 16 <= bytes; i += 16) {
        const __m128i x = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
        // Split every byte into its two nibbles.  SSE has no 8-bit shift, so
        // shift 16-bit lanes instead: bits that bleed across byte boundaries
        // land in the high nibble positions and are cleared by the 0x0f mask.
        const __m128i lo8 = _mm_and_si128(x, nibble);
        const __m128i hi8 = _mm_and_si128(_mm_srli_epi16(x, 4), nibble);

        // Zero-extend each group of 4 bytes to four 32-bit ints; the byte
        // shifts (_mm_srli_si128) select groups 1..3 of the 16-byte vector.
        const __m128i lo_i0 = _mm_cvtepu8_epi32(lo8);
        const __m128i lo_i1 = _mm_cvtepu8_epi32(_mm_srli_si128(lo8, 4));
        const __m128i lo_i2 = _mm_cvtepu8_epi32(_mm_srli_si128(lo8, 8));
        const __m128i lo_i3 = _mm_cvtepu8_epi32(_mm_srli_si128(lo8, 12));
        const __m128i hi_i0 = _mm_cvtepu8_epi32(hi8);
        const __m128i hi_i1 = _mm_cvtepu8_epi32(_mm_srli_si128(hi8, 4));
        const __m128i hi_i2 = _mm_cvtepu8_epi32(_mm_srli_si128(hi8, 8));
        const __m128i hi_i3 = _mm_cvtepu8_epi32(_mm_srli_si128(hi8, 12));

        // Dequantize: convert to float, then v * scale + zero.
        const __m128 lo_f0 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(lo_i0), scale4), zero4);
        const __m128 lo_f1 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(lo_i1), scale4), zero4);
        const __m128 lo_f2 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(lo_i2), scale4), zero4);
        const __m128 lo_f3 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(lo_i3), scale4), zero4);
        const __m128 hi_f0 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(hi_i0), scale4), zero4);
        const __m128 hi_f1 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(hi_i1), scale4), zero4);
        const __m128 hi_f2 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(hi_i2), scale4), zero4);
        const __m128 hi_f3 = _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(hi_i3), scale4), zero4);

        // Spill both lanes to the aligned scratch buffers...
        _mm_storeu_ps(low_buffer, lo_f0);
        _mm_storeu_ps(low_buffer + 4, lo_f1);
        _mm_storeu_ps(low_buffer + 8, lo_f2);
        _mm_storeu_ps(low_buffer + 12, lo_f3);
        _mm_storeu_ps(high_buffer, hi_f0);
        _mm_storeu_ps(high_buffer + 4, hi_f1);
        _mm_storeu_ps(high_buffer + 8, hi_f2);
        _mm_storeu_ps(high_buffer + 12, hi_f3);

        // ...then interleave scalar-side to match the reference layout:
        // high nibble first, low nibble second.
        const int out = 2 * i;
        for (int k = 0; k < 16; ++k) {
            dst[out + 2 * k] = high_buffer[k];
            dst[out + 2 * k + 1] = low_buffer[k];
        }
    }

    // Scalar tail for the remaining bytes when size/2 is not a multiple of 16;
    // identical to q41_dequant_ref.
    for (; i < bytes; ++i) {
        const int x = src[i];
        const int hi = x / 16;
        const int lo = x % 16;
        dst[2 * i] = hi * scale + zero;
        dst[2 * i + 1] = lo * scale + zero;
    }
}
#endif

// Largest element-wise absolute difference between `a` and `b`, compared
// over the shorter of the two lengths.  Returns 0.0 if either is empty.
double max_abs_diff(const std::vector<float>& a, const std::vector<float>& b) {
    const size_t count = std::min(a.size(), b.size());
    double worst = 0.0;
    for (size_t idx = 0; idx < count; ++idx) {
        // Subtract in float (as the reference does) before widening, so the
        // rounding behavior matches element-wise float arithmetic exactly.
        const double delta = static_cast<double>(std::fabs(a[idx] - b[idx]));
        if (delta > worst) {
            worst = delta;
        }
    }
    return worst;
}

// Times `fn`: runs `warmup` untimed calls, then `repeat` timed calls, and
// returns the average wall-clock time per call in milliseconds.
//
// Fixes over the original:
//  * guards repeat <= 0 (the old code divided by `repeat` unconditionally,
//    so `--repeat 0` was a division by zero);
//  * measures with steady_clock, which is guaranteed monotonic, instead of
//    high_resolution_clock, which may alias a non-steady clock;
//  * accumulates as a double-precision duration instead of truncating the
//    total to whole microseconds before averaging.
double benchmark_ms(const std::function<void()>& fn, int warmup, int repeat) {
    for (int i = 0; i < warmup; ++i) {
        fn();
    }
    if (repeat <= 0) {
        return 0.0;  // nothing timed; avoid dividing by zero
    }
    const auto t0 = std::chrono::steady_clock::now();
    for (int i = 0; i < repeat; ++i) {
        fn();
    }
    const auto t1 = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> elapsed = t1 - t0;
    return elapsed.count() / repeat;
}

// Prints the baseline vs optimized comparison for the q41_dequant benchmark:
// per-call milliseconds, per-item nanoseconds, speedup factor, and the
// maximum absolute difference between the two outputs.  `items` is the
// number of dequantized floats per call.
void print_compare(double before_ms, double after_ms, double diff, double items) {
    // Clamp the denominator away from zero so a 0 ms "after" time cannot
    // produce a division by zero.
    const double denom = std::max(after_ms, std::numeric_limits<double>::min());
    const double speedup = before_ms / denom;
    const double ns_before = before_ms * 1e6 / items;
    const double ns_after = after_ms * 1e6 / items;

    // Manipulators (std::fixed, precision) persist across statements, so the
    // byte-for-byte output matches the original single chained statement.
    std::cout << "[q41_dequant]\n";
    std::cout << "  baseline: " << std::fixed << std::setprecision(4) << before_ms << " ms";
    std::cout << "  (" << std::setprecision(2) << ns_before << " ns/item)\n";
    std::cout << "  optimized: " << std::fixed << std::setprecision(4) << after_ms << " ms";
    std::cout << "  (" << std::setprecision(2) << ns_after << " ns/item)\n";
    std::cout << "  speedup: x" << std::setprecision(3) << speedup << "\n";
    std::cout << "  max_abs_diff: " << std::setprecision(8) << diff << "\n";
}

// Parses "--flag value" pairs from the command line into a BenchConfig.
// Unknown flags are ignored; values go through std::atoi, so non-numeric
// text silently parses as 0 (then gets clamped below).  A trailing flag
// with no value is skipped by the `i + 1 < argc` bound.
//
// Fix over the original: warmup and repeat are now clamped too.  Previously
// `--repeat 0` (or a negative value) flowed straight into benchmark_ms,
// whose average divides by repeat — a division by zero.
BenchConfig parse_args(int argc, char** argv) {
    BenchConfig cfg;
    for (int i = 1; i + 1 < argc; i += 2) {
        const std::string key = argv[i];
        const int value = std::atoi(argv[i + 1]);
        if (key == "--warmup") {
            cfg.warmup = value;
        } else if (key == "--repeat") {
            cfg.repeat = value;
        } else if (key == "--hidden") {
            cfg.hidden = value;
        } else if (key == "--tokens") {
            cfg.tokens = value;
        }
    }
    cfg.warmup = std::max(0, cfg.warmup);  // negative warmup makes no sense
    cfg.repeat = std::max(1, cfg.repeat);  // benchmark_ms divides by repeat
    cfg.hidden = std::max(4, cfg.hidden);
    cfg.tokens = std::max(1, cfg.tokens);
    return cfg;
}

}

// Benchmark driver: compares the scalar q41 dequant reference against the
// SSE path (when compiled in), reporting timing and max output difference.
int main(int argc, char** argv) {
    const BenchConfig cfg = parse_args(argc, argv);
    std::cout << "sse_operator_bench\n"
              << "  warmup=" << cfg.warmup
              << " repeat=" << cfg.repeat
              << " hidden=" << cfg.hidden
              << " tokens=" << cfg.tokens << "\n"
              << "  SSE=" << (MNN_BENCH_HAS_SSE ? "ON" : "OFF") << "\n\n";

    // dequant_size floats out; the packed int4 source is half that in bytes.
    const int dequant_size = cfg.hidden * cfg.tokens;
    std::vector<uint8_t> src(dequant_size / 2);
    std::vector<float> ref(dequant_size), opt(dequant_size);

    // Fixed seed so both paths see identical, reproducible input bytes.
    std::mt19937 rng(42);
    std::uniform_int_distribution<int> u8(0, 255);
    for (int i = 0; i < static_cast<int>(src.size()); ++i) {
        src[i] = static_cast<uint8_t>(u8(rng));
    }

    // Arbitrary but exactly-representable quantization parameters.
    const float scale = 0.03125f;
    const float zero = -2.0f;

    const double base = benchmark_ms([&] { q41_dequant_ref(src.data(), ref.data(), scale, zero, dequant_size); },
                                     cfg.warmup, cfg.repeat);

#if MNN_BENCH_HAS_SSE
    const double simd = benchmark_ms([&] { q41_dequant_sse(src.data(), opt.data(), scale, zero, dequant_size); },
                                     cfg.warmup, cfg.repeat);
    // Re-run both paths once outside the timing loops so the correctness
    // comparison uses freshly written buffers.
    q41_dequant_ref(src.data(), ref.data(), scale, zero, dequant_size);
    q41_dequant_sse(src.data(), opt.data(), scale, zero, dequant_size);
    print_compare(base, simd, max_abs_diff(ref, opt), dequant_size);
#else
    // No SSE: report the baseline against itself (speedup x1, diff 0).
    print_compare(base, base, 0.0, dequant_size);
#endif

    std::cout << "\nDone.\n";
    return 0;
}
屏幕截图 2026-03-13 184540

@wangzhaode wangzhaode self-assigned this Mar 13, 2026
@wangzhaode
Copy link
Collaborator

感谢你的优化和代码!但是直接在DiskEmbedding里通过宏来判断是否使用SIMD指令不太符合MNN的代码编程规范。反量化的技术方案会在最近几个版本考虑进行优化。

@EricMoin EricMoin closed this Mar 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants