From ca6d09bed94aebd253e2fb431b62ddb224ef014c Mon Sep 17 00:00:00 2001 From: bab2min Date: Fri, 10 Jan 2025 01:11:40 +0900 Subject: [PATCH 01/53] make KnLM builder general (uint16_t & uint32_t) --- include/kiwi/Kiwi.h | 24 +++++++- include/kiwi/Knlm.h | 6 +- include/kiwi/Trie.hpp | 8 +-- src/KiwiBuilder.cpp | 114 ++++++++++++++++++++++--------------- src/Knlm.hpp | 4 +- src/SubstringExtractor.cpp | 2 +- src/count.hpp | 67 +++++++++++----------- 7 files changed, 134 insertions(+), 91 deletions(-) diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 1c2ee7d2..fd49d1ff 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -541,12 +541,34 @@ namespace kiwi BuildOption options = BuildOption::none; ArchType archType = ArchType::none; + public: + struct ModelBuildArgs + { + std::string morphemeDef; + std::vector corpora; + size_t minMorphCnt = 10; + size_t lmOrder = 4; + std::vector lmMinCnts = { 1 }; + size_t numWorkers = 1; + size_t sbgSize = 1000000; + bool useLmTagHistory = true; + bool quantizeLm = true; + bool compressLm = true; + float dropoutSampling = 0.05f; + float dropoutProb = 0.15f; + }; + + private: + using MorphemeMap = UnorderedMap, std::pair>; + + template + std::unique_ptr buildKnLM(const ModelBuildArgs& args, size_t lmVocabSize, MorphemeMap& morphMap) const; + void loadMorphBin(std::istream& is); void saveMorphBin(std::ostream& os) const; FormRaw& addForm(const KString& form); size_t addForm(Vector& newForms, UnorderedMap& newFormMap, KString form) const; - using MorphemeMap = UnorderedMap, std::pair>; void initMorphemes(); diff --git a/include/kiwi/Knlm.h b/include/kiwi/Knlm.h index ef54eb9b..7727ad17 100644 --- a/include/kiwi/Knlm.h +++ b/include/kiwi/Knlm.h @@ -15,8 +15,6 @@ namespace kiwi { namespace lm { - using Vid = uint16_t; - struct Header { uint64_t num_nodes, node_offset, key_offset, ll_offset, gamma_offset, qtable_offset, htx_offset; @@ -61,12 +59,12 @@ namespace kiwi static std::unique_ptr create(utils::MemoryObject&& mem, 
ArchType archType = ArchType::none); - template> + template> static utils::MemoryOwner build(Trie&& ngram_cf, size_t order, const std::vector& min_cf_by_order, size_t unk_id, size_t bos_id, size_t eos_id, float unigram_alpha, size_t quantize, bool compress, - const std::vector>* bigram_list = nullptr, + const std::vector>* bigram_list = nullptr, const HistoryTx* history_transformer = nullptr, const void* extra_buf = nullptr, size_t extra_buf_size = 0 diff --git a/include/kiwi/Trie.hpp b/include/kiwi/Trie.hpp index 2bd9a42a..49ff5781 100644 --- a/include/kiwi/Trie.hpp +++ b/include/kiwi/Trie.hpp @@ -150,8 +150,8 @@ namespace kiwi } } - template - void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const + template + void traverseWithKeys(_Fn&& fn, std::vector<_CKey, _Alloc>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const { fn((Node*)this, rkeys); @@ -487,8 +487,8 @@ namespace kiwi return nodes[0].traverse(std::forward<_Fn>(fn), rkeys, maxDepth, ignoreNegative); } - template - void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const + template + void traverseWithKeys(_Fn&& fn, std::vector<_CKey, _Alloc>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const { return nodes[0].traverseWithKeys(std::forward<_Fn>(fn), rkeys, maxDepth, ignoreNegative); } diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 06b47b86..6190c487 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -760,7 +760,7 @@ void KiwiBuilder::loadMorphBin(std::istream& is) for (auto& form : forms) { const size_t idx = &form - &forms[0]; - if (idx < defaultFormSize + 27) continue; + if (idx < defaultFormSize) continue; formMap.emplace(form.form, idx); } } @@ -834,24 +834,10 @@ void KiwiBuilder::initMorphemes() morphemes[defaultTagSize + 28].userScore = -1.5f; } -KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) +template +unique_ptr 
KiwiBuilder::buildKnLM(const ModelBuildArgs& args, size_t lmVocabSize, MorphemeMap& realMorph) const { - if (!(args.lmMinCnts.size() == 1 || args.lmMinCnts.size() == args.lmOrder)) - { - throw invalid_argument{ "lmMinCnts should have 1 or lmOrder elements" }; - } - - archType = getSelectedArch(ArchType::default_); - initMorphemes(); - - ifstream ifs; - auto realMorph = loadMorphemesFromTxt(openFile(ifs, args.morphemeDef), [&](POSTag tag, float cnt) - { - return cnt >= args.minMorphCnt; - }); - updateForms(); - - RaggedVector sents; + RaggedVector sents; for (auto& path : args.corpora) { ifstream ifs; @@ -881,11 +867,7 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) } } - size_t lmVocabSize = 0; - for (auto& p : realMorph) lmVocabSize = max(p.second.first, lmVocabSize); - lmVocabSize += 1; - - Vector historyTx(lmVocabSize); + Vector historyTx(lmVocabSize); if (args.useLmTagHistory) { for (size_t i = 0; i < lmVocabSize; ++i) @@ -894,24 +876,13 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) } } - vector> bigramList; utils::ThreadPool pool; - if (args.numWorkers > 1) + if (args.numWorkers >= 1) { pool.~ThreadPool(); new (&pool) utils::ThreadPool{ args.numWorkers }; } - size_t lmMinCnt = *std::min(args.lmMinCnts.begin(), args.lmMinCnts.end()); - auto cntNodes = utils::count(sents.begin(), sents.end(), lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? 
&historyTx : nullptr); - // discount for bos node cnt - if (args.useLmTagHistory) - { - cntNodes.root().getNext(lmVocabSize)->val /= 2; - } - else - { - cntNodes.root().getNext(0)->val /= 2; - } + const size_t lmMinCnt = *std::min(args.lmMinCnts.begin(), args.lmMinCnts.end()); std::vector minCnts; if (args.lmMinCnts.size() == 1) { @@ -922,37 +893,84 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) { minCnts = args.lmMinCnts; } - langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build( + + vector> bigramList; + auto cntNodes = utils::count(sents.begin(), sents.end(), lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); + // discount for bos node cnt + if (args.useLmTagHistory) + { + cntNodes.root().getNext(lmVocabSize)->val /= 2; + } + else + { + cntNodes.root().getNext(0)->val /= 2; + } + + return lm::KnLangModelBase::create(lm::KnLangModelBase::build( cntNodes, args.lmOrder, minCnts, 2, 0, 1, 1e-5, args.quantizeLm ? 8 : 0, - args.compressLm, + sizeof(VocabTy) == 2 ? args.compressLm : false, &bigramList, args.useLmTagHistory ? 
&historyTx : nullptr ), archType); +} + + +KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) +{ + if (!(args.lmMinCnts.size() == 1 || args.lmMinCnts.size() == args.lmOrder)) + { + throw invalid_argument{ "lmMinCnts should have 1 or lmOrder elements" }; + } + + archType = getSelectedArch(ArchType::default_); + initMorphemes(); + + ifstream ifs; + auto realMorph = loadMorphemesFromTxt(openFile(ifs, args.morphemeDef), [&](POSTag tag, float cnt) + { + return cnt >= args.minMorphCnt; + }); + updateForms(); + + size_t lmVocabSize = 0; + for (auto& p : realMorph) lmVocabSize = max(p.second.first, lmVocabSize); + lmVocabSize += 1; + + if (lmVocabSize <= 0xFFFF) + { + langMdl.knlm = buildKnLM(args, lmVocabSize, realMorph); + } + else + { + langMdl.knlm = buildKnLM(args, lmVocabSize, realMorph); + } updateMorphemes(); } + namespace kiwi { + template class SBDataFeeder { - const RaggedVector& sents; + const RaggedVector& sents; const lm::KnLangModelBase* lm = nullptr; Vector> lmBuf; Vector> nodeBuf; public: - SBDataFeeder(const RaggedVector& _sents, const lm::KnLangModelBase* _lm, size_t numThreads = 1) + SBDataFeeder(const RaggedVector& _sents, const lm::KnLangModelBase* _lm, size_t numThreads = 1) : sents{ _sents }, lm{ _lm }, lmBuf(numThreads), nodeBuf(numThreads) { } - sb::FeedingData operator()(size_t i, size_t threadId = 0) + sb::FeedingData operator()(size_t i, size_t threadId = 0) { - sb::FeedingData ret; + sb::FeedingData ret; ret.len = sents[i].size(); if (lmBuf[threadId].size() < ret.len) { @@ -971,9 +989,11 @@ namespace kiwi KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) : KiwiBuilder{ modelPath } { + using Vid = uint16_t; + auto realMorph = restoreMorphemeMap(); - sb::SkipBigramTrainer sbg; - RaggedVector sents; + sb::SkipBigramTrainer sbg; + RaggedVector sents; for (auto& path : args.corpora) { ifstream ifs; @@ -1050,7 +1070,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) return true; }; - sbg 
= sb::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, args.sbgSize, langMdl.knlm->nonLeafNodeSize() }; + sbg = sb::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, args.sbgSize, langMdl.knlm->nonLeafNodeSize() }; Vector lmLogProbs; Vector baseNodes; auto tc = sbg.newContext(); @@ -1088,7 +1108,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) if (args.numWorkers <= 1) { - sbg.train(SBDataFeeder{ sents, langMdl.knlm.get() }, [&](const sb::ObservingData& od) + sbg.train(SBDataFeeder{ sents, langMdl.knlm.get() }, [&](const sb::ObservingData& od) { llCnt += od.cntRecent; llMean += (od.llRecent - llMean * od.cntRecent) / llCnt; @@ -1102,7 +1122,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) } else { - sbg.trainMulti(args.numWorkers, SBDataFeeder{ sents, langMdl.knlm.get(), 8 }, [&](const sb::ObservingData& od) + sbg.trainMulti(args.numWorkers, SBDataFeeder{ sents, langMdl.knlm.get(), 8 }, [&](const sb::ObservingData& od) { llCnt += od.cntRecent; llMean += (od.llRecent - llMean * od.cntRecent) / llCnt; diff --git a/src/Knlm.hpp b/src/Knlm.hpp index 5b1f0185..1f9f5821 100644 --- a/src/Knlm.hpp +++ b/src/Knlm.hpp @@ -1088,11 +1088,11 @@ namespace kiwi using type = utils::TrieNodeEx; }; - template + template utils::MemoryOwner KnLangModelBase::build(Trie&& ngram_cf, size_t order, const std::vector& min_cf_by_order, size_t unk_id, size_t bos_id, size_t eos_id, float unigram_alpha, size_t quantize, bool compress, - const std::vector>* bigram_list, const HistoryTx* history_transformer, + const std::vector>* bigram_list, const HistoryTx* history_transformer, const void* extra_buf, size_t extra_buf_size ) { diff --git a/src/SubstringExtractor.cpp b/src/SubstringExtractor.cpp index 1ac6d5e4..452c6b24 100644 --- a/src/SubstringExtractor.cpp +++ b/src/SubstringExtractor.cpp @@ -384,7 +384,7 @@ namespace kiwi utils::MemoryOwner mem; { auto 
trie = count(); - mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCfByOrder, unkTokenId, bosTokenId, eosTokenId, + mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCfByOrder, unkTokenId, bosTokenId, eosTokenId, 1e-5f, 0, false, nullptr, (const Vector*)nullptr, extraBuf.data(), extraBuf.size()); } diff --git a/src/count.hpp b/src/count.hpp index ac6e7505..4d6deed8 100644 --- a/src/count.hpp +++ b/src/count.hpp @@ -29,10 +29,10 @@ namespace kiwi #else template using map = std::map; #endif - using Vid = uint16_t; - using CTrieNode = TrieNodeEx>>; + template + using CTrieNode = TrieNodeEx>>; - static constexpr Vid non_vocab_id = (Vid)-1; + static constexpr size_t non_vocab_id = (size_t)-1; template class StrideIter : public _Iterator @@ -77,14 +77,15 @@ namespace kiwi { struct vvhash { - size_t operator()(const std::pair& k) const + template + size_t operator()(const std::pair& k) const { - return std::hash{}(k.first) ^ std::hash{}(k.second); + return std::hash{}(k.first) ^ std::hash{}(k.second); } }; } - template + template void countUnigrams(std::vector& unigramCf, std::vector& unigramDf, _DocIter docBegin, _DocIter docEnd ) @@ -93,7 +94,7 @@ namespace kiwi { auto doc = *docIt; if (!doc.size()) continue; - std::unordered_set uniqs; + std::unordered_set uniqs; for (size_t i = 0; i < doc.size(); ++i) { if (doc[i] == non_vocab_id) continue; @@ -110,24 +111,24 @@ namespace kiwi } } - template> - void countBigrams(map, size_t>& bigramCf, - map, size_t>& bigramDf, + template> + void countBigrams(map, size_t>& bigramCf, + map, size_t>& bigramDf, _DocIter docBegin, _DocIter docEnd, _Freqs&& vocabFreqs, _Freqs&& vocabDf, size_t candMinCnt, size_t candMinDf, const _HistoryTx* historyTransformer = nullptr ) { - std::unordered_set, detail::vvhash> uniqBigram; + std::unordered_set, detail::vvhash> uniqBigram; for (auto docIt = docBegin; docIt != docEnd; ++docIt) { auto doc = *docIt; if (!doc.size()) continue; - Vid prevWord = doc[0]; + VocabTy prevWord = 
doc[0]; for (size_t j = 1; j < doc.size(); ++j) { - Vid curWord = doc[j]; + VocabTy curWord = doc[j]; if (curWord != non_vocab_id && vocabFreqs[curWord] >= candMinCnt && vocabDf[curWord] >= candMinDf) { if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf) @@ -144,8 +145,8 @@ namespace kiwi } } - template> - void countNgrams(ContinuousTrie& dest, + template> + void countNgrams(ContinuousTrie>& dest, _DocIter docBegin, _DocIter docEnd, _Freqs&& vocabFreqs, _Freqs&& vocabDf, _BigramPairs&& validPairs, size_t candMinCnt, size_t candMinDf, size_t maxNgrams, @@ -154,7 +155,7 @@ namespace kiwi { if (dest.empty()) { - dest = ContinuousTrie{ 1, 1024 }; + dest = ContinuousTrie>{ 1, 1024 }; } const auto& allocNode = [&]() { return dest.newNode(); }; const auto& historyTx = [&](size_t i) { return (*historyTransformer)[i]; }; @@ -165,7 +166,7 @@ namespace kiwi if (!doc.size()) continue; dest.reserveMore(doc.size() * maxNgrams * 2); - Vid prevWord = _reverse ? *doc.rbegin() : *doc.begin(); + VocabTy prevWord = _reverse ? 
*doc.rbegin() : *doc.begin(); size_t labelLen = 0; auto node = &dest[0]; if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf) @@ -175,7 +176,7 @@ namespace kiwi labelLen = 1; } - const auto func = [&](Vid curWord) + const auto func = [&](VocabTy curWord) { if (curWord != non_vocab_id && (vocabFreqs[curWord] < candMinCnt || vocabDf[curWord] < candMinDf)) { @@ -239,19 +240,21 @@ namespace kiwi } } - inline void mergeNgramCounts(ContinuousTrie& dest, ContinuousTrie&& src) + template + inline void mergeNgramCounts(ContinuousTrie>& dest, ContinuousTrie>&& src) { if (src.empty()) return; - if (dest.empty()) dest = ContinuousTrie{ 1 }; + if (dest.empty()) dest = ContinuousTrie>{ 1 }; - std::vector rkeys; - src.traverseWithKeys([&](const CTrieNode* node, const std::vector& rkeys) + std::vector rkeys; + src.traverseWithKeys([&](const CTrieNode* node, const std::vector& rkeys) { dest.build(rkeys.begin(), rkeys.end(), 0)->val += node->val; }, rkeys); } - inline float branchingEntropy(const CTrieNode* node, size_t minCnt) + template + inline float branchingEntropy(const CTrieNode* node, size_t minCnt) { float entropy = 0; size_t rest = node->val; @@ -300,16 +303,16 @@ namespace kiwi return std::move(data[0]); } - template> - ContinuousTrie count(_DocIter docBegin, _DocIter docEnd, + template> + ContinuousTrie> count(_DocIter docBegin, _DocIter docEnd, size_t minCf, size_t minDf, size_t maxNgrams, - ThreadPool* pool = nullptr, std::vector>* bigramList = nullptr, + ThreadPool* pool = nullptr, std::vector>* bigramList = nullptr, const _HistoryTx* historyTransformer = nullptr ) { // counting unigrams & bigrams std::vector unigramCf, unigramDf; - map, size_t> bigramCf, bigramDf; + map, size_t> bigramCf, bigramDf; if (pool && pool->size() > 1) { @@ -325,7 +328,7 @@ namespace kiwi { futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid) { - countUnigrams(localdata[tid].first, localdata[tid].second, + 
countUnigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd) ); @@ -351,7 +354,7 @@ namespace kiwi } else { - countUnigrams(unigramCf, unigramDf, docBegin, docEnd); + countUnigrams(unigramCf, unigramDf, docBegin, docEnd); } if (pool && pool->size() > 1) @@ -402,7 +405,7 @@ namespace kiwi } } - ContinuousTrie trieNodes{ 1 }; + ContinuousTrie> trieNodes{ 1 }; if (historyTransformer) { for (size_t i = 0; i < unigramCf.size(); ++i) @@ -434,7 +437,7 @@ namespace kiwi // counting ngrams else { - std::unordered_set, detail::vvhash> validPairs; + std::unordered_set, detail::vvhash> validPairs; for (auto& p : bigramCf) { if (p.second >= minCf && bigramDf[p.first] >= minDf) validPairs.emplace(p.first); @@ -442,7 +445,7 @@ namespace kiwi if (pool && pool->size() > 1) { - using LocalFw = ContinuousTrie; + using LocalFw = ContinuousTrie>; std::vector localdata(pool->size()); std::vector> futures; const size_t stride = pool->size() * 8; From cb1ebc06d53731a23ad17eb9bf28a1d559d987f0 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 13 Jan 2025 01:06:51 +0900 Subject: [PATCH 02/53] Implement PoC of KnLM transposed & PcLM --- .gitmodules | 3 + include/kiwi/Dataset.h | 6 +- include/kiwi/Kiwi.h | 108 ++--- include/kiwi/LmState.h | 9 + include/kiwi/Mmap.h | 11 + include/kiwi/PCLanguageModel.h | 52 ++ include/kiwi/Types.h | 9 + src/Dataset.cpp | 109 ++++- src/Kiwi.cpp | 49 +- src/KiwiBuilder.cpp | 82 +++- src/LmState.hpp | 49 +- src/PCLanguageModel.cpp | 267 +++++++++++ src/PCLanguageModel.hpp | 335 +++++++++++++ src/PathEvaluator.hpp | 743 +++++++++++++++++++++++++++-- src/SkipBigramModel.cpp | 45 ++ src/SkipBigramModel.hpp | 38 -- src/capi/kiwi_c.cpp | 2 +- test/test_cpp.cpp | 21 +- third_party/streamvbyte | 1 + tools/evaluator_main.cpp | 53 +- tools/pclm_builder.cpp | 62 +++ vsproj/build_pclm.vcxproj | 225 +++++++++ vsproj/kiwi_shared_library.vcxproj | 19 +- 23 files changed, 2098 insertions(+), 200 
deletions(-) create mode 100644 include/kiwi/PCLanguageModel.h create mode 100644 src/PCLanguageModel.cpp create mode 100644 src/PCLanguageModel.hpp create mode 100644 src/SkipBigramModel.cpp create mode 160000 third_party/streamvbyte create mode 100644 tools/pclm_builder.cpp create mode 100644 vsproj/build_pclm.vcxproj diff --git a/.gitmodules b/.gitmodules index a9fcfa31..e76d2d4b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -24,3 +24,6 @@ [submodule "third_party/json"] path = third_party/json url = https://github.com/nlohmann/json +[submodule "third_party/streamvbyte"] + path = third_party/streamvbyte + url = https://github.com/fast-pack/streamvbyte diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index e45738d5..0a0e9727 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -1,6 +1,7 @@ #pragma once #include #include "Kiwi.h" +#include "FrozenTrie.h" namespace kiwi { @@ -30,7 +31,7 @@ namespace kiwi struct ThreadLocal { std::mt19937_64 rng; - Vector tokenBuf; + Vector tokenBuf, contextualTokenBuf; Vector lmLProbsBuf; Vector outNgramNodeBuf; Deque historyBuf; @@ -58,6 +59,7 @@ namespace kiwi Deque> futures; const Vector* morphemes = nullptr; const Vector* forms = nullptr; + utils::FrozenTrie contextualMapper; size_t knlmVocabSize = 0; size_t batchSize = 0; size_t causalContextSize = 0; @@ -103,6 +105,6 @@ namespace kiwi Range::const_iterator> getSent(size_t idx) const; std::vector getAugmentedSent(size_t idx); - std::vector, size_t>> extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers = 1) const; + std::vector, size_t>> extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers = 1, bool exclusiveCnt = false) const; }; } diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index fd49d1ff..ecc874c2 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -58,7 +58,8 @@ namespace kiwi class Kiwi { friend class KiwiBuilder; - friend class PathEvaluator; + friend struct BestPathFinder; + template friend struct 
PathEvaluator; friend class cmb::AutoJoiner; template class LmState> friend struct NewAutoJoinerGetter; @@ -156,6 +157,8 @@ namespace kiwi ArchType archType() const { return selectedArch; } + ModelType modelType() const { return langMdl.type; } + /** * @brief 현재 Kiwi 객체가 오타 교정 기능이 켜진 상태로 생성되었는지 알려준다. * @@ -536,9 +539,10 @@ namespace kiwi LangModel langMdl; std::shared_ptr combiningRule; WordDetector detector; - + size_t numThreads = 0; BuildOption options = BuildOption::none; + ModelType modelType = ModelType::none; ArchType archType = ArchType::none; public: @@ -579,15 +583,15 @@ namespace kiwi template void _addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const; - + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; void updateForms(); - void updateMorphemes(); + void updateMorphemes(size_t vocabSize = 0); size_t findMorpheme(U16StringView form, POSTag tag) const; - + std::pair addWord(U16StringView newForm, POSTag tag, float score, size_t origMorphemeId, size_t lmMorphemeId); std::pair addWord(const std::u16string& newForm, POSTag tag, float score, size_t origMorphemeId, size_t lmMorphemeId); std::pair addWord(U16StringView form, POSTag tag = POSTag::nnp, float score = 0); @@ -611,44 +615,28 @@ namespace kiwi ) const; void addCombinedMorphemes( - Vector& newForms, - UnorderedMap& newFormMap, - Vector& newMorphemes, - UnorderedMap>& newFormCands, - size_t leftId, - size_t rightId, + Vector& newForms, + UnorderedMap& newFormMap, + Vector& newMorphemes, + UnorderedMap>& newFormCands, + size_t leftId, + size_t rightId, size_t ruleId ) 
const; void buildCombinedMorphemes( - Vector& newForms, + Vector& newForms, UnorderedMap& newFormMap, - Vector& newMorphemes, + Vector& newMorphemes, UnorderedMap>& newFormCands ) const; void addAllomorphsToRule(); public: - struct ModelBuildArgs - { - std::string morphemeDef; - std::vector corpora; - size_t minMorphCnt = 10; - size_t lmOrder = 4; - std::vector lmMinCnts = { 1 }; - size_t numWorkers = 1; - size_t sbgSize = 1000000; - bool useLmTagHistory = true; - bool quantizeLm = true; - bool compressLm = true; - float dropoutSampling = 0.05f; - float dropoutProb = 0.15f; - }; - /** * @brief KiwiBuilder의 기본 생성자 - * + * * @note 이 생성자로 생성된 경우 `ready() == false`인 상태이므로 유효한 Kiwi 객체를 생성할 수 없다. */ KiwiBuilder(); @@ -665,9 +653,9 @@ namespace kiwi /** * @brief KiwiBuilder를 raw 데이터로부터 생성한다. - * - * - * @note 이 함수는 현재 내부적으로 기본 모델 구축에 쓰인다. + * + * + * @note 이 함수는 현재 내부적으로 기본 모델 구축에 쓰인다. * 추후 공개 데이터로도 쉽게 직접 모델을 구축할 수 있도록 개선된 API를 제공할 예정. */ KiwiBuilder(const ModelBuildArgs& args); @@ -679,16 +667,16 @@ namespace kiwi /** * @brief KiwiBuilder를 모델 파일로부터 생성한다. - * + * * @param modelPath 모델이 위치한 경로 * @param numThreads 모델 및 형태소 분석에 사용할 스레드 개수 * @param options 생성 옵션. `kiwi::BuildOption`을 참조 */ - KiwiBuilder(const std::string& modelPath, size_t numThreads = 0, BuildOption options = BuildOption::default_, bool useSBG = false); + KiwiBuilder(const std::string& modelPath, size_t numThreads = 0, BuildOption options = BuildOption::default_, ModelType modelType = ModelType::knlm); /** * @brief 현재 KiwiBuilder 객체가 유효한 분석 모델을 로딩한 상태인지 알려준다. - * + * * @return 유효한 상태면 true를 반환한다. 기본 생성자로 생성한 경우 `ready() == false`이며, * 다른 생성자로 생성한 경우는 `ready() == true`이다. */ @@ -701,7 +689,7 @@ namespace kiwi /** * @brief 사전에 새로운 형태소를 추가한다. 이미 동일한 형태소가 있는 경우는 무시된다. - * + * * @param form 새로운 형태소의 형태 * @param tag 품사 태그 * @param score 페널티 점수. 이에 대한 자세한 설명은 하단의 note 참조. @@ -709,7 +697,7 @@ namespace kiwi * @note 이 방법으로 추가된 형태소는 언어모델 탐색에서 어휘 사전 외 토큰(OOV 토큰)으로 처리된다. 
* 이 방법으로 추가된 형태소는 항상 분석 과정에서 최우선으로 탐색되지는 않으므로 최상의 결과를 위해서는 `score` 값을 조절할 필요가 있다. * `score` 값을 높게 설정할수록 다른 후보들과의 경쟁에서 이 형태소가 더 높은 점수를 받아 최종 분석 결과에 노출될 가능성이 높아진다. - * 만약 이 방법으로 추가된 형태소가 원치 않는 상황에서 과도하게 출력되는 경우라면 `score`를 더 작은 값으로, + * 만약 이 방법으로 추가된 형태소가 원치 않는 상황에서 과도하게 출력되는 경우라면 `score`를 더 작은 값으로, * 반대로 원하는 상황에서도 출력되지 않는 경우라면 `score`를 더 큰 값으로 조절하는 게 좋다. */ std::pair addWord(const std::u16string& form, POSTag tag = POSTag::nnp, float score = 0); @@ -738,26 +726,26 @@ namespace kiwi * @param score 페널티 점수. 이에 대한 자세한 설명은 하단의 `addWord`함수의 note 참조. * @exception kiwi::UnknownMorphemeException `analyzed`로 주어진 형태소 중 하나라도 존재하지 않는게 있는 경우 예외를 발생시킨다. * @return 형태소열을 추가하는데 성공했으면 true, 동일한 형태소열이 존재하여 추가에 실패한 경우 false를 반환한다. - * @note 이 함수는 특정 문자열이 어떻게 분석되어야하는지 직접적으로 지정해줄 수 있다. - * 따라서 `addWord` 함수를 사용해도 오분석이 발생하는 경우, 이 함수를 통해 해당 사례들에 대해 정확한 분석 결과를 추가하면 원하는 분석 결과를 얻을 수 있다. + * @note 이 함수는 특정 문자열이 어떻게 분석되어야하는지 직접적으로 지정해줄 수 있다. + * 따라서 `addWord` 함수를 사용해도 오분석이 발생하는 경우, 이 함수를 통해 해당 사례들에 대해 정확한 분석 결과를 추가하면 원하는 분석 결과를 얻을 수 있다. */ - bool addPreAnalyzedWord(const std::u16string& form, - const std::vector>& analyzed, + bool addPreAnalyzedWord(const std::u16string& form, + const std::vector>& analyzed, std::vector> positions = {}, float score = 0 ); - bool addPreAnalyzedWord(const char16_t* form, - const std::vector>& analyzed, + bool addPreAnalyzedWord(const char16_t* form, + const std::vector>& analyzed, std::vector> positions = {}, float score = 0 ); /** * @brief 규칙에 의해 변형된 형태소 목록을 생성하여 자동 추가한다. 
- * - * @param tag - * @param repl - * @param score + * + * @param tag + * @param repl + * @param score * @return 새로 추가된 변형된 형태소의 ID와 그 형태를 pair로 묶은 목록 */ template @@ -792,24 +780,24 @@ namespace kiwi } /** - * @brief - * - * @param dictPath - * @return + * @brief + * + * @param dictPath + * @return */ size_t loadDictionary(const std::string& dictPath); - std::vector extractWords(const U16MultipleReader& reader, + std::vector extractWords(const U16MultipleReader& reader, size_t minCnt = 10, size_t maxWordLen = 10, float minScore = 0.25, float posThreshold = -3, bool lmFilter = true ) const; - std::vector extractAddWords(const U16MultipleReader& reader, + std::vector extractAddWords(const U16MultipleReader& reader, size_t minCnt = 10, size_t maxWordLen = 10, float minScore = 0.25, float posThreshold = -3, bool lmFilter = true ); /** * @brief 현재 단어 및 사전 설정을 기반으로 Kiwi 객체를 생성한다. - * + * * @param typos * @param typoCostThreshold * @return 형태소 분석 준비가 완료된 Kiwi의 객체. @@ -830,8 +818,8 @@ namespace kiwi using TokenFilter = std::function; - HSDataset makeHSDataset(const std::vector& inputPathes, - size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, + HSDataset makeHSDataset(const std::vector& inputPathes, + size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb = 0, double dropoutProbOnHistory = 0, const TokenFilter& tokenFilter = {}, @@ -840,7 +828,13 @@ namespace kiwi bool separateDefaultMorpheme = false, const std::string& morphemeDefPath = {}, size_t morphemeDefMinCnt = 0, + const std::vector>>& contextualMapper = {}, HSDataset* splitDataset = nullptr ) const; + + BuildOption getOptions() const { return options; } + ModelType getModelType() const { return modelType; } + + static void buildMorphData(const std::string& morphemeDefPath, const std::string& outputPath, size_t minCnt = 10); }; } diff --git a/include/kiwi/LmState.h b/include/kiwi/LmState.h index 1399343d..dddaac98 100644 --- 
a/include/kiwi/LmState.h +++ b/include/kiwi/LmState.h @@ -5,13 +5,22 @@ #include "Trie.hpp" #include "Knlm.h" #include "SkipBigramModel.h" +#include "PCLanguageModel.h" namespace kiwi { struct LangModel { + ModelType type; std::shared_ptr knlm; std::shared_ptr sbg; + std::shared_ptr pclm; + + size_t vocabSize() const + { + if (knlm) return knlm->getHeader().vocab_size; + else return pclm->getHeader().vocabSize; + } }; class LmObjectBase diff --git a/include/kiwi/Mmap.h b/include/kiwi/Mmap.h index a2c6ef3d..7850812d 100644 --- a/include/kiwi/Mmap.h +++ b/include/kiwi/Mmap.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #ifdef _WIN32 #define NOMINMAX @@ -260,6 +261,16 @@ namespace kiwi const void* get() const { return obj->get(); } size_t size() const { return obj->size(); } + + void writeToFile(const std::string& filepath) const + { + std::ofstream ofs; + if (!openFile(ofs, filepath, std::ios_base::binary)) + { + throw IOException{ "Cannot open file : " + filepath }; + } + ofs.write((const char*)get(), size()); + } }; template diff --git a/include/kiwi/PCLanguageModel.h b/include/kiwi/PCLanguageModel.h new file mode 100644 index 00000000..e2aca30f --- /dev/null +++ b/include/kiwi/PCLanguageModel.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ArchUtils.h" +#include "Mmap.h" + +namespace kiwi +{ + namespace pclm + { + struct Header + { + uint64_t vocabSize, contextSize; + uint16_t dim; + uint8_t contextType, outputType; + uint8_t keySize, windowSize, quantize, _reserved; + uint64_t numNodes; + uint64_t nodeOffset, keyOffset, valueOffset, embOffset; + }; + + template + struct Node + { + KeyType numNexts = 0; + ValueType value = 0; + DiffType lower = 0; + uint32_t nextOffset = 0; + }; + + class PCLanguageModelBase + { + protected: + utils::MemoryObject base; + + PCLanguageModelBase(utils::MemoryObject&& mem) : base{ std::move(mem) } + { + } + public: + virtual ~PCLanguageModelBase() {} + const 
Header& getHeader() const { return *reinterpret_cast(base.get()); } + + static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none); + }; + } +} diff --git a/include/kiwi/Types.h b/include/kiwi/Types.h index e6258499..40bc4c62 100644 --- a/include/kiwi/Types.h +++ b/include/kiwi/Types.h @@ -303,6 +303,15 @@ namespace kiwi default_ = integrateAllomorph | loadDefaultDict | loadTypoDict | loadMultiDict, }; + enum class ModelType + { + none = 0, /**< Select default model */ + knlm = 1, /**< Kneser-Ney Language Model */ + sbg = 2, /**< Skip-Bigram Model */ + pclm = 3, /**< Pre-computed Context Language Model */ + knlmTransposed, + }; + struct Morpheme; /** diff --git a/src/Dataset.cpp b/src/Dataset.cpp index f0b258f7..95812b7a 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -89,6 +89,7 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, { auto& local = locals[localId]; auto& tokens = local.tokenBuf; + auto& contextualTokens = local.contextualTokenBuf; tokens.reserve(sents.get()[shuffledIdx[sentFirst]].size()); for (size_t s = sentFirst; s < sentLast; ++s) { @@ -133,6 +134,53 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, history.back() = tokenToVocab[tokens[0]]; } } + + if (causalContextSize && contextualMapper.size()) + { + auto* node = contextualMapper.root(); + contextualTokens.clear(); + contextualTokens.reserve(tokens.size()); + for (size_t i = 0; i < tokens.size(); ++i) + { + const int32_t v = tokenToVocab[tokens[i]]; + auto* next = node->template nextOpt(contextualMapper, v); + while (!next) + { + node = node->fail(); + if (!node) break; + next = node->template nextOpt(contextualMapper, v); + } + if (next) + { + auto val = next->val(contextualMapper); + if (contextualMapper.hasMatch(val)) + { + contextualTokens.emplace_back(val - 1); + } + else if 
(contextualMapper.hasSubmatch(val)) + { + auto sub = next->fail(); + for (; sub; sub = sub->fail()) + { + val = sub->val(contextualMapper); + if (contextualMapper.hasMatch(val)) + { + break; + } + } + if (sub) contextualTokens.emplace_back(val - 1); + else contextualTokens.emplace_back(nonVocab); + } + node = next; + } + else + { + contextualTokens.emplace_back(nonVocab); + node = contextualMapper.root(); + } + } + } + for (size_t i = 1; i < tokens.size(); ++i) { int32_t v = tokenToVocab[tokens[i]]; @@ -157,6 +205,10 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, { local.inData.emplace_back(nonVocab); } + else if (contextualMapper.size()) + { + local.inData.emplace_back(contextualTokens[i + j - causalContextSize]); + } else { auto t = tokens[i + j - causalContextSize]; @@ -363,7 +415,7 @@ std::vector HSDataset::getAugmentedSent(size_t idx) return ret; } -std::vector, size_t>> kiwi::HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers) const +std::vector, size_t>> kiwi::HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers, bool exclusiveCnt) const { using Pair = std::pair, size_t>; std::vector ret; @@ -373,15 +425,58 @@ std::vector, size_t>> kiwi::HSDataset::extractPr counter.addArray(&*sent.begin(), &*sent.end()); } auto trie = counter.count(); - trie.traverse([&](size_t cnt, const std::vector& prefix) + if (exclusiveCnt) { - if (cnt < minCnt) return; - if (std::find_if(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; }) != prefix.end()) + Vector, size_t>> cnts_by_length(maxLength); + trie.traverse([&](size_t cnt, const std::vector& prefix) { - return; + if (cnt < minCnt) return; + if (std::find_if(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; }) != prefix.end()) + { + return; + } + Vector p(prefix.begin(), prefix.end()); + cnts_by_length[p.size() - 1].emplace(move(p), cnt); + }); + + Vector suffix; + suffix.reserve(maxLength); + for (size_t i = 1; i < 
maxLength; ++i) + { + for (auto& p : cnts_by_length[i]) + { + suffix.clear(); + suffix.insert(suffix.end(), p.first.begin() + 1, p.first.end()); + auto it = cnts_by_length[i - 1].find(suffix); + if (it == cnts_by_length[i - 1].end() || it->second < p.second) + { + throw std::runtime_error("This should not happen"); + } + it->second -= p.second; + } } - ret.emplace_back(prefix, cnt); - }); + + for (auto& cnts : cnts_by_length) + { + for (auto& p : cnts) + { + if (p.second < minCnt) continue; + ret.emplace_back(std::vector{ p.first.begin(), p.first.end() }, p.second); + } + } + } + else + { + trie.traverse([&](size_t cnt, const std::vector& prefix) + { + if (cnt < minCnt) return; + if (std::find_if(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; }) != prefix.end()) + { + return; + } + ret.emplace_back(prefix, cnt); + }); + } std::sort(ret.begin(), ret.end(), [](const Pair& a, const Pair& b) { diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index 7543038d..927a9f2c 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -59,12 +59,17 @@ namespace kiwi static tp::Table lmKnLM_16{ FindBestPathGetter::type>{} }; static tp::Table lmKnLM_32{ FindBestPathGetter::type>{} }; static tp::Table lmKnLM_64{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLMT_8{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLMT_16{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLMT_32{ FindBestPathGetter::type>{} }; static tp::Table lmSbg_8{ FindBestPathGetter::type>{} }; static tp::Table lmSbg_16{ FindBestPathGetter::type>{} }; static tp::Table lmSbg_32{ FindBestPathGetter::type>{} }; static tp::Table lmSbg_64{ FindBestPathGetter::type>{} }; + static tp::Table lmPcLM_16{ FindBestPathGetter::type>{} }; + static tp::Table lmPcLM_32{ FindBestPathGetter::type>{} }; - if (langMdl.sbg) + if (langMdl.type == ModelType::sbg) { switch (langMdl.sbg->getHeader().keySize) { @@ -84,7 +89,7 @@ namespace kiwi throw Exception{ "Wrong `lmKeySize`" }; } } - else if(langMdl.knlm) + else 
if(langMdl.type == ModelType::knlm) { switch (langMdl.knlm->getHeader().key_size) { @@ -99,11 +104,45 @@ namespace kiwi break; case 8: dfFindBestPath = (void*)lmKnLM_64[static_cast(selectedArch)]; + default: + throw Exception{ "Wrong `lmKeySize`" }; + } + } + else if (langMdl.type == ModelType::knlmTransposed) + { + switch (langMdl.knlm->getHeader().key_size) + { + case 1: + dfFindBestPath = (void*)lmKnLMT_8[static_cast(selectedArch)]; + break; + case 2: + dfFindBestPath = (void*)lmKnLMT_16[static_cast(selectedArch)]; + break; + case 4: + dfFindBestPath = (void*)lmKnLMT_32[static_cast(selectedArch)]; break; default: throw Exception{ "Wrong `lmKeySize`" }; } } + else if (langMdl.type == ModelType::pclm) + { + switch (langMdl.pclm->getHeader().keySize) + { + case 2: + dfFindBestPath = (void*)lmPcLM_16[static_cast(selectedArch)]; + break; + case 4: + dfFindBestPath = (void*)lmPcLM_32[static_cast(selectedArch)]; + break; + default: + throw Exception{ "Wrong `lmKeySize`" }; + } + } + else + { + throw Exception{ "Unsupported model type" }; + } } Kiwi::~Kiwi() = default; @@ -651,7 +690,7 @@ namespace kiwi inline void insertPathIntoResults( vector& ret, Vector& spStatesByRet, - const Vector& pathes, + const Vector& pathes, size_t topN, Match matchOptions, bool integrateAllomorph, @@ -677,7 +716,7 @@ namespace kiwi Vector selectedPathes(pathes.size()); for (size_t i = 0; i < ret.size(); ++i) { - auto pred = [&](const PathEvaluator::ChunkResult& p) + auto pred = [&](const BestPathFinder::ChunkResult& p) { return p.prevState == spStatesByRet[i]; }; @@ -1059,7 +1098,7 @@ namespace kiwi if (nodes.size() <= 2) continue; findPretokenizedGroupOfNode(nodeInWhichPretokenized, nodes, pretokenizedPrev, pretokenizedFirst); - Vector res = (*reinterpret_cast(dfFindBestPath))( + Vector res = (*reinterpret_cast(dfFindBestPath))( this, spStatesByRet, nodes.data(), diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 6190c487..f0c39f00 100644 --- a/src/KiwiBuilder.cpp +++ 
b/src/KiwiBuilder.cpp @@ -738,12 +738,13 @@ void KiwiBuilder::updateForms() } } -void KiwiBuilder::updateMorphemes() +void KiwiBuilder::updateMorphemes(size_t vocabSize) { + if (vocabSize == 0) vocabSize = langMdl.vocabSize(); for (auto& m : morphemes) { if (m.lmMorphemeId > 0) continue; - if (m.tag == POSTag::p || (&m - morphemes.data() + m.combined) < langMdl.knlm->getHeader().vocab_size) + if (m.tag == POSTag::p || (&m - morphemes.data() + m.combined) < vocabSize) { m.lmMorphemeId = &m - morphemes.data(); } @@ -770,8 +771,8 @@ void KiwiBuilder::saveMorphBin(std::ostream& os) const serializer::writeMany(os, serializer::toKey("KIWI"), forms, morphemes); } -KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOption _options, bool useSBG) - : detector{ modelPath, _numThreads }, options{ _options }, numThreads{ _numThreads ? _numThreads : thread::hardware_concurrency() } +KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOption _options, ModelType _modelType) + : detector{ modelPath, _numThreads }, options{ _options }, modelType{ _modelType }, numThreads{ _numThreads ? 
_numThreads : thread::hardware_concurrency() } { archType = getSelectedArch(ArchType::default_); @@ -780,12 +781,23 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio utils::imstream iss{ mm }; loadMorphBin(iss); } - langMdl.knlm = lm::KnLangModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), archType); - if (useSBG) + + langMdl.type = modelType; + if (modelType == ModelType::knlm || modelType == ModelType::knlmTransposed || modelType == ModelType::sbg) + { + langMdl.knlm = lm::KnLangModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), archType); + } + + if (modelType == ModelType::sbg) { langMdl.sbg = sb::SkipBigramModelBase::create(utils::MMap(modelPath + string{ "/skipbigram.mdl" }), archType); } + if (modelType == ModelType::pclm) + { + langMdl.pclm = pclm::PCLanguageModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType); + } + if (!!(options & BuildOption::loadDefaultDict)) { loadDictionary(modelPath + "/default.dict"); @@ -866,7 +878,7 @@ unique_ptr KiwiBuilder::buildKnLM(const ModelBuildArgs& arg } } } - + Vector historyTx(lmVocabSize); if (args.useLmTagHistory) { @@ -907,12 +919,12 @@ unique_ptr KiwiBuilder::buildKnLM(const ModelBuildArgs& arg } return lm::KnLangModelBase::create(lm::KnLangModelBase::build( - cntNodes, - args.lmOrder, minCnts, - 2, 0, 1, 1e-5, + cntNodes, + args.lmOrder, minCnts, + 2, 0, 1, 1e-5, args.quantizeLm ? 8 : 0, sizeof(VocabTy) == 2 ? args.compressLm : false, - &bigramList, + &bigramList, args.useLmTagHistory ? 
&historyTx : nullptr ), archType); } @@ -2333,6 +2345,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, bool separateDefaultMorpheme, const string& morphemeDefPath, size_t morphemeDefMinCnt, + const vector>>& contextualMapper, HSDataset* splitDataset ) const { @@ -2382,14 +2395,14 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, { ifstream ifs; auto cvtSents = RaggedVector::from_memory(openFile(ifs, path, ios_base::binary)); - if (splitRatio > 0) - { - throw invalid_argument("splitDataset cannot be used with binary input"); - } + double splitCnt = 0; for (auto s : cvtSents) { - sents.emplace_back(); - sents.insert_data(s.begin(), s.end()); + splitCnt += splitRatio; + auto& o = splitDataset && splitCnt >= 1 ? splitDataset->sents.get() : sents; + o.emplace_back(); + o.insert_data(s.begin(), s.end()); + splitCnt = fmod(splitCnt, 1.); } } catch (const runtime_error&) @@ -2444,6 +2457,17 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, dataset.totalTokens += dataset.numValidTokensInSent(i) - 1; } + if (!contextualMapper.empty()) + { + utils::ContinuousTrie> cmTrie(1); + for (auto& p : contextualMapper) + { + cmTrie.build(p.second.begin(), p.second.end(), p.first + 1); + } + cmTrie.fillFail(); + dataset.contextualMapper = utils::FrozenTrie{ cmTrie, ArchTypeHolder{} }; + } + if (splitDataset) { splitDataset->windowTokenValidness = dataset.windowTokenValidness; @@ -2454,6 +2478,30 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, { splitDataset->totalTokens += splitDataset->numValidTokensInSent(i) - 1; } + + if (!contextualMapper.empty()) + { + splitDataset->contextualMapper = dataset.contextualMapper; + } } return dataset; } + +void KiwiBuilder::buildMorphData(const string& morphemeDefPath, const string& outputPath, size_t minCnt) +{ + KiwiBuilder kb; + kb.initMorphemes(); + ifstream ifs; + auto realMorph = kb.loadMorphemesFromTxt(openFile(ifs, morphemeDefPath), [&](POSTag tag, float cnt) + { + return 
cnt >= minCnt; + }); + + size_t lmVocabSize = 0; + for (auto& p : realMorph) lmVocabSize = max(p.second.first, lmVocabSize); + lmVocabSize += 1; + kb.updateForms(); + kb.updateMorphemes(lmVocabSize); + ofstream ofs; + kb.saveMorphBin(openFile(ofs, outputPath + "/sj.morph", ios_base::binary)); +} diff --git a/src/LmState.hpp b/src/LmState.hpp index c019875d..daca463c 100644 --- a/src/LmState.hpp +++ b/src/LmState.hpp @@ -4,6 +4,7 @@ #include #include "Knlm.hpp" #include "SkipBigramModel.hpp" +#include "PCLanguageModel.hpp" namespace kiwi { @@ -27,13 +28,14 @@ namespace kiwi } }; - template + template class KnLMState { friend struct Hash>; int32_t node = 0; public: static constexpr ArchType arch = _arch; + static constexpr bool transposed = _transposed; KnLMState() = default; KnLMState(const LangModel& lm) : node{ (int32_t)static_cast&>(*lm.knlm).getBosNodeIdx() } {} @@ -62,6 +64,7 @@ namespace kiwi std::array history = { {0,} }; public: static constexpr ArchType arch = _arch; + static constexpr bool transposed = false; SbgState() = default; SbgState(const LangModel& lm) : KnLMState<_arch, VocabTy>{ lm } {} @@ -101,6 +104,32 @@ namespace kiwi } }; + template + class PcLMState + { + friend struct Hash>; + int32_t node = 0; + uint32_t contextIdx = 0; + size_t historyPos = 0; + std::array history = { {0,} }; + public: + static constexpr ArchType arch = _arch; + + PcLMState() = default; + PcLMState(const LangModel& lm) {} + + bool operator==(const PcLMState& other) const + { + return node == other.node && historyPos == other.historyPos && history == other.history; + } + + float next(const LangModel& lm, VocabTy next) + { + auto& pclm = static_cast&>(*lm.pclm); + return pclm.progress(node, contextIdx, next); + } + }; + // hash for LmState template struct Hash> @@ -111,10 +140,10 @@ namespace kiwi } }; - template - struct Hash> + template + struct Hash> { - size_t operator()(const KnLMState& state) const + size_t operator()(const KnLMState& state) const { std::hash hasher; 
return hasher(state.node); @@ -143,12 +172,24 @@ namespace kiwi template using type = KnLMState; }; + template + struct WrappedKnLMTransposed + { + template using type = KnLMState; + }; + template struct WrappedSbg { template using type = SbgState; }; + template + struct WrappedPcLM + { + template using type = PcLMState; + }; + template class LmObject : public LmObjectBase { diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp new file mode 100644 index 00000000..ca769c9d --- /dev/null +++ b/src/PCLanguageModel.cpp @@ -0,0 +1,267 @@ +#include +#include +#include "PCLanguageModel.hpp" +#include "StrUtils.h" +#include "FrozenTrie.hpp" + +using namespace std; + +namespace kiwi +{ + namespace pclm + { + utils::MemoryObject PCLanguageModelBase::build(const string& contextDefinition, const string& embedding) + { + ifstream contextStr, embeddingStr; + if (!openFile(contextStr, contextDefinition)) + { + throw IOException{ "Cannot open file : " + contextDefinition }; + } + + uint32_t maxClusterId = 0; + using Node = utils::TrieNodeEx>>; + utils::ContinuousTrie trie(1); + { + Vector, uint32_t>> contextMap; + UnorderedMap, uint32_t> erasedContexts; + Vector context; + string line; + while (getline(contextStr, line)) + { + auto tokens = split(line, '\t'); + if (tokens.size() <= 1) + { + throw IOException{ "Invalid format : " + contextDefinition }; + } + + auto clusterId = stol(tokens[0].begin(), tokens[0].end()); + if (clusterId < 0) throw IOException{ "Invalid format : " + contextDefinition }; + context.clear(); + for (size_t i = 1; i < tokens.size(); ++i) + { + auto id = stol(tokens[i].begin(), tokens[i].end()); + if (id < 0) throw IOException{ "Invalid format : " + contextDefinition }; + context.push_back(id); + } + if (contextMap.size() < context.size()) contextMap.resize(context.size()); + contextMap[context.size() - 1][context] = (uint32_t)clusterId; + maxClusterId = max(maxClusterId, (uint32_t)(clusterId + 1)); + } + + for (size_t i = contextMap.size(); i-- > 
0;) // remove redundant context + { + auto& c = contextMap[i]; + for (auto it = c.begin(); it != c.end();) + { + bool erase = false; + for (size_t j = i; j-- > 0; ) + { + auto& c2 = contextMap[j]; + context.clear(); + context.insert(context.end(), it->first.end() - j - 1, it->first.end()); + auto found = c2.find(context); + if (found != c2.end()) + { + erase = found->second == it->second; + break; + } + } + + if (erase) + { + if (it->first.size() < contextMap.size()) + { + erasedContexts.emplace(it->first, it->second); + } + it = c.erase(it); + } + else ++it; + } + } + + for (auto& c : contextMap) + { + for (auto& p : c) + { + trie.build(p.first.begin(), p.first.end(), p.second + 1); + } + } + for (auto& p : erasedContexts) + { + if (auto* node = trie.find(p.first.begin(), p.first.end())) + { + //node->val = p.second + 1; + } + } + } + + Vector nodeSizes; + nodeSizes.reserve(trie.size()); + Vector keys; + keys.reserve(trie.size()); + Vector values; + values.reserve(trie.size()); + Vector valueNewIdx(maxClusterId + 1); + { + Vector valueCnts(valueNewIdx.size()); + Vector valueArgsorted(valueNewIdx.size()); + Vector rkeys; + trie.traverseWithKeys([&](const Node* node, const Vector& rkeys) + { + nodeSizes.emplace_back(node->next.size()); + for (auto& p : node->next) + { + keys.emplace_back(p.first); + } + values.emplace_back(node->val); + valueCnts[node->val]++; + }, rkeys); + + valueCnts[0] = -1; + + // remap value idx by frequency + iota(valueArgsorted.begin(), valueArgsorted.end(), 0); + sort(valueArgsorted.begin(), valueArgsorted.end(), [&](uint32_t a, uint32_t b) { return valueCnts[a] > valueCnts[b]; }); + for (size_t i = 0; i < valueArgsorted.size(); ++i) + { + valueNewIdx[valueArgsorted[i]] = (uint32_t)i; + } + for (auto& v : values) v = valueNewIdx[v]; + } + + assert(nodeSizes.size() - 1 == keys.size()); + + Vector compressedNodeSizes(streamvbyte_max_compressedbytes(nodeSizes.size())); + compressedNodeSizes.resize(streamvbyte_encode_0124(nodeSizes.data(), 
nodeSizes.size(), compressedNodeSizes.data())); + Vector compressedValues(streamvbyte_max_compressedbytes(values.size())); + compressedValues.resize(streamvbyte_encode_0124(values.data(), values.size(), compressedValues.data())); + Vector compressedKeys(streamvbyte_max_compressedbytes(keys.size())); + compressedKeys.resize(streamvbyte_encode(keys.data(), keys.size(), compressedKeys.data())); + + if (!openFile(embeddingStr, embedding, ios_base::binary)) + { + throw IOException{ "Cannot open file : " + embedding }; + } + const uint32_t dim = utils::read(embeddingStr); + const uint32_t contextSize = utils::read(embeddingStr); + const uint32_t outputSize = utils::read(embeddingStr); + + Vector contextEmb(dim * contextSize); + Vector contextEmbScale(contextSize); + Vector contextEmbBias(contextSize); + Vector contextConfidence(contextSize); + Vector distantEmb(dim * outputSize); + Vector distantEmbScale(outputSize); + Vector distantEmbBias(outputSize); + Vector distantConfidence(outputSize); + Vector outputEmb(dim * outputSize); + Vector outputEmbScale(outputSize); + + embeddingStr.read((char*)contextEmb.data(), contextEmb.size()); + embeddingStr.read((char*)contextEmbScale.data(), contextEmbScale.size() * sizeof(uint16_t)); + embeddingStr.read((char*)contextEmbBias.data(), contextEmbBias.size() * sizeof(uint16_t)); + embeddingStr.read((char*)contextConfidence.data(), contextConfidence.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantEmb.data(), distantEmb.size()); + embeddingStr.read((char*)distantEmbScale.data(), distantEmbScale.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantEmbBias.data(), distantEmbBias.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantConfidence.data(), distantConfidence.size() * sizeof(uint16_t)); + embeddingStr.read((char*)outputEmb.data(), outputEmb.size()); + embeddingStr.read((char*)outputEmbScale.data(), outputEmbScale.size() * sizeof(uint16_t)); + + // remap context embedding + { + Vector 
newContextEmb(contextEmb.size()); + Vector newContextEmbScale(contextSize); + Vector newContextEmbBias(contextSize); + for (size_t i = 0; i < contextSize; ++i) + { + auto idx = valueNewIdx[i]; + auto src = contextEmb.data() + i * dim; + auto dst = newContextEmb.data() + idx * dim; + copy(src, src + dim, dst); + newContextEmbScale[idx] = contextEmbScale[i]; + newContextEmbBias[idx] = contextEmbBias[i]; + } + contextEmb = move(newContextEmb); + contextEmbScale = move(newContextEmbScale); + contextEmbBias = move(newContextEmbBias); + } + + Header header; + memset(&header, 0, sizeof(Header)); + header.dim = dim; + header.contextSize = contextSize; + header.vocabSize = outputSize; + header.keySize = 4; + header.windowSize = 8; + header.numNodes = nodeSizes.size(); + + size_t finalSize = 0; + header.nodeOffset = alignedOffsetInc(finalSize, sizeof(Header)); + header.keyOffset = alignedOffsetInc(finalSize, compressedNodeSizes.size()); + header.valueOffset = alignedOffsetInc(finalSize, compressedKeys.size()); + header.embOffset = alignedOffsetInc(finalSize, compressedValues.size()); + finalSize += dim * (contextSize + outputSize * 2); + finalSize += contextSize * sizeof(uint16_t) * 3; + finalSize += outputSize * sizeof(uint16_t) * 4; + + utils::MemoryOwner mem{ finalSize }; + utils::omstream ostr{ (char*)mem.get(), (std::ptrdiff_t)mem.size() }; + ostr.write((const char*)&header, sizeof(Header)); + writePadding(ostr); + ostr.write((const char*)compressedNodeSizes.data(), compressedNodeSizes.size()); + writePadding(ostr); + ostr.write((const char*)compressedKeys.data(), compressedKeys.size()); + writePadding(ostr); + ostr.write((const char*)compressedValues.data(), compressedValues.size()); + writePadding(ostr); + ostr.write((const char*)contextEmb.data(), contextEmb.size()); + ostr.write((const char*)contextEmbScale.data(), contextEmbScale.size() * sizeof(uint16_t)); + ostr.write((const char*)contextEmbBias.data(), contextEmbBias.size() * sizeof(uint16_t)); + 
ostr.write((const char*)contextConfidence.data(), contextConfidence.size() * sizeof(uint16_t)); + ostr.write((const char*)distantEmb.data(), distantEmb.size()); + ostr.write((const char*)distantEmbScale.data(), distantEmbScale.size() * sizeof(uint16_t)); + ostr.write((const char*)distantEmbBias.data(), distantEmbBias.size() * sizeof(uint16_t)); + ostr.write((const char*)distantConfidence.data(), distantConfidence.size() * sizeof(uint16_t)); + ostr.write((const char*)outputEmb.data(), outputEmb.size()); + ostr.write((const char*)outputEmbScale.data(), outputEmbScale.size() * sizeof(uint16_t)); + return mem; + } + + template + std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + { + auto& header = *reinterpret_cast(mem.get()); + switch (header.keySize) + { + case 1: + return make_unique>(std::move(mem)); + case 2: + return make_unique>(std::move(mem)); + case 4: + return make_unique>(std::move(mem)); + default: + throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; + } + } + + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + + struct CreateOptimizedModelGetter + { + template + struct Wrapper + { + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; + }; + }; + + std::unique_ptr PCLanguageModelBase::create(utils::MemoryObject&& mem, ArchType archType) + { + static tp::Table table{ CreateOptimizedModelGetter{} }; + auto fn = table[static_cast(archType)]; + if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; + return (*fn)(std::move(mem)); + } + } +} diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp new file mode 100644 index 00000000..afd7f461 --- /dev/null +++ b/src/PCLanguageModel.hpp @@ -0,0 +1,335 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include "ArchAvailable.h" +#include "search.h" +#include "streamvbyte.h" + + +namespace kiwi +{ + namespace pclm + { + inline float 
half2float(uint16_t h) + { + union + { + uint32_t i; + float f; + } u; + u.i = (uint32_t)(h & 0x8000) << 16; + u.i |= ((uint32_t)(h & 0x7FFF) + 0x1C000) << 13; + return u.f; + } + + inline void dequantize(float* out, const int8_t* ints, size_t n, float scale) + { + for (size_t i = 0; i < n; ++i) + { + out[i] = ints[i] * scale; + } + } + + template + class PCLanguageModel : public PCLanguageModelBase + { + using MyNode = Node; + + std::unique_ptr nodeData; + std::unique_ptr keyData; + std::unique_ptr valueData; + std::unique_ptr contextEmb; + std::unique_ptr contextBias; + std::unique_ptr contextConf; + std::unique_ptr distantEmb; + std::unique_ptr distantBias; + std::unique_ptr distantConf; + std::unique_ptr outputEmb; + + MyNode* findLowerNode(MyNode* node, KeyType k) const + { + while (node->lower) + { + auto* lowerNode = node + node->lower; + auto* keys = &keyData[lowerNode->nextOffset]; + auto* values = &valueData[lowerNode->nextOffset]; + int32_t found; + if (nst::search( + keys, + values, + lowerNode->numNexts, + k, + found + ) && found >= 0) + { + return lowerNode + found; + } + node = lowerNode; + } + return node; + } + + uint32_t findLowerValue(MyNode* node, KeyType k) const + { + while (node->lower) + { + auto* lowerNode = node + node->lower; + auto* keys = &keyData[lowerNode->nextOffset]; + auto* values = &valueData[lowerNode->nextOffset]; + int32_t found; + if (nst::search( + keys, + values, + lowerNode->numNexts, + k, + found + )) + { + if (found >= 0) + { + return lowerNode[found].value; + } + else + { + return -found; + } + } + node = lowerNode; + } + return node->value; + } + + public: + PCLanguageModel(utils::MemoryObject&& mem) : PCLanguageModelBase{ std::move(mem) } + { + auto* ptr = reinterpret_cast(base.get()); + auto& header = getHeader(); + + Vector nodeSizes(header.numNodes); + streamvbyte_decode_0124(reinterpret_cast(ptr + header.nodeOffset), nodeSizes.data(), header.numNodes); + keyData = make_unique(header.numNodes - 1); + if 
(std::is_same::value) + { + streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), (uint32_t*)keyData.get(), header.numNodes - 1); + } + else + { + Vector tempKeyData(header.numNodes - 1); + streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), tempKeyData.data(), header.numNodes - 1); + std::copy(tempKeyData.begin(), tempKeyData.end(), keyData.get()); + } + Vector values(header.numNodes); + streamvbyte_decode_0124(reinterpret_cast(ptr + header.valueOffset), values.data(), header.numNodes); + + size_t numNonLeafNodes = 0, numLeafNodes = 0; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (nodeSizes[i]) numNonLeafNodes++; + else numLeafNodes++; + } + + nodeData = make_unique(numNonLeafNodes); + valueData = make_unique(header.numNodes - 1); + + size_t nonLeafIdx = 0, leafIdx = 0, nextOffset = 0; + Vector> keyRanges; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (nodeSizes[i]) + { + auto& node = nodeData[nonLeafIdx]; + if (!keyRanges.empty()) + { + auto& back = keyRanges.back(); + valueData[back[1]] = nonLeafIdx - back[0]; + } + node.value = values[i]; + node.numNexts = nodeSizes[i]; + node.nextOffset = nextOffset; + nextOffset += nodeSizes[i]; + keyRanges.emplace_back(std::array{ nonLeafIdx, (size_t)node.nextOffset, (size_t)(node.nextOffset + node.numNexts) }); + nonLeafIdx++; + } + else + { + auto& back = keyRanges.back(); + valueData[back[1]] = -(int32_t)values[i]; + back[1]++; + while (keyRanges.back()[1] == keyRanges.back()[2]) + { + keyRanges.pop_back(); + if (keyRanges.empty()) break; + keyRanges.back()[1]++; + } + leafIdx++; + } + } + + Vector tempBuf; + for (size_t i = 0; i < nonLeafIdx; ++i) + { + auto& node = nodeData[i]; + nst::prepare(&keyData[node.nextOffset], &valueData[node.nextOffset], node.numNexts, tempBuf); + } + + Deque dq; + for (dq.emplace_back(&nodeData[0]); !dq.empty(); dq.pop_front()) + { + auto p = dq.front(); + for (size_t i = 0; i < p->numNexts; ++i) + { + auto k = keyData[p->nextOffset + i]; + auto v = 
valueData[p->nextOffset + i]; + if (v <= 0) continue; + auto* child = &p[v]; + child->lower = findLowerNode(p, k) - child; + if (child->value == 0) + { + child->value = findLowerValue(p, k); + } + dq.emplace_back(child); + } + } + + auto* eptr = ptr + header.embOffset; + contextEmb = make_unique(header.contextSize * header.dim); + contextBias = make_unique(header.contextSize); + contextConf = make_unique(header.contextSize); + distantEmb = make_unique(header.vocabSize * header.dim); + distantBias = make_unique(header.vocabSize); + distantConf = make_unique(header.vocabSize); + outputEmb = make_unique(header.vocabSize * header.dim); + + const uint16_t* contextEmbScale = reinterpret_cast(eptr + header.contextSize * header.dim); + for (size_t i = 0; i < header.contextSize; ++i) + { + dequantize(&contextEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(contextEmbScale[i])); + eptr += header.dim; + } + eptr += header.contextSize * sizeof(uint16_t); + for (size_t i = 0; i < header.contextSize; ++i) + { + contextBias[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.contextSize; ++i) + { + contextConf[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + + const uint16_t* distantEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); + for (size_t i = 0; i < header.vocabSize; ++i) + { + dequantize(&distantEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(distantEmbScale[i])); + eptr += header.dim; + } + eptr += header.vocabSize * sizeof(uint16_t); + for (size_t i = 0; i < header.vocabSize; ++i) + { + distantBias[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.vocabSize; ++i) + { + distantConf[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + + const uint16_t* outputEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); + for (size_t i = 0; i < 
header.vocabSize; ++i) + { + dequantize(&outputEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(outputEmbScale[i])); + eptr += header.dim; + } + } + + uint32_t progressContextNode(int32_t& nodeIdx, KeyType next) const + { + while (1) + { + int32_t v; + auto* node = &nodeData[nodeIdx]; + auto* keys = &keyData[node->nextOffset]; + auto* values = &valueData[node->nextOffset]; + PREFETCH_T0(node + node->lower); + if (!nst::search( + keys, + values, + node->numNexts, next, v + )) + { + if (!node->lower) return 0; + nodeIdx += node->lower; + PREFETCH_T0(&keyData[nodeData[nodeIdx].nextOffset]); + continue; + } + + // non-leaf node + if (v > 0) + { + nodeIdx += v; + return nodeData[nodeIdx].value; + } + // leaf node + else + { + while (node->lower) + { + node += node->lower; + int32_t lv; + if (nst::search( + &keyData[node->nextOffset], + &valueData[node->nextOffset], + node->numNexts, next, lv + )) + { + if (lv > 0) + { + node += lv; + nodeIdx = node - &nodeData[0]; + return (uint32_t)-v; + } + } + } + nodeIdx = 0; + return (uint32_t)-v; + } + } + } + + float progress(int32_t& nodeIdx, uint32_t& contextIdx, KeyType next) const + { + const auto& header = getHeader(); + Eigen::Map contextVec{ &contextEmb[contextIdx * header.dim], header.dim }; + Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; + float ll = contextVec.dot(outputVec) - contextBias[contextIdx]; + contextIdx = progressContextNode(nodeIdx, next); + return ll; + } + }; + + static constexpr size_t serialAlignment = 16; + + inline size_t alignedOffsetInc(size_t& offset, size_t inc, size_t alignment = serialAlignment) + { + return offset = (offset + inc + alignment - 1) & ~(alignment - 1); + } + + inline std::ostream& writePadding(std::ostream& os, size_t alignment = serialAlignment) + { + const size_t pos = os.tellp(); + size_t pad = ((pos + alignment - 1) & ~(alignment - 1)) - pos; + for (size_t i = 0; i < pad; ++i) + { + os.put(0); + } + return os; + } + } +} diff --git 
a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index d89ffe36..08f36941 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -55,7 +55,85 @@ namespace kiwi top1Small, }; - class PathEvaluator + template + struct PathEvaluator + { + template + static void eval(const Kiwi* kw, + const KGraphNode* startNode, + const KGraphNode* node, + const size_t topN, + Vector>>& cache, + const Vector& ownFormList, + size_t i, + size_t ownFormId, + CandTy&& cands, + bool unknownForm, + const Vector& prevSpStates, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ); + + template + static void evalSingleMorpheme( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + size_t ownFormId, + const Morpheme* curMorph, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const float ignoreCondScore, + const float nodeLevelDiscount, + const Vector& prevSpStates + ); + }; + + // evaluator using transposed order + template + struct PathEvaluator::type> + { + template + static void eval(const Kiwi* kw, + const KGraphNode* startNode, + const KGraphNode* node, + const size_t topN, + Vector>>& cache, + const Vector& ownFormList, + size_t i, + size_t ownFormId, + CandTy&& cands, + bool unknownForm, + const Vector& prevSpStates, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ); + + template + static void evalMorphemes( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + size_t ownFormId, + const Vector& morphs, + const Vector& morphScores, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const float ignoreCondScore, + const float nodeLevelDiscount, + const Vector& prevSpStates + ); + }; + + struct BestPathFinder { public: struct Result @@ -111,7 +189,7 @@ namespace kiwi }; 
template - static Vector findBestPath(const Kiwi* kw, + static Vector find(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, const size_t graphSize, @@ -122,43 +200,9 @@ namespace kiwi bool mergeSaisiot = false, const std::unordered_set* blocklist = nullptr ); - - template - static void evalPath(const Kiwi* kw, - const KGraphNode* startNode, - const KGraphNode* node, - const size_t topN, - Vector>>& cache, - const Vector& ownFormList, - size_t i, - size_t ownFormId, - CandTy&& cands, - bool unknownForm, - const Vector& prevSpStates, - bool splitComplex = false, - bool splitSaisiot = false, - bool mergeSaisiot = false, - const std::unordered_set* blocklist = nullptr - ); - - template - static void evalSingleMorpheme( - Vector>& resultOut, - const Kiwi* kw, - const Vector& ownForms, - const Vector>>& cache, - size_t ownFormId, - const Morpheme* curMorph, - const KGraphNode* node, - const KGraphNode* startNode, - const size_t topN, - const float ignoreCondScore, - const float nodeLevelDiscount, - const Vector& prevSpStates - ); }; - using FnFindBestPath = decltype(&PathEvaluator::findBestPath>); + using FnFindBestPath = decltype(&BestPathFinder::find>); template class LmState> struct FindBestPathGetter @@ -166,7 +210,7 @@ namespace kiwi template struct Wrapper { - static constexpr FnFindBestPath value = &PathEvaluator::findBestPath(i)>>; + static constexpr FnFindBestPath value = &BestPathFinder::find(i)>>; }; }; @@ -633,8 +677,9 @@ namespace kiwi } }; - template - void PathEvaluator::evalSingleMorpheme( + template + template + void PathEvaluator::evalSingleMorpheme( Vector>& resultOut, const Kiwi* kw, const Vector& ownForms, @@ -657,7 +702,7 @@ namespace kiwi const auto spacePenalty = kw->spacePenalty; const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; - const size_t langVocabSize = langMdl.knlm->getHeader().vocab_size; + const size_t langVocabSize = langMdl.vocabSize(); const Morpheme* lastMorph; Wid firstWid; @@ -823,8 +868,9 @@ 
namespace kiwi return; } - template - void PathEvaluator::evalPath(const Kiwi* kw, + template + template + void PathEvaluator::eval(const Kiwi* kw, const KGraphNode* startNode, const KGraphNode* node, const size_t topN, @@ -841,7 +887,7 @@ namespace kiwi const std::unordered_set* blocklist ) { - const size_t langVocabSize = kw->langMdl.knlm->getHeader().vocab_size; + const size_t langVocabSize = kw->langMdl.vocabSize(); auto& nCache = cache[i]; Vector> refCache; @@ -998,9 +1044,603 @@ namespace kiwi nCache.resize(validCount); } + template + struct LmEvalData + { + LmState state; + float score = 0; + uint32_t length = 0; + }; + + template + template + void PathEvaluator::type>::evalMorphemes( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + size_t ownFormId, + const Vector& morphs, + const Vector& morphScores, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const float ignoreCondScore, + const float nodeLevelDiscount, + const Vector& prevSpStates + ) + { + thread_local Vector> bestPathIndicesSmall; + thread_local Vector> bestPathValuesSmall; + thread_local UnorderedMap, WordLL> bestPathes; + // pair: [index, size] + thread_local UnorderedMap, pair> bestPathIndices; + thread_local Vector> bestPathValues; + thread_local Vector rootIds; + thread_local Vector> evalMatrix; + thread_local Vector nextWids; + + const LangModel& langMdl = kw->langMdl; + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + const size_t langVocabSize = langMdl.vocabSize(); + + size_t totalPrevPathes = 0; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + totalPrevPathes += cache[prev - startNode].size(); + } + evalMatrix.resize(totalPrevPathes * morphs.size()); + nextWids.clear(); + + const bool useContainerForSmall = totalPrevPathes <= 48; + + size_t prevId = -1; + for (auto* prev = 
node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& prevPath : cache[prev - startNode]) + { + ++prevId; + + const kchar_t* leftFormFirst, * leftFormLast; + if (prevPath.ownFormId) + { + leftFormFirst = ownForms[prevPath.ownFormId - 1].data(); + leftFormLast = leftFormFirst + ownForms[0].size(); + } + else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) + { + leftFormFirst = morphBase[prevPath.wid].kform->data(); + leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); + } + else + { + leftFormFirst = prevPath.morpheme->getForm().data(); + leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); + } + const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; + + for (size_t curId = 0; curId < morphs.size(); ++curId) + { + const auto curMorph = morphs[curId]; + float candScore = prevPath.accScore + curMorph->userScore + nodeLevelDiscount + morphScores[curId]; + Wid firstWid; + if (curMorph->chunks.empty() || curMorph->complex) + { + firstWid = curMorph->lmMorphemeId; + } + else + { + firstWid = curMorph->chunks[0]->lmMorphemeId; + } + + if (prevPath.combineSocket) + { + // merge with only the same socket + if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex)) + { + goto invalidCandidate; + } + if (prev->endPos < node->startPos) + { + if (allowedSpaceBetweenChunk) candScore -= spacePenalty; + else goto invalidCandidate; + } + firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; + } + + const CondVowel cvowel = curMorph->vowel; + const CondPolarity cpolar = curMorph->polar; + + if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) + { + // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 + } + else if (ignoreCondScore) + { + candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 
0 : ignoreCondScore; + } + else + { + if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) goto invalidCandidate; + } + + size_t length = 0; + if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex)) + { + // no op + } + else + { + if (morphBase[firstWid].tag == POSTag::p) + { + goto invalidCandidate; + } + + if (curMorph->chunks.empty() || curMorph->complex) + { + length = 1; + } + else + { + length = curMorph->chunks.size(); + for (size_t i = 1; i < length; ++i) + { + const Wid wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + goto invalidCandidate; + } + } + } + } + evalMatrix[prevId * morphs.size() + curId].state = prevPath.lmState; + evalMatrix[prevId * morphs.size() + curId].score = candScore; + evalMatrix[prevId * morphs.size() + curId].length = length; + if (length > 0) nextWids.emplace_back(firstWid); + if (length > 1) + { + for (size_t i = 1; i < length; ++i) + { + nextWids.emplace_back(curMorph->chunks[i]->lmMorphemeId); + } + } + continue; + invalidCandidate: + evalMatrix[prevId * morphs.size() + curId].score = -INFINITY; + evalMatrix[prevId * morphs.size() + curId].length = 0; + } + } + } + + { + size_t widOffset = 0; + for (auto& e : evalMatrix) + { + //if (e.length == 0) continue; + float score = 0; + for (size_t i = 0; i < e.length; ++i) + { + score += e.state.next(langMdl, nextWids[widOffset + i]); + } + e.score += score; + widOffset += e.length; + } + } + + for (size_t curId = 0; curId < morphs.size(); ++curId) + { + const auto curMorph = morphs[curId]; + + if (top1) + { + if (useContainerForSmall) + { + bestPathIndicesSmall.clear(); + bestPathValuesSmall.clear(); + } + else + { + bestPathes.clear(); + } + } + else + { + bestPathIndices.clear(); + bestPathValues.clear(); + } + + + const Morpheme* lastMorph; + if (curMorph->chunks.empty() || curMorph->complex) + { + lastMorph = curMorph->getCombined() ? 
curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + } + + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } + + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + + size_t prevId = -1; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& prevPath : cache[prev - startNode]) + { + ++prevId; + auto& em = evalMatrix[prevId * morphs.size() + curId]; + if (em.score < -99999) + { + continue; + } + + if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) + { + rootIds.resize(prevSpStates.size()); + iota(rootIds.begin(), rootIds.end(), 0); + } + else + { + rootIds.resize(1); + rootIds[0] = commonRootId; + } + + for (auto rootId : rootIds) + { + const auto* prevMorpheme = &morphBase[prevPath.wid]; + auto spState = prevPath.spState; + if (rootId != commonRootId) + { + spState = prevSpStates[rootId]; + } + const float candScoreWithRule = em.score + ruleBasedScorer(prevMorpheme, spState); + + // update special state + if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; + if (ruleBasedScorer.curMorphSbType) + { + spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + } + + PathHash ph{ em.state, prevPath.rootId, spState }; + if (top1) + { + if 
(useContainerForSmall) + { + const size_t foundIdx = find(bestPathIndicesSmall.begin(), bestPathIndicesSmall.end(), ph) - bestPathIndicesSmall.begin(); + if (foundIdx >= bestPathIndicesSmall.size()) + { + bestPathIndicesSmall.emplace_back(ph); + bestPathValuesSmall.emplace_back(curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); + if (rootId != commonRootId) bestPathValuesSmall.back().rootId = rootId; + } + else + { + auto& target = bestPathValuesSmall[foundIdx]; + if (candScoreWithRule > target.accScore) + { + target = WordLL{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; + if (rootId != commonRootId) target.rootId = rootId; + } + } + } + else + { + WordLL newPath{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; + if (rootId != commonRootId) newPath.rootId = rootId; + auto inserted = bestPathes.emplace(ph, newPath); + if (!inserted.second) + { + auto& target = inserted.first->second; + if (candScoreWithRule > target.accScore) + { + target = newPath; + } + } + } + } + else + { + auto inserted = bestPathIndices.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); + if (inserted.second) + { + bestPathValues.emplace_back(curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); + if (rootId != commonRootId) bestPathValues.back().rootId = rootId; + bestPathValues.resize(bestPathValues.size() + topN - 1); + } + else + { + auto bestPathFirst = bestPathValues.begin() + inserted.first->second.first; + auto bestPathLast = bestPathValues.begin() + inserted.first->second.first + inserted.first->second.second; + if (distance(bestPathFirst, bestPathLast) < topN) + { + *bestPathLast = WordLL{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; + if (rootId != commonRootId) bestPathLast->rootId = rootId; + 
push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); + ++inserted.first->second.second; + } + else + { + if (candScoreWithRule > bestPathFirst->accScore) + { + pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + *(bestPathLast - 1) = WordLL{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; + if (rootId != commonRootId) (*(bestPathLast - 1)).rootId = rootId; + push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + } + } + } + } + } + } + } + + if (top1) + { + if (useContainerForSmall) + { + for (auto& p : bestPathValuesSmall) + { + resultOut.emplace_back(move(p)); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + if (curMorph->chunks.empty() || curMorph->complex) + { + newPath.wid = lastSeqId; + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + else + { + newPath.wid = lastSeqId; + } + } + } + else + { + for (auto& p : bestPathes) + { + resultOut.emplace_back(move(p.second)); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + if (curMorph->chunks.empty() || curMorph->complex) + { + newPath.wid = lastSeqId; + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + else + { + newPath.wid = lastSeqId; + } + } + } + } + else + { + for (auto& p : bestPathIndices) + { + const auto index = p.second.first; + const auto size = p.second.second; + for (size_t i = 0; i < size; ++i) + { + resultOut.emplace_back(move(bestPathValues[index + i])); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + if (curMorph->chunks.empty() || curMorph->complex) + { + newPath.wid = lastSeqId; + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + else + { + newPath.wid = lastSeqId; + } + } + } + } + } + } + + template + template + void PathEvaluator::type>::eval(const Kiwi* kw, + const KGraphNode* startNode, + const 
KGraphNode* node, + const size_t topN, + Vector>>& cache, + const Vector& ownFormList, + size_t i, + size_t ownFormId, + CandTy&& cands, + bool unknownForm, + const Vector& prevSpStates, + bool splitComplex, + bool splitSaisiot, + bool mergeSaisiot, + const std::unordered_set* blocklist + ) + { + thread_local Vector maxScores; + thread_local Vector validMorphCands; + thread_local Vector lbScores; + const size_t langVocabSize = kw->langMdl.vocabSize(); + auto& nCache = cache[i]; + + float whitespaceDiscount = 0; + if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) + { + whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; + } + const float typoDiscount = -node->typoCost * kw->typoCostWeight; + float unknownFormDiscount = 0; + if (unknownForm) + { + size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); + unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); + } + + const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; + const Morpheme* zCodaMorph = nullptr; + const Morpheme* zSiotMorph = nullptr; + validMorphCands.clear(); + lbScores.clear(); + for (auto& curMorph : cands) + { + if (splitComplex && curMorph->getCombined()->complex) continue; + if (blocklist && blocklist->count(curMorph->getCombined())) continue; + + // 덧붙은 받침(zCoda)을 위한 지름길 + if (curMorph->tag == POSTag::z_coda) + { + zCodaMorph = curMorph; + continue; + } + else if (curMorph->tag == POSTag::z_siot) + { + zSiotMorph = curMorph; + continue; + } + + if (!curMorph->chunks.empty() && !curMorph->complex) + { + // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 + if (node->prev && node[-(int)node->prev].endPos < node->startPos + && curMorph->kform + && curMorph->kform->size() == 1 + && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') + && curMorph->chunks[0]->kform + && curMorph->chunks[0]->kform->size() == 1 + && 
(*curMorph->chunks[0]->kform)[0] == u'하') + { + continue; + } + } + validMorphCands.emplace_back(curMorph); + lbScores.emplace_back(kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag)); + } + + for (bool ignoreCond : {false, true}) + { + // 덧붙은 받침(zCoda)을 위한 지름길 + if (zCodaMorph) + { + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { + auto lastTag = kw->morphemes[p.wid].tag; + if (!isJClass(lastTag) && !isEClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += zCodaMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= zCodaMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[zCodaMorph->lmMorphemeId]; + newPath.wid = zCodaMorph->lmMorphemeId; + } + } + } + // 사이시옷(zSiot)을 위한 지름길 + if (zSiotMorph) + { + if (!(splitSaisiot || mergeSaisiot)) + { + continue; + } + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { + auto lastTag = kw->morphemes[p.wid].tag; + if (!isNNClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += zSiotMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= zSiotMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[zSiotMorph->lmMorphemeId]; + newPath.wid = zSiotMorph->lmMorphemeId; + } + } + } + + if (topN == 1) + { + evalMorphemes(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, lbScores, + node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + } + else + { + evalMorphemes(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, lbScores, + node, startNode, topN, ignoreCond ? 
-10 : 0, nodeLevelDiscount, prevSpStates); + } + if (!nCache.empty()) break; + } + + maxScores.clear(); + maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); + + if (topN == 1) + { + for (auto& c : nCache) + { + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + maxScores[rootId] = max(maxScores[rootId], c.accScore); + } + } + else + { + for (auto& c : nCache) + { + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + if (c.accScore > maxScores[rootId * topN]) + { + pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + maxScores[rootId * topN + topN - 1] = c.accScore; + push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + } + } + } + + size_t validCount = 0; + for (size_t i = 0; i < nCache.size(); ++i) + { + const auto rootId = nCache[i].rootId == commonRootId ? 0 : nCache[i].rootId + 1; + if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; + if (validCount != i) nCache[validCount] = move(nCache[i]); + validCount++; + } + nCache.resize(validCount); + } + template - inline PathEvaluator::Path generateTokenList(const WordLL* result, + inline BestPathFinder::Path generateTokenList(const WordLL* result, const utils::ContainerSearcher>& csearcher, const KGraphNode* graph, const Vector& ownFormList, @@ -1021,7 +1661,7 @@ namespace kiwi return morphFirst + morph->lmMorphemeId; }; - PathEvaluator::Path ret; + BestPathFinder::Path ret; const WordLL* prev = steps.back()->parent; for (auto it = steps.rbegin(); it != steps.rend(); ++it) { @@ -1114,7 +1754,7 @@ namespace kiwi } template - Vector PathEvaluator::findBestPath(const Kiwi* kw, + Vector BestPathFinder::find(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, const size_t graphSize, @@ -1132,7 +1772,7 @@ namespace kiwi Vector ownFormList; Vector 
unknownNodeCands, unknownNodeLCands; - const size_t langVocabSize = kw->langMdl.knlm->getHeader().vocab_size; + const size_t langVocabSize = kw->langMdl.vocabSize(); const KGraphNode* startNode = graph; const KGraphNode* endNode = graph + graphSize - 1; @@ -1163,6 +1803,7 @@ namespace kiwi } #endif + using Evaluator = PathEvaluator; // middle nodes for (size_t i = 1; i < graphSize - 1; ++i) { @@ -1176,7 +1817,7 @@ namespace kiwi if (node->form) { - evalPath(kw, startNode, node, topN, cache, + Evaluator::eval(kw, startNode, node, topN, cache, ownFormList, i, ownFormId, node->form->candidate, false, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); if (all_of(node->form->candidate.begin(), node->form->candidate.end(), [](const Morpheme* m) @@ -1186,14 +1827,14 @@ namespace kiwi { ownFormList.emplace_back(node->form->form); ownFormId = ownFormList.size(); - evalPath(kw, startNode, node, topN, cache, + Evaluator::eval(kw, startNode, node, topN, cache, ownFormList, i, ownFormId, unknownNodeLCands, true, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); }; } else { - evalPath(kw, startNode, node, topN, cache, + Evaluator::eval(kw, startNode, node, topN, cache, ownFormList, i, ownFormId, unknownNodeCands, true, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); } diff --git a/src/SkipBigramModel.cpp b/src/SkipBigramModel.cpp new file mode 100644 index 00000000..f290779a --- /dev/null +++ b/src/SkipBigramModel.cpp @@ -0,0 +1,45 @@ +#include "SkipBigramModel.hpp" + +namespace kiwi +{ + namespace sb + { + template + std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + { + auto& header = *reinterpret_cast(mem.get()); + switch (header.keySize) + { + case 1: + return make_unique>(std::move(mem)); + case 2: + return make_unique>(std::move(mem)); + case 4: + return make_unique>(std::move(mem)); + case 8: + return make_unique>(std::move(mem)); + default: + throw std::runtime_error{ "Unsupported `key_size` : " + 
std::to_string((size_t)header.keySize) }; + } + } + + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + + struct CreateOptimizedModelGetter + { + template + struct Wrapper + { + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; + }; + }; + + std::unique_ptr SkipBigramModelBase::create(utils::MemoryObject&& mem, ArchType archType) + { + static tp::Table table{ CreateOptimizedModelGetter{} }; + auto fn = table[static_cast(archType)]; + if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; + return (*fn)(std::move(mem)); + } + } +} diff --git a/src/SkipBigramModel.hpp b/src/SkipBigramModel.hpp index 160afd3a..9e9eee2b 100644 --- a/src/SkipBigramModel.hpp +++ b/src/SkipBigramModel.hpp @@ -97,43 +97,5 @@ namespace kiwi float evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const; }; - - template - std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) - { - auto& header = *reinterpret_cast(mem.get()); - switch (header.keySize) - { - case 1: - return make_unique>(std::move(mem)); - case 2: - return make_unique>(std::move(mem)); - case 4: - return make_unique>(std::move(mem)); - case 8: - return make_unique>(std::move(mem)); - default: - throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; - } - } - - using FnCreateOptimizedModel = decltype(&createOptimizedModel); - - struct CreateOptimizedModelGetter - { - template - struct Wrapper - { - static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; - }; - }; - - inline std::unique_ptr SkipBigramModelBase::create(utils::MemoryObject&& mem, ArchType archType) - { - static tp::Table table{ CreateOptimizedModelGetter{} }; - auto fn = table[static_cast(archType)]; - if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; - return (*fn)(std::move(mem)); - } } } diff --git a/src/capi/kiwi_c.cpp 
b/src/capi/kiwi_c.cpp index 1ba7abf4..f7d61ec5 100644 --- a/src/capi/kiwi_c.cpp +++ b/src/capi/kiwi_c.cpp @@ -111,7 +111,7 @@ kiwi_builder_h kiwi_builder_init(const char* model_path, int num_threads, int op { BuildOption buildOption = (BuildOption)(options & 0xFF); bool useSBG = !!(options & KIWI_BUILD_MODEL_TYPE_SBG); - return (kiwi_builder_h)new KiwiBuilder{ model_path, (size_t)num_threads, buildOption, useSBG}; + return (kiwi_builder_h)new KiwiBuilder{ model_path, (size_t)num_threads, buildOption, useSBG ? ModelType::sbg : ModelType::knlm }; } catch (...) { diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 09fdee5c..17906b87 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -436,7 +436,7 @@ TEST(KiwiCpp, HSDataset) }; HSDataset trainset, devset; - trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., 0., tokenFilter, {}, 0.1, false, {}, 0, &devset); + trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., 0., tokenFilter, {}, 0.1, false, {}, 0, {}, &devset); for (size_t i = 0; i < 2; ++i) { { @@ -505,6 +505,7 @@ TEST(KiwiCpp, SentenceBoundaryErrors) EXPECT_EQ(sentRanges.size(), 1); if (sentRanges.size() > 1) { + kiwi.splitIntoSents(str, Match::allWithNormalizing, &res); for (auto& r : sentRanges) { std::cerr << std::string{ &str[r.first], r.second - r.first } << std::endl; @@ -612,12 +613,14 @@ TEST(KiwiCpp, FalsePositiveSB) u"도서전에서 관람객의 관심을 받을 것으로 예상되는 프로그램으로는 '인문학 아카데미'가 있어요. 이 프로그램에서는 유시민 전 의원, 광고인 박웅현 씨 등이 문화 역사 미학 등 다양한 분야에 대해 강의할 예정이다. 또한, '북 멘토 프로그램'도 이어져요. 
이 프로그램에서는 각 분야 전문가들이 경험과 노하우를 전수해 주는 프로그램으로, 시 창작(이정록 시인), 번역(강주헌 번역가), 북 디자인(오진경 북디자이너) 등의 분야에서 멘토링이 이뤄져요.", }) { - auto tokens = kiwi.analyze(str, 10, Match::allWithNormalizing)[0].first; + auto res = kiwi.analyze(str, 1, Match::allWithNormalizing); + auto tokens = res[0].first; auto sbCount = std::count_if(tokens.begin(), tokens.end(), [](const TokenInfo& t) { return t.tag == POSTag::sb; }); EXPECT_EQ(sbCount, 0); + kiwi.analyze(str, 10, Match::allWithNormalizing); } } @@ -911,12 +914,14 @@ TEST(KiwiCpp, AnalyzeWithLoadDefaultDict) TEST(KiwiCpp, AnalyzeSBG) { - Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::none, true }.build(); - kiwi.analyze(TEST_SENT, Match::all); - - auto tokens = kiwi.analyze(u"이 번호로 전화를 이따가 꼭 반드시 걸어.", kiwi::Match::allWithNormalizing).first; - EXPECT_EQ(tokens.size(), 11); - EXPECT_EQ(tokens[8].str, u"걸"); + Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::none, ModelType::knlm }.build(); + Kiwi kiwiSbg = KiwiBuilder{ MODEL_PATH, 0, BuildOption::none, ModelType::sbg }.build(); + kiwiSbg.analyze(TEST_SENT, Match::all); + + auto res = kiwi.analyze(u"이 번호로 전화를 이따가 꼭 반드시 걸어.", 3, kiwi::Match::allWithNormalizing); + auto resSbg = kiwiSbg.analyze(u"이 번호로 전화를 이따가 꼭 반드시 걸어.", 3, kiwi::Match::allWithNormalizing); + EXPECT_EQ(resSbg[0].first.size(), 11); + EXPECT_EQ(resSbg[0].first[8].str, u"걸"); } TEST(KiwiCpp, AnalyzeMultithread) diff --git a/third_party/streamvbyte b/third_party/streamvbyte new file mode 160000 index 00000000..f27641e3 --- /dev/null +++ b/third_party/streamvbyte @@ -0,0 +1 @@ +Subproject commit f27641e3194d14d667e30928a418685d943ab62c diff --git a/tools/evaluator_main.cpp b/tools/evaluator_main.cpp index 4913add8..9dc876a4 100644 --- a/tools/evaluator_main.cpp +++ b/tools/evaluator_main.cpp @@ -11,8 +11,20 @@ using namespace std; using namespace kiwi; +const char* modelTypeToStr(ModelType type) +{ + switch (type) + { + case ModelType::knlm: return "knlm"; + case ModelType::knlmTransposed: return "knlm-transposed"; + 
case ModelType::sbg: return "sbg"; + case ModelType::pclm: return "pclm"; + } + return "unknown"; +} + int doEvaluate(const string& modelPath, const string& output, const vector& input, - bool normCoda, bool zCoda, bool multiDict, bool useSBG, + bool normCoda, bool zCoda, bool multiDict, ModelType modelType, float typoCostWeight, bool bTypo, bool cTypo, bool lTypo, int repeat) { @@ -48,14 +60,18 @@ int doEvaluate(const string& modelPath, const string& output, const vector 0) kw.setTypoCostWeight(typoCostWeight); cout << "Loading Time : " << timer.getElapsed() << " ms" << endl; cout << "ArchType : " << archToStr(kw.archType()) << endl; - cout << "LM Size : " << (kw.getKnLM()->getMemory().size() / 1024. / 1024.) << " MB" << endl; + cout << "Model Type : " << modelTypeToStr(kw.modelType()) << endl; + if (kw.getKnLM()) + { + cout << "LM Size : " << (kw.getKnLM()->getMemory().size() / 1024. / 1024.) << " MB" << endl; + } cout << "Mem Usage : " << (tutils::getCurrentPhysicalMemoryUsage() / 1024.) 
<< " MB\n" << endl; double avgMicro = 0, avgMacro = 0; @@ -130,7 +146,7 @@ int main(int argc, const char* argv[]) SwitchArg noNormCoda{ "", "no-normcoda", "without normalizing coda", false }; SwitchArg noZCoda{ "", "no-zcoda", "without z-coda", false }; SwitchArg noMulti{ "", "no-multi", "turn off multi dict", false }; - SwitchArg useSBG{ "", "sbg", "use SkipBigram", false }; + ValueArg modelType{ "t", "type", "model type", false, "knlm", "string" }; ValueArg typoWeight{ "", "typo", "typo weight", false, 0.f, "float"}; SwitchArg bTypo{ "", "btypo", "make basic-typo-tolerant model", false }; SwitchArg cTypo{ "", "ctypo", "make continual-typo-tolerant model", false }; @@ -144,7 +160,7 @@ int main(int argc, const char* argv[]) cmd.add(noNormCoda); cmd.add(noZCoda); cmd.add(noMulti); - cmd.add(useSBG); + cmd.add(modelType); cmd.add(typoWeight); cmd.add(bTypo); cmd.add(cTypo); @@ -160,7 +176,32 @@ int main(int argc, const char* argv[]) cerr << "error: " << e.error() << " for arg " << e.argId() << endl; return -1; } + ModelType kiwiModelType = ModelType::none; + { + auto v = modelType.getValue(); + if (v == "knlm") + { + kiwiModelType = ModelType::knlm; + } + else if (v == "sbg") + { + kiwiModelType = ModelType::sbg; + } + else if (v == "knlm-transposed") + { + kiwiModelType = ModelType::knlmTransposed; + } + else if (v == "pclm") + { + kiwiModelType = ModelType::pclm; + } + else + { + cerr << "Invalid model type" << endl; + return -1; + } + } return doEvaluate(model, output, files.getValue(), - !noNormCoda, !noZCoda, !noMulti, useSBG, typoWeight, bTypo, cTypo, lTypo, repeat); + !noNormCoda, !noZCoda, !noMulti, kiwiModelType, typoWeight, bTypo, cTypo, lTypo, repeat); } diff --git a/tools/pclm_builder.cpp b/tools/pclm_builder.cpp new file mode 100644 index 00000000..1542e80e --- /dev/null +++ b/tools/pclm_builder.cpp @@ -0,0 +1,62 @@ +#include +#include + +#include +#include +#include +#include "toolUtils.h" + +using namespace std; +using namespace kiwi; + +int run(const 
std::string& morphemeDef, const std::string& contextDef, const std::string& embedding, size_t minCnt, const std::string& output) +{ + try + { + tutils::Timer timer; + KiwiBuilder::buildMorphData(morphemeDef, output, minCnt); + auto ret = pclm::PCLanguageModelBase::build(contextDef, embedding); + ret.writeToFile(output + "/pclm.mdl"); + double tm = timer.getElapsed(); + cout << "Total: " << tm << " ms " << endl; + return 0; + } + catch (const exception& e) + { + cerr << e.what() << endl; + return -1; + } +} + +using namespace TCLAP; + +int main(int argc, const char* argv[]) +{ + tutils::setUTF8Output(); + + CmdLine cmd{ "Kiwi PCLanguageModel Builder", ' ', "0.21.0" }; + + ValueArg mdef{ "m", "morpheme-def", "morpheme definition", true, "", "string" }; + ValueArg cdef{ "c", "context-def", "context definition", true, "", "string" }; + ValueArg emb{ "e", "emb", "embedding file", true, "", "string" }; + ValueArg minCnt{ "n", "min-cnt", "min count of morpheme", false, 10, "int" }; + ValueArg output{ "o", "output", "", true, "", "string" }; + + cmd.add(mdef); + cmd.add(cdef); + cmd.add(emb); + cmd.add(minCnt); + cmd.add(output); + + try + { + cmd.parse(argc, argv); + } + catch (const ArgException& e) + { + cerr << "error: " << e.error() << " for arg " << e.argId() << endl; + return -1; + } + + return run(mdef, cdef, emb, minCnt, output); +} diff --git a/vsproj/build_pclm.vcxproj b/vsproj/build_pclm.vcxproj new file mode 100644 index 00000000..dd26dd5a --- /dev/null +++ b/vsproj/build_pclm.vcxproj @@ -0,0 +1,225 @@ + + + + + Debug + ARM64 + + + Debug + Win32 + + + Release + ARM64 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {C63940BA-24B0-452C-A618-E435888BB45C} + Win32Proj + KiwiRun + 10.0 + + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + Application + true + v143 + Unicode + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + Application + false + 
v143 + true + Unicode + + + + + + + + + + + + + + + + + + + + + + + + + + + false + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + false + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + true + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + true + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + true + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + false + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + + Level3 + NotUsing + MaxSpeed + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreaded + /utf-8 %(AdditionalOptions) + + + Console + true + true + + + + + Level3 + NotUsing + MaxSpeed + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreaded + /utf-8 %(AdditionalOptions) + + + Console + true + true + + + + + NotUsing + Level3 + Disabled + WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDebug + /utf-8 %(AdditionalOptions) + + + Console + + + + + NotUsing + Level3 + Disabled + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDebug + /utf-8 %(AdditionalOptions) + + + Console + + + + + NotUsing + Level3 + Disabled + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDebug + /utf-8 %(AdditionalOptions) + + + Console + + + + + Level3 + NotUsing + MaxSpeed + true + true + WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreaded + /utf-8 
%(AdditionalOptions) + + + Console + true + true + + + + + {f790bc37-2732-4ed1-9ca5-7248bed3588e} + + + + + + + + + \ No newline at end of file diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index aca1d9e3..58754a20 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -40,6 +40,7 @@ + @@ -69,6 +70,7 @@ + @@ -120,7 +122,9 @@ + + @@ -136,6 +140,13 @@ + + + + + + + {F790BC37-2732-4ED1-9CA5-7248BED3588E} @@ -213,7 +224,7 @@ true - $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\include\;$(IncludePath) + $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\third_party/streamvbyte/include;$(ProjectDir)..\include\;$(IncludePath) true @@ -225,7 +236,7 @@ false - $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\include\;$(IncludePath) + $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\third_party/streamvbyte/include;$(ProjectDir)..\include\;$(IncludePath) false @@ -251,7 +262,7 @@ NotUsing Level3 Disabled - 
KIWI_ARCH_X86_64=1;KIWI_USE_MIMALLOC;_DEBUG;_CONSOLE;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + KIWI_ARCH_X86_64=1;__restrict__=__restrict;KIWI_USE_MIMALLOC;_DEBUG;_CONSOLE;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) true MultiThreadedDebug /Qvec-report:1 /utf-8 /D _CRT_SECURE_NO_WARNINGS=1 /bigobj %(AdditionalOptions) @@ -314,7 +325,7 @@ MaxSpeed true true - KIWI_ARCH_X86_64=1;KIWI_USE_BTREE;KIWI_USE_MIMALLOC;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + KIWI_ARCH_X86_64=1;__restrict__=__restrict;KIWI_USE_BTREE;KIWI_USE_MIMALLOC;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) true AdvancedVectorExtensions2 /Qvec-report:1 /utf-8 /D _CRT_SECURE_NO_WARNINGS=1 /bigobj %(AdditionalOptions) From 8c02638eabdf8e1cf93ff204c63afe519d1a3aa1 Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 21 Jan 2025 00:21:21 +0900 Subject: [PATCH 03/53] Refactor PathEvaluator.hpp --- include/kiwi/Kiwi.h | 2 +- src/Kiwi.cpp | 96 +-------- src/PathEvaluator.cpp | 126 +++++++++++ src/PathEvaluator.h | 105 ++++++++++ src/PathEvaluator.hpp | 321 +++-------------------------- vsproj/kiwi_shared_library.vcxproj | 4 + 6 files changed, 274 insertions(+), 380 deletions(-) create mode 100644 src/PathEvaluator.cpp create mode 100644 src/PathEvaluator.h diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index ecc874c2..981fe8ba 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -86,7 +86,7 @@ namespace kiwi std::shared_ptr combiningRule; std::unique_ptr pool; - inline const Morpheme* getDefaultMorpheme(POSTag tag) const; + const Morpheme* getDefaultMorpheme(POSTag tag) const; template cmb::AutoJoiner newJoinerImpl() const diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index 927a9f2c..b00e8b20 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -54,95 +54,7 @@ namespace kiwi continualTypoTolerant, lengtheningTypoTolerant); dfFindForm = (void*)getFindFormFn(selectedArch, typoTolerant); - - static tp::Table lmKnLM_8{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_16{ 
FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_32{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_64{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLMT_8{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLMT_16{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLMT_32{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_8{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_16{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_32{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_64{ FindBestPathGetter::type>{} }; - static tp::Table lmPcLM_16{ FindBestPathGetter::type>{} }; - static tp::Table lmPcLM_32{ FindBestPathGetter::type>{} }; - - if (langMdl.type == ModelType::sbg) - { - switch (langMdl.sbg->getHeader().keySize) - { - case 1: - dfFindBestPath = (void*)lmSbg_8[static_cast(selectedArch)]; - break; - case 2: - dfFindBestPath = (void*)lmSbg_16[static_cast(selectedArch)]; - break; - case 4: - dfFindBestPath = (void*)lmSbg_32[static_cast(selectedArch)]; - break; - case 8: - dfFindBestPath = (void*)lmSbg_64[static_cast(selectedArch)]; - break; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else if(langMdl.type == ModelType::knlm) - { - switch (langMdl.knlm->getHeader().key_size) - { - case 1: - dfFindBestPath = (void*)lmKnLM_8[static_cast(selectedArch)]; - break; - case 2: - dfFindBestPath = (void*)lmKnLM_16[static_cast(selectedArch)]; - break; - case 4: - dfFindBestPath = (void*)lmKnLM_32[static_cast(selectedArch)]; - break; - case 8: - dfFindBestPath = (void*)lmKnLM_64[static_cast(selectedArch)]; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else if (langMdl.type == ModelType::knlmTransposed) - { - switch (langMdl.knlm->getHeader().key_size) - { - case 1: - dfFindBestPath = (void*)lmKnLMT_8[static_cast(selectedArch)]; - break; - case 2: - dfFindBestPath = (void*)lmKnLMT_16[static_cast(selectedArch)]; - break; - case 4: - dfFindBestPath = 
(void*)lmKnLMT_32[static_cast(selectedArch)]; - break; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else if (langMdl.type == ModelType::pclm) - { - switch (langMdl.pclm->getHeader().keySize) - { - case 2: - dfFindBestPath = (void*)lmPcLM_16[static_cast(selectedArch)]; - break; - case 4: - dfFindBestPath = (void*)lmPcLM_32[static_cast(selectedArch)]; - break; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else - { - throw Exception{ "Unsupported model type" }; - } + dfFindBestPath = (void*)getFindBestPathFn(selectedArch, langMdl); } Kiwi::~Kiwi() = default; @@ -690,7 +602,7 @@ namespace kiwi inline void insertPathIntoResults( vector& ret, Vector& spStatesByRet, - const Vector& pathes, + const Vector& pathes, size_t topN, Match matchOptions, bool integrateAllomorph, @@ -716,7 +628,7 @@ namespace kiwi Vector selectedPathes(pathes.size()); for (size_t i = 0; i < ret.size(); ++i) { - auto pred = [&](const BestPathFinder::ChunkResult& p) + auto pred = [&](const PathResult& p) { return p.prevState == spStatesByRet[i]; }; @@ -1098,7 +1010,7 @@ namespace kiwi if (nodes.size() <= 2) continue; findPretokenizedGroupOfNode(nodeInWhichPretokenized, nodes, pretokenizedPrev, pretokenizedFirst); - Vector res = (*reinterpret_cast(dfFindBestPath))( + Vector res = (*reinterpret_cast(dfFindBestPath))( this, spStatesByRet, nodes.data(), diff --git a/src/PathEvaluator.cpp b/src/PathEvaluator.cpp new file mode 100644 index 00000000..77f1e664 --- /dev/null +++ b/src/PathEvaluator.cpp @@ -0,0 +1,126 @@ +#include "PathEvaluator.hpp" + +using namespace std; + +namespace kiwi +{ + template class LmState> + struct FindBestPathGetter + { + template + struct Wrapper + { + static constexpr FnFindBestPath value = &BestPathFinder::findBestPath(i)>>; + }; + }; + + + template + inline FnFindBestPath getPcLMFindBestPath(ArchType archType, size_t windowSize) + { + static tp::Table w4{ FindBestPathGetter::type>{} }; + static tp::Table w7{ FindBestPathGetter::type>{} }; + 
static tp::Table w8{ FindBestPathGetter::type>{} }; + switch (windowSize) + { + case 4: + return w4[static_cast(archType)]; + case 7: + return w7[static_cast(archType)]; + case 8: + return w8[static_cast(archType)]; + default: + throw Exception{ "Unsupported `window_size` : " + to_string(windowSize) }; + } + } + + FnFindBestPath getFindBestPathFn(ArchType archType, const LangModel& langMdl) + { + const auto archIdx = static_cast(archType); + static tp::Table lmKnLM_8{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLM_16{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLM_32{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLMT_8{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLMT_16{ FindBestPathGetter::type>{} }; + static tp::Table lmKnLMT_32{ FindBestPathGetter::type>{} }; + static tp::Table lmSbg_8{ FindBestPathGetter::type>{} }; + static tp::Table lmSbg_16{ FindBestPathGetter::type>{} }; + static tp::Table lmSbg_32{ FindBestPathGetter::type>{} }; + + if (langMdl.type == ModelType::sbg) + { + switch (langMdl.sbg->getHeader().keySize) + { + case 1: + return lmSbg_8[archIdx]; + case 2: + return lmSbg_16[archIdx]; + case 4: + return lmSbg_32[archIdx]; + default: + throw Exception{ "Wrong `lmKeySize`" }; + } + } + else if (langMdl.type == ModelType::knlm) + { + switch (langMdl.knlm->getHeader().key_size) + { + case 1: + return lmKnLM_8[archIdx]; + break; + case 2: + return lmKnLM_16[archIdx]; + break; + case 4: + return lmKnLM_32[archIdx]; + default: + throw Exception{ "Wrong `lmKeySize`" }; + } + } + else if (langMdl.type == ModelType::knlmTransposed) + { + switch (langMdl.knlm->getHeader().key_size) + { + case 1: + return lmKnLMT_8[archIdx]; + break; + case 2: + return lmKnLMT_16[archIdx]; + break; + case 4: + return lmKnLMT_32[archIdx]; + default: + throw Exception{ "Wrong `lmKeySize`" }; + } + } + else if (langMdl.type == ModelType::pclm) + { + switch (langMdl.pclm->getHeader().keySize) + { + case 2: + return 
getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); + case 4: + return getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); + default: + throw Exception{ "Wrong `lmKeySize`" }; + } + } + else if (langMdl.type == ModelType::pclmLocal) + { + switch (langMdl.pclm->getHeader().keySize) + { + case 2: + return getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); + case 4: + return getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); + default: + throw Exception{ "Wrong `lmKeySize`" }; + } + } + else + { + + } + return nullptr; + } +} diff --git a/src/PathEvaluator.h b/src/PathEvaluator.h new file mode 100644 index 00000000..61823660 --- /dev/null +++ b/src/PathEvaluator.h @@ -0,0 +1,105 @@ +#pragma once +#include + +namespace kiwi +{ + struct SpecialState + { + uint8_t singleQuote : 1; + uint8_t doubleQuote : 1; + uint8_t bulletHash : 6; + + SpecialState() : singleQuote{ 0 }, doubleQuote{ 0 }, bulletHash{ 0 } + { + } + + operator uint8_t() const + { + return reinterpret_cast(*this); + } + + bool operator<(const SpecialState& o) const + { + return (uint8_t)(*this) < (uint8_t)o; + } + + bool operator==(const SpecialState& o) const + { + return (uint8_t)(*this) == (uint8_t)o; + } + }; + + struct PathNode + { + const Morpheme* morph = nullptr; + KString str; + uint32_t begin = 0, end = 0; + float wordScore = 0, typoCost = 0; + uint32_t typoFormId = 0; + uint32_t nodeId = 0; + + PathNode(const Morpheme* _morph = nullptr, + const KString& _str = {}, + uint32_t _begin = 0, + uint32_t _end = 0, + float _wordScore = 0, + float _typoCost = 0, + uint32_t _typoFormId = 0, + uint32_t _nodeId = 0 + ) + : morph{ _morph }, str{ _str }, begin{ _begin }, end{ _end }, + wordScore{ _wordScore }, typoCost{ _typoCost }, typoFormId{ _typoFormId }, nodeId{ _nodeId } + { + } + + bool operator==(const PathNode& o) const + { + return morph == o.morph + && str == o.str + && begin == o.begin + && end == o.end + && wordScore == 
o.wordScore + && typoCost == o.typoCost + && typoFormId == o.typoFormId; + } + }; + using Path = Vector; + + struct PathResult + { + Path path; + float score = 0; + SpecialState prevState; + SpecialState curState; + + PathResult(Path&& _path = {}, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) + : path{ move(_path) }, score{ _score }, prevState{ _prevState }, curState{ _curState } + { + } + + PathResult(const Path& _path, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) + : path{ _path }, score{ _score }, prevState{ _prevState }, curState{ _curState } + { + } + }; + + struct BestPathFinder + { + template + static Vector findBestPath(const Kiwi* kw, + const Vector& prevSpStates, + const KGraphNode* graph, + const size_t graphSize, + const size_t topN, + bool openEnd, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ); + }; + + using FnFindBestPath = decltype(&BestPathFinder::findBestPath); + + FnFindBestPath getFindBestPathFn(ArchType archType, const LangModel& langMdl); +} diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index 08f36941..23d04db5 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -12,37 +12,12 @@ #include "StrUtils.h" #include "SortUtils.hpp" #include "LimitedVector.hpp" +#include "PathEvaluator.h" using namespace std; namespace kiwi { - struct SpecialState - { - uint8_t singleQuote : 1; - uint8_t doubleQuote : 1; - uint8_t bulletHash : 6; - - SpecialState() : singleQuote{ 0 }, doubleQuote{ 0 }, bulletHash{ 0 } - { - } - - operator uint8_t() const - { - return reinterpret_cast(*this); - } - - bool operator<(const SpecialState& o) const - { - return (uint8_t)(*this) < (uint8_t)o; - } - - bool operator==(const SpecialState& o) const - { - return (uint8_t)(*this) == (uint8_t)o; - } - }; - template struct WordLL; @@ -115,7 +90,7 @@ namespace kiwi const std::unordered_set* blocklist = 
nullptr ); - template + template static void evalMorphemes( Vector>& resultOut, const Kiwi* kw, @@ -127,93 +102,13 @@ namespace kiwi const KGraphNode* node, const KGraphNode* startNode, const size_t topN, + const size_t totalPrevPathes, const float ignoreCondScore, const float nodeLevelDiscount, const Vector& prevSpStates ); }; - struct BestPathFinder - { - public: - struct Result - { - const Morpheme* morph = nullptr; - KString str; - uint32_t begin = 0, end = 0; - float wordScore = 0, typoCost = 0; - uint32_t typoFormId = 0; - uint32_t nodeId = 0; - - Result(const Morpheme* _morph = nullptr, - const KString& _str = {}, - uint32_t _begin = 0, - uint32_t _end = 0, - float _wordScore = 0, - float _typoCost = 0, - uint32_t _typoFormId = 0, - uint32_t _nodeId = 0 - ) - : morph{ _morph }, str{ _str }, begin{ _begin }, end{ _end }, - wordScore{ _wordScore }, typoCost{ _typoCost }, typoFormId{ _typoFormId }, nodeId{ _nodeId } - { - } - - bool operator==(const Result& o) const - { - return morph == o.morph - && str == o.str - && begin == o.begin - && end == o.end - && wordScore == o.wordScore - && typoCost == o.typoCost - && typoFormId == o.typoFormId; - } - }; - using Path = Vector; - - struct ChunkResult - { - Path path; - float score = 0; - SpecialState prevState; - SpecialState curState; - - ChunkResult(Path&& _path = {}, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) - : path{ move(_path) }, score{ _score }, prevState{ _prevState }, curState{ _curState } - {} - - ChunkResult(const Path& _path, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) - : path{ _path }, score{ _score }, prevState{ _prevState }, curState{ _curState } - {} - }; - - template - static Vector find(const Kiwi* kw, - const Vector& prevSpStates, - const KGraphNode* graph, - const size_t graphSize, - const size_t topN, - bool openEnd, - bool splitComplex = false, - bool splitSaisiot = false, - bool mergeSaisiot = false, - const 
std::unordered_set* blocklist = nullptr - ); - }; - - using FnFindBestPath = decltype(&BestPathFinder::find>); - - template class LmState> - struct FindBestPathGetter - { - template - struct Wrapper - { - static constexpr FnFindBestPath value = &BestPathFinder::find(i)>>; - }; - }; - template struct WordLL { @@ -1053,7 +948,7 @@ namespace kiwi }; template - template + template void PathEvaluator::type>::evalMorphemes( Vector>& resultOut, const Kiwi* kw, @@ -1065,17 +960,13 @@ namespace kiwi const KGraphNode* node, const KGraphNode* startNode, const size_t topN, + const size_t totalPrevPathes, const float ignoreCondScore, const float nodeLevelDiscount, const Vector& prevSpStates ) { - thread_local Vector> bestPathIndicesSmall; - thread_local Vector> bestPathValuesSmall; - thread_local UnorderedMap, WordLL> bestPathes; - // pair: [index, size] - thread_local UnorderedMap, pair> bestPathIndices; - thread_local Vector> bestPathValues; + thread_local BestPathConatiner bestPathCont; thread_local Vector rootIds; thread_local Vector> evalMatrix; thread_local Vector nextWids; @@ -1086,16 +977,9 @@ namespace kiwi const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; const size_t langVocabSize = langMdl.vocabSize(); - size_t totalPrevPathes = 0; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) - { - totalPrevPathes += cache[prev - startNode].size(); - } evalMatrix.resize(totalPrevPathes * morphs.size()); nextWids.clear(); - const bool useContainerForSmall = totalPrevPathes <= 48; - size_t prevId = -1; for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { @@ -1232,25 +1116,7 @@ namespace kiwi for (size_t curId = 0; curId < morphs.size(); ++curId) { const auto curMorph = morphs[curId]; - - if (top1) - { - if (useContainerForSmall) - { - bestPathIndicesSmall.clear(); - bestPathValuesSmall.clear(); - } - else - { - bestPathes.clear(); - } - } - else - { - bestPathIndices.clear(); - bestPathValues.clear(); - } - + bestPathCont.clear(); 
const Morpheme* lastMorph; if (curMorph->chunks.empty() || curMorph->complex) @@ -1319,146 +1185,12 @@ namespace kiwi } PathHash ph{ em.state, prevPath.rootId, spState }; - if (top1) - { - if (useContainerForSmall) - { - const size_t foundIdx = find(bestPathIndicesSmall.begin(), bestPathIndicesSmall.end(), ph) - bestPathIndicesSmall.begin(); - if (foundIdx >= bestPathIndicesSmall.size()) - { - bestPathIndicesSmall.emplace_back(ph); - bestPathValuesSmall.emplace_back(curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); - if (rootId != commonRootId) bestPathValuesSmall.back().rootId = rootId; - } - else - { - auto& target = bestPathValuesSmall[foundIdx]; - if (candScoreWithRule > target.accScore) - { - target = WordLL{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; - if (rootId != commonRootId) target.rootId = rootId; - } - } - } - else - { - WordLL newPath{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; - if (rootId != commonRootId) newPath.rootId = rootId; - auto inserted = bestPathes.emplace(ph, newPath); - if (!inserted.second) - { - auto& target = inserted.first->second; - if (candScoreWithRule > target.accScore) - { - target = newPath; - } - } - } - } - else - { - auto inserted = bestPathIndices.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); - if (inserted.second) - { - bestPathValues.emplace_back(curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); - if (rootId != commonRootId) bestPathValues.back().rootId = rootId; - bestPathValues.resize(bestPathValues.size() + topN - 1); - } - else - { - auto bestPathFirst = bestPathValues.begin() + inserted.first->second.first; - auto bestPathLast = bestPathValues.begin() + inserted.first->second.first + inserted.first->second.second; - if (distance(bestPathFirst, bestPathLast) < topN) - { - 
*bestPathLast = WordLL{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; - if (rootId != commonRootId) bestPathLast->rootId = rootId; - push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); - ++inserted.first->second.second; - } - else - { - if (candScoreWithRule > bestPathFirst->accScore) - { - pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); - *(bestPathLast - 1) = WordLL{ curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState }; - if (rootId != commonRootId) (*(bestPathLast - 1)).rootId = rootId; - push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); - } - } - } - } + bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); } } } - if (top1) - { - if (useContainerForSmall) - { - for (auto& p : bestPathValuesSmall) - { - resultOut.emplace_back(move(p)); - auto& newPath = resultOut.back(); - - // fill the rest information of resultOut - if (curMorph->chunks.empty() || curMorph->complex) - { - newPath.wid = lastSeqId; - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; - } - else - { - newPath.wid = lastSeqId; - } - } - } - else - { - for (auto& p : bestPathes) - { - resultOut.emplace_back(move(p.second)); - auto& newPath = resultOut.back(); - - // fill the rest information of resultOut - if (curMorph->chunks.empty() || curMorph->complex) - { - newPath.wid = lastSeqId; - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; - } - else - { - newPath.wid = lastSeqId; - } - } - } - } - else - { - for (auto& p : bestPathIndices) - { - const auto index = p.second.first; - const auto size = p.second.second; - for (size_t i = 0; i < size; ++i) - { - resultOut.emplace_back(move(bestPathValues[index + i])); - auto& newPath = resultOut.back(); - - // fill the rest information of resultOut - if 
(curMorph->chunks.empty() || curMorph->complex) - { - newPath.wid = lastSeqId; - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; - } - else - { - newPath.wid = lastSeqId; - } - } - } - } + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); } } @@ -1560,6 +1292,7 @@ namespace kiwi newPath.wid = zCodaMorph->lmMorphemeId; } } + continue; } // 사이시옷(zSiot)을 위한 지름길 if (zSiotMorph) @@ -1583,19 +1316,33 @@ namespace kiwi newPath.wid = zSiotMorph->lmMorphemeId; } } + continue; + } + + size_t totalPrevPathes = 0; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + totalPrevPathes += cache[prev - startNode].size(); } + const bool useContainerForSmall = totalPrevPathes <= 48; - if (topN == 1) + if (topN > 1) + { + evalMorphemes(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, lbScores, + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + } + else if (useContainerForSmall) { - evalMorphemes(nCache, kw, ownFormList, cache, + evalMorphemes(nCache, kw, ownFormList, cache, ownFormId, validMorphCands, lbScores, - node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } else { - evalMorphemes(nCache, kw, ownFormList, cache, + evalMorphemes(nCache, kw, ownFormList, cache, ownFormId, validMorphCands, lbScores, - node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + node, startNode, topN, totalPrevPathes, ignoreCond ? 
-10 : 0, nodeLevelDiscount, prevSpStates); } if (!nCache.empty()) break; } @@ -1640,7 +1387,7 @@ namespace kiwi template - inline BestPathFinder::Path generateTokenList(const WordLL* result, + inline Path generateTokenList(const WordLL* result, const utils::ContainerSearcher>& csearcher, const KGraphNode* graph, const Vector& ownFormList, @@ -1661,7 +1408,7 @@ namespace kiwi return morphFirst + morph->lmMorphemeId; }; - BestPathFinder::Path ret; + Path ret; const WordLL* prev = steps.back()->parent; for (auto it = steps.rbegin(); it != steps.rend(); ++it) { @@ -1754,7 +1501,7 @@ namespace kiwi } template - Vector BestPathFinder::find(const Kiwi* kw, + Vector BestPathFinder::findBestPath(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, const size_t graphSize, @@ -1906,7 +1653,7 @@ namespace kiwi #endif utils::ContainerSearcher> csearcher{ cache }; - Vector ret; + Vector ret; size_t numUniqRootIdAndSpState; { UnorderedSet> uniqRootIdAndSpState; @@ -1939,7 +1686,7 @@ namespace kiwi ret.emplace_back(move(tokens), cand[i].accScore, uniqStates[cand[i].rootId], cand[i].spState); } } - sort(ret.begin(), ret.end(), [](const ChunkResult& a, const ChunkResult& b) + sort(ret.begin(), ret.end(), [](const PathResult& a, const PathResult& b) { return a.score > b.score; }); diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index 58754a20..e77740a1 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -68,6 +68,7 @@ + @@ -122,6 +123,9 @@ + + EIGEN_VECTORIZE_AVX512;KIWI_ARCH_X86_64=1;__restrict__=__restrict;KIWI_USE_BTREE;KIWI_USE_MIMALLOC;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + From 01fe634cb51b69d6d0173b96ebb50cfb0665533b Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 21 Jan 2025 00:26:46 +0900 Subject: [PATCH 04/53] implemented PoC of softmax window token mixing --- include/kiwi/PCLanguageModel.h | 4 +- include/kiwi/Types.h | 1 + src/LmState.hpp | 45 ++++++++++--- 
src/PCLanguageModel.cpp | 96 +++++++++++++++++++++------- src/PCLanguageModel.hpp | 111 +++++++++++++++++++++++++++++---- 5 files changed, 211 insertions(+), 46 deletions(-) diff --git a/include/kiwi/PCLanguageModel.h b/include/kiwi/PCLanguageModel.h index e2aca30f..14c23698 100644 --- a/include/kiwi/PCLanguageModel.h +++ b/include/kiwi/PCLanguageModel.h @@ -45,8 +45,8 @@ namespace kiwi virtual ~PCLanguageModelBase() {} const Header& getHeader() const { return *reinterpret_cast(base.get()); } - static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding); - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none); + static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, bool reorderContextIdx = true); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false); }; } } diff --git a/include/kiwi/Types.h b/include/kiwi/Types.h index 40bc4c62..78144de0 100644 --- a/include/kiwi/Types.h +++ b/include/kiwi/Types.h @@ -309,6 +309,7 @@ namespace kiwi knlm = 1, /**< Kneser-Ney Language Model */ sbg = 2, /**< Skip-Bigram Model */ pclm = 3, /**< Pre-computed Context Language Model */ + pclmLocal = 4, /**< Pre-computed Context Language Model (Only local context) */ knlmTransposed, }; diff --git a/src/LmState.hpp b/src/LmState.hpp index daca463c..a24be7ae 100644 --- a/src/LmState.hpp +++ b/src/LmState.hpp @@ -104,29 +104,58 @@ namespace kiwi } }; - template + template class PcLMState { - friend struct Hash>; + friend struct Hash>; + protected: int32_t node = 0; uint32_t contextIdx = 0; + public: + static constexpr ArchType arch = _arch; + static constexpr bool transposed = true; + + PcLMState() = default; + PcLMState(const LangModel& lm) {} + + bool operator==(const PcLMState& other) const + { + return node == other.node; + } + + float next(const LangModel& lm, VocabTy next) + { + auto& pclm = 
static_cast&>(*lm.pclm); + size_t historyPos = 0; + std::array history = { {0,} }; + return pclm.progress(node, contextIdx, historyPos, history, next); + } + }; + + template + class PcLMState : public PcLMState + { + static constexpr bool useDistantTokens = true; + friend struct Hash>; + protected: size_t historyPos = 0; - std::array history = { {0,} }; + std::array history = { {0,} }; public: static constexpr ArchType arch = _arch; + static constexpr bool transposed = true; PcLMState() = default; PcLMState(const LangModel& lm) {} bool operator==(const PcLMState& other) const { - return node == other.node && historyPos == other.historyPos && history == other.history; + return PcLMState::operator==(other) && historyPos == other.historyPos && history == other.history; } float next(const LangModel& lm, VocabTy next) { - auto& pclm = static_cast&>(*lm.pclm); - return pclm.progress(node, contextIdx, next); + auto& pclm = static_cast&>(*lm.pclm); + return pclm.progress(node, contextIdx, historyPos, history, next); } }; @@ -184,10 +213,10 @@ namespace kiwi template using type = SbgState; }; - template + template struct WrappedPcLM { - template using type = PcLMState; + template using type = PcLMState; }; template diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index ca769c9d..88d3e7b2 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -10,7 +10,7 @@ namespace kiwi { namespace pclm { - utils::MemoryObject PCLanguageModelBase::build(const string& contextDefinition, const string& embedding) + utils::MemoryObject PCLanguageModelBase::build(const string& contextDefinition, const string& embedding, bool reorderContextId) { ifstream contextStr, embeddingStr; if (!openFile(contextStr, contextDefinition)) @@ -86,13 +86,6 @@ namespace kiwi trie.build(p.first.begin(), p.first.end(), p.second + 1); } } - for (auto& p : erasedContexts) - { - if (auto* node = trie.find(p.first.begin(), p.first.end())) - { - //node->val = p.second + 1; - } - } } Vector 
nodeSizes; @@ -120,13 +113,16 @@ namespace kiwi valueCnts[0] = -1; // remap value idx by frequency - iota(valueArgsorted.begin(), valueArgsorted.end(), 0); - sort(valueArgsorted.begin(), valueArgsorted.end(), [&](uint32_t a, uint32_t b) { return valueCnts[a] > valueCnts[b]; }); - for (size_t i = 0; i < valueArgsorted.size(); ++i) + if (reorderContextId) { - valueNewIdx[valueArgsorted[i]] = (uint32_t)i; + iota(valueArgsorted.begin(), valueArgsorted.end(), 0); + sort(valueArgsorted.begin(), valueArgsorted.end(), [&](uint32_t a, uint32_t b) { return valueCnts[a] > valueCnts[b]; }); + for (size_t i = 0; i < valueArgsorted.size(); ++i) + { + valueNewIdx[valueArgsorted[i]] = (uint32_t)i; + } + for (auto& v : values) v = valueNewIdx[v]; } - for (auto& v : values) v = valueNewIdx[v]; } assert(nodeSizes.size() - 1 == keys.size()); @@ -145,34 +141,43 @@ namespace kiwi const uint32_t dim = utils::read(embeddingStr); const uint32_t contextSize = utils::read(embeddingStr); const uint32_t outputSize = utils::read(embeddingStr); + const uint32_t windowSize = utils::read(embeddingStr); Vector contextEmb(dim * contextSize); Vector contextEmbScale(contextSize); Vector contextEmbBias(contextSize); + Vector contextValidTokenSum(contextSize); Vector contextConfidence(contextSize); Vector distantEmb(dim * outputSize); Vector distantEmbScale(outputSize); Vector distantEmbBias(outputSize); Vector distantConfidence(outputSize); + vector positionConfidence(windowSize); Vector outputEmb(dim * outputSize); Vector outputEmbScale(outputSize); + Vector distantMask(outputSize); embeddingStr.read((char*)contextEmb.data(), contextEmb.size()); embeddingStr.read((char*)contextEmbScale.data(), contextEmbScale.size() * sizeof(uint16_t)); embeddingStr.read((char*)contextEmbBias.data(), contextEmbBias.size() * sizeof(uint16_t)); + embeddingStr.read((char*)contextValidTokenSum.data(), contextValidTokenSum.size() * sizeof(uint16_t)); embeddingStr.read((char*)contextConfidence.data(), 
contextConfidence.size() * sizeof(uint16_t)); embeddingStr.read((char*)distantEmb.data(), distantEmb.size()); embeddingStr.read((char*)distantEmbScale.data(), distantEmbScale.size() * sizeof(uint16_t)); embeddingStr.read((char*)distantEmbBias.data(), distantEmbBias.size() * sizeof(uint16_t)); embeddingStr.read((char*)distantConfidence.data(), distantConfidence.size() * sizeof(uint16_t)); + embeddingStr.read((char*)positionConfidence.data(), positionConfidence.size() * sizeof(uint16_t)); embeddingStr.read((char*)outputEmb.data(), outputEmb.size()); embeddingStr.read((char*)outputEmbScale.data(), outputEmbScale.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantMask.data(), distantMask.size()); // remap context embedding + if (reorderContextId) { Vector newContextEmb(contextEmb.size()); Vector newContextEmbScale(contextSize); Vector newContextEmbBias(contextSize); + Vector newContextValidTokenSum(contextSize); for (size_t i = 0; i < contextSize; ++i) { auto idx = valueNewIdx[i]; @@ -181,10 +186,29 @@ namespace kiwi copy(src, src + dim, dst); newContextEmbScale[idx] = contextEmbScale[i]; newContextEmbBias[idx] = contextEmbBias[i]; + newContextValidTokenSum[idx] = contextValidTokenSum[i]; } contextEmb = move(newContextEmb); contextEmbScale = move(newContextEmbScale); contextEmbBias = move(newContextEmbBias); + contextValidTokenSum = move(newContextValidTokenSum); + } + + // compress distantMask into bits + const size_t compressedDistantMaskSize = (outputSize + 7) / 8; + { + for (size_t i = 0; i < outputSize; ++i) + { + if (i % 8 == 0) + { + distantMask[i / 8] = distantMask[i]; + } + else + { + distantMask[i / 8] |= distantMask[i] << (i % 8); + } + } + distantMask.resize(compressedDistantMaskSize); } Header header; @@ -193,7 +217,7 @@ namespace kiwi header.contextSize = contextSize; header.vocabSize = outputSize; header.keySize = 4; - header.windowSize = 8; + header.windowSize = windowSize; header.numNodes = nodeSizes.size(); size_t finalSize = 0; @@ -202,8 
+226,10 @@ namespace kiwi header.valueOffset = alignedOffsetInc(finalSize, compressedKeys.size()); header.embOffset = alignedOffsetInc(finalSize, compressedValues.size()); finalSize += dim * (contextSize + outputSize * 2); - finalSize += contextSize * sizeof(uint16_t) * 3; + finalSize += contextSize * sizeof(uint16_t) * 4; finalSize += outputSize * sizeof(uint16_t) * 4; + finalSize += windowSize * sizeof(uint16_t); + finalSize += compressedDistantMaskSize; utils::MemoryOwner mem{ finalSize }; utils::omstream ostr{ (char*)mem.get(), (std::ptrdiff_t)mem.size() }; @@ -218,48 +244,70 @@ namespace kiwi ostr.write((const char*)contextEmb.data(), contextEmb.size()); ostr.write((const char*)contextEmbScale.data(), contextEmbScale.size() * sizeof(uint16_t)); ostr.write((const char*)contextEmbBias.data(), contextEmbBias.size() * sizeof(uint16_t)); + ostr.write((const char*)contextValidTokenSum.data(), contextValidTokenSum.size() * sizeof(uint16_t)); ostr.write((const char*)contextConfidence.data(), contextConfidence.size() * sizeof(uint16_t)); ostr.write((const char*)distantEmb.data(), distantEmb.size()); ostr.write((const char*)distantEmbScale.data(), distantEmbScale.size() * sizeof(uint16_t)); ostr.write((const char*)distantEmbBias.data(), distantEmbBias.size() * sizeof(uint16_t)); ostr.write((const char*)distantConfidence.data(), distantConfidence.size() * sizeof(uint16_t)); + ostr.write((const char*)positionConfidence.data(), positionConfidence.size() * sizeof(uint16_t)); ostr.write((const char*)outputEmb.data(), outputEmb.size()); ostr.write((const char*)outputEmbScale.data(), outputEmbScale.size() * sizeof(uint16_t)); + ostr.write((const char*)distantMask.data(), distantMask.size()); return mem; } - template + template + inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) + { + auto& header = *reinterpret_cast(mem.get()); + switch (header.windowSize) + { + case 4: + return make_unique>(std::move(mem)); + case 7: + return 
make_unique>(std::move(mem)); + case 8: + return make_unique>(std::move(mem)); + default: + throw std::runtime_error{ "Unsupported `window_size` : " + std::to_string((size_t)header.windowSize) }; + }; + } + + template std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) { auto& header = *reinterpret_cast(mem.get()); switch (header.keySize) { case 1: - return make_unique>(std::move(mem)); + return createOptimizedModelWithWindowSize(std::move(mem)); case 2: - return make_unique>(std::move(mem)); + return createOptimizedModelWithWindowSize(std::move(mem)); case 4: - return make_unique>(std::move(mem)); + return createOptimizedModelWithWindowSize(std::move(mem)); default: throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; } } - using FnCreateOptimizedModel = decltype(&createOptimizedModel); + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + template struct CreateOptimizedModelGetter { template struct Wrapper { - static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i), useDistantTokens>; }; }; - std::unique_ptr PCLanguageModelBase::create(utils::MemoryObject&& mem, ArchType archType) + std::unique_ptr PCLanguageModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens) { - static tp::Table table{ CreateOptimizedModelGetter{} }; - auto fn = table[static_cast(archType)]; + static tp::Table tableWithoutDistantTokens{ CreateOptimizedModelGetter{} }, + tableWithDistantTokens{ CreateOptimizedModelGetter{} }; + auto fn = (useDistantTokens ? 
tableWithDistantTokens : tableWithoutDistantTokens)[static_cast(archType)]; if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; return (*fn)(std::move(mem)); } diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index afd7f461..4d73c0af 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -9,7 +9,7 @@ #include "ArchAvailable.h" #include "search.h" #include "streamvbyte.h" - +#include "SkipBigramModelImpl.hpp" namespace kiwi { @@ -35,7 +35,14 @@ namespace kiwi } } - template + template + void logsoftmaxInplace(Arr& arr) + { + arr -= arr.maxCoeff(); + arr -= std::log(arr.exp().sum()); + } + + template class PCLanguageModel : public PCLanguageModelBase { using MyNode = Node; @@ -45,11 +52,14 @@ namespace kiwi std::unique_ptr valueData; std::unique_ptr contextEmb; std::unique_ptr contextBias; + std::unique_ptr contextValidTokenSum; std::unique_ptr contextConf; std::unique_ptr distantEmb; std::unique_ptr distantBias; std::unique_ptr distantConf; + std::unique_ptr positionConf; std::unique_ptr outputEmb; + std::unique_ptr distantMask; MyNode* findLowerNode(MyNode* node, KeyType k) const { @@ -199,10 +209,15 @@ namespace kiwi auto* eptr = ptr + header.embOffset; contextEmb = make_unique(header.contextSize * header.dim); contextBias = make_unique(header.contextSize); + contextValidTokenSum = make_unique(header.contextSize); contextConf = make_unique(header.contextSize); - distantEmb = make_unique(header.vocabSize * header.dim); - distantBias = make_unique(header.vocabSize); - distantConf = make_unique(header.vocabSize); + if (useDistantTokens) + { + distantEmb = make_unique(header.vocabSize * header.dim); + distantBias = make_unique(header.vocabSize); + distantConf = make_unique(header.vocabSize); + positionConf = make_unique(header.windowSize); + } outputEmb = make_unique(header.vocabSize * header.dim); const uint16_t* contextEmbScale = reinterpret_cast(eptr + header.contextSize * header.dim); 
@@ -218,6 +233,11 @@ namespace kiwi eptr += sizeof(uint16_t); } for (size_t i = 0; i < header.contextSize; ++i) + { + contextValidTokenSum[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.contextSize; ++i) { contextConf[i] = half2float(*reinterpret_cast(eptr)); eptr += sizeof(uint16_t); @@ -226,18 +246,23 @@ namespace kiwi const uint16_t* distantEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); for (size_t i = 0; i < header.vocabSize; ++i) { - dequantize(&distantEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(distantEmbScale[i])); + if (useDistantTokens) dequantize(&distantEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(distantEmbScale[i])); eptr += header.dim; } eptr += header.vocabSize * sizeof(uint16_t); for (size_t i = 0; i < header.vocabSize; ++i) { - distantBias[i] = half2float(*reinterpret_cast(eptr)); + if (useDistantTokens) distantBias[i] = half2float(*reinterpret_cast(eptr)); eptr += sizeof(uint16_t); } for (size_t i = 0; i < header.vocabSize; ++i) { - distantConf[i] = half2float(*reinterpret_cast(eptr)); + if (useDistantTokens) distantConf[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.windowSize; ++i) + { + if (useDistantTokens) positionConf[i] = half2float(*reinterpret_cast(eptr)); eptr += sizeof(uint16_t); } @@ -247,6 +272,14 @@ namespace kiwi dequantize(&outputEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(outputEmbScale[i])); eptr += header.dim; } + eptr += header.vocabSize * sizeof(uint16_t); + + if (useDistantTokens) + { + const size_t compressedDistantMaskSize = (header.vocabSize + 7) / 8; + distantMask = make_unique(compressedDistantMaskSize); + std::copy(eptr, eptr + compressedDistantMaskSize, distantMask.get()); + } } uint32_t progressContextNode(int32_t& nodeIdx, KeyType next) const @@ -303,13 +336,67 @@ namespace kiwi } } - float progress(int32_t& 
nodeIdx, uint32_t& contextIdx, KeyType next) const + inline bool distantTokenMask(uint32_t idx) const + { + if (useDistantTokens) return (distantMask[idx / 8] & (1 << (idx % 8))) != 0; + else return false; + } + + float progress(int32_t& nodeIdx, + uint32_t& contextIdx, + size_t& historyPos, + std::array& history, + KeyType next) const { const auto& header = getHeader(); - Eigen::Map contextVec{ &contextEmb[contextIdx * header.dim], header.dim }; - Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; - float ll = contextVec.dot(outputVec) - contextBias[contextIdx]; + const bool validDistantToken = distantTokenMask(next); + float ll = 0; + + thread_local Eigen::MatrixXf mat; + mat.resize(header.dim, 1 + windowSize); + thread_local Eigen::VectorXf lls; + lls.resize(1 + windowSize); + if (useDistantTokens && validDistantToken) + { + lls[0] = contextConf[contextIdx]; + lls.tail(windowSize) = Eigen::Map{ &positionConf[0], windowSize }; + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[(historyPos + i) % windowSize]; + lls[i + 1] += historyToken ? 
distantConf[historyToken] : -99999; + } + logsoftmaxInplace(lls.array()); + + mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; + lls[0] -= contextBias[contextIdx]; + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[(historyPos + i) % windowSize]; + if (historyToken) mat.col(i + 1) = Eigen::Map{ &distantEmb[historyToken * header.dim], header.dim }; + else mat.col(i + 1).setZero(); + lls[i + 1] -= distantBias[historyToken]; + } + lls.tail(windowSize).array() += contextValidTokenSum[contextIdx]; + Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; + lls += mat.transpose() * outputVec; + ll = sb::LogExpSum{}(lls.data(), std::integral_constant()); + } + else + { + lls[0] = -contextBias[contextIdx]; + mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; + Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; + lls.head(1) += mat.transpose() * outputVec; + ll = lls[0]; + } + contextIdx = progressContextNode(nodeIdx, next); + if (history[windowSize]) + { + history[historyPos] = history[windowSize]; + historyPos = (historyPos + 1) % windowSize; + } + history[windowSize] = validDistantToken ? 
next : 0; return ll; } }; From dd8777122dfa5280cab8778cf2a1d7ab0342cc12 Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 21 Jan 2025 00:28:30 +0900 Subject: [PATCH 05/53] Add `exclusiveWindow` arg to `HSDataset` --- include/kiwi/Dataset.h | 3 ++- src/Dataset.cpp | 33 +++++++++++++++++++++++++++++---- src/KiwiBuilder.cpp | 4 ++-- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index 0a0e9727..9d97d089 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -64,6 +64,7 @@ namespace kiwi size_t batchSize = 0; size_t causalContextSize = 0; size_t windowSize = 0; + bool exclusiveWindow = true; size_t totalTokens = 0; size_t passedSents = 0; size_t passedWorkItems = 0; @@ -74,7 +75,7 @@ namespace kiwi size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut); public: - HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0, double _dropoutProbOnHistory = 0); + HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, bool _exclusiveWindow = true, size_t _workers = 0, double _dropoutProb = 0, double _dropoutProbOnHistory = 0); ~HSDataset(); HSDataset(const HSDataset&) = delete; HSDataset(HSDataset&&) /*noexcept*/; diff --git a/src/Dataset.cpp b/src/Dataset.cpp index 95812b7a..1b17add8 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -5,7 +5,9 @@ using namespace kiwi; -HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, +HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, + size_t _windowSize, bool _exclusiveWindow, + size_t _workers, double _dropoutProb, double _dropoutProbOnHistory) : workers{ _workers ? 
make_unique(_workers) : nullptr }, dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} }, @@ -13,7 +15,8 @@ HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windo locals( _workers ? workers->size() : 1), batchSize{ _batchSize }, causalContextSize{ _causalContextSize }, - windowSize{ _windowSize } + windowSize{ _windowSize }, + exclusiveWindow{ _exclusiveWindow } { } @@ -181,6 +184,7 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, } } + int32_t lastV = nonVocab; for (size_t i = 1; i < tokens.size(); ++i) { int32_t v = tokenToVocab[tokens[i]]; @@ -225,12 +229,33 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, if (windowTokenValidness[v]) { std::copy(history.begin(), history.end(), std::back_inserter(local.inData)); - history.pop_front(); - history.push_back(v); + if (exclusiveWindow) + { + if (lastV != nonVocab) + { + history.pop_front(); + history.push_back(lastV); + } + lastV = v; + } + else + { + history.pop_front(); + history.push_back(v); + } } else { local.inData.resize(local.inData.size() + windowSize, -1); + if (exclusiveWindow) + { + if (lastV != nonVocab) + { + history.pop_front(); + history.push_back(lastV); + } + lastV = nonVocab; + } } } diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index f0c39f00..88711b31 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -2349,7 +2349,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, HSDataset* splitDataset ) const { - HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb, dropoutProbOnHistory }; + HSDataset dataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, dropoutProbOnHistory }; auto& sents = dataset.sents.get(); const KiwiBuilder* srcBuilder = this; MorphemeMap realMorph; @@ -2382,7 +2382,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, if (splitDataset) { - *splitDataset = 
HSDataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb }; + *splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb }; splitDataset->dummyBuilder = dataset.dummyBuilder; splitDataset->knlm = knlm; splitDataset->morphemes = &srcBuilder->morphemes; From f5eccd444db5b6ed6e205d636736c61bcf89b49f Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 21 Jan 2025 00:29:02 +0900 Subject: [PATCH 06/53] Add constructor for pclmLocal mode LM --- src/KiwiBuilder.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 88711b31..a0e3ed68 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -795,7 +795,11 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio if (modelType == ModelType::pclm) { - langMdl.pclm = pclm::PCLanguageModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType); + langMdl.pclm = pclm::PCLanguageModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, true); + } + else if (modelType == ModelType::pclmLocal) + { + langMdl.pclm = pclm::PCLanguageModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, false); } if (!!(options & BuildOption::loadDefaultDict)) From ad74c55360a41aff4a24fa08f8f036a007a227aa Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 25 Jan 2025 01:11:23 +0900 Subject: [PATCH 07/53] Refactor LangModel implementations --- include/kiwi/Joiner.h | 119 +- include/kiwi/Kiwi.h | 33 +- include/kiwi/Knlm.h | 24 +- include/kiwi/LangModel.h | 80 ++ include/kiwi/LmState.h | 43 - include/kiwi/PCLanguageModel.h | 19 +- include/kiwi/SkipBigramModel.h | 23 +- src/BestPathContainer.hpp | 293 +++++ src/Joiner.cpp | 317 +----- src/Joiner.hpp | 239 +++- src/Kiwi.cpp | 137 +-- src/Kiwi.hpp | 19 + src/KiwiBuilder.cpp | 45 +- src/Knlm.cpp | 361 +++++- src/Knlm.hpp | 523 +++------ src/LmState.hpp | 304 ----- src/PCLanguageModel.cpp | 297 ++++- 
src/PCLanguageModel.hpp | 338 ++---- src/PathEvaluator.cpp | 126 -- src/PathEvaluator.h | 4 +- src/PathEvaluator.hpp | 1704 +++++++++++----------------- src/SkipBigramModel.cpp | 60 +- src/SkipBigramModel.hpp | 80 +- src/SkipBigramModelImpl.hpp | 2 +- src/SkipBigramTrainer.hpp | 16 +- src/archImpl/avx2.cpp | 2 +- src/archImpl/avx512bw.cpp | 2 +- src/archImpl/neon.cpp | 2 +- src/archImpl/none.cpp | 2 +- src/archImpl/sse2.cpp | 2 +- src/archImpl/sse4_1.cpp | 2 +- vsproj/kiwi_shared_library.vcxproj | 7 +- 32 files changed, 2465 insertions(+), 2760 deletions(-) create mode 100644 include/kiwi/LangModel.h create mode 100644 src/BestPathContainer.hpp create mode 100644 src/Kiwi.hpp delete mode 100644 src/LmState.hpp delete mode 100644 src/PathEvaluator.cpp diff --git a/include/kiwi/Joiner.h b/include/kiwi/Joiner.h index 4447d3d1..98247929 100644 --- a/include/kiwi/Joiner.h +++ b/include/kiwi/Joiner.h @@ -1,12 +1,11 @@ #pragma once #include "Types.h" #include "ArchUtils.h" -#include "LmState.h" +#include "LangModel.h" namespace kiwi { class Kiwi; - template class VoidState; struct Form; namespace cmb @@ -57,34 +56,113 @@ namespace kiwi LmState lmState; float score = 0; - Candidate(const CompiledRule& _cr, const LangModel& lm) + Candidate(const CompiledRule& _cr, const lm::ILangModel* lm) : joiner{ _cr }, lmState{ lm } { } }; template - struct Candidate> + struct Candidate> { Joiner joiner; - Candidate(const CompiledRule& _cr, const LangModel& lm) + Candidate(const CompiledRule& _cr, const lm::ILangModel* lm) : joiner{ _cr } { } }; - class AutoJoiner + class ErasedVector { - friend class kiwi::Kiwi; + using FnDestruct = void(*)(ErasedVector*); + using FnCopyConstruct = void(*)(ErasedVector*, const ErasedVector&); + + template + static void destructImpl(ErasedVector* self) + { + auto* target = reinterpret_cast*>(&self->vec); + target->~Vector(); + } + + template + static void copyConstructImpl(ErasedVector* self, const ErasedVector& other) + { + auto* target = 
reinterpret_cast*>(&self->vec); + new (target) Vector{ *reinterpret_cast*>(&other.vec) }; + } - struct AddVisitor; - struct AddVisitor2; - const Kiwi* kiwi = nullptr; union { - typename std::aligned_storage) + sizeof(int), alignof(Vector)>::type candBuf; + Vector vec; }; + FnDestruct destruct = nullptr; + FnCopyConstruct copyConstruct = nullptr; + public: + + template + ErasedVector(Vector&& v) + { + auto* target = reinterpret_cast*>(&vec); + new (target) Vector{ move(v) }; + destruct = &destructImpl; + copyConstruct = ©ConstructImpl; + } + + ~ErasedVector() + { + if (destruct) + { + (*destruct)(this); + destruct = nullptr; + copyConstruct = nullptr; + } + } + + ErasedVector(const ErasedVector& other) + : destruct{ other.destruct }, copyConstruct{ other.copyConstruct } + { + if (!destruct) return; + (*copyConstruct)(this, other); + } + + ErasedVector(ErasedVector&& other) + { + std::swap(vec, other.vec); + std::swap(destruct, other.destruct); + std::swap(copyConstruct, other.copyConstruct); + } + + ErasedVector& operator=(const ErasedVector& other) + { + this->~ErasedVector(); + new (this) ErasedVector{ other }; + } + + ErasedVector& operator=(ErasedVector&& other) + { + std::swap(vec, other.vec); + std::swap(destruct, other.destruct); + std::swap(copyConstruct, other.copyConstruct); + return *this; + } + + template + Vector& get() + { + return *reinterpret_cast*>(&vec); + } + + template + const Vector& get() const + { + return *reinterpret_cast*>(&vec); + } + }; + + class AutoJoiner + { + friend class kiwi::Kiwi; template explicit AutoJoiner(const Kiwi& kiwi, Candidate&& state); @@ -93,16 +171,27 @@ namespace kiwi void foreachMorpheme(const Form* formHead, Func&& func) const; template - void add(size_t morphemeId, Space space, Vector>& candidates); + void addImpl(size_t morphemeId, Space space, Vector>& candidates); template - void add(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates); + void addImpl2(U16StringView form, 
POSTag tag, bool inferRegularity, Space space, Vector>& candidates); template - void addWithoutSearch(size_t morphemeId, Space space, Vector>>& candidates); + void addWithoutSearchImpl(size_t morphemeId, Space space, Vector>>& candidates); template - void addWithoutSearch(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates); + void addWithoutSearchImpl2(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates); + + template + struct Dispatcher; + + using FnAdd = void(*)(AutoJoiner*, size_t, Space, Vector>>&); + using FnAdd2 = void(*)(AutoJoiner*, U16StringView, POSTag, bool, Space, Vector>>&); + + const Kiwi* kiwi = nullptr; + FnAdd dfAdd = nullptr; + FnAdd2 dfAdd2 = nullptr; + ErasedVector candBuf; public: ~AutoJoiner(); diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 981fe8ba..b3ebc9be 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -25,7 +25,7 @@ #include "ThreadPool.h" #include "WordDetector.h" #include "TagUtils.h" -#include "LmState.h" +#include "LangModel.h" #include "Joiner.h" #include "TypoTransformer.h" @@ -60,6 +60,7 @@ namespace kiwi friend class KiwiBuilder; friend struct BestPathFinder; template friend struct PathEvaluator; + template friend struct MorphemeEvaluator; friend class cmb::AutoJoiner; template class LmState> friend struct NewAutoJoinerGetter; @@ -82,22 +83,17 @@ namespace kiwi Vector typoPtrs; Vector typoForms; utils::FrozenTrie formTrie; - LangModel langMdl; + std::shared_ptr langMdl; std::shared_ptr combiningRule; std::unique_ptr pool; const Morpheme* getDefaultMorpheme(POSTag tag) const; - template - cmb::AutoJoiner newJoinerImpl() const - { - return cmb::AutoJoiner{ *this, cmb::Candidate{ *combiningRule, langMdl } }; - } - ArchType selectedArch = ArchType::none; void* dfSplitByTrie = nullptr; void* dfFindForm = nullptr; void* dfFindBestPath = nullptr; + void* dfNewJoiner = nullptr; public: enum class SpecialMorph { @@ -130,7 +126,7 @@ 
namespace kiwi * kiwi::KiwiBuilder 를 통해 생성된 객체만이 형태소 분석에 사용할 수 있다. */ Kiwi(ArchType arch = ArchType::default_, - LangModel _langMdl = {}, + const std::shared_ptr& _langMdl = {}, bool typoTolerant = false, bool continualTypoTolerant = false, bool lengtheningTypoTolerant = false); @@ -157,7 +153,7 @@ namespace kiwi ArchType archType() const { return selectedArch; } - ModelType modelType() const { return langMdl.type; } + ModelType modelType() const { return langMdl ? langMdl->getType() : ModelType::none; } /** * @brief 현재 Kiwi 객체가 오타 교정 기능이 켜진 상태로 생성되었는지 알려준다. @@ -371,6 +367,10 @@ namespace kiwi TokenResult* tokenizedResultOut = nullptr ) const; + + template + cmb::AutoJoiner newJoinerImpl() const; + /** * @brief 형태소들을 결합하여 텍스트로 복원해주는 작업을 수행하는 AutoJoiner를 반환한다. * @@ -381,11 +381,6 @@ namespace kiwi */ cmb::AutoJoiner newJoiner(bool lmSearch = true) const; - /** - * @brief Kiwi에 내장된 언어 모델에 접근할 수 있는 LmObject 객체를 생성한다. - */ - std::unique_ptr newLmObject() const; - /** * @brief `TokenInfo::typoFormId`로부터 실제 오타 형태를 복원한다. 
* @@ -517,9 +512,9 @@ namespace kiwi integrateAllomorph = v; } - const lm::KnLangModelBase* getKnLM() const + const lm::ILangModel* getLangModel() const { - return langMdl.knlm.get(); + return langMdl.get(); } void findMorpheme(std::vector& out, const std::u16string& s, POSTag tag = POSTag::unknown) const; @@ -536,7 +531,7 @@ namespace kiwi Vector forms; Vector morphemes; UnorderedMap formMap; - LangModel langMdl; + std::shared_ptr langMdl; std::shared_ptr combiningRule; WordDetector detector; @@ -682,7 +677,7 @@ namespace kiwi */ bool ready() const { - return !!langMdl.knlm; + return !!langMdl; } void saveModel(const std::string& modelPath) const; diff --git a/include/kiwi/Knlm.h b/include/kiwi/Knlm.h index 7727ad17..7ab8044a 100644 --- a/include/kiwi/Knlm.h +++ b/include/kiwi/Knlm.h @@ -1,21 +1,12 @@ #pragma once -#include -#include -#include -#include -#include -#include - -#include "Utils.h" -#include "Mmap.h" -#include "ArchUtils.h" +#include "LangModel.h" namespace kiwi { namespace lm { - struct Header + struct KnLangModelHeader { uint64_t num_nodes, node_offset, key_offset, ll_offset, gamma_offset, qtable_offset, htx_offset; uint64_t unk_id, bos_id, eos_id, vocab_size; @@ -24,7 +15,7 @@ namespace kiwi }; template - struct Node + struct KnLangModelNode { KeyType num_nexts = 0; DiffType lower = 0; @@ -32,7 +23,7 @@ namespace kiwi float ll = 0, gamma = 0; }; - class KnLangModelBase + class KnLangModelBase : public ILangModel { protected: utils::MemoryObject base; @@ -50,14 +41,17 @@ namespace kiwi public: virtual ~KnLangModelBase() {} - const Header& getHeader() const { return *reinterpret_cast(base.get()); } + size_t vocabSize() const override { return getHeader().vocab_size; } + size_t getMemorySize() const override { return base.size(); } + + const KnLangModelHeader& getHeader() const { return *reinterpret_cast(base.get()); } virtual ptrdiff_t getLowerNode(ptrdiff_t node_idx) const = 0; virtual size_t nonLeafNodeSize() const = 0; virtual const void* 
getExtraBuf() const = 0; - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool transposed = false); template> static utils::MemoryOwner build(Trie&& ngram_cf, diff --git a/include/kiwi/LangModel.h b/include/kiwi/LangModel.h new file mode 100644 index 00000000..ec839e63 --- /dev/null +++ b/include/kiwi/LangModel.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "Utils.h" +#include "Mmap.h" +#include "ArchUtils.h" +#include "Types.h" + +namespace kiwi +{ + namespace lm + { + class ILangModel + { + public: + virtual ~ILangModel() = default; + virtual ModelType getType() const = 0; + virtual size_t vocabSize() const = 0; + virtual size_t getMemorySize() const = 0; + + virtual void* getFindBestPathFn() const = 0; + virtual void* getNewJoinerFn() const = 0; + }; + + template + struct LmStateBase + { + float next(const ILangModel* langMdl, typename DerivedLM::VocabType nextToken) + { + using LmStateType = typename DerivedLM::LmStateType; + return static_cast(this)->nextImpl(static_cast(langMdl), nextToken); + } + }; + + template + class VoidLangModel; + + template + struct VoidState : public LmStateBase> + { + bool operator==(const VoidState& other) const + { + return true; + } + + float nextImpl(const VoidLangModel* langMdl, uint32_t nextToken) + { + return 0; + } + }; + + template + class VoidLangModel : public ILangModel + { + public: + using VocabType = uint32_t; + using LmStateType = VoidState; + + ModelType getType() const override { return ModelType::none; } + size_t vocabSize() const override { return 0; } + void* getFindBestPathFn() const override { return nullptr; } + void* getNewJoinerFn() const override { return nullptr; } + }; + } + + template + struct Hash> + { + size_t operator()(const lm::VoidState& state) const + { + return 0; + } + }; +} diff --git 
a/include/kiwi/LmState.h b/include/kiwi/LmState.h index dddaac98..e69de29b 100644 --- a/include/kiwi/LmState.h +++ b/include/kiwi/LmState.h @@ -1,43 +0,0 @@ -#pragma once - -#include -#include "Utils.h" -#include "Trie.hpp" -#include "Knlm.h" -#include "SkipBigramModel.h" -#include "PCLanguageModel.h" - -namespace kiwi -{ - struct LangModel - { - ModelType type; - std::shared_ptr knlm; - std::shared_ptr sbg; - std::shared_ptr pclm; - - size_t vocabSize() const - { - if (knlm) return knlm->getHeader().vocab_size; - else return pclm->getHeader().vocabSize; - } - }; - - class LmObjectBase - { - public: - virtual ~LmObjectBase() {} - - virtual size_t vocabSize() const = 0; - - virtual float evalSequence(const uint32_t* seq, size_t length, size_t stride) const = 0; - - virtual void predictNext(const uint32_t* seq, size_t length, size_t stride, float* outScores) const = 0; - - virtual void evalSequences( - const uint32_t* prefix, size_t prefixLength, size_t prefixStride, - const uint32_t* suffix, size_t suffixLength, size_t suffixStride, - size_t seqSize, const uint32_t** seq, const size_t* seqLength, const size_t* seqStride, float* outScores - ) const = 0; - }; -} diff --git a/include/kiwi/PCLanguageModel.h b/include/kiwi/PCLanguageModel.h index 14c23698..c58835c0 100644 --- a/include/kiwi/PCLanguageModel.h +++ b/include/kiwi/PCLanguageModel.h @@ -9,12 +9,13 @@ #include "ArchUtils.h" #include "Mmap.h" +#include "LangModel.h" namespace kiwi { - namespace pclm + namespace lm { - struct Header + struct PcLangModelHeader { uint64_t vocabSize, contextSize; uint16_t dim; @@ -33,20 +34,24 @@ namespace kiwi uint32_t nextOffset = 0; }; - class PCLanguageModelBase + class PcLangModelBase : public ILangModel { protected: utils::MemoryObject base; - PCLanguageModelBase(utils::MemoryObject&& mem) : base{ std::move(mem) } + PcLangModelBase(utils::MemoryObject&& mem) : base{ std::move(mem) } { } public: - virtual ~PCLanguageModelBase() {} - const Header& getHeader() const { return 
*reinterpret_cast(base.get()); } + virtual ~PcLangModelBase() {} + size_t vocabSize() const override { return getHeader().vocabSize; } + ModelType getType() const override { return ModelType::pclm; } + size_t getMemorySize() const override { return base.size(); } + + const PcLangModelHeader& getHeader() const { return *reinterpret_cast(base.get()); } static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, bool reorderContextIdx = true); - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false); }; } } diff --git a/include/kiwi/SkipBigramModel.h b/include/kiwi/SkipBigramModel.h index 69e73069..396eb583 100644 --- a/include/kiwi/SkipBigramModel.h +++ b/include/kiwi/SkipBigramModel.h @@ -1,26 +1,18 @@ #pragma once -#include -#include -#include -#include -#include -#include - -#include "ArchUtils.h" -#include "Mmap.h" +#include "Knlm.h" namespace kiwi { - namespace sb + namespace lm { - struct Header + struct SkipBigramModelHeader { uint64_t vocabSize; uint8_t keySize, windowSize, compressed, quantize, _rsv[4]; }; - class SkipBigramModelBase + class SkipBigramModelBase : public ILangModel { protected: utils::MemoryObject base; @@ -30,9 +22,12 @@ namespace kiwi } public: virtual ~SkipBigramModelBase() {} - const Header& getHeader() const { return *reinterpret_cast(base.get()); } + size_t vocabSize() const override { return getHeader().vocabSize; } + ModelType getType() const override { return ModelType::sbg; } + + const SkipBigramModelHeader& getHeader() const { return *reinterpret_cast(base.get()); } - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none); + static std::unique_ptr create(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem, ArchType archType = ArchType::none); }; } } diff 
--git a/src/BestPathContainer.hpp b/src/BestPathContainer.hpp new file mode 100644 index 00000000..6cee09cf --- /dev/null +++ b/src/BestPathContainer.hpp @@ -0,0 +1,293 @@ +#pragma once + +#include + +namespace kiwi +{ + template + struct WordLL; + + using Wid = uint32_t; + + enum class PathEvaluatingMode + { + topN, + top1, + top1Small, + }; + + template + struct WordLL + { + const Morpheme* morpheme = nullptr; + float accScore = 0, accTypoCost = 0; + const WordLL* parent = nullptr; + LmState lmState; + Wid wid = 0; + uint16_t ownFormId = 0; + uint8_t combineSocket = 0; + uint8_t rootId = 0; + SpecialState spState; + + WordLL() = default; + + WordLL(const Morpheme* _morph, float _accScore, float _accTypoCost, const WordLL* _parent, LmState _lmState, SpecialState _spState) + : morpheme{ _morph }, + accScore{ _accScore }, + accTypoCost{ _accTypoCost }, + parent{ _parent }, + lmState{ _lmState }, + spState{ _spState }, + rootId{ parent ? parent->rootId : (uint8_t)0 } + { + } + + const WordLL* root() const + { + if (parent) return parent->root(); + else return this; + } + }; + + static constexpr uint8_t commonRootId = -1; + + template + struct PathHash + { + LmState lmState; + uint8_t rootId, spState; + + PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) + : lmState{ _lmState }, rootId{ _rootId }, spState{ _spState } + { + } + + PathHash(const WordLL& wordLl, const Morpheme* morphBase) + : PathHash{ wordLl.lmState, wordLl.rootId, wordLl.spState } + { + } + + bool operator==(const PathHash& o) const + { + return lmState == o.lmState && rootId == o.rootId && spState == o.spState; + } + }; + + template + struct Hash> + { + size_t operator()(const PathHash& p) const + { + size_t ret = 0; + if (sizeof(PathHash) % sizeof(size_t)) + { + auto ptr = reinterpret_cast(&p); + for (size_t i = 0; i < sizeof(PathHash) / sizeof(uint32_t); ++i) + { + ret ^= ptr[i]; + } + } + else + { + auto ptr = reinterpret_cast(&p); + for (size_t i = 0; i < 
sizeof(PathHash) / sizeof(size_t); ++i) + { + ret ^= ptr[i]; + } + } + return ret; + } + }; + + struct WordLLGreater + { + template + bool operator()(const WordLL& a, const WordLL& b) const + { + return a.accScore > b.accScore; + } + }; + + template + inline std::ostream& printDebugPath(std::ostream& os, const WordLL& src) + { + if (src.parent) + { + printDebugPath(os, *src.parent); + } + + if (src.morpheme) src.morpheme->print(os); + else os << "NULL"; + os << " , "; + return os; + } + + template + class BestPathConatiner; + + template + class BestPathConatiner + { + // pair: [index, size] + UnorderedMap, std::pair> bestPathIndex; + Vector> bestPathValues; + public: + inline void clear() + { + bestPathIndex.clear(); + bestPathValues.clear(); + } + + inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + { + auto inserted = bestPathIndex.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); + if (inserted.second) + { + bestPathValues.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); + if (rootId != commonRootId) bestPathValues.back().rootId = rootId; + bestPathValues.resize(bestPathValues.size() + topN - 1); + } + else + { + auto bestPathFirst = bestPathValues.begin() + inserted.first->second.first; + auto bestPathLast = bestPathValues.begin() + inserted.first->second.first + inserted.first->second.second; + if (distance(bestPathFirst, bestPathLast) < topN) + { + *bestPathLast = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + if (rootId != commonRootId) bestPathLast->rootId = rootId; + push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); + ++inserted.first->second.second; + } + else + { + if (accScore > bestPathFirst->accScore) + { + pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + *(bestPathLast - 1) = WordLL{ morph, accScore, accTypoCost, parent, 
move(lmState), spState }; + if (rootId != commonRootId) (*(bestPathLast - 1)).rootId = rootId; + push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + } + } + } + } + + inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + { + for (auto& p : bestPathIndex) + { + const auto index = p.second.first; + const auto size = p.second.second; + for (size_t i = 0; i < size; ++i) + { + resultOut.emplace_back(move(bestPathValues[index + i])); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + newPath.wid = lastSeqId; + if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + { + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + } + } + } + }; + + template + class BestPathConatiner + { + UnorderedMap, WordLL> bestPathes; + public: + inline void clear() + { + bestPathes.clear(); + } + + inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + { + WordLL newPath{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + if (rootId != commonRootId) newPath.rootId = rootId; + auto inserted = bestPathes.emplace(ph, newPath); + if (!inserted.second) + { + auto& target = inserted.first->second; + if (accScore > target.accScore) + { + target = newPath; + } + } + } + + inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + { + for (auto& p : bestPathes) + { + resultOut.emplace_back(move(p.second)); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + newPath.wid = lastSeqId; + if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + { + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + } + } + }; + + template + class BestPathConatiner + { + Vector> bestPathIndicesSmall; + Vector> 
bestPathValuesSmall; + public: + + inline void clear() + { + bestPathIndicesSmall.clear(); + bestPathValuesSmall.clear(); + } + + inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + { + const auto it = find(bestPathIndicesSmall.begin(), bestPathIndicesSmall.end(), ph); + if (it == bestPathIndicesSmall.end()) + { + bestPathIndicesSmall.push_back(ph); + bestPathValuesSmall.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); + if (rootId != commonRootId) bestPathValuesSmall.back().rootId = rootId; + } + else + { + auto& target = bestPathValuesSmall[it - bestPathIndicesSmall.begin()]; + if (accScore > target.accScore) + { + target = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + if (rootId != commonRootId) target.rootId = rootId; + } + } + } + + inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + { + for (auto& p : bestPathValuesSmall) + { + resultOut.emplace_back(move(p)); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + newPath.wid = lastSeqId; + if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + { + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + } + } + }; +} \ No newline at end of file diff --git a/src/Joiner.cpp b/src/Joiner.cpp index c80eb13a..21c4c007 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -212,339 +212,44 @@ namespace kiwi } } - AutoJoiner::~AutoJoiner() - { - reinterpret_cast(candBuf).~CandVector(); - } - - AutoJoiner::AutoJoiner(const AutoJoiner& o) - : kiwi{ o.kiwi } - { - new (&candBuf) CandVector{ reinterpret_cast(o.candBuf) }; - } + AutoJoiner::~AutoJoiner() = default; - AutoJoiner::AutoJoiner(AutoJoiner&& o) - : kiwi{ o.kiwi } - { - new (&candBuf) CandVector{ reinterpret_cast(o.candBuf) }; - } - - 
AutoJoiner& AutoJoiner::operator=(const AutoJoiner& o) - { - kiwi = o.kiwi; - reinterpret_cast(candBuf) = reinterpret_cast(o.candBuf); - return *this; - } + AutoJoiner::AutoJoiner(const AutoJoiner& o) = default; - AutoJoiner& AutoJoiner::operator=(AutoJoiner&& o) - { - kiwi = o.kiwi; - reinterpret_cast(candBuf) = reinterpret_cast(o.candBuf); - return *this; - } + AutoJoiner::AutoJoiner(AutoJoiner&& o) = default; - template - void AutoJoiner::add(size_t morphemeId, Space space, Vector>& candidates) - { - auto& morph = kiwi->morphemes[morphemeId]; - for (auto& cand : candidates) - { - cand.score += cand.lmState.next(kiwi->langMdl, morph.lmMorphemeId); - cand.joiner.add(morph.getForm(), morph.tag, space); - } - - sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const cmb::Candidate& b) - { - return a.score > b.score; - }); - } + AutoJoiner& AutoJoiner::operator=(const AutoJoiner& o) = default; - template - void AutoJoiner::foreachMorpheme(const Form* formHead, Func&& func) const - { - if (kiwi->isTypoTolerant()) - { - auto tformHead = reinterpret_cast(formHead); - do - { - if (tformHead->score() == 0) - { - for (auto m : tformHead->form(kiwi->forms.data()).candidate) - { - func(m); - } - } - ++tformHead; - } while (tformHead[-1].hash() == tformHead[0].hash()); - } - else - { - do - { - for (auto m : formHead->candidate) - { - func(m); - } - ++formHead; - } while (formHead[-1].form == formHead[0].form); - } - } - - template - void AutoJoiner::add(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates) - { - const Form* formHead; - auto node = kiwi->formTrie.root(); - for (auto c : normalizeHangul(form)) - { - node = node->template nextOpt(kiwi->formTrie, c); - if (!node) break; - } - - // prevent unknown or partial tag - POSTag fixedTag = tag; - if (tag == POSTag::unknown || tag == POSTag::p) - { - fixedTag = POSTag::nnp; - } - - if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie))) - { - Vector 
cands; - foreachMorpheme(formHead, [&](const Morpheme* m) - { - if (areTagsEqual(m->tag, fixedTag, inferRegularity)) - { - cands.emplace_back(m); - } - }); - - if (cands.size() <= 1) - { - auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId; - if (!cands.empty()) tag = cands[0]->tag; - for (auto& cand : candidates) - { - cand.score += cand.lmState.next(kiwi->langMdl, lmId); - cand.joiner.add(form, tag, space); - } - } - else - { - size_t oSize = candidates.size(); - for (size_t i = 1; i < cands.size(); ++i) - { - for (size_t o = 0; o < oSize; ++o) - { - candidates.emplace_back(candidates[o]); - auto& n = candidates.back(); - n.score += n.lmState.next(kiwi->langMdl, cands[i]->lmMorphemeId); - n.joiner.add(form, cands[i]->tag, space); - } - } - for (size_t o = 0; o < oSize; ++o) - { - auto& n = candidates[o]; - n.score += n.lmState.next(kiwi->langMdl, cands[0]->lmMorphemeId); - n.joiner.add(form, cands[0]->tag, space); - } - - UnorderedMap> bestScoreByState; - for (size_t i = 0; i < candidates.size(); ++i) - { - auto& c = candidates[i]; - auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, (uint32_t)i)); - if (!inserted.second) - { - if (inserted.first->second.first < c.score) - { - inserted.first->second = make_pair(c.score, i); - } - } - } - - if (bestScoreByState.size() < candidates.size()) - { - Vector> newCandidates; - newCandidates.reserve(bestScoreByState.size()); - for (auto& p : bestScoreByState) - { - newCandidates.emplace_back(std::move(candidates[p.second.second])); - } - candidates = std::move(newCandidates); - } - } - } - else - { - auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag)); - for (auto& cand : candidates) - { - cand.score += cand.lmState.next(kiwi->langMdl, lmId); - cand.joiner.add(form, tag, space); - } - } - sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const cmb::Candidate& b) - { - return a.score > b.score; - }); - } - - template - void 
AutoJoiner::addWithoutSearch(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates) - { - if (inferRegularity) - { - auto node = kiwi->formTrie.root(); - for (auto c : normalizeHangul(form)) - { - node = node->template nextOpt(kiwi->formTrie, c); - if (!node) break; - } - - if (node) - { - if (const Form* formHead = node->val(kiwi->formTrie)) - { - Vector cands; - foreachMorpheme(formHead, [&](const Morpheme* m) - { - if (areTagsEqual(m->tag, tag, true)) - { - cands.emplace_back(m); - } - }); - - if (!cands.empty()) - { - tag = cands[0]->tag; - } - } - } - } - candidates[0].joiner.add(form, tag, space); - } - - template - void AutoJoiner::addWithoutSearch(size_t morphemeId, Space space, Vector>>& candidates) - { - auto& morph = kiwi->morphemes[morphemeId]; - for (auto& cand : candidates) - { - cand.joiner.add(morph.getForm(), morph.tag, space); - } - } - - struct AutoJoiner::AddVisitor - { - AutoJoiner* joiner; - U16StringView form; - POSTag tag; - bool inferRegularity; - Space space; - - AddVisitor(AutoJoiner* _joiner, U16StringView _form, POSTag _tag, bool _inferRegularity, Space _space) - : joiner{ _joiner }, form{ _form }, tag{ _tag }, inferRegularity{ _inferRegularity }, space{ _space } - { - } - - template - void operator()(Vector>>& o) const - { - return joiner->addWithoutSearch(form, tag, inferRegularity, space, o); - } - - template - void operator()(Vector>& o) const - { - return joiner->add(form, tag, inferRegularity, space, o); - } - }; - - struct AutoJoiner::AddVisitor2 - { - AutoJoiner* joiner; - size_t morphemeId; - Space space; - - AddVisitor2(AutoJoiner* _joiner, size_t _morphemeId, Space _space) - : joiner{ _joiner }, morphemeId{ _morphemeId }, space{ _space } - { - } - - template - void operator()(Vector>>& o) const - { - return joiner->addWithoutSearch(morphemeId, space, o); - } - - template - void operator()(Vector>& o) const - { - return joiner->add(morphemeId, space, o); - } - }; - - struct GetU16Visitor - { - 
vector>* rangesOut; - - GetU16Visitor(vector>* _rangesOut) - : rangesOut{ _rangesOut } - { - } - - template - u16string operator()(const Vector>& o) const - { - return o[0].joiner.getU16(rangesOut); - } - }; - - struct GetU8Visitor - { - vector>* rangesOut; - - GetU8Visitor(vector>* _rangesOut) - : rangesOut{ _rangesOut } - { - } - - template - string operator()(const Vector>& o) const - { - return o[0].joiner.getU8(rangesOut); - } - }; + AutoJoiner& AutoJoiner::operator=(AutoJoiner&& o) = default; void AutoJoiner::add(size_t morphemeId, Space space) { - return mapbox::util::apply_visitor(AddVisitor2{ this, morphemeId, space }, reinterpret_cast(candBuf)); + return (*dfAdd)(this, morphemeId, space, candBuf.get>>()); } void AutoJoiner::add(const u16string& form, POSTag tag, bool inferRegularity, Space space) { - return mapbox::util::apply_visitor(AddVisitor{ this, nonstd::to_string_view(form), tag, inferRegularity, space }, reinterpret_cast(candBuf)); + return (*dfAdd2)(this, nonstd::to_string_view(form), tag, inferRegularity, space, candBuf.get>>()); } void AutoJoiner::add(const char16_t* form, POSTag tag, bool inferRegularity, Space space) { - return mapbox::util::apply_visitor(AddVisitor{ this, U16StringView{ form }, tag, inferRegularity, space }, reinterpret_cast(candBuf)); + return (*dfAdd2)(this, U16StringView{ form }, tag, inferRegularity, space, candBuf.get>>()); } void AutoJoiner::add(U16StringView form, POSTag tag, Space space) { - return mapbox::util::apply_visitor(AddVisitor{ this, form, tag, false, space }, reinterpret_cast(candBuf)); + return (*dfAdd2)(this, form, tag, false, space, candBuf.get>>()); } u16string AutoJoiner::getU16(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU16Visitor{ rangesOut }, reinterpret_cast(candBuf)); + return candBuf.get>>()[0].joiner.getU16(rangesOut); } string AutoJoiner::getU8(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU8Visitor{ rangesOut }, reinterpret_cast(candBuf)); + 
return candBuf.get>>()[0].joiner.getU8(rangesOut); } } } diff --git a/src/Joiner.hpp b/src/Joiner.hpp index bf195e2f..b7537039 100644 --- a/src/Joiner.hpp +++ b/src/Joiner.hpp @@ -1,10 +1,9 @@ #pragma once #include #include -#include +#include #include "Combiner.h" #include "StrUtils.h" -#include "LmState.hpp" using namespace std; @@ -12,40 +11,234 @@ namespace kiwi { namespace cmb { - namespace detail + template + void AutoJoiner::addImpl(size_t morphemeId, Space space, Vector>& candidates) + { + auto& morph = kiwi->morphemes[morphemeId]; + for (auto& cand : candidates) + { + cand.score += cand.lmState.next(kiwi->langMdl.get(), morph.lmMorphemeId); + cand.joiner.add(morph.getForm(), morph.tag, space); + } + + sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const cmb::Candidate& b) + { + return a.score > b.score; + }); + } + + template + void AutoJoiner::foreachMorpheme(const Form* formHead, Func&& func) const { - template class Type, class> - struct VCUnpack2nd; + if (kiwi->isTypoTolerant()) + { + auto tformHead = reinterpret_cast(formHead); + do + { + if (tformHead->score() == 0) + { + for (auto m : tformHead->form(kiwi->forms.data()).candidate) + { + func(m); + } + } + ++tformHead; + } while (tformHead[-1].hash() == tformHead[0].hash()); + } + else + { + do + { + for (auto m : formHead->candidate) + { + func(m); + } + ++formHead; + } while (formHead[-1].form == formHead[0].form); + } + } + + template + void AutoJoiner::addImpl2(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates) + { + const Form* formHead; + auto node = kiwi->formTrie.root(); + for (auto c : normalizeHangul(form)) + { + node = node->template nextOpt(kiwi->formTrie, c); + if (!node) break; + } + + // prevent unknown or partial tag + POSTag fixedTag = tag; + if (tag == POSTag::unknown || tag == POSTag::p) + { + fixedTag = POSTag::nnp; + } + + if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie))) + { + Vector cands; + 
foreachMorpheme(formHead, [&](const Morpheme* m) + { + if (areTagsEqual(m->tag, fixedTag, inferRegularity)) + { + cands.emplace_back(m); + } + }); + + if (cands.size() <= 1) + { + auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId; + if (!cands.empty()) tag = cands[0]->tag; + for (auto& cand : candidates) + { + cand.score += cand.lmState.next(kiwi->langMdl.get(), lmId); + cand.joiner.add(form, tag, space); + } + } + else + { + size_t oSize = candidates.size(); + for (size_t i = 1; i < cands.size(); ++i) + { + for (size_t o = 0; o < oSize; ++o) + { + candidates.emplace_back(candidates[o]); + auto& n = candidates.back(); + n.score += n.lmState.next(kiwi->langMdl.get(), cands[i]->lmMorphemeId); + n.joiner.add(form, cands[i]->tag, space); + } + } + for (size_t o = 0; o < oSize; ++o) + { + auto& n = candidates[o]; + n.score += n.lmState.next(kiwi->langMdl.get(), cands[0]->lmMorphemeId); + n.joiner.add(form, cands[0]->tag, space); + } - template class Type, std::ptrdiff_t ...arches> - struct VCUnpack2nd> + UnorderedMap> bestScoreByState; + for (size_t i = 0; i < candidates.size(); ++i) + { + auto& c = candidates[i]; + auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, (uint32_t)i)); + if (!inserted.second) + { + if (inserted.first->second.first < c.score) + { + inserted.first->second = make_pair(c.score, i); + } + } + } + + if (bestScoreByState.size() < candidates.size()) + { + Vector> newCandidates; + newCandidates.reserve(bestScoreByState.size()); + for (auto& p : bestScoreByState) + { + newCandidates.emplace_back(std::move(candidates[p.second.second])); + } + candidates = std::move(newCandidates); + } + } + } + else + { + auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag)); + for (auto& cand : candidates) + { + cand.score += cand.lmState.next(kiwi->langMdl.get(), lmId); + cand.joiner.add(form, tag, space); + } + } + sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const 
cmb::Candidate& b) + { + return a.score > b.score; + }); + } + + template + void AutoJoiner::addWithoutSearchImpl2(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates) + { + if (inferRegularity) { - using type = std::tuple(arches)>>>...>; - }; + auto node = kiwi->formTrie.root(); + for (auto c : normalizeHangul(form)) + { + node = node->template nextOpt(kiwi->formTrie, c); + if (!node) break; + } + + if (node) + { + if (const Form* formHead = node->val(kiwi->formTrie)) + { + Vector cands; + foreachMorpheme(formHead, [&](const Morpheme* m) + { + if (areTagsEqual(m->tag, tag, true)) + { + cands.emplace_back(m); + } + }); - template class ... Types> - struct VCUnpack; + if (!cands.empty()) + { + tag = cands[0]->tag; + } + } + } + } + candidates[0].joiner.add(form, tag, space); + } - template class ... Types> - struct VCUnpack, Types...> + template + void AutoJoiner::addWithoutSearchImpl(size_t morphemeId, Space space, Vector>>& candidates) + { + auto& morph = kiwi->morphemes[morphemeId]; + for (auto& cand : candidates) { - using type = TupleCat>::type...>; - }; + cand.joiner.add(morph.getForm(), morph.tag, space); + } } - using CandTypeTuple = typename detail::VCUnpack::type, WrappedKnLM::type, WrappedKnLM::type, WrappedKnLM::type, - WrappedSbg<8, uint8_t>::type, WrappedSbg<8, uint16_t>::type, WrappedSbg<8, uint32_t>::type, WrappedSbg<8, uint64_t>::type - >::type; + template + struct AutoJoiner::Dispatcher + { + static void add(AutoJoiner* joiner, size_t morphemeId, Space space, Vector>& candidates) + { + return joiner->addImpl(morphemeId, space, candidates); + } - using CandVector = typename detail::VariantFromTuple::type; + static void add2(AutoJoiner* joiner, U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates) + { + return joiner->addImpl2(form, tag, inferRegularity, space, candidates); + } + }; + + template + struct AutoJoiner::Dispatcher> + { + static void add(AutoJoiner* joiner, size_t 
morphemeId, Space space, Vector>>& candidates) + { + return joiner->addWithoutSearchImpl(morphemeId, space, candidates); + } + + static void add2(AutoJoiner* joiner, U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates) + { + return joiner->addWithoutSearchImpl2(form, tag, inferRegularity, space, candidates); + } + }; template AutoJoiner::AutoJoiner(const Kiwi& _kiwi, Candidate&& state) - : kiwi{ &_kiwi } + : kiwi{ &_kiwi }, candBuf{ Vector>{ { move(state) } } } { - new (&candBuf) CandVector(Vector>{ { move(state) } }); + using Dp = Dispatcher; + dfAdd = reinterpret_cast(&Dp::add); + dfAdd2 = reinterpret_cast(&Dp::add2); } + } } diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index b00e8b20..66af12ba 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -4,16 +4,17 @@ #include #include #include +#include #include "ArchAvailable.h" #include "KTrie.h" #include "FeatureTestor.h" #include "FrozenTrie.hpp" -#include "LmState.hpp" #include "StrUtils.h" #include "SortUtils.hpp" #include "serializer.hpp" #include "Joiner.hpp" #include "PathEvaluator.hpp" +#include "Kiwi.hpp" using namespace std; @@ -42,19 +43,19 @@ namespace kiwi } Kiwi::Kiwi(ArchType arch, - LangModel _langMdl, + const std::shared_ptr & _langMdl, bool typoTolerant, bool continualTypoTolerant, bool lengtheningTypoTolerant) - : langMdl(_langMdl) + : langMdl{ _langMdl }, selectedArch{ arch } { - selectedArch = arch; dfSplitByTrie = (void*)getSplitByTrieFn(selectedArch, typoTolerant, continualTypoTolerant, lengtheningTypoTolerant); dfFindForm = (void*)getFindFormFn(selectedArch, typoTolerant); - dfFindBestPath = (void*)getFindBestPathFn(selectedArch, langMdl); + dfFindBestPath = langMdl ? langMdl->getFindBestPathFn() : nullptr; + dfNewJoiner = langMdl ? 
langMdl->getNewJoinerFn() : nullptr; } Kiwi::~Kiwi() = default; @@ -1148,137 +1149,15 @@ namespace kiwi return _asyncAnalyzeEcho(move(str), move(pretokenized), matchOptions, blocklist); } - using FnNewAutoJoiner = cmb::AutoJoiner(Kiwi::*)() const; - - template class LmState> - struct NewAutoJoinerGetter - { - template - struct Wrapper - { - static constexpr FnNewAutoJoiner value = &Kiwi::newJoinerImpl(i)>>; - }; - }; - cmb::AutoJoiner Kiwi::newJoiner(bool lmSearch) const { - static tp::Table lmVoid{ NewAutoJoinerGetter{} }; - static tp::Table lmKnLM_8{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmKnLM_16{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmKnLM_32{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmKnLM_64{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_8{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_16{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_32{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_64{ NewAutoJoinerGetter::type>{} }; - - const auto archIdx = static_cast(selectedArch); - if (lmSearch) { - size_t vocabTySize = langMdl.knlm->getHeader().key_size; - if (langMdl.sbg) - { - switch (vocabTySize) - { - case 1: - return (this->*lmSbg_8[archIdx])(); - case 2: - return (this->*lmSbg_16[archIdx])(); - case 4: - return (this->*lmSbg_32[archIdx])(); - case 8: - return (this->*lmSbg_64[archIdx])(); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize)}; - } - } - else - { - switch (vocabTySize) - { - case 1: - return (this->*lmKnLM_8[archIdx])(); - case 2: - return (this->*lmKnLM_16[archIdx])(); - case 4: - return (this->*lmKnLM_32[archIdx])(); - case 8: - return (this->*lmKnLM_64[archIdx])(); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize) }; - } - } - } - else - { - return (this->*lmVoid[archIdx])(); - } - } - - using FnNewLmObject = std::unique_ptr(*)(const LangModel&); - - template - std::unique_ptr makeNewLmObject(const LangModel& lm) - 
{ - return make_unique>(lm); - } - - template class LmState> - struct NewLmObjectGetter - { - template - struct Wrapper - { - static constexpr FnNewLmObject value = makeNewLmObject(i)>>; - }; - }; - - std::unique_ptr Kiwi::newLmObject() const - { - static tp::Table lmKnLM_8{ NewLmObjectGetter::type>{} }; - static tp::Table lmKnLM_16{ NewLmObjectGetter::type>{} }; - static tp::Table lmKnLM_32{ NewLmObjectGetter::type>{} }; - static tp::Table lmKnLM_64{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_8{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_16{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_32{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_64{ NewLmObjectGetter::type>{} }; - - const auto archIdx = static_cast(selectedArch); - - size_t vocabTySize = langMdl.knlm->getHeader().key_size; - if (langMdl.sbg) - { - switch (vocabTySize) - { - case 1: - return (lmSbg_8[archIdx])(langMdl); - case 2: - return (lmSbg_16[archIdx])(langMdl); - case 4: - return (lmSbg_32[archIdx])(langMdl); - case 8: - return (lmSbg_64[archIdx])(langMdl); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize) }; - } + return (*reinterpret_cast(dfNewJoiner))(this); } else { - switch (vocabTySize) - { - case 1: - return (lmKnLM_8[archIdx])(langMdl); - case 2: - return (lmKnLM_16[archIdx])(langMdl); - case 4: - return (lmKnLM_32[archIdx])(langMdl); - case 8: - return (lmKnLM_64[archIdx])(langMdl); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize) }; - } + return cmb::AutoJoiner{ *this, cmb::Candidate>{ *combiningRule, langMdl.get() }}; } } diff --git a/src/Kiwi.hpp b/src/Kiwi.hpp new file mode 100644 index 00000000..4c39e668 --- /dev/null +++ b/src/Kiwi.hpp @@ -0,0 +1,19 @@ +#pragma once +#include + +namespace kiwi +{ + using FnNewJoiner = cmb::AutoJoiner(*)(const Kiwi*); + + template + cmb::AutoJoiner Kiwi::newJoinerImpl() const + { + return cmb::AutoJoiner{ *this, cmb::Candidate{ *combiningRule, langMdl.get() 
}}; + } + + template + cmb::AutoJoiner newJoinerWithKiwi(const Kiwi* kiwi) + { + return kiwi->newJoinerImpl(); + } +} diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index a0e3ed68..360f15e9 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -16,6 +16,7 @@ #include "RaggedVector.hpp" #include "SkipBigramTrainer.hpp" #include "SkipBigramModel.hpp" +#include "PCLanguageModel.hpp" #include "SortUtils.hpp" using namespace std; @@ -740,7 +741,7 @@ void KiwiBuilder::updateForms() void KiwiBuilder::updateMorphemes(size_t vocabSize) { - if (vocabSize == 0) vocabSize = langMdl.vocabSize(); + if (vocabSize == 0) vocabSize = langMdl->vocabSize(); for (auto& m : morphemes) { if (m.lmMorphemeId > 0) continue; @@ -782,20 +783,17 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio loadMorphBin(iss); } - langMdl.type = modelType; - if (modelType == ModelType::knlm || modelType == ModelType::knlmTransposed || modelType == ModelType::sbg) + if (modelType == ModelType::knlm || modelType == ModelType::knlmTransposed) { - langMdl.knlm = lm::KnLangModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), archType); + langMdl = lm::KnLangModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), archType, modelType == ModelType::knlmTransposed); } - - if (modelType == ModelType::sbg) + else if (modelType == ModelType::sbg) { - langMdl.sbg = sb::SkipBigramModelBase::create(utils::MMap(modelPath + string{ "/skipbigram.mdl" }), archType); + langMdl = lm::SkipBigramModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), utils::MMap(modelPath + string{ "/skipbigram.mdl" }), archType); } - - if (modelType == ModelType::pclm) + else if (modelType == ModelType::pclm || modelType == ModelType::pclmLocal) { - langMdl.pclm = pclm::PCLanguageModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, true); + langMdl = lm::PcLangModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, modelType 
== ModelType::pclm); } else if (modelType == ModelType::pclmLocal) { @@ -957,11 +955,11 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) if (lmVocabSize <= 0xFFFF) { - langMdl.knlm = buildKnLM(args, lmVocabSize, realMorph); + langMdl = buildKnLM(args, lmVocabSize, realMorph); } else { - langMdl.knlm = buildKnLM(args, lmVocabSize, realMorph); + langMdl = buildKnLM(args, lmVocabSize, realMorph); } updateMorphemes(); @@ -984,9 +982,9 @@ namespace kiwi { } - sb::FeedingData operator()(size_t i, size_t threadId = 0) + lm::FeedingData operator()(size_t i, size_t threadId = 0) { - sb::FeedingData ret; + lm::FeedingData ret; ret.len = sents[i].size(); if (lmBuf[threadId].size() < ret.len) { @@ -1008,7 +1006,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) using Vid = uint16_t; auto realMorph = restoreMorphemeMap(); - sb::SkipBigramTrainer sbg; + lm::SkipBigramTrainer sbg; RaggedVector sents; for (auto& path : args.corpora) { @@ -1086,7 +1084,9 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) return true; }; - sbg = sb::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, args.sbgSize, langMdl.knlm->nonLeafNodeSize() }; + auto* knlm = dynamic_cast(langMdl.get()); + + sbg = lm::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, args.sbgSize, knlm->nonLeafNodeSize() }; Vector lmLogProbs; Vector baseNodes; auto tc = sbg.newContext(); @@ -1107,7 +1107,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) lmLogProbs.resize(sent.size()); baseNodes.resize(sent.size()); } - langMdl.knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin()); + knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin()); //float sum = sbg.evaluate(&sent[0], lmLogProbs.data(), sent.size()); float sum = accumulate(lmLogProbs.begin() + 1, lmLogProbs.begin() + sent.size(), 0.); size_t cnt = sent.size() - 1; @@ -1124,7 +1124,7 @@ 
KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) if (args.numWorkers <= 1) { - sbg.train(SBDataFeeder{ sents, langMdl.knlm.get() }, [&](const sb::ObservingData& od) + sbg.train(SBDataFeeder{ sents, knlm }, [&](const lm::ObservingData& od) { llCnt += od.cntRecent; llMean += (od.llRecent - llMean * od.cntRecent) / llCnt; @@ -1138,7 +1138,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) } else { - sbg.trainMulti(args.numWorkers, SBDataFeeder{ sents, langMdl.knlm.get(), 8 }, [&](const sb::ObservingData& od) + sbg.trainMulti(args.numWorkers, SBDataFeeder{ sents, knlm, 8 }, [&](const lm::ObservingData& od) { llCnt += od.cntRecent; llMean += (od.llRecent - llMean * od.cntRecent) / llCnt; @@ -1167,7 +1167,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) lmLogProbs.resize(sent.size()); baseNodes.resize(sent.size()); } - langMdl.knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin(), baseNodes.begin()); + knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin(), baseNodes.begin()); float sum = sbg.evaluate(&sent[0], baseNodes.data(), lmLogProbs.data(), sent.size()); size_t cnt = sent.size() - 1; llCnt += cnt; @@ -1201,7 +1201,8 @@ void KiwiBuilder::saveModel(const string& modelPath) const saveMorphBin(ofs); } { - auto mem = langMdl.knlm->getMemory(); + auto* knlm = dynamic_cast(langMdl.get()); + auto mem = knlm->getMemory(); ofstream ofs{ modelPath + "/sj.knlm", ios_base::binary }; ofs.write((const char*)mem.get(), mem.size()); } @@ -2379,7 +2380,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, } } - auto& knlm = srcBuilder->langMdl.knlm; + auto knlm = dynamic_pointer_cast(langMdl); dataset.knlm = knlm; dataset.morphemes = &srcBuilder->morphemes; dataset.forms = &srcBuilder->forms; diff --git a/src/Knlm.cpp b/src/Knlm.cpp index 2ab66388..58ffd0f0 100644 --- a/src/Knlm.cpp +++ b/src/Knlm.cpp @@ -1,46 +1,383 @@ #include "Knlm.hpp" +#include 
"PathEvaluator.hpp" +#include "Joiner.hpp" +#include "Kiwi.hpp" namespace kiwi { namespace lm { - template + template + template + void KnLangModel::dequantizeDispatch( + tp::seq, + size_t bits, + Vector& restored_floats, Vector& restored_leaf_ll, + const char* llq_data, size_t llq_size, + const char* gammaq_data, size_t gammaq_size, + const float* ll_table, + const float* gamma_table, + size_t num_non_leaf_nodes, + size_t num_leaf_nodes + ) + { + using Fn = void(*)(Vector&, Vector&, + const char*, size_t, + const char*, size_t, + const float*, + const float*, + size_t, + size_t); + static constexpr Fn table[] = { + &dequantize... + }; + return table[bits - 1](restored_floats, restored_leaf_ll, + llq_data, llq_size, + gammaq_data, gammaq_size, + ll_table, gamma_table, + num_non_leaf_nodes, num_leaf_nodes + ); + } + + template + KnLangModel::KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move(mem) } + { + auto* ptr = reinterpret_cast(base.get()); + auto& header = getHeader(); + const size_t quantized = header.quantized & 0x1F; + const bool compressed = header.quantized & 0x80; + + Vector d_node_size; + auto* node_sizes = reinterpret_cast(ptr + header.node_offset); + key_data = make_unique((header.ll_offset - header.key_offset) / sizeof(KeyType)); + std::memcpy(&key_data[0], ptr + header.key_offset, header.ll_offset - header.key_offset); + size_t num_leaf_nodes = 0; + if (compressed) + { + d_node_size.resize(header.num_nodes); + auto qc_header = reinterpret_cast(ptr + header.node_offset); + auto qc_body = reinterpret_cast(qc_header + (header.num_nodes + 3) / 4); + QCode::template decode<8>((uint16_t*)d_node_size.data(), qc_header, qc_body, 0, header.num_nodes); + node_sizes = d_node_size.data(); + } + + for (size_t i = 0; i < header.num_nodes; ++i) + { + if (node_sizes[i]) num_non_leaf_nodes++; + else num_leaf_nodes++; + } + + // restore ll & gamma data + Vector restored_leaf_ll, restored_floats; + const float* ll_data = nullptr; + const float* 
gamma_data = nullptr; + const float* leaf_ll_data = nullptr; + if (quantized) + { + if (quantized > 16) + { + throw std::runtime_error{ "16+ bits quantization not supported." }; + } + + restored_floats.resize(num_non_leaf_nodes * 2); + restored_leaf_ll.resize(num_leaf_nodes); + leaf_ll_data = restored_leaf_ll.data(); + ll_data = &restored_floats[0]; + gamma_data = &restored_floats[num_non_leaf_nodes]; + + const float* ll_table = reinterpret_cast(ptr + header.qtable_offset); + const float* gamma_table = ll_table + ((size_t)1 << quantized); + + dequantizeDispatch(tp::gen_seq<16>{}, quantized, restored_floats, restored_leaf_ll, + ptr + header.ll_offset, header.gamma_offset - header.ll_offset, + ptr + header.gamma_offset, header.qtable_offset - header.gamma_offset, + ll_table, + gamma_table, + num_non_leaf_nodes, + num_leaf_nodes + ); + extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized)); + } + else + { + ll_data = reinterpret_cast(ptr + header.ll_offset); + gamma_data = reinterpret_cast(ptr + header.gamma_offset); + leaf_ll_data = ll_data + num_non_leaf_nodes; + extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes); + } + + size_t htx_vocab_size = header.vocab_size; + if (header.htx_offset) + { + htx_data = reinterpret_cast(ptr + header.htx_offset); + htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1; + extra_buf = toAlignedPtr(htx_data + header.vocab_size); + } + + if (!header.extra_buf_size) + { + extra_buf = nullptr; + } + + // restore node's data + node_data = make_unique(num_non_leaf_nodes); + all_value_data = make_unique(header.num_nodes - 1 + htx_vocab_size); + value_data = &all_value_data[htx_vocab_size]; + std::fill(&all_value_data[0], value_data, 0); + + size_t non_leaf_idx = 0, leaf_idx = 0, next_offset = 0; + Vector> key_ranges; + for (size_t i = 0; i < header.num_nodes; ++i) + { + if (node_sizes[i]) + { + auto& node = node_data[non_leaf_idx]; + if (!key_ranges.empty()) + { + auto& back = key_ranges.back(); + 
value_data[back[1]] = non_leaf_idx - back[0]; + } + node.num_nexts = node_sizes[i]; + node.next_offset = next_offset; + node.ll = ll_data[non_leaf_idx]; + node.gamma = gamma_data[non_leaf_idx]; + next_offset += node_sizes[i]; + key_ranges.emplace_back(std::array{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) }); + non_leaf_idx++; + } + else + { + auto& back = key_ranges.back(); + reinterpret_cast(value_data[back[1]]) = leaf_ll_data[leaf_idx]; + back[1]++; + while (key_ranges.back()[1] == key_ranges.back()[2]) + { + key_ranges.pop_back(); + if (key_ranges.empty()) break; + key_ranges.back()[1]++; + } + leaf_idx++; + } + } + + for (size_t i = 0; i < node_data[0].num_nexts; ++i) + { + auto k = key_data[i]; + auto v = value_data[i]; + all_value_data[k] = v; + } + + Vector tempBuf; + for (size_t i = 0; i < non_leaf_idx; ++i) + { + auto& node = node_data[i]; + nst::prepare(&key_data[node.next_offset], &value_data[node.next_offset], node.num_nexts, tempBuf); + } + + if (htx_data) + { + ptrdiff_t node = 0; + progress(node, (KeyType)header.bos_id); + unk_ll = getLL(node, (KeyType)header.unk_id); + bos_node_idx = 0; + progress(bos_node_idx, htx_data[(KeyType)header.bos_id]); + } + else + { + unk_ll = getLL(0, (KeyType)header.unk_id); + bos_node_idx = 0; + progress(bos_node_idx, (KeyType)header.bos_id); + } + + Deque dq; + for (dq.emplace_back(&node_data[0]); !dq.empty(); dq.pop_front()) + { + auto p = dq.front(); + for (size_t i = 0; i < p->num_nexts; ++i) + { + auto k = key_data[p->next_offset + i]; + auto v = value_data[p->next_offset + i]; + if (v <= 0) continue; + auto* child = &p[v]; + child->lower = findLowerNode(p, k) - child; + dq.emplace_back(child); + } + } + } + + template + float KnLangModel::getLL(ptrdiff_t node_idx, KeyType next) const + { + DiffType v; + auto* node = &node_data[node_idx]; + if (node_idx == 0) + { + v = all_value_data[next]; + if (v == 0) return unk_ll; + } + else + { + if (!nst::search( + 
&key_data[node->next_offset], + &value_data[node->next_offset], + node->num_nexts, next, v + )) + { + return node->gamma + getLL(node_idx + node->lower, next); + } + } + + // non-leaf node + if (v > 0) + { + return node_data[node_idx + v].ll; + } + // leaf node + else + { + return reinterpret_cast(v); + } + } + + template + template + float KnLangModel::progress(IdxType& node_idx, KeyType next) const + { + float acc = 0; + while (1) + { + DiffType v; + auto* node = &node_data[node_idx]; + auto* keys = &key_data[node->next_offset]; + auto* values = &value_data[node->next_offset]; + PREFETCH_T0(node + node->lower); + if (node_idx == 0) + { + v = all_value_data[next]; + if (v == 0) + { + if (htx_data) + { + IdxType lv; + if (nst::search( + &key_data[0], + value_data, + node_data[0].num_nexts, htx_data[next], lv + )) node_idx = lv; + else node_idx = 0; + } + return acc + unk_ll; + } + } + else + { + if (!nst::search( + keys, + values, + node->num_nexts, next, v + )) + { + acc += node->gamma; + node_idx += node->lower; + PREFETCH_T0(&key_data[node_data[node_idx].next_offset]); + continue; + } + } + + // non-leaf node + if (v > 0) + { + node_idx += v; + return acc + node_data[node_idx].ll; + } + // leaf node + else + { + while (node->lower) + { + node += node->lower; + DiffType lv; + if (nst::search( + &key_data[node->next_offset], + &value_data[node->next_offset], + node->num_nexts, next, lv + )) + { + if (lv > 0) + { + node += lv; + node_idx = node - &node_data[0]; + return acc + reinterpret_cast(v); + } + } + } + if (htx_data) + { + IdxType lv; + if (nst::search( + &key_data[0], + value_data, + node_data[0].num_nexts, htx_data[next], lv + )) node_idx = lv; + else node_idx = 0; + } + else node_idx = 0; + return acc + reinterpret_cast(v); + } + } + } + + template + void* KnLangModel::getFindBestPathFn() const + { + return (void*)&BestPathFinder::findBestPath>; + } + + template + void* KnLangModel::getNewJoinerFn() const + { + return (void*)&newJoinerWithKiwi; + } + + 
template std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) { auto* ptr = reinterpret_cast(mem.get()); - auto& header = *reinterpret_cast(ptr); + auto& header = *reinterpret_cast(ptr); switch (header.key_size) { case 1: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); case 2: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); case 4: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); case 8: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); default: throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.key_size) }; } } - using FnCreateOptimizedModel = decltype(&createOptimizedModel); + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + template struct CreateOptimizedModelGetter { template struct Wrapper { - static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i), transposed>; }; }; - std::unique_ptr KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType) + std::unique_ptr KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool transposed) { - static tp::Table table{ CreateOptimizedModelGetter{} }; - auto fn = table[static_cast(archType)]; + static tp::Table table{ CreateOptimizedModelGetter{} }; + static tp::Table tableTransposed{ CreateOptimizedModelGetter{} }; + auto fn = (transposed ? 
tableTransposed : table)[static_cast(archType)]; if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; return (*fn)(std::move(mem)); } } -} \ No newline at end of file +} diff --git a/src/Knlm.hpp b/src/Knlm.hpp index 1f9f5821..8c8dab82 100644 --- a/src/Knlm.hpp +++ b/src/Knlm.hpp @@ -17,80 +17,13 @@ namespace kiwi { namespace lm { - static constexpr size_t serialAlignment = 16; - - using QCode = qe::QCode<0, 2, 8, 16>; - - template - inline void dequantize( - Vector& restored_floats, Vector& restored_leaf_ll, - const char* llq_data, size_t llq_size, - const char* gammaq_data, size_t gammaq_size, - const float* ll_table, - const float* gamma_table, - size_t num_non_leaf_nodes, - size_t num_leaf_nodes - ) - { - FixedLengthEncoder llq{ llq_data, (ptrdiff_t)llq_size }; - FixedLengthEncoder gammaq{ gammaq_data, (ptrdiff_t)gammaq_size }; - - for (size_t i = 0; i < num_non_leaf_nodes; ++i) - { - restored_floats[i] = ll_table[llq.read()]; - } - - for (size_t i = 0; i < num_leaf_nodes; ++i) - { - restored_leaf_ll[i] = ll_table[llq.read()]; - } - - for (size_t i = 0; i < num_non_leaf_nodes; ++i) - { - restored_floats[i + num_non_leaf_nodes] = gamma_table[gammaq.read()]; - } - } - - template<> - inline void dequantize<8>( - Vector& restored_floats, Vector& restored_leaf_ll, - const char* llq_data, size_t llq_size, - const char* gammaq_data, size_t gammaq_size, - const float* ll_table, - const float* gamma_table, - size_t num_non_leaf_nodes, - size_t num_leaf_nodes - ) - { - const uint8_t* non_leaf_q = reinterpret_cast(llq_data); - for (size_t i = 0; i < num_non_leaf_nodes; ++i) - { - restored_floats[i] = ll_table[non_leaf_q[i]]; - } - - const uint8_t* leaf_q = reinterpret_cast(llq_data + num_non_leaf_nodes); - for (size_t i = 0; i < num_leaf_nodes; ++i) - { - restored_leaf_ll[i] = ll_table[leaf_q[i]]; - } + template + class KnLMState; - const uint8_t* gamma_q = reinterpret_cast(gammaq_data); - for (size_t i = 0; i < 
num_non_leaf_nodes; ++i) - { - restored_floats[i + num_non_leaf_nodes] = gamma_table[gamma_q[i]]; - } - } - - inline const void* toAlignedPtr(const void* ptr, size_t alignment = serialAlignment) - { - auto addr = reinterpret_cast(ptr); - return reinterpret_cast((addr + alignment - 1) & ~(alignment - 1)); - } - - template + template class KnLangModel : public KnLangModelBase { - using MyNode = Node; + using MyNode = KnLangModelNode; std::unique_ptr node_data; std::unique_ptr key_data; @@ -140,311 +73,22 @@ namespace kiwi const float* gamma_table, size_t num_non_leaf_nodes, size_t num_leaf_nodes - ) - { - using Fn = void(*)(Vector&, Vector&, - const char*, size_t, - const char*, size_t, - const float*, - const float*, - size_t, - size_t); - static constexpr Fn table[] = { - &dequantize... - }; - return table[bits - 1](restored_floats, restored_leaf_ll, - llq_data, llq_size, - gammaq_data, gammaq_size, - ll_table, gamma_table, - num_non_leaf_nodes, num_leaf_nodes - ); - } + ); public: - KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move(mem) } - { - auto* ptr = reinterpret_cast(base.get()); - auto& header = getHeader(); - const size_t quantized = header.quantized & 0x1F; - const bool compressed = header.quantized & 0x80; - - Vector d_node_size; - auto* node_sizes = reinterpret_cast(ptr + header.node_offset); - key_data = make_unique((header.ll_offset - header.key_offset) / sizeof(KeyType)); - std::memcpy(&key_data[0], ptr + header.key_offset, header.ll_offset - header.key_offset); - size_t num_leaf_nodes = 0; - if (compressed) - { - d_node_size.resize(header.num_nodes); - auto qc_header = reinterpret_cast(ptr + header.node_offset); - auto qc_body = reinterpret_cast(qc_header + (header.num_nodes + 3) / 4); - QCode::template decode<8>((uint16_t*)d_node_size.data(), qc_header, qc_body, 0, header.num_nodes); - node_sizes = d_node_size.data(); - } - - for (size_t i = 0; i < header.num_nodes; ++i) - { - if (node_sizes[i]) num_non_leaf_nodes++; - else 
num_leaf_nodes++; - } - - // restore ll & gamma data - Vector restored_leaf_ll, restored_floats; - const float* ll_data = nullptr; - const float* gamma_data = nullptr; - const float* leaf_ll_data = nullptr; - if (quantized) - { - if (quantized > 16) - { - throw std::runtime_error{ "16+ bits quantization not supported." }; - } - - restored_floats.resize(num_non_leaf_nodes * 2); - restored_leaf_ll.resize(num_leaf_nodes); - leaf_ll_data = restored_leaf_ll.data(); - ll_data = &restored_floats[0]; - gamma_data = &restored_floats[num_non_leaf_nodes]; - - const float* ll_table = reinterpret_cast(ptr + header.qtable_offset); - const float* gamma_table = ll_table + ((size_t)1 << quantized); - - dequantizeDispatch(tp::gen_seq<16>{}, quantized, restored_floats, restored_leaf_ll, - ptr + header.ll_offset, header.gamma_offset - header.ll_offset, - ptr + header.gamma_offset, header.qtable_offset - header.gamma_offset, - ll_table, - gamma_table, - num_non_leaf_nodes, - num_leaf_nodes - ); - extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized)); - } - else - { - ll_data = reinterpret_cast(ptr + header.ll_offset); - gamma_data = reinterpret_cast(ptr + header.gamma_offset); - leaf_ll_data = ll_data + num_non_leaf_nodes; - extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes); - } - - size_t htx_vocab_size = header.vocab_size; - if (header.htx_offset) - { - htx_data = reinterpret_cast(ptr + header.htx_offset); - htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1; - extra_buf = toAlignedPtr(htx_data + header.vocab_size); - } - - if (!header.extra_buf_size) - { - extra_buf = nullptr; - } + using VocabType = KeyType; + using LmStateType = KnLMState; - // restore node's data - node_data = make_unique(num_non_leaf_nodes); - all_value_data = make_unique(header.num_nodes - 1 + htx_vocab_size); - value_data = &all_value_data[htx_vocab_size]; - std::fill(&all_value_data[0], value_data, 0); + KnLangModel(utils::MemoryObject&& mem); + ModelType 
getType() const override { return transposed ? ModelType::knlmTransposed : ModelType::knlm; } - size_t non_leaf_idx = 0, leaf_idx = 0, next_offset = 0; - Vector> key_ranges; - for (size_t i = 0; i < header.num_nodes; ++i) - { - if (node_sizes[i]) - { - auto& node = node_data[non_leaf_idx]; - if (!key_ranges.empty()) - { - auto& back = key_ranges.back(); - value_data[back[1]] = non_leaf_idx - back[0]; - } - node.num_nexts = node_sizes[i]; - node.next_offset = next_offset; - node.ll = ll_data[non_leaf_idx]; - node.gamma = gamma_data[non_leaf_idx]; - next_offset += node_sizes[i]; - key_ranges.emplace_back(std::array{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) }); - non_leaf_idx++; - } - else - { - auto& back = key_ranges.back(); - reinterpret_cast(value_data[back[1]]) = leaf_ll_data[leaf_idx]; - back[1]++; - while (key_ranges.back()[1] == key_ranges.back()[2]) - { - key_ranges.pop_back(); - if (key_ranges.empty()) break; - key_ranges.back()[1]++; - } - leaf_idx++; - } - } - - for (size_t i = 0; i < node_data[0].num_nexts; ++i) - { - auto k = key_data[i]; - auto v = value_data[i]; - all_value_data[k] = v; - } - - Vector tempBuf; - for (size_t i = 0; i < non_leaf_idx; ++i) - { - auto& node = node_data[i]; - nst::prepare(&key_data[node.next_offset], &value_data[node.next_offset], node.num_nexts, tempBuf); - } - - if (htx_data) - { - ptrdiff_t node = 0; - progress(node, (KeyType)header.bos_id); - unk_ll = getLL(node, (KeyType)header.unk_id); - bos_node_idx = 0; - progress(bos_node_idx, htx_data[(KeyType)header.bos_id]); - } - else - { - unk_ll = getLL(0, (KeyType)header.unk_id); - bos_node_idx = 0; - progress(bos_node_idx, (KeyType)header.bos_id); - } - - Deque dq; - for (dq.emplace_back(&node_data[0]); !dq.empty(); dq.pop_front()) - { - auto p = dq.front(); - for (size_t i = 0; i < p->num_nexts; ++i) - { - auto k = key_data[p->next_offset + i]; - auto v = value_data[p->next_offset + i]; - if (v <= 0) continue; - auto* child = 
&p[v]; - child->lower = findLowerNode(p, k) - child; - dq.emplace_back(child); - } - } - } - - float getLL(ptrdiff_t node_idx, KeyType next) const - { - DiffType v; - auto* node = &node_data[node_idx]; - if (node_idx == 0) - { - v = all_value_data[next]; - if (v == 0) return unk_ll; - } - else - { - if (!nst::search( - &key_data[node->next_offset], - &value_data[node->next_offset], - node->num_nexts, next, v - )) - { - return node->gamma + getLL(node_idx + node->lower, next); - } - } + void* getFindBestPathFn() const override; + void* getNewJoinerFn() const override; - // non-leaf node - if (v > 0) - { - return node_data[node_idx + v].ll; - } - // leaf node - else - { - return reinterpret_cast(v); - } - } + float getLL(ptrdiff_t node_idx, KeyType next) const; template - float progress(IdxType& node_idx, KeyType next) const - { - float acc = 0; - while (1) - { - DiffType v; - auto* node = &node_data[node_idx]; - auto* keys = &key_data[node->next_offset]; - auto* values = &value_data[node->next_offset]; - PREFETCH_T0(node + node->lower); - if (node_idx == 0) - { - v = all_value_data[next]; - if (v == 0) - { - if (htx_data) - { - IdxType lv; - if (nst::search( - &key_data[0], - value_data, - node_data[0].num_nexts, htx_data[next], lv - )) node_idx = lv; - else node_idx = 0; - } - return acc + unk_ll; - } - } - else - { - if (!nst::search( - keys, - values, - node->num_nexts, next, v - )) - { - acc += node->gamma; - node_idx += node->lower; - PREFETCH_T0(&key_data[node_data[node_idx].next_offset]); - continue; - } - } - - // non-leaf node - if (v > 0) - { - node_idx += v; - return acc + node_data[node_idx].ll; - } - // leaf node - else - { - while (node->lower) - { - node += node->lower; - DiffType lv; - if (nst::search( - &key_data[node->next_offset], - &value_data[node->next_offset], - node->num_nexts, next, lv - )) - { - if (lv > 0) - { - node += lv; - node_idx = node - &node_data[0]; - return acc + reinterpret_cast(v); - } - } - } - if (htx_data) - { - IdxType lv; 
- if (nst::search( - &key_data[0], - value_data, - node_data[0].num_nexts, htx_data[next], lv - )) node_idx = lv; - else node_idx = 0; - } - else node_idx = 0; - return acc + reinterpret_cast(v); - } - } - } + float progress(IdxType& node_idx, KeyType next) const; float _progress(ptrdiff_t& node_idx, size_t next) const override { @@ -723,6 +367,100 @@ namespace kiwi } }; + template + struct KnLMState : public LmStateBase> + { + int32_t node = 0; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = _transposed; + + KnLMState() = default; + KnLMState(const ILangModel* lm) : node{ (int32_t)static_cast*>(lm)->getBosNodeIdx() } {} + + bool operator==(const KnLMState& other) const + { + return node == other.node; + } + + float nextImpl(const KnLangModel<_arch, VocabTy, transposed>* lm, VocabTy next) + { + return lm->progress(node, next); + } + + }; + + static constexpr size_t serialAlignment = 16; + + using QCode = qe::QCode<0, 2, 8, 16>; + + template + inline void dequantize( + Vector& restored_floats, Vector& restored_leaf_ll, + const char* llq_data, size_t llq_size, + const char* gammaq_data, size_t gammaq_size, + const float* ll_table, + const float* gamma_table, + size_t num_non_leaf_nodes, + size_t num_leaf_nodes + ) + { + FixedLengthEncoder llq{ llq_data, (ptrdiff_t)llq_size }; + FixedLengthEncoder gammaq{ gammaq_data, (ptrdiff_t)gammaq_size }; + + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i] = ll_table[llq.read()]; + } + + for (size_t i = 0; i < num_leaf_nodes; ++i) + { + restored_leaf_ll[i] = ll_table[llq.read()]; + } + + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i + num_non_leaf_nodes] = gamma_table[gammaq.read()]; + } + } + + template<> + inline void dequantize<8>( + Vector& restored_floats, Vector& restored_leaf_ll, + const char* llq_data, size_t llq_size, + const char* gammaq_data, size_t gammaq_size, + const float* ll_table, + const float* gamma_table, + size_t 
num_non_leaf_nodes, + size_t num_leaf_nodes + ) + { + const uint8_t* non_leaf_q = reinterpret_cast(llq_data); + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i] = ll_table[non_leaf_q[i]]; + } + + const uint8_t* leaf_q = reinterpret_cast(llq_data + num_non_leaf_nodes); + for (size_t i = 0; i < num_leaf_nodes; ++i) + { + restored_leaf_ll[i] = ll_table[leaf_q[i]]; + } + + const uint8_t* gamma_q = reinterpret_cast(gammaq_data); + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i + num_non_leaf_nodes] = gamma_table[gamma_q[i]]; + } + } + + inline const void* toAlignedPtr(const void* ptr, size_t alignment = serialAlignment) + { + auto addr = reinterpret_cast(ptr); + return reinterpret_cast((addr + alignment - 1) & ~(alignment - 1)); + } + + template void quantize(const std::vector& ll_table, const std::vector& gamma_table, const std::vector& ll, const std::vector& leaf_ll, @@ -797,7 +535,7 @@ namespace kiwi } template - utils::MemoryOwner buildCompressedModel(Header header, + utils::MemoryOwner buildCompressedModel(KnLangModelHeader header, const std::vector& min_cf_by_order, float unigram_alpha, utils::ContinuousTrie&& compressed_ngrams, @@ -968,7 +706,7 @@ namespace kiwi quantizeDispatch(tp::gen_seq<16>{}, quantized, ll_table, gamma_table, - ll, leaf_ll, gamma, + ll, leaf_ll, gamma, llq, gammaq ); } @@ -990,7 +728,7 @@ namespace kiwi size_t final_size = 0; - header.node_offset = alignedOffsetInc(final_size, sizeof(Header)); + header.node_offset = alignedOffsetInc(final_size, sizeof(KnLangModelHeader)); if (compressed) { header.key_offset = alignedOffsetInc(final_size, c_node_size.tellp()); @@ -1026,7 +764,7 @@ namespace kiwi utils::MemoryOwner ret{ final_size + extra_buf_size }; utils::omstream ostr{ (char*)ret.get(), (std::ptrdiff_t)ret.size() }; - ostr.write((const char*)&header, sizeof(Header)); + ostr.write((const char*)&header, sizeof(KnLangModelHeader)); writePadding(ostr); if (compressed) { @@ -1089,8 +827,8 @@ 
namespace kiwi }; template - utils::MemoryOwner KnLangModelBase::build(Trie&& ngram_cf, - size_t order, const std::vector& min_cf_by_order, + utils::MemoryOwner KnLangModelBase::build(Trie&& ngram_cf, + size_t order, const std::vector& min_cf_by_order, size_t unk_id, size_t bos_id, size_t eos_id, float unigram_alpha, size_t quantize, bool compress, const std::vector>* bigram_list, const HistoryTx* history_transformer, const void* extra_buf, size_t extra_buf_size @@ -1098,7 +836,7 @@ namespace kiwi { using TrieNode = typename GetNodeType::type>::type>::type; using Key = typename TrieNode::Key; - if (quantize > 16) throw std::invalid_argument{ "16+ bits quantization not supported."}; + if (quantize > 16) throw std::invalid_argument{ "16+ bits quantization not supported." }; size_t max_vid = 0; utils::ContinuousTrie compressed_ngrams{ 1 }; std::vector unigram_pats, unigram_cnts; @@ -1186,7 +924,7 @@ namespace kiwi denom = std::accumulate(unigram_cnts.begin(), unigram_cnts.end(), 0.); for (auto& p : unigram_cnts) p /= denom; - Header header = { 0, }; + KnLangModelHeader header = { 0, }; header.order = order; header.diff_size = 4; header.unk_id = unk_id; @@ -1198,36 +936,47 @@ namespace kiwi if (max_vid <= 0xFF) { - return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - unigram_pats, unigram_cnts, ngram_ncnt, + return buildCompressedModel(header, min_cf_by_order, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } else if (max_vid <= 0xFFFF) { return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - unigram_pats, unigram_cnts, ngram_ncnt, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } else if (max_vid <= 0xFFFFFFFF) { return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - 
unigram_pats, unigram_cnts, ngram_ncnt, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } else { return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - unigram_pats, unigram_cnts, ngram_ncnt, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } } } -} \ No newline at end of file + + template + struct Hash> + { + size_t operator()(const lm::KnLMState& state) const + { + std::hash hasher; + return hasher(state.node); + } + }; + +} diff --git a/src/LmState.hpp b/src/LmState.hpp deleted file mode 100644 index a24be7ae..00000000 --- a/src/LmState.hpp +++ /dev/null @@ -1,304 +0,0 @@ -#pragma once - -#include -#include -#include "Knlm.hpp" -#include "SkipBigramModel.hpp" -#include "PCLanguageModel.hpp" - -namespace kiwi -{ - template - class VoidState - { - public: - static constexpr ArchType arch = _arch; - - VoidState() = default; - VoidState(const LangModel& lm) {} - - bool operator==(const VoidState& other) const - { - return true; - } - - float next(const LangModel& lm, size_t next) - { - return 0; - } - }; - - template - class KnLMState - { - friend struct Hash>; - int32_t node = 0; - public: - static constexpr ArchType arch = _arch; - static constexpr bool transposed = _transposed; - - KnLMState() = default; - KnLMState(const LangModel& lm) : node{ (int32_t)static_cast&>(*lm.knlm).getBosNodeIdx() } {} - - bool operator==(const KnLMState& other) const - { - return node == other.node; - } - - float next(const LangModel& lm, VocabTy next) - { - return static_cast&>(*lm.knlm).progress(node, next); - } - - void predict(const LangModel& lm, float* out) const - { - - } - }; - - template - class SbgState : public KnLMState<_arch, VocabTy> - { - friend struct Hash>; - size_t historyPos = 0; - std::array history = { {0,} }; - public: - static constexpr ArchType arch = 
_arch; - static constexpr bool transposed = false; - - SbgState() = default; - SbgState(const LangModel& lm) : KnLMState<_arch, VocabTy>{ lm } {} - - bool operator==(const SbgState& other) const - { - return KnLMState<_arch, VocabTy>::operator==(other) && historyPos == other.historyPos && history == other.history; - } - - void getLastHistory(VocabTy* out, size_t n) const - { - for (size_t i = 0; i < n; ++i) - { - out[i] = history[(historyPos + windowSize + i - n) % windowSize]; - } - } - - float next(const LangModel& lm, VocabTy next) - { - auto& sbg = static_cast&>(*lm.sbg); - float ll = KnLMState::next(lm, next); - if (sbg.isValidVocab(next)) - { - if (ll > -13) - { - ll = sbg.evaluate(history.data(), windowSize, next, ll); - } - history[historyPos] = next; - historyPos = (historyPos + 1) % windowSize; - } - return ll; - } - - void predict(const LangModel& lm, float* out) const - { - - } - }; - - template - class PcLMState - { - friend struct Hash>; - protected: - int32_t node = 0; - uint32_t contextIdx = 0; - public: - static constexpr ArchType arch = _arch; - static constexpr bool transposed = true; - - PcLMState() = default; - PcLMState(const LangModel& lm) {} - - bool operator==(const PcLMState& other) const - { - return node == other.node; - } - - float next(const LangModel& lm, VocabTy next) - { - auto& pclm = static_cast&>(*lm.pclm); - size_t historyPos = 0; - std::array history = { {0,} }; - return pclm.progress(node, contextIdx, historyPos, history, next); - } - }; - - template - class PcLMState : public PcLMState - { - static constexpr bool useDistantTokens = true; - friend struct Hash>; - protected: - size_t historyPos = 0; - std::array history = { {0,} }; - public: - static constexpr ArchType arch = _arch; - static constexpr bool transposed = true; - - PcLMState() = default; - PcLMState(const LangModel& lm) {} - - bool operator==(const PcLMState& other) const - { - return PcLMState::operator==(other) && historyPos == other.historyPos && history == 
other.history; - } - - float next(const LangModel& lm, VocabTy next) - { - auto& pclm = static_cast&>(*lm.pclm); - return pclm.progress(node, contextIdx, historyPos, history, next); - } - }; - - // hash for LmState - template - struct Hash> - { - size_t operator()(const VoidState& state) const - { - return 0; - } - }; - - template - struct Hash> - { - size_t operator()(const KnLMState& state) const - { - std::hash hasher; - return hasher(state.node); - } - }; - - template - struct Hash> - { - size_t operator()(const SbgState& state) const - { - Hash> hasher; - std::hash vocabHasher; - size_t ret = hasher(state); - for (size_t i = 0; i < windowSize; ++i) - { - ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); - } - return ret; - } - }; - - template - struct WrappedKnLM - { - template using type = KnLMState; - }; - - template - struct WrappedKnLMTransposed - { - template using type = KnLMState; - }; - - template - struct WrappedSbg - { - template using type = SbgState; - }; - - template - struct WrappedPcLM - { - template using type = PcLMState; - }; - - template - class LmObject : public LmObjectBase - { - LangModel mdl; - public: - LmObject(const LangModel& _mdl) : mdl(_mdl) - { - } - - size_t vocabSize() const override - { - return mdl.knlm->getHeader().vocab_size; - } - - template - float evalSequence(It first, It last) const - { - float ret = 0; - LmStateTy state{ mdl }; - for (; first != last; ++first) - { - ret += state.next(mdl, *first); - } - return ret; - } - - float evalSequence(const uint32_t* seq, size_t length, size_t stride) const override - { - float ret = 0; - LmStateTy state{ mdl }; - for (size_t i = 0; i < length; ++i) - { - ret += state.next(mdl, *seq); - seq = reinterpret_cast(reinterpret_cast(seq) + stride); - } - return ret; - } - - void predictNext(const uint32_t* seq, size_t length, size_t stride, float* outScores) const override - { - LmStateTy state{ mdl }; - for (size_t i = 0; i < length; ++i) - { - 
state.next(mdl, *seq); - seq = reinterpret_cast(reinterpret_cast(seq) + stride); - } - state.predict(mdl, outScores); - } - - void evalSequences( - const uint32_t* prefix, size_t prefixLength, size_t prefixStride, - const uint32_t* suffix, size_t suffixLength, size_t suffixStride, - size_t seqSize, const uint32_t** seq, const size_t* seqLength, const size_t* seqStride, float* outScores - ) const override - { - float ret = 0; - LmStateTy state{ mdl }; - for (size_t i = 0; i < prefixLength; ++i) - { - ret += state.next(mdl, *prefix); - prefix = reinterpret_cast(reinterpret_cast(prefix) + prefixStride); - } - - Vector states(seqSize, state); - std::fill(outScores, outScores + seqSize, ret); - for (size_t s = 0; s < seqSize; ++s) - { - auto p = seq[s]; - for (size_t i = 0; i < seqLength[s]; ++i) - { - outScores[s] += states[s].next(mdl, *p); - p = reinterpret_cast(reinterpret_cast(p) + seqStride[s]); - } - - for (size_t i = 0; i < suffixLength; ++i) - { - outScores[s] += states[s].next(mdl, *suffix); - suffix = reinterpret_cast(reinterpret_cast(suffix) + suffixStride); - } - } - } - }; -} diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index 88d3e7b2..c5e3c220 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -1,5 +1,8 @@ #include #include +#include "PathEvaluator.hpp" +#include "Joiner.hpp" +#include "Kiwi.hpp" #include "PCLanguageModel.hpp" #include "StrUtils.h" #include "FrozenTrie.hpp" @@ -8,9 +11,263 @@ using namespace std; namespace kiwi { - namespace pclm + namespace lm { - utils::MemoryObject PCLanguageModelBase::build(const string& contextDefinition, const string& embedding, bool reorderContextId) + inline float half2float(uint16_t h) + { + union + { + uint32_t i; + float f; + } u; + u.i = (uint32_t)(h & 0x8000) << 16; + u.i |= ((uint32_t)(h & 0x7FFF) + 0x1C000) << 13; + return u.f; + } + + inline void dequantize(float* out, const int8_t* ints, size_t n, float scale) + { + for (size_t i = 0; i < n; ++i) + { + out[i] = ints[i] 
* scale; + } + } + + template + void logsoftmaxInplace(Arr& arr) + { + arr -= arr.maxCoeff(); + arr -= std::log(arr.exp().sum()); + } + + template + PcLangModel::PcLangModel(utils::MemoryObject&& mem) : PcLangModelBase{ std::move(mem) } + { + auto* ptr = reinterpret_cast(base.get()); + auto& header = getHeader(); + + Vector nodeSizes(header.numNodes); + streamvbyte_decode_0124(reinterpret_cast(ptr + header.nodeOffset), nodeSizes.data(), header.numNodes); + keyData = make_unique(header.numNodes - 1); + if (std::is_same::value) + { + streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), (uint32_t*)keyData.get(), header.numNodes - 1); + } + else + { + Vector tempKeyData(header.numNodes - 1); + streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), tempKeyData.data(), header.numNodes - 1); + std::copy(tempKeyData.begin(), tempKeyData.end(), keyData.get()); + } + Vector values(header.numNodes); + streamvbyte_decode_0124(reinterpret_cast(ptr + header.valueOffset), values.data(), header.numNodes); + + size_t numNonLeafNodes = 0, numLeafNodes = 0; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (nodeSizes[i]) numNonLeafNodes++; + else numLeafNodes++; + } + + nodeData = make_unique(numNonLeafNodes); + valueData = make_unique(header.numNodes - 1); + + size_t nonLeafIdx = 0, leafIdx = 0, nextOffset = 0; + Vector> keyRanges; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (nodeSizes[i]) + { + auto& node = nodeData[nonLeafIdx]; + if (!keyRanges.empty()) + { + auto& back = keyRanges.back(); + valueData[back[1]] = nonLeafIdx - back[0]; + } + node.value = values[i]; + node.numNexts = nodeSizes[i]; + node.nextOffset = nextOffset; + nextOffset += nodeSizes[i]; + keyRanges.emplace_back(std::array{ nonLeafIdx, (size_t)node.nextOffset, (size_t)(node.nextOffset + node.numNexts) }); + nonLeafIdx++; + } + else + { + auto& back = keyRanges.back(); + valueData[back[1]] = -(int32_t)values[i]; + back[1]++; + while (keyRanges.back()[1] == keyRanges.back()[2]) + { + 
keyRanges.pop_back(); + if (keyRanges.empty()) break; + keyRanges.back()[1]++; + } + leafIdx++; + } + } + + Vector tempBuf; + for (size_t i = 0; i < nonLeafIdx; ++i) + { + auto& node = nodeData[i]; + nst::prepare(&keyData[node.nextOffset], &valueData[node.nextOffset], node.numNexts, tempBuf); + } + + Deque dq; + for (dq.emplace_back(&nodeData[0]); !dq.empty(); dq.pop_front()) + { + auto p = dq.front(); + for (size_t i = 0; i < p->numNexts; ++i) + { + auto k = keyData[p->nextOffset + i]; + auto v = valueData[p->nextOffset + i]; + if (v <= 0) continue; + auto* child = &p[v]; + child->lower = findLowerNode(p, k) - child; + if (child->value == 0) + { + child->value = findLowerValue(p, k); + } + dq.emplace_back(child); + } + } + + auto* eptr = ptr + header.embOffset; + contextEmb = make_unique(header.contextSize * header.dim); + contextBias = make_unique(header.contextSize); + contextValidTokenSum = make_unique(header.contextSize); + contextConf = make_unique(header.contextSize); + if (useDistantTokens) + { + distantEmb = make_unique(header.vocabSize * header.dim); + distantBias = make_unique(header.vocabSize); + distantConf = make_unique(header.vocabSize); + positionConf = make_unique(header.windowSize); + } + outputEmb = make_unique(header.vocabSize * header.dim); + + const uint16_t* contextEmbScale = reinterpret_cast(eptr + header.contextSize * header.dim); + for (size_t i = 0; i < header.contextSize; ++i) + { + dequantize(&contextEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(contextEmbScale[i])); + eptr += header.dim; + } + eptr += header.contextSize * sizeof(uint16_t); + for (size_t i = 0; i < header.contextSize; ++i) + { + contextBias[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.contextSize; ++i) + { + contextValidTokenSum[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.contextSize; ++i) + { + contextConf[i] = 
half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + + const uint16_t* distantEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); + for (size_t i = 0; i < header.vocabSize; ++i) + { + if (useDistantTokens) dequantize(&distantEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(distantEmbScale[i])); + eptr += header.dim; + } + eptr += header.vocabSize * sizeof(uint16_t); + for (size_t i = 0; i < header.vocabSize; ++i) + { + if (useDistantTokens) distantBias[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.vocabSize; ++i) + { + if (useDistantTokens) distantConf[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + for (size_t i = 0; i < header.windowSize; ++i) + { + if (useDistantTokens) positionConf[i] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + + const uint16_t* outputEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); + for (size_t i = 0; i < header.vocabSize; ++i) + { + dequantize(&outputEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(outputEmbScale[i])); + eptr += header.dim; + } + eptr += header.vocabSize * sizeof(uint16_t); + + if (useDistantTokens) + { + const size_t compressedDistantMaskSize = (header.vocabSize + 7) / 8; + distantMask = make_unique(compressedDistantMaskSize); + std::copy(eptr, eptr + compressedDistantMaskSize, distantMask.get()); + } + } + + template + float PcLangModel::progress(int32_t& nodeIdx, + uint32_t& contextIdx, + size_t& historyPos, + std::array& history, + KeyType next) const + { + const auto& header = getHeader(); + const bool validDistantToken = distantTokenMask(next); + float ll = 0; + + thread_local Eigen::MatrixXf mat; + mat.resize(header.dim, 1 + windowSize); + thread_local Eigen::VectorXf lls; + lls.resize(1 + windowSize); + if (useDistantTokens && validDistantToken) + { + lls[0] = contextConf[contextIdx]; + lls.tail(windowSize) = 
Eigen::Map{ &positionConf[0], windowSize }; + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[(historyPos + i) % windowSize]; + lls[i + 1] += historyToken ? distantConf[historyToken] : -99999; + } + logsoftmaxInplace(lls.array()); + + mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; + lls[0] -= contextBias[contextIdx]; + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[(historyPos + i) % windowSize]; + if (historyToken) mat.col(i + 1) = Eigen::Map{ &distantEmb[historyToken * header.dim], header.dim }; + else mat.col(i + 1).setZero(); + lls[i + 1] -= distantBias[historyToken]; + } + lls.tail(windowSize).array() += contextValidTokenSum[contextIdx]; + Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; + lls += mat.transpose() * outputVec; + ll = LogExpSum{}(lls.data(), std::integral_constant()); + } + else + { + lls[0] = -contextBias[contextIdx]; + mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; + Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; + lls.head(1) += mat.transpose() * outputVec; + ll = lls[0]; + } + + contextIdx = progressContextNode(nodeIdx, next); + if (history[windowSize]) + { + history[historyPos] = history[windowSize]; + historyPos = (historyPos + 1) % windowSize; + } + history[windowSize] = validDistantToken ? 
next : 0; + return ll; + } + + utils::MemoryObject PcLangModelBase::build(const string& contextDefinition, const string& embedding, bool reorderContextId) { ifstream contextStr, embeddingStr; if (!openFile(contextStr, contextDefinition)) @@ -211,8 +468,8 @@ namespace kiwi distantMask.resize(compressedDistantMaskSize); } - Header header; - memset(&header, 0, sizeof(Header)); + PcLangModelHeader header; + memset(&header, 0, sizeof(PcLangModelHeader)); header.dim = dim; header.contextSize = contextSize; header.vocabSize = outputSize; @@ -221,7 +478,7 @@ namespace kiwi header.numNodes = nodeSizes.size(); size_t finalSize = 0; - header.nodeOffset = alignedOffsetInc(finalSize, sizeof(Header)); + header.nodeOffset = alignedOffsetInc(finalSize, sizeof(PcLangModelHeader)); header.keyOffset = alignedOffsetInc(finalSize, compressedNodeSizes.size()); header.valueOffset = alignedOffsetInc(finalSize, compressedKeys.size()); header.embOffset = alignedOffsetInc(finalSize, compressedValues.size()); @@ -233,7 +490,7 @@ namespace kiwi utils::MemoryOwner mem{ finalSize }; utils::omstream ostr{ (char*)mem.get(), (std::ptrdiff_t)mem.size() }; - ostr.write((const char*)&header, sizeof(Header)); + ostr.write((const char*)&header, sizeof(PcLangModelHeader)); writePadding(ostr); ostr.write((const char*)compressedNodeSizes.data(), compressedNodeSizes.size()); writePadding(ostr); @@ -257,27 +514,39 @@ namespace kiwi return mem; } + template + void* PcLangModel::getFindBestPathFn() const + { + return (void*)&BestPathFinder::findBestPath>; + } + + template + void* PcLangModel::getNewJoinerFn() const + { + return (void*)&newJoinerWithKiwi; + } + template - inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) + inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) { - auto& header = *reinterpret_cast(mem.get()); + auto& header = *reinterpret_cast(mem.get()); switch (header.windowSize) { case 4: - return make_unique>(std::move(mem)); + 
return make_unique>(std::move(mem)); case 7: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); case 8: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); default: throw std::runtime_error{ "Unsupported `window_size` : " + std::to_string((size_t)header.windowSize) }; }; } template - std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) { - auto& header = *reinterpret_cast(mem.get()); + auto& header = *reinterpret_cast(mem.get()); switch (header.keySize) { case 1: @@ -303,7 +572,7 @@ namespace kiwi }; }; - std::unique_ptr PCLanguageModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens) + std::unique_ptr PcLangModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens) { static tp::Table tableWithoutDistantTokens{ CreateOptimizedModelGetter{} }, tableWithDistantTokens{ CreateOptimizedModelGetter{} }; diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index 4d73c0af..c6e1e5b7 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -13,37 +13,13 @@ namespace kiwi { - namespace pclm + namespace lm { - inline float half2float(uint16_t h) - { - union - { - uint32_t i; - float f; - } u; - u.i = (uint32_t)(h & 0x8000) << 16; - u.i |= ((uint32_t)(h & 0x7FFF) + 0x1C000) << 13; - return u.f; - } - - inline void dequantize(float* out, const int8_t* ints, size_t n, float scale) - { - for (size_t i = 0; i < n; ++i) - { - out[i] = ints[i] * scale; - } - } - - template - void logsoftmaxInplace(Arr& arr) - { - arr -= arr.maxCoeff(); - arr -= std::log(arr.exp().sum()); - } + template + class PcLMState; template - class PCLanguageModel : public PCLanguageModelBase + class PcLangModel : public PcLangModelBase { using MyNode = Node; @@ -115,172 +91,13 @@ namespace kiwi } public: - PCLanguageModel(utils::MemoryObject&& mem) : PCLanguageModelBase{ std::move(mem) } - { - auto* 
ptr = reinterpret_cast(base.get()); - auto& header = getHeader(); - - Vector nodeSizes(header.numNodes); - streamvbyte_decode_0124(reinterpret_cast(ptr + header.nodeOffset), nodeSizes.data(), header.numNodes); - keyData = make_unique(header.numNodes - 1); - if (std::is_same::value) - { - streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), (uint32_t*)keyData.get(), header.numNodes - 1); - } - else - { - Vector tempKeyData(header.numNodes - 1); - streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), tempKeyData.data(), header.numNodes - 1); - std::copy(tempKeyData.begin(), tempKeyData.end(), keyData.get()); - } - Vector values(header.numNodes); - streamvbyte_decode_0124(reinterpret_cast(ptr + header.valueOffset), values.data(), header.numNodes); - - size_t numNonLeafNodes = 0, numLeafNodes = 0; - for (size_t i = 0; i < header.numNodes; ++i) - { - if (nodeSizes[i]) numNonLeafNodes++; - else numLeafNodes++; - } - - nodeData = make_unique(numNonLeafNodes); - valueData = make_unique(header.numNodes - 1); - - size_t nonLeafIdx = 0, leafIdx = 0, nextOffset = 0; - Vector> keyRanges; - for (size_t i = 0; i < header.numNodes; ++i) - { - if (nodeSizes[i]) - { - auto& node = nodeData[nonLeafIdx]; - if (!keyRanges.empty()) - { - auto& back = keyRanges.back(); - valueData[back[1]] = nonLeafIdx - back[0]; - } - node.value = values[i]; - node.numNexts = nodeSizes[i]; - node.nextOffset = nextOffset; - nextOffset += nodeSizes[i]; - keyRanges.emplace_back(std::array{ nonLeafIdx, (size_t)node.nextOffset, (size_t)(node.nextOffset + node.numNexts) }); - nonLeafIdx++; - } - else - { - auto& back = keyRanges.back(); - valueData[back[1]] = -(int32_t)values[i]; - back[1]++; - while (keyRanges.back()[1] == keyRanges.back()[2]) - { - keyRanges.pop_back(); - if (keyRanges.empty()) break; - keyRanges.back()[1]++; - } - leafIdx++; - } - } + using VocabType = KeyType; + using LmStateType = PcLMState; - Vector tempBuf; - for (size_t i = 0; i < nonLeafIdx; ++i) - { - auto& node = 
nodeData[i]; - nst::prepare(&keyData[node.nextOffset], &valueData[node.nextOffset], node.numNexts, tempBuf); - } - - Deque dq; - for (dq.emplace_back(&nodeData[0]); !dq.empty(); dq.pop_front()) - { - auto p = dq.front(); - for (size_t i = 0; i < p->numNexts; ++i) - { - auto k = keyData[p->nextOffset + i]; - auto v = valueData[p->nextOffset + i]; - if (v <= 0) continue; - auto* child = &p[v]; - child->lower = findLowerNode(p, k) - child; - if (child->value == 0) - { - child->value = findLowerValue(p, k); - } - dq.emplace_back(child); - } - } + PcLangModel(utils::MemoryObject&& mem); - auto* eptr = ptr + header.embOffset; - contextEmb = make_unique(header.contextSize * header.dim); - contextBias = make_unique(header.contextSize); - contextValidTokenSum = make_unique(header.contextSize); - contextConf = make_unique(header.contextSize); - if (useDistantTokens) - { - distantEmb = make_unique(header.vocabSize * header.dim); - distantBias = make_unique(header.vocabSize); - distantConf = make_unique(header.vocabSize); - positionConf = make_unique(header.windowSize); - } - outputEmb = make_unique(header.vocabSize * header.dim); - - const uint16_t* contextEmbScale = reinterpret_cast(eptr + header.contextSize * header.dim); - for (size_t i = 0; i < header.contextSize; ++i) - { - dequantize(&contextEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(contextEmbScale[i])); - eptr += header.dim; - } - eptr += header.contextSize * sizeof(uint16_t); - for (size_t i = 0; i < header.contextSize; ++i) - { - contextBias[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); - } - for (size_t i = 0; i < header.contextSize; ++i) - { - contextValidTokenSum[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); - } - for (size_t i = 0; i < header.contextSize; ++i) - { - contextConf[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); - } - - const uint16_t* distantEmbScale = reinterpret_cast(eptr + header.vocabSize * 
header.dim); - for (size_t i = 0; i < header.vocabSize; ++i) - { - if (useDistantTokens) dequantize(&distantEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(distantEmbScale[i])); - eptr += header.dim; - } - eptr += header.vocabSize * sizeof(uint16_t); - for (size_t i = 0; i < header.vocabSize; ++i) - { - if (useDistantTokens) distantBias[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); - } - for (size_t i = 0; i < header.vocabSize; ++i) - { - if (useDistantTokens) distantConf[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); - } - for (size_t i = 0; i < header.windowSize; ++i) - { - if (useDistantTokens) positionConf[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); - } - - const uint16_t* outputEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); - for (size_t i = 0; i < header.vocabSize; ++i) - { - dequantize(&outputEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(outputEmbScale[i])); - eptr += header.dim; - } - eptr += header.vocabSize * sizeof(uint16_t); - - if (useDistantTokens) - { - const size_t compressedDistantMaskSize = (header.vocabSize + 7) / 8; - distantMask = make_unique(compressedDistantMaskSize); - std::copy(eptr, eptr + compressedDistantMaskSize, distantMask.get()); - } - } + void* getFindBestPathFn() const override; + void* getNewJoinerFn() const override; uint32_t progressContextNode(int32_t& nodeIdx, KeyType next) const { @@ -342,81 +159,90 @@ namespace kiwi else return false; } - float progress(int32_t& nodeIdx, - uint32_t& contextIdx, - size_t& historyPos, - std::array& history, - KeyType next) const - { - const auto& header = getHeader(); - const bool validDistantToken = distantTokenMask(next); - float ll = 0; - - thread_local Eigen::MatrixXf mat; - mat.resize(header.dim, 1 + windowSize); - thread_local Eigen::VectorXf lls; - lls.resize(1 + windowSize); - if (useDistantTokens && validDistantToken) - { - lls[0] = 
contextConf[contextIdx]; - lls.tail(windowSize) = Eigen::Map{ &positionConf[0], windowSize }; - for (size_t i = 0; i < windowSize; ++i) - { - const auto historyToken = history[(historyPos + i) % windowSize]; - lls[i + 1] += historyToken ? distantConf[historyToken] : -99999; - } - logsoftmaxInplace(lls.array()); + float progress(int32_t& nodeIdx, + uint32_t& contextIdx, + size_t& historyPos, + std::array& history, + KeyType next) const; - mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; - lls[0] -= contextBias[contextIdx]; - for (size_t i = 0; i < windowSize; ++i) - { - const auto historyToken = history[(historyPos + i) % windowSize]; - if (historyToken) mat.col(i + 1) = Eigen::Map{ &distantEmb[historyToken * header.dim], header.dim }; - else mat.col(i + 1).setZero(); - lls[i + 1] -= distantBias[historyToken]; - } - lls.tail(windowSize).array() += contextValidTokenSum[contextIdx]; - Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; - lls += mat.transpose() * outputVec; - ll = sb::LogExpSum{}(lls.data(), std::integral_constant()); - } - else - { - lls[0] = -contextBias[contextIdx]; - mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; - Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; - lls.head(1) += mat.transpose() * outputVec; - ll = lls[0]; - } + }; - contextIdx = progressContextNode(nodeIdx, next); - if (history[windowSize]) - { - history[historyPos] = history[windowSize]; - historyPos = (historyPos + 1) % windowSize; - } - history[windowSize] = validDistantToken ? 
next : 0; - return ll; + template + struct PcLMState : public LmStateBase> + { + int32_t node = 0; + uint32_t contextIdx = 0; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = true; + + PcLMState() = default; + PcLMState(const ILangModel* lm) {} + + bool operator==(const PcLMState& other) const + { + return node == other.node; + } + + float nextImpl(const PcLangModel* lm, VocabTy next) + { + size_t historyPos = 0; + std::array history = { {0,} }; + return lm->progress(node, contextIdx, historyPos, history, next); } }; - static constexpr size_t serialAlignment = 16; + template + struct PcLMState : public LmStateBase> + { + static constexpr bool useDistantTokens = true; + + int32_t node = 0; + uint32_t contextIdx = 0; + size_t historyPos = 0; + std::array history = { {0,} }; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = true; - inline size_t alignedOffsetInc(size_t& offset, size_t inc, size_t alignment = serialAlignment) + PcLMState() = default; + PcLMState(const ILangModel* lm) {} + + bool operator==(const PcLMState& other) const + { + return node == other.node && historyPos == other.historyPos && history == other.history; + } + + float nextImpl(const PcLangModel* lm, VocabTy next) + { + return lm->progress(node, contextIdx, historyPos, history, next); + } + }; + } + + template + struct Hash> + { + size_t operator()(const lm::PcLMState& state) const { - return offset = (offset + inc + alignment - 1) & ~(alignment - 1); + Hash hasher; + return hasher(state.node); } + }; - inline std::ostream& writePadding(std::ostream& os, size_t alignment = serialAlignment) + template + struct Hash> + { + size_t operator()(const lm::PcLMState& state) const { - const size_t pos = os.tellp(); - size_t pad = ((pos + alignment - 1) & ~(alignment - 1)) - pos; - for (size_t i = 0; i < pad; ++i) + Hash hasher; + std::hash vocabHasher; + size_t ret = hasher(state.node); + for (size_t i = 0; i < state.history.size(); ++i) { - 
os.put(0); + ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); } - return os; + return ret; } - } + }; } diff --git a/src/PathEvaluator.cpp b/src/PathEvaluator.cpp deleted file mode 100644 index 77f1e664..00000000 --- a/src/PathEvaluator.cpp +++ /dev/null @@ -1,126 +0,0 @@ -#include "PathEvaluator.hpp" - -using namespace std; - -namespace kiwi -{ - template class LmState> - struct FindBestPathGetter - { - template - struct Wrapper - { - static constexpr FnFindBestPath value = &BestPathFinder::findBestPath(i)>>; - }; - }; - - - template - inline FnFindBestPath getPcLMFindBestPath(ArchType archType, size_t windowSize) - { - static tp::Table w4{ FindBestPathGetter::type>{} }; - static tp::Table w7{ FindBestPathGetter::type>{} }; - static tp::Table w8{ FindBestPathGetter::type>{} }; - switch (windowSize) - { - case 4: - return w4[static_cast(archType)]; - case 7: - return w7[static_cast(archType)]; - case 8: - return w8[static_cast(archType)]; - default: - throw Exception{ "Unsupported `window_size` : " + to_string(windowSize) }; - } - } - - FnFindBestPath getFindBestPathFn(ArchType archType, const LangModel& langMdl) - { - const auto archIdx = static_cast(archType); - static tp::Table lmKnLM_8{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_16{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_32{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLMT_8{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLMT_16{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLMT_32{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_8{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_16{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_32{ FindBestPathGetter::type>{} }; - - if (langMdl.type == ModelType::sbg) - { - switch (langMdl.sbg->getHeader().keySize) - { - case 1: - return lmSbg_8[archIdx]; - case 2: - return lmSbg_16[archIdx]; - case 4: - return lmSbg_32[archIdx]; - default: - throw 
Exception{ "Wrong `lmKeySize`" }; - } - } - else if (langMdl.type == ModelType::knlm) - { - switch (langMdl.knlm->getHeader().key_size) - { - case 1: - return lmKnLM_8[archIdx]; - break; - case 2: - return lmKnLM_16[archIdx]; - break; - case 4: - return lmKnLM_32[archIdx]; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else if (langMdl.type == ModelType::knlmTransposed) - { - switch (langMdl.knlm->getHeader().key_size) - { - case 1: - return lmKnLMT_8[archIdx]; - break; - case 2: - return lmKnLMT_16[archIdx]; - break; - case 4: - return lmKnLMT_32[archIdx]; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else if (langMdl.type == ModelType::pclm) - { - switch (langMdl.pclm->getHeader().keySize) - { - case 2: - return getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); - case 4: - return getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else if (langMdl.type == ModelType::pclmLocal) - { - switch (langMdl.pclm->getHeader().keySize) - { - case 2: - return getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); - case 4: - return getPcLMFindBestPath(archType, langMdl.pclm->getHeader().windowSize); - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else - { - - } - return nullptr; - } -} diff --git a/src/PathEvaluator.h b/src/PathEvaluator.h index 61823660..b6d4b1d0 100644 --- a/src/PathEvaluator.h +++ b/src/PathEvaluator.h @@ -85,7 +85,7 @@ namespace kiwi struct BestPathFinder { - template + template static Vector findBestPath(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, @@ -100,6 +100,4 @@ namespace kiwi }; using FnFindBestPath = decltype(&BestPathFinder::findBestPath); - - FnFindBestPath getFindBestPathFn(ArchType archType, const LangModel& langMdl); } diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index 23d04db5..ae1bba27 100644 --- a/src/PathEvaluator.hpp +++ 
b/src/PathEvaluator.hpp @@ -4,244 +4,21 @@ #include #include #include +#include #include "ArchAvailable.h" #include "KTrie.h" #include "FeatureTestor.h" #include "FrozenTrie.hpp" -#include "LmState.hpp" #include "StrUtils.h" #include "SortUtils.hpp" #include "LimitedVector.hpp" #include "PathEvaluator.h" +#include "BestPathContainer.hpp" using namespace std; namespace kiwi { - template - struct WordLL; - - using Wid = uint32_t; - - enum class PathEvaluatingMode - { - topN, - top1, - top1Small, - }; - - template - struct PathEvaluator - { - template - static void eval(const Kiwi* kw, - const KGraphNode* startNode, - const KGraphNode* node, - const size_t topN, - Vector>>& cache, - const Vector& ownFormList, - size_t i, - size_t ownFormId, - CandTy&& cands, - bool unknownForm, - const Vector& prevSpStates, - bool splitComplex = false, - bool splitSaisiot = false, - bool mergeSaisiot = false, - const std::unordered_set* blocklist = nullptr - ); - - template - static void evalSingleMorpheme( - Vector>& resultOut, - const Kiwi* kw, - const Vector& ownForms, - const Vector>>& cache, - size_t ownFormId, - const Morpheme* curMorph, - const KGraphNode* node, - const KGraphNode* startNode, - const size_t topN, - const float ignoreCondScore, - const float nodeLevelDiscount, - const Vector& prevSpStates - ); - }; - - // evaluator using transposed order - template - struct PathEvaluator::type> - { - template - static void eval(const Kiwi* kw, - const KGraphNode* startNode, - const KGraphNode* node, - const size_t topN, - Vector>>& cache, - const Vector& ownFormList, - size_t i, - size_t ownFormId, - CandTy&& cands, - bool unknownForm, - const Vector& prevSpStates, - bool splitComplex = false, - bool splitSaisiot = false, - bool mergeSaisiot = false, - const std::unordered_set* blocklist = nullptr - ); - - template - static void evalMorphemes( - Vector>& resultOut, - const Kiwi* kw, - const Vector& ownForms, - const Vector>>& cache, - size_t ownFormId, - const Vector& morphs, - 
const Vector& morphScores, - const KGraphNode* node, - const KGraphNode* startNode, - const size_t topN, - const size_t totalPrevPathes, - const float ignoreCondScore, - const float nodeLevelDiscount, - const Vector& prevSpStates - ); - }; - - template - struct WordLL - { - const Morpheme* morpheme = nullptr; - float accScore = 0, accTypoCost = 0; - const WordLL* parent = nullptr; - LmState lmState; - Wid wid = 0; - uint16_t ownFormId = 0; - uint8_t combineSocket = 0; - uint8_t rootId = 0; - SpecialState spState; - - WordLL() = default; - - WordLL(const Morpheme* _morph, float _accScore, float _accTypoCost, const WordLL* _parent, LmState _lmState, SpecialState _spState) - : morpheme{ _morph }, - accScore{ _accScore }, - accTypoCost{ _accTypoCost }, - parent{ _parent }, - lmState{ _lmState }, - spState{ _spState }, - rootId{ parent ? parent->rootId : (uint8_t)0 } - { - } - - const WordLL* root() const - { - if (parent) return parent->root(); - else return this; - } - }; - - static constexpr uint8_t commonRootId = -1; - - template - struct PathHash - { - LmState lmState; - uint8_t rootId, spState; - - PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) - : lmState{ _lmState }, rootId{ _rootId }, spState { _spState } - { - } - - PathHash(const WordLL& wordLl, const Morpheme* morphBase) - : PathHash{ wordLl.lmState, wordLl.rootId, wordLl.spState } - { - } - - bool operator==(const PathHash& o) const - { - return lmState == o.lmState && rootId == o.rootId && spState == o.spState; - } - }; - - template - struct PathHash> - { - using LmState = SbgState; - - KnLMState<_arch, VocabTy> lmState; - array lastMorphemes; - uint8_t rootId, spState; - - PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) - : lmState{ _lmState }, rootId{ _rootId }, spState{ _spState } - { - _lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); - } - - - PathHash(const WordLL& wordLl, const Morpheme* morphBase) - : PathHash{ 
wordLl.lmState, wordLl.rootId, wordLl.spState } - { - } - - bool operator==(const PathHash& o) const - { - return lmState == o.lmState && lastMorphemes == o.lastMorphemes && spState == o.spState; - } - }; - - template - struct Hash> - { - size_t operator()(const PathHash& p) const - { - size_t ret = 0; - if (sizeof(PathHash) % sizeof(size_t)) - { - auto ptr = reinterpret_cast(&p); - for (size_t i = 0; i < sizeof(PathHash) / sizeof(uint32_t); ++i) - { - ret ^= ptr[i]; - } - } - else - { - auto ptr = reinterpret_cast(&p); - for (size_t i = 0; i < sizeof(PathHash) / sizeof(size_t); ++i) - { - ret ^= ptr[i]; - } - } - return ret; - } - }; - - struct WordLLGreater - { - template - bool operator()(const WordLL& a, const WordLL& b) const - { - return a.accScore > b.accScore; - } - }; - - template - inline std::ostream& printDebugPath(std::ostream& os, const WordLL& src) - { - if (src.parent) - { - printDebugPath(os, *src.parent); - } - - if (src.morpheme) src.morpheme->print(os); - else os << "NULL"; - os << " , "; - return os; - } - inline bool hasLeftBoundary(const KGraphNode* node) { // 시작 지점은 항상 왼쪽 경계로 처리 @@ -403,988 +180,846 @@ namespace kiwi || m == Kiwi::SpecialMorph::doubleQuoteOpen || m == Kiwi::SpecialMorph::doubleQuoteClose; } - template - class BestPathConatiner; + template + struct LmEvalData + { + LmState state; + float score = 0; + uint32_t length = 0; + }; + + template + struct PathEvaluator; template - class BestPathConatiner + struct PathEvaluator::type> { - // pair: [index, size] - UnorderedMap, pair> bestPathIndex; - Vector> bestPathValues; - public: - inline void clear() + const Kiwi* kw; + const KGraphNode* startNode; + const size_t topN; + Vector>>& cache; + const Vector& ownFormList; + const Vector& prevSpStates; + + PathEvaluator(const Kiwi* _kw, + const KGraphNode* _startNode, + size_t _topN, + Vector>>& _cache, + const Vector& _ownFormList, + const Vector& _prevSpStates + ) + : kw{ _kw }, startNode{ _startNode }, topN{ _topN }, cache{ _cache }, 
ownFormList{ _ownFormList }, prevSpStates{ _prevSpStates } { - bestPathIndex.clear(); - bestPathValues.clear(); } - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, - const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + template + void operator()( + const size_t nodeIdx, + const size_t ownFormId, + CandTy&& cands, + bool unknownForm, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ) const { - auto inserted = bestPathIndex.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); - if (inserted.second) + const size_t langVocabSize = kw->langMdl->vocabSize(); + auto* const node = startNode + nodeIdx; + auto& nCache = cache[nodeIdx]; + Vector> refCache; + + float whitespaceDiscount = 0; + if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) { - bestPathValues.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); - if (rootId != commonRootId) bestPathValues.back().rootId = rootId; - bestPathValues.resize(bestPathValues.size() + topN - 1); + whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; } - else + const float typoDiscount = -node->typoCost * kw->typoCostWeight; + float unknownFormDiscount = 0; + if (unknownForm) { - auto bestPathFirst = bestPathValues.begin() + inserted.first->second.first; - auto bestPathLast = bestPathValues.begin() + inserted.first->second.first + inserted.first->second.second; - if (distance(bestPathFirst, bestPathLast) < topN) - { - *bestPathLast = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) bestPathLast->rootId = rootId; - push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); - ++inserted.first->second.second; - } - else - { - if (accScore > bestPathFirst->accScore) - { - pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); - *(bestPathLast - 1) = 
WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) (*(bestPathLast - 1)).rootId = rootId; - push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); - } - } + size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); + unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); } - } - inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) - { - for (auto& p : bestPathIndex) + const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; + + size_t totalPrevPathes = 0; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + totalPrevPathes += cache[prev - startNode].size(); + } + const bool useContainerForSmall = totalPrevPathes <= 48; + + for (bool ignoreCond : {false, true}) { - const auto index = p.second.first; - const auto size = p.second.second; - for (size_t i = 0; i < size; ++i) + for (auto& curMorph : cands) { - resultOut.emplace_back(move(bestPathValues[index + i])); - auto& newPath = resultOut.back(); + if (splitComplex && curMorph->getCombined()->complex) continue; + if (blocklist && blocklist->count(curMorph->getCombined())) continue; - // fill the rest information of resultOut - newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + // 덧붙은 받침(zCoda)을 위한 지름길 + if (curMorph->tag == POSTag::z_coda) { - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { + auto lastTag = kw->morphemes[p.wid].tag; + if (!isJClass(lastTag) && !isEClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += curMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= curMorph->userScore; + newPath.parent = &p; + newPath.morpheme = 
&kw->morphemes[curMorph->lmMorphemeId]; + newPath.wid = curMorph->lmMorphemeId; + } + } + continue; } - } - } - } - }; + // 사이시옷(zSiot)을 위한 지름길 + if (curMorph->tag == POSTag::z_siot) + { + if (!(splitSaisiot || mergeSaisiot)) + { + continue; + } - template - class BestPathConatiner - { - UnorderedMap, WordLL> bestPathes; - public: - inline void clear() - { - bestPathes.clear(); - } + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { + auto lastTag = kw->morphemes[p.wid].tag; + if (!isNNClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += curMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= curMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; + newPath.wid = curMorph->lmMorphemeId; + } + } + continue; + } - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, - const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) - { - WordLL newPath{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) newPath.rootId = rootId; - auto inserted = bestPathes.emplace(ph, newPath); - if (!inserted.second) - { - auto& target = inserted.first->second; - if (accScore > target.accScore) - { - target = newPath; - } - } - } + // if the morpheme has chunk set + if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + { + // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 + if (node->prev && node[-(int)node->prev].endPos < node->startPos + && curMorph->kform + && curMorph->kform->size() == 1 + && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') + && curMorph->chunks[0]->kform + && curMorph->chunks[0]->kform->size() == 1 + && (*curMorph->chunks[0]->kform)[0] == u'하') + { + continue; + } + } - inline void 
writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) - { - for (auto& p : bestPathes) - { - resultOut.emplace_back(move(p.second)); - auto& newPath = resultOut.back(); + if (topN > 1) + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } + else if (useContainerForSmall) + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } + else + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } - // fill the rest information of resultOut - newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) - { - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; } + if (!nCache.empty()) break; } - } - }; - template - class BestPathConatiner - { - Vector> bestPathIndicesSmall; - Vector> bestPathValuesSmall; - public: + thread_local Vector maxScores; + maxScores.clear(); + maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); - inline void clear() - { - bestPathIndicesSmall.clear(); - bestPathValuesSmall.clear(); - } - - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, - const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) - { - const auto it = find(bestPathIndicesSmall.begin(), bestPathIndicesSmall.end(), ph); - if (it == bestPathIndicesSmall.end()) + if (topN == 1) { - bestPathIndicesSmall.push_back(ph); - bestPathValuesSmall.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); - if (rootId != commonRootId) bestPathValuesSmall.back().rootId = rootId; + for (auto& c : nCache) + { + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 
0 : c.rootId + 1; + maxScores[rootId] = max(maxScores[rootId], c.accScore); + } } else { - auto& target = bestPathValuesSmall[it - bestPathIndicesSmall.begin()]; - if (accScore > target.accScore) + for (auto& c : nCache) { - target = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) target.rootId = rootId; + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + if (c.accScore > maxScores[rootId * topN]) + { + pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + maxScores[rootId * topN + topN - 1] = c.accScore; + push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + } } } - } - inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) - { - for (auto& p : bestPathValuesSmall) + size_t validCount = 0; + for (size_t i = 0; i < nCache.size(); ++i) { - resultOut.emplace_back(move(p)); - auto& newPath = resultOut.back(); - - // fill the rest information of resultOut - newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) - { - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; - } + const auto rootId = nCache[i].rootId == commonRootId ? 
0 : nCache[i].rootId + 1; + if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; + if (validCount != i) nCache[validCount] = move(nCache[i]); + validCount++; } + nCache.resize(validCount); } - }; - template - template - void PathEvaluator::evalSingleMorpheme( - Vector>& resultOut, - const Kiwi* kw, - const Vector& ownForms, - const Vector>>& cache, - size_t ownFormId, - const Morpheme* curMorph, - const KGraphNode* node, - const KGraphNode* startNode, - const size_t topN, - const float ignoreCondScore, - const float nodeLevelDiscount, - const Vector& prevSpStates - ) - { - thread_local BestPathConatiner bestPathCont; - thread_local Vector rootIds; + template + void evalSingleMorpheme( + Vector>& resultOut, + const KGraphNode* node, + const size_t ownFormId, + const Morpheme* curMorph, + const float ignoreCondScore, + const float nodeLevelDiscount + ) const + { + thread_local BestPathConatiner bestPathCont; + thread_local Vector rootIds; - const LangModel& langMdl = kw->langMdl; - const Morpheme* morphBase = kw->morphemes.data(); - const auto spacePenalty = kw->spacePenalty; - const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + const auto* langMdl = kw->getLangModel(); + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; - const size_t langVocabSize = langMdl.vocabSize(); + const size_t langVocabSize = langMdl->vocabSize(); - const Morpheme* lastMorph; - Wid firstWid; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) - { - lastMorph = curMorph->getCombined() ? 
curMorph->getCombined() : curMorph; - firstWid = curMorph->lmMorphemeId; - } - // if the morpheme has chunk set - else - { - lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; - firstWid = curMorph->chunks[0]->lmMorphemeId; - } + const Morpheme* lastMorph; + Wid firstWid; + if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + { + lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; + firstWid = curMorph->lmMorphemeId; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + firstWid = curMorph->chunks[0]->lmMorphemeId; + } - Wid lastSeqId; - if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) - { - lastSeqId = lastMorph - kw->morphemes.data(); - } - else - { - lastSeqId = lastMorph->lmMorphemeId; - } + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } - bestPathCont.clear(); - const float additionalScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + bestPathCont.clear(); + const float additionalScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); - RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) - { - for (auto& prevPath : cache[prev - startNode]) + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { - // 사이시옷 뒤에 명사가 아닌 태그가 오거나 공백이 있는 경우 제외 - if (prevPath.morpheme->tag == POSTag::z_siot && ( - !isNNClass(curMorph->tag) || prev->endPos < node->startPos - )) - { - continue; - } - - float candScore = prevPath.accScore + additionalScore; - if 
(prevPath.combineSocket) + for (auto& prevPath : cache[prev - startNode]) { - // merge with only the same socket - if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + // 사이시옷 뒤에 명사가 아닌 태그가 오거나 공백이 있는 경우 제외 + if (prevPath.morpheme->tag == POSTag::z_siot && ( + !isNNClass(curMorph->tag) || prev->endPos < node->startPos + )) { continue; } - if (prev->endPos < node->startPos) + + float candScore = prevPath.accScore + additionalScore; + if (prevPath.combineSocket) { - if (allowedSpaceBetweenChunk) candScore -= spacePenalty; - else continue; + // merge with only the same socket + if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + { + continue; + } + if (prev->endPos < node->startPos) + { + if (allowedSpaceBetweenChunk) candScore -= spacePenalty; + else continue; + } + firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; } - firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; - } - const kchar_t* leftFormFirst, * leftFormLast; - if (prevPath.ownFormId) - { - leftFormFirst = ownForms[prevPath.ownFormId - 1].data(); - leftFormLast = leftFormFirst + ownForms[0].size(); - } - else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) - { - leftFormFirst = morphBase[prevPath.wid].kform->data(); - leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); - } - else - { - leftFormFirst = prevPath.morpheme->getForm().data(); - leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); - } + const kchar_t* leftFormFirst, * leftFormLast; + if (prevPath.ownFormId) + { + leftFormFirst = ownFormList[prevPath.ownFormId - 1].data(); + leftFormLast = leftFormFirst + ownFormList[0].size(); + } + else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) + { + leftFormFirst = morphBase[prevPath.wid].kform->data(); + leftFormLast = leftFormFirst + 
morphBase[prevPath.wid].kform->size(); + } + else + { + leftFormFirst = prevPath.morpheme->getForm().data(); + leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); + } - const CondVowel cvowel = curMorph->vowel; - const CondPolarity cpolar = curMorph->polar; - const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; - if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) - { - // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 - } - else if (ignoreCondScore) - { - candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 0 : ignoreCondScore; - } - else - { - if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) continue; - } + const CondVowel cvowel = curMorph->vowel; + const CondPolarity cpolar = curMorph->polar; + const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; + if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) + { + // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 + } + else if (ignoreCondScore) + { + candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 
0 : ignoreCondScore; + } + else + { + if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) continue; + } - auto cLmState = prevPath.lmState; - if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) - { - // no-op - } - else - { - if (morphBase[firstWid].tag == POSTag::p) + auto cLmState = prevPath.lmState; + if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) { - // prohibit without - goto continueFor; + // no-op } - float ll = cLmState.next(langMdl, firstWid); - candScore += ll; - if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + else { - for (size_t i = 1; i < curMorph->chunks.size(); ++i) + if (morphBase[firstWid].tag == POSTag::p) + { + // prohibit without + goto continueFor; + } + float ll = cLmState.next(langMdl, firstWid); + candScore += ll; + if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) { - const auto wid = curMorph->chunks[i]->lmMorphemeId; - if (morphBase[wid].tag == POSTag::p) + for (size_t i = 1; i < curMorph->chunks.size(); ++i) { - // prohibit without - goto continueFor; + const auto wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + // prohibit without + goto continueFor; + } + ll = cLmState.next(langMdl, wid); + candScore += ll; } - ll = cLmState.next(langMdl, wid); - candScore += ll; } } - } - if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) - { - rootIds.resize(prevSpStates.size()); - iota(rootIds.begin(), rootIds.end(), 0); - } - else - { - rootIds.resize(1); - rootIds[0] = commonRootId; - } - - for (auto rootId : rootIds) - { - const auto* prevMorpheme = &morphBase[prevPath.wid]; - auto spState = prevPath.spState; - if (rootId != commonRootId) + if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) { - 
spState = prevSpStates[rootId]; + rootIds.resize(prevSpStates.size()); + iota(rootIds.begin(), rootIds.end(), 0); } - const float candScoreWithRule = candScore + ruleBasedScorer(prevMorpheme, spState); - - // update special state - if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; - if (ruleBasedScorer.curMorphSbType) + else { - spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + rootIds.resize(1); + rootIds[0] = commonRootId; } - PathHash ph{ cLmState, prevPath.rootId, spState }; - bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(cLmState), spState); - } - - continueFor:; - } - } - - bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); - return; - } - - template - template - void PathEvaluator::eval(const Kiwi* kw, - const KGraphNode* startNode, - const KGraphNode* node, - const size_t topN, - Vector>>& cache, - const Vector& ownFormList, - size_t i, - size_t ownFormId, - CandTy&& cands, - bool unknownForm, - const Vector& prevSpStates, - bool splitComplex, - bool splitSaisiot, - bool mergeSaisiot, - const std::unordered_set* blocklist - ) - { - const size_t langVocabSize = kw->langMdl.vocabSize(); - auto& nCache = cache[i]; - Vector> refCache; - - float whitespaceDiscount = 0; - if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) - { - whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; - } - const float typoDiscount = -node->typoCost * kw->typoCostWeight; - float unknownFormDiscount = 0; - if (unknownForm) 
- { - size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); - unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); - } - - const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; - - size_t totalPrevPathes = 0; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) - { - totalPrevPathes += cache[prev - startNode].size(); - } - const bool useContainerForSmall = totalPrevPathes <= 48; - - for (bool ignoreCond : {false, true}) - { - for (auto& curMorph : cands) - { - if (splitComplex && curMorph->getCombined()->complex) continue; - if (blocklist && blocklist->count(curMorph->getCombined())) continue; - - // 덧붙은 받침(zCoda)을 위한 지름길 - if (curMorph->tag == POSTag::z_coda) - { - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + for (auto rootId : rootIds) { - for (auto& p : cache[prev - startNode]) + const auto* prevMorpheme = &morphBase[prevPath.wid]; + auto spState = prevPath.spState; + if (rootId != commonRootId) { - auto lastTag = kw->morphemes[p.wid].tag; - if (!isJClass(lastTag) && !isEClass(lastTag)) continue; - nCache.emplace_back(p); - auto& newPath = nCache.back(); - newPath.accScore += curMorph->userScore * kw->typoCostWeight; - newPath.accTypoCost -= curMorph->userScore; - newPath.parent = &p; - newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; - newPath.wid = curMorph->lmMorphemeId; + spState = prevSpStates[rootId]; } - } - continue; - } - // 사이시옷(zSiot)을 위한 지름길 - if (curMorph->tag == POSTag::z_siot) - { - if (!(splitSaisiot || mergeSaisiot)) - { - continue; - } + const float candScoreWithRule = candScore + ruleBasedScorer(prevMorpheme, spState); - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) - { - for (auto& p : cache[prev - startNode]) + // update special state + if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; + else if 
(ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; + if (ruleBasedScorer.curMorphSbType) { - auto lastTag = kw->morphemes[p.wid].tag; - if (!isNNClass(lastTag)) continue; - nCache.emplace_back(p); - auto& newPath = nCache.back(); - newPath.accScore += curMorph->userScore * kw->typoCostWeight; - newPath.accTypoCost -= curMorph->userScore; - newPath.parent = &p; - newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; - newPath.wid = curMorph->lmMorphemeId; + spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); } - } - continue; - } - // if the morpheme has chunk set - if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) - { - // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 - if (node->prev && node[-(int)node->prev].endPos < node->startPos - && curMorph->kform - && curMorph->kform->size() == 1 - && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') - && curMorph->chunks[0]->kform - && curMorph->chunks[0]->kform->size() == 1 - && (*curMorph->chunks[0]->kform)[0] == u'하') - { - continue; + PathHash ph{ cLmState, prevPath.rootId, spState }; + bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(cLmState), spState); } - } - if (topN > 1) - { - evalSingleMorpheme(nCache, kw, ownFormList, cache, - ownFormId, curMorph, - node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); - } - else if (useContainerForSmall) - { - evalSingleMorpheme(nCache, kw, ownFormList, cache, - ownFormId, curMorph, - node, startNode, topN, ignoreCond ? 
-10 : 0, nodeLevelDiscount, prevSpStates); - } - else - { - evalSingleMorpheme(nCache, kw, ownFormList, cache, - ownFormId, curMorph, - node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + continueFor:; } - } - if (!nCache.empty()) break; - } - - thread_local Vector maxScores; - maxScores.clear(); - maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); - if (topN == 1) - { - for (auto& c : nCache) - { - if (c.morpheme->combineSocket) continue; - const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; - maxScores[rootId] = max(maxScores[rootId], c.accScore); - } - } - else - { - for (auto& c : nCache) - { - if (c.morpheme->combineSocket) continue; - const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; - if (c.accScore > maxScores[rootId * topN]) - { - pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); - maxScores[rootId * topN + topN - 1] = c.accScore; - push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); - } - } - } - - size_t validCount = 0; - for (size_t i = 0; i < nCache.size(); ++i) - { - const auto rootId = nCache[i].rootId == commonRootId ? 
0 : nCache[i].rootId + 1; - if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; - if (validCount != i) nCache[validCount] = move(nCache[i]); - validCount++; + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); } - nCache.resize(validCount); - } - - template - struct LmEvalData - { - LmState state; - float score = 0; - uint32_t length = 0; }; template - template - void PathEvaluator::type>::evalMorphemes( - Vector>& resultOut, - const Kiwi* kw, - const Vector& ownForms, - const Vector>>& cache, - size_t ownFormId, - const Vector& morphs, - const Vector& morphScores, - const KGraphNode* node, - const KGraphNode* startNode, - const size_t topN, - const size_t totalPrevPathes, - const float ignoreCondScore, - const float nodeLevelDiscount, - const Vector& prevSpStates - ) + struct MorphemeEvaluator { - thread_local BestPathConatiner bestPathCont; - thread_local Vector rootIds; - thread_local Vector> evalMatrix; - thread_local Vector nextWids; - - const LangModel& langMdl = kw->langMdl; - const Morpheme* morphBase = kw->morphemes.data(); - const auto spacePenalty = kw->spacePenalty; - const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; - const size_t langVocabSize = langMdl.vocabSize(); - - evalMatrix.resize(totalPrevPathes * morphs.size()); - nextWids.clear(); - - size_t prevId = -1; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + template + void eval( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + size_t ownFormId, + const Vector& morphs, + const Vector& morphScores, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const size_t totalPrevPathes, + const float ignoreCondScore, + const float nodeLevelDiscount, + const Vector& prevSpStates + ) const { - for (auto& prevPath : cache[prev - startNode]) - { - ++prevId; + thread_local BestPathConatiner bestPathCont; + thread_local Vector rootIds; + thread_local Vector> 
evalMatrix; + thread_local Vector nextWids; - const kchar_t* leftFormFirst, * leftFormLast; - if (prevPath.ownFormId) - { - leftFormFirst = ownForms[prevPath.ownFormId - 1].data(); - leftFormLast = leftFormFirst + ownForms[0].size(); - } - else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) - { - leftFormFirst = morphBase[prevPath.wid].kform->data(); - leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); - } - else + const auto* langMdl = kw->getLangModel(); + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + const size_t langVocabSize = langMdl->vocabSize(); + + evalMatrix.resize(totalPrevPathes * morphs.size()); + nextWids.clear(); + + size_t prevId = -1; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& prevPath : cache[prev - startNode]) { - leftFormFirst = prevPath.morpheme->getForm().data(); - leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); - } - const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; + ++prevId; - for (size_t curId = 0; curId < morphs.size(); ++curId) - { - const auto curMorph = morphs[curId]; - float candScore = prevPath.accScore + curMorph->userScore + nodeLevelDiscount + morphScores[curId]; - Wid firstWid; - if (curMorph->chunks.empty() || curMorph->complex) + const kchar_t* leftFormFirst, * leftFormLast; + if (prevPath.ownFormId) { - firstWid = curMorph->lmMorphemeId; + leftFormFirst = ownForms[prevPath.ownFormId - 1].data(); + leftFormLast = leftFormFirst + ownForms[0].size(); + } + else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) + { + leftFormFirst = morphBase[prevPath.wid].kform->data(); + leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); } else { - firstWid = curMorph->chunks[0]->lmMorphemeId; + 
leftFormFirst = prevPath.morpheme->getForm().data(); + leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); } + const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; - if (prevPath.combineSocket) + for (size_t curId = 0; curId < morphs.size(); ++curId) { - // merge with only the same socket - if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex)) + const auto curMorph = morphs[curId]; + float candScore = prevPath.accScore + curMorph->userScore + nodeLevelDiscount + morphScores[curId]; + Wid firstWid; + if (curMorph->chunks.empty() || curMorph->complex) { - goto invalidCandidate; + firstWid = curMorph->lmMorphemeId; } - if (prev->endPos < node->startPos) + else { - if (allowedSpaceBetweenChunk) candScore -= spacePenalty; - else goto invalidCandidate; + firstWid = curMorph->chunks[0]->lmMorphemeId; } - firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; - } - const CondVowel cvowel = curMorph->vowel; - const CondPolarity cpolar = curMorph->polar; + if (prevPath.combineSocket) + { + // merge with only the same socket + if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex)) + { + goto invalidCandidate; + } + if (prev->endPos < node->startPos) + { + if (allowedSpaceBetweenChunk) candScore -= spacePenalty; + else goto invalidCandidate; + } + firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; + } - if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) - { - // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 - } - else if (ignoreCondScore) - { - candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 
0 : ignoreCondScore; - } - else - { - if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) goto invalidCandidate; - } + const CondVowel cvowel = curMorph->vowel; + const CondPolarity cpolar = curMorph->polar; - size_t length = 0; - if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex)) - { - // no op - } - else - { - if (morphBase[firstWid].tag == POSTag::p) + if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) + { + // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 + } + else if (ignoreCondScore) { - goto invalidCandidate; + candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 0 : ignoreCondScore; + } + else + { + if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) goto invalidCandidate; } - if (curMorph->chunks.empty() || curMorph->complex) + size_t length = 0; + if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex)) { - length = 1; + // no op } else { - length = curMorph->chunks.size(); - for (size_t i = 1; i < length; ++i) + if (morphBase[firstWid].tag == POSTag::p) { - const Wid wid = curMorph->chunks[i]->lmMorphemeId; - if (morphBase[wid].tag == POSTag::p) + goto invalidCandidate; + } + + if (curMorph->chunks.empty() || curMorph->complex) + { + length = 1; + } + else + { + length = curMorph->chunks.size(); + for (size_t i = 1; i < length; ++i) { - goto invalidCandidate; + const Wid wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + goto invalidCandidate; + } } } } - } - evalMatrix[prevId * morphs.size() + curId].state = prevPath.lmState; - evalMatrix[prevId * morphs.size() + curId].score = candScore; - evalMatrix[prevId * morphs.size() + curId].length = length; - if (length > 0) nextWids.emplace_back(firstWid); - if (length > 1) - { - for (size_t i = 1; i < length; ++i) + evalMatrix[prevId * morphs.size() + curId].state = prevPath.lmState; + evalMatrix[prevId * morphs.size() + curId].score = 
candScore; + evalMatrix[prevId * morphs.size() + curId].length = length; + if (length > 0) nextWids.emplace_back(firstWid); + if (length > 1) { - nextWids.emplace_back(curMorph->chunks[i]->lmMorphemeId); + for (size_t i = 1; i < length; ++i) + { + nextWids.emplace_back(curMorph->chunks[i]->lmMorphemeId); + } } + continue; + invalidCandidate: + evalMatrix[prevId * morphs.size() + curId].score = -INFINITY; + evalMatrix[prevId * morphs.size() + curId].length = 0; } - continue; - invalidCandidate: - evalMatrix[prevId * morphs.size() + curId].score = -INFINITY; - evalMatrix[prevId * morphs.size() + curId].length = 0; } } - } - { - size_t widOffset = 0; - for (auto& e : evalMatrix) { - //if (e.length == 0) continue; - float score = 0; - for (size_t i = 0; i < e.length; ++i) + size_t widOffset = 0; + for (auto& e : evalMatrix) { - score += e.state.next(langMdl, nextWids[widOffset + i]); + //if (e.length == 0) continue; + float score = 0; + for (size_t i = 0; i < e.length; ++i) + { + score += e.state.next(langMdl, nextWids[widOffset + i]); + } + e.score += score; + widOffset += e.length; } - e.score += score; - widOffset += e.length; - } - } - - for (size_t curId = 0; curId < morphs.size(); ++curId) - { - const auto curMorph = morphs[curId]; - bestPathCont.clear(); - - const Morpheme* lastMorph; - if (curMorph->chunks.empty() || curMorph->complex) - { - lastMorph = curMorph->getCombined() ? 
curMorph->getCombined() : curMorph; - } - // if the morpheme has chunk set - else - { - lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; } - Wid lastSeqId; - if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) - { - lastSeqId = lastMorph - kw->morphemes.data(); - } - else + for (size_t curId = 0; curId < morphs.size(); ++curId) { - lastSeqId = lastMorph->lmMorphemeId; - } + const auto curMorph = morphs[curId]; + bestPathCont.clear(); - RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const Morpheme* lastMorph; + if (curMorph->chunks.empty() || curMorph->complex) + { + lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + } - size_t prevId = -1; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) - { - for (auto& prevPath : cache[prev - startNode]) + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) { - ++prevId; - auto& em = evalMatrix[prevId * morphs.size() + curId]; - if (em.score < -99999) - { - continue; - } + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } - if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) - { - rootIds.resize(prevSpStates.size()); - iota(rootIds.begin(), rootIds.end(), 0); - } - else - { - rootIds.resize(1); - rootIds[0] = commonRootId; - } + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; - for (auto rootId : rootIds) + size_t prevId = -1; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& prevPath : cache[prev - startNode]) { - const auto* prevMorpheme = &morphBase[prevPath.wid]; - auto spState = prevPath.spState; - if (rootId != commonRootId) + ++prevId; + auto& em = 
evalMatrix[prevId * morphs.size() + curId]; + if (em.score < -99999) { - spState = prevSpStates[rootId]; + continue; } - const float candScoreWithRule = em.score + ruleBasedScorer(prevMorpheme, spState); - // update special state - if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; - if (ruleBasedScorer.curMorphSbType) + if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) { - spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + rootIds.resize(prevSpStates.size()); + iota(rootIds.begin(), rootIds.end(), 0); + } + else + { + rootIds.resize(1); + rootIds[0] = commonRootId; } - PathHash ph{ em.state, prevPath.rootId, spState }; - bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); + for (auto rootId : rootIds) + { + const auto* prevMorpheme = &morphBase[prevPath.wid]; + auto spState = prevPath.spState; + if (rootId != commonRootId) + { + spState = prevSpStates[rootId]; + } + const float candScoreWithRule = em.score + ruleBasedScorer(prevMorpheme, spState); + + // update special state + if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == 
Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; + if (ruleBasedScorer.curMorphSbType) + { + spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + } + + PathHash ph{ em.state, prevPath.rootId, spState }; + bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); + } } } - } - bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } } - } + }; template - template - void PathEvaluator::type>::eval(const Kiwi* kw, - const KGraphNode* startNode, - const KGraphNode* node, - const size_t topN, - Vector>>& cache, - const Vector& ownFormList, - size_t i, - size_t ownFormId, - CandTy&& cands, - bool unknownForm, - const Vector& prevSpStates, - bool splitComplex, - bool splitSaisiot, - bool mergeSaisiot, - const std::unordered_set* blocklist - ) + struct PathEvaluator::type> { - thread_local Vector maxScores; - thread_local Vector validMorphCands; - thread_local Vector lbScores; - const size_t langVocabSize = kw->langMdl.vocabSize(); - auto& nCache = cache[i]; - - float whitespaceDiscount = 0; - if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) - { - whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; - } - const float typoDiscount = -node->typoCost * kw->typoCostWeight; - float unknownFormDiscount = 0; - if (unknownForm) + const Kiwi* kw; + const KGraphNode* startNode; + const size_t topN; + Vector>>& cache; + const Vector& ownFormList; + const Vector& prevSpStates; + + PathEvaluator(const Kiwi* _kw, + const KGraphNode* _startNode, + size_t _topN, + Vector>>& _cache, + const Vector& _ownFormList, + const Vector& _prevSpStates + ) + : kw{ _kw }, startNode{ _startNode }, topN{ _topN }, cache{ _cache }, ownFormList{ _ownFormList }, prevSpStates{ _prevSpStates } { - size_t unknownLen = node->uform.empty() ? 
node->form->form.size() : node->uform.size(); - unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); } - const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; - const Morpheme* zCodaMorph = nullptr; - const Morpheme* zSiotMorph = nullptr; - validMorphCands.clear(); - lbScores.clear(); - for (auto& curMorph : cands) + template + void operator()( + const size_t nodeIdx, + const size_t ownFormId, + CandTy&& cands, + bool unknownForm, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ) const { - if (splitComplex && curMorph->getCombined()->complex) continue; - if (blocklist && blocklist->count(curMorph->getCombined())) continue; - - // 덧붙은 받침(zCoda)을 위한 지름길 - if (curMorph->tag == POSTag::z_coda) + thread_local Vector maxScores; + thread_local Vector validMorphCands; + thread_local Vector lbScores; + const size_t langVocabSize = kw->langMdl->vocabSize(); + auto* const node = startNode + nodeIdx; + auto& nCache = cache[nodeIdx]; + + float whitespaceDiscount = 0; + if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) { - zCodaMorph = curMorph; - continue; + whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; } - else if (curMorph->tag == POSTag::z_siot) + const float typoDiscount = -node->typoCost * kw->typoCostWeight; + float unknownFormDiscount = 0; + if (unknownForm) { - zSiotMorph = curMorph; - continue; + size_t unknownLen = node->uform.empty() ? 
node->form->form.size() : node->uform.size(); + unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); } - if (!curMorph->chunks.empty() && !curMorph->complex) + const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; + const Morpheme* zCodaMorph = nullptr; + const Morpheme* zSiotMorph = nullptr; + validMorphCands.clear(); + lbScores.clear(); + for (auto& curMorph : cands) { - // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 - if (node->prev && node[-(int)node->prev].endPos < node->startPos - && curMorph->kform - && curMorph->kform->size() == 1 - && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') - && curMorph->chunks[0]->kform - && curMorph->chunks[0]->kform->size() == 1 - && (*curMorph->chunks[0]->kform)[0] == u'하') + if (splitComplex && curMorph->getCombined()->complex) continue; + if (blocklist && blocklist->count(curMorph->getCombined())) continue; + + // 덧붙은 받침(zCoda)을 위한 지름길 + if (curMorph->tag == POSTag::z_coda) { + zCodaMorph = curMorph; + continue; + } + else if (curMorph->tag == POSTag::z_siot) + { + zSiotMorph = curMorph; continue; } - } - validMorphCands.emplace_back(curMorph); - lbScores.emplace_back(kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag)); - } - for (bool ignoreCond : {false, true}) - { - // 덧붙은 받침(zCoda)을 위한 지름길 - if (zCodaMorph) - { - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + if (!curMorph->chunks.empty() && !curMorph->complex) { - for (auto& p : cache[prev - startNode]) + // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 + if (node->prev && node[-(int)node->prev].endPos < node->startPos + && curMorph->kform + && curMorph->kform->size() == 1 + && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') + && curMorph->chunks[0]->kform + && curMorph->chunks[0]->kform->size() == 1 + && (*curMorph->chunks[0]->kform)[0] == u'하') { 
- auto lastTag = kw->morphemes[p.wid].tag; - if (!isJClass(lastTag) && !isEClass(lastTag)) continue; - nCache.emplace_back(p); - auto& newPath = nCache.back(); - newPath.accScore += zCodaMorph->userScore * kw->typoCostWeight; - newPath.accTypoCost -= zCodaMorph->userScore; - newPath.parent = &p; - newPath.morpheme = &kw->morphemes[zCodaMorph->lmMorphemeId]; - newPath.wid = zCodaMorph->lmMorphemeId; + continue; } } - continue; + validMorphCands.emplace_back(curMorph); + lbScores.emplace_back(kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag)); } - // 사이시옷(zSiot)을 위한 지름길 - if (zSiotMorph) + + for (bool ignoreCond : {false, true}) { - if (!(splitSaisiot || mergeSaisiot)) + // 덧붙은 받침(zCoda)을 위한 지름길 + if (zCodaMorph) { + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { + auto lastTag = kw->morphemes[p.wid].tag; + if (!isJClass(lastTag) && !isEClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += zCodaMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= zCodaMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[zCodaMorph->lmMorphemeId]; + newPath.wid = zCodaMorph->lmMorphemeId; + } + } continue; } - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + // 사이시옷(zSiot)을 위한 지름길 + if (zSiotMorph) { - for (auto& p : cache[prev - startNode]) + if (!(splitSaisiot || mergeSaisiot)) { - auto lastTag = kw->morphemes[p.wid].tag; - if (!isNNClass(lastTag)) continue; - nCache.emplace_back(p); - auto& newPath = nCache.back(); - newPath.accScore += zSiotMorph->userScore * kw->typoCostWeight; - newPath.accTypoCost -= zSiotMorph->userScore; - newPath.parent = &p; - newPath.morpheme = &kw->morphemes[zSiotMorph->lmMorphemeId]; - newPath.wid = zSiotMorph->lmMorphemeId; + continue; + } + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { 
+ auto lastTag = kw->morphemes[p.wid].tag; + if (!isNNClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += zSiotMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= zSiotMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[zSiotMorph->lmMorphemeId]; + newPath.wid = zSiotMorph->lmMorphemeId; + } } + continue; } - continue; - } - size_t totalPrevPathes = 0; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) - { - totalPrevPathes += cache[prev - startNode].size(); - } - const bool useContainerForSmall = totalPrevPathes <= 48; + size_t totalPrevPathes = 0; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + totalPrevPathes += cache[prev - startNode].size(); + } + const bool useContainerForSmall = totalPrevPathes <= 48; - if (topN > 1) - { - evalMorphemes(nCache, kw, ownFormList, cache, - ownFormId, validMorphCands, lbScores, - node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); - } - else if (useContainerForSmall) - { - evalMorphemes(nCache, kw, ownFormList, cache, - ownFormId, validMorphCands, lbScores, - node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); - } - else - { - evalMorphemes(nCache, kw, ownFormList, cache, - ownFormId, validMorphCands, lbScores, - node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + MorphemeEvaluator me; + if (topN > 1) + { + me.eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, lbScores, + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + } + else if (useContainerForSmall) + { + me.eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, lbScores, + node, startNode, topN, totalPrevPathes, ignoreCond ? 
-10 : 0, nodeLevelDiscount, prevSpStates); + } + else + { + me.eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, lbScores, + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + } + if (!nCache.empty()) break; } - if (!nCache.empty()) break; - } - maxScores.clear(); - maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); + maxScores.clear(); + maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); - if (topN == 1) - { - for (auto& c : nCache) + if (topN == 1) { - if (c.morpheme->combineSocket) continue; - const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; - maxScores[rootId] = max(maxScores[rootId], c.accScore); + for (auto& c : nCache) + { + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + maxScores[rootId] = max(maxScores[rootId], c.accScore); + } } - } - else - { - for (auto& c : nCache) + else { - if (c.morpheme->combineSocket) continue; - const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; - if (c.accScore > maxScores[rootId * topN]) + for (auto& c : nCache) { - pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); - maxScores[rootId * topN + topN - 1] = c.accScore; - push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + if (c.accScore > maxScores[rootId * topN]) + { + pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + maxScores[rootId * topN + topN - 1] = c.accScore; + push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + } } } - } - size_t validCount = 0; - for (size_t i = 0; i < nCache.size(); ++i) - { - const auto rootId = nCache[i].rootId == commonRootId ? 
0 : nCache[i].rootId + 1; - if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; - if (validCount != i) nCache[validCount] = move(nCache[i]); - validCount++; + size_t validCount = 0; + for (size_t i = 0; i < nCache.size(); ++i) + { + const auto rootId = nCache[i].rootId == commonRootId ? 0 : nCache[i].rootId + 1; + if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; + if (validCount != i) nCache[validCount] = move(nCache[i]); + validCount++; + } + nCache.resize(validCount); } - nCache.resize(validCount); - } - + }; template inline Path generateTokenList(const WordLL* result, @@ -1500,7 +1135,7 @@ namespace kiwi return ret; } - template + template Vector BestPathFinder::findBestPath(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, @@ -1514,12 +1149,14 @@ namespace kiwi ) { static constexpr size_t eosId = 1; + using LmState = typename LangModel::LmStateType; + const auto* langMdl = kw->getLangModel(); Vector>> cache(graphSize); Vector ownFormList; Vector unknownNodeCands, unknownNodeLCands; - const size_t langVocabSize = kw->langMdl.vocabSize(); + const size_t langVocabSize = langMdl->vocabSize(); const KGraphNode* startNode = graph; const KGraphNode* endNode = graph + graphSize - 1; @@ -1537,7 +1174,7 @@ namespace kiwi } // start node - cache[0].emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ kw->langMdl }, SpecialState{}); + cache[0].emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ langMdl }, SpecialState{}); cache[0].back().rootId = commonRootId; #ifdef DEBUG_PRINT @@ -1550,7 +1187,9 @@ namespace kiwi } #endif - using Evaluator = PathEvaluator; + PathEvaluator evaluator{ + kw, startNode, topN, cache, ownFormList, uniqStates, + }; // middle nodes for (size_t i = 1; i < graphSize - 1; ++i) { @@ -1564,9 +1203,8 @@ namespace kiwi if (node->form) { - Evaluator::eval(kw, startNode, node, topN, cache, - ownFormList, i, ownFormId, node->form->candidate, - false, 
uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); + evaluator(i, ownFormId, node->form->candidate, + false, splitComplex, splitSaisiot, mergeSaisiot, blocklist); if (all_of(node->form->candidate.begin(), node->form->candidate.end(), [](const Morpheme* m) { return m->combineSocket || !(m->chunks.empty() || m->complex || m->saisiot); @@ -1574,16 +1212,14 @@ namespace kiwi { ownFormList.emplace_back(node->form->form); ownFormId = ownFormList.size(); - Evaluator::eval(kw, startNode, node, topN, cache, - ownFormList, i, ownFormId, unknownNodeLCands, - true, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); + evaluator(i, ownFormId, unknownNodeLCands, + true, splitComplex, splitSaisiot, mergeSaisiot, blocklist); }; } else { - Evaluator::eval(kw, startNode, node, topN, cache, - ownFormList, i, ownFormId, unknownNodeCands, - true, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); + evaluator(i, ownFormId, unknownNodeCands, + true, splitComplex, splitSaisiot, mergeSaisiot, blocklist); } #ifdef DEBUG_PRINT @@ -1613,7 +1249,7 @@ namespace kiwi } if (p.morpheme->tag == POSTag::z_siot) continue; - float c = p.accScore + (openEnd ? 0 : p.lmState.next(kw->langMdl, eosId)); + float c = p.accScore + (openEnd ? 
0 : p.lmState.next(langMdl, eosId)); if (p.spState.singleQuote) c -= 2; if (p.spState.doubleQuote) c -= 2; if (p.rootId == commonRootId) diff --git a/src/SkipBigramModel.cpp b/src/SkipBigramModel.cpp index f290779a..380ea6d7 100644 --- a/src/SkipBigramModel.cpp +++ b/src/SkipBigramModel.cpp @@ -1,23 +1,65 @@ +#include "PathEvaluator.hpp" +#include "Joiner.hpp" +#include "Kiwi.hpp" #include "SkipBigramModel.hpp" namespace kiwi { - namespace sb + template + struct PathHash> { + using LmState = lm::SbgState; + + lm::KnLMState<_arch, VocabTy> lmState; + std::array lastMorphemes; + uint8_t rootId, spState; + + PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) + : lmState{ _lmState.knlm }, rootId{ _rootId }, spState{ _spState } + { + _lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); + } + + + PathHash(const WordLL& wordLl, const Morpheme* morphBase) + : PathHash{ wordLl.lmState, wordLl.rootId, wordLl.spState } + { + } + + bool operator==(const PathHash& o) const + { + return lmState == o.lmState && lastMorphemes == o.lastMorphemes && spState == o.spState; + } + }; + + namespace lm + { + template + void* SkipBigramModel::getFindBestPathFn() const + { + return (void*)&BestPathFinder::findBestPath>; + } + + template + void* SkipBigramModel::getNewJoinerFn() const + { + return (void*)&newJoinerWithKiwi; + } + template - std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + std::unique_ptr createOptimizedModel(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem) { - auto& header = *reinterpret_cast(mem.get()); + auto& header = *reinterpret_cast(sbgMem.get()); switch (header.keySize) { case 1: - return make_unique>(std::move(mem)); + return make_unique>(std::move(knlmMem), std::move(sbgMem)); case 2: - return make_unique>(std::move(mem)); + return make_unique>(std::move(knlmMem), std::move(sbgMem)); case 4: - return make_unique>(std::move(mem)); + return make_unique>(std::move(knlmMem), std::move(sbgMem)); 
case 8: - return make_unique>(std::move(mem)); + return make_unique>(std::move(knlmMem), std::move(sbgMem)); default: throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; } @@ -34,12 +76,12 @@ namespace kiwi }; }; - std::unique_ptr SkipBigramModelBase::create(utils::MemoryObject&& mem, ArchType archType) + std::unique_ptr SkipBigramModelBase::create(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem, ArchType archType) { static tp::Table table{ CreateOptimizedModelGetter{} }; auto fn = table[static_cast(archType)]; if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; - return (*fn)(std::move(mem)); + return (*fn)(std::move(knlmMem), std::move(sbgMem)); } } } diff --git a/src/SkipBigramModel.hpp b/src/SkipBigramModel.hpp index 9e9eee2b..8ee2c31a 100644 --- a/src/SkipBigramModel.hpp +++ b/src/SkipBigramModel.hpp @@ -5,15 +5,22 @@ #include #include #include "ArchAvailable.h" +#include "Knlm.hpp" #include "search.h" namespace kiwi { - namespace sb + namespace lm { + template + class SbgState; + template class SkipBigramModel : public SkipBigramModelBase { + friend class SbgState; + + KnLangModel knlm; std::unique_ptr ptrs; std::unique_ptr restoredFloats; std::unique_ptr keyData; @@ -22,12 +29,19 @@ namespace kiwi const float* compensations = nullptr; float logWindowSize; public: - SkipBigramModel(utils::MemoryObject&& mem) : SkipBigramModelBase{ std::move(mem) } + using VocabType = KeyType; + using LmStateType = SbgState; + + size_t getMemorySize() const override { return base.size() + knlm.getMemorySize(); } + void* getFindBestPathFn() const override; + void* getNewJoinerFn() const override; + + SkipBigramModel(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem) : SkipBigramModelBase{ std::move(sbgMem) }, knlm{ std::move(knlmMem) } { auto* ptr = reinterpret_cast(base.get()); auto& header = getHeader(); - const KeyType* kSizes = reinterpret_cast(ptr += 
sizeof(Header)); + const KeyType* kSizes = reinterpret_cast(ptr += sizeof(SkipBigramModelHeader)); ptrs = make_unique(header.vocabSize + 1); ptrs[0] = 0; for (size_t i = 0; i < header.vocabSize; ++i) @@ -97,5 +111,65 @@ namespace kiwi float evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const; }; + + template + struct SbgState : public LmStateBase> + { + KnLMState<_arch, VocabTy> knlm; + size_t historyPos = 0; + std::array history = { {0,} }; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = false; + + SbgState() = default; + SbgState(const ILangModel* lm) : knlm{ &static_cast*>(lm)->knlm } {} + + bool operator==(const SbgState& other) const + { + return knlm == other.knlm && historyPos == other.historyPos && history == other.history; + } + + void getLastHistory(VocabTy* out, size_t n) const + { + for (size_t i = 0; i < n; ++i) + { + out[i] = history[(historyPos + windowSize + i - n) % windowSize]; + } + } + + float nextImpl(const SkipBigramModel<_arch, VocabTy, windowSize>* lm, VocabTy next) + { + float ll = lm->knlm.progress(knlm.node, next); + if (lm->isValidVocab(next)) + { + if (ll > -13) + { + ll = lm->evaluate(history.data(), windowSize, next, ll); + } + history[historyPos] = next; + historyPos = (historyPos + 1) % windowSize; + } + return ll; + } + }; } + + + template + struct Hash> + { + size_t operator()(const lm::SbgState& state) const + { + Hash> hasher; + std::hash vocabHasher; + size_t ret = hasher(state.knlm); + for (size_t i = 0; i < windowSize; ++i) + { + ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + return ret; + } + }; + } diff --git a/src/SkipBigramModelImpl.hpp b/src/SkipBigramModelImpl.hpp index 887ab5f9..c89a4d78 100644 --- a/src/SkipBigramModelImpl.hpp +++ b/src/SkipBigramModelImpl.hpp @@ -6,7 +6,7 @@ namespace kiwi { - namespace sb + namespace lm { template struct LogExpSum diff --git a/src/SkipBigramTrainer.hpp b/src/SkipBigramTrainer.hpp 
index 9195a607..b9bcf7c3 100644 --- a/src/SkipBigramTrainer.hpp +++ b/src/SkipBigramTrainer.hpp @@ -12,7 +12,7 @@ namespace kiwi { - namespace sb + namespace lm { struct TrainContext { @@ -768,7 +768,7 @@ namespace kiwi utils::MemoryOwner convertToModel(float trimThreshold = -15, bool quantize = true) const { - Header header = { 0, }; + SkipBigramModelHeader header = { 0, }; header.vocabSize = ptrs.size() - 1; header.keySize = sizeof(VocabTy); header.windowSize = windowSize; @@ -821,7 +821,7 @@ namespace kiwi mse = nuq::nuquant(compensationTable.data(), allCompensations, 256); std::transform(compensationTable.begin(), compensationTable.end(), compensationTable.begin(), [](float f) { return -std::pow(f, 16.f); }); - size_t totalModelSize = sizeof(Header); + size_t totalModelSize = sizeof(SkipBigramModelHeader); totalModelSize += header.vocabSize * sizeof(VocabTy); totalModelSize += finalVocabSize * sizeof(VocabTy); totalModelSize += header.vocabSize * sizeof(uint8_t); @@ -832,8 +832,8 @@ namespace kiwi utils::MemoryOwner ret{ totalModelSize }; auto* ptr = reinterpret_cast(ret.get()); - *reinterpret_cast(ptr) = header; - auto* ks = reinterpret_cast(ptr += sizeof(Header)); + *reinterpret_cast(ptr) = header; + auto* ks = reinterpret_cast(ptr += sizeof(SkipBigramModelHeader)); for (auto& v : compensations) { *ks++ = v.first.size(); @@ -870,7 +870,7 @@ namespace kiwi } else { - size_t totalModelSize = sizeof(Header); + size_t totalModelSize = sizeof(SkipBigramModelHeader); totalModelSize += header.vocabSize * sizeof(VocabTy); totalModelSize += finalVocabSize * sizeof(VocabTy); totalModelSize += header.vocabSize * sizeof(float); @@ -879,8 +879,8 @@ namespace kiwi utils::MemoryOwner ret{ totalModelSize }; auto* ptr = reinterpret_cast(ret.get()); - *reinterpret_cast(ptr) = header; - auto* ks = reinterpret_cast(ptr += sizeof(Header)); + *reinterpret_cast(ptr) = header; + auto* ks = reinterpret_cast(ptr += sizeof(SkipBigramModelHeader)); for (auto& v : compensations) { *ks++ 
= v.first.size(); diff --git a/src/archImpl/avx2.cpp b/src/archImpl/avx2.cpp index 76f27d75..3d8bd19a 100644 --- a/src/archImpl/avx2.cpp +++ b/src/archImpl/avx2.cpp @@ -2,7 +2,7 @@ namespace kiwi { - namespace sb + namespace lm { template<> struct LogExpSum diff --git a/src/archImpl/avx512bw.cpp b/src/archImpl/avx512bw.cpp index ad290331..03f6d8b2 100644 --- a/src/archImpl/avx512bw.cpp +++ b/src/archImpl/avx512bw.cpp @@ -2,7 +2,7 @@ namespace kiwi { - namespace sb + namespace lm { template<> struct LogExpSum diff --git a/src/archImpl/neon.cpp b/src/archImpl/neon.cpp index 57d80ec8..87ffe397 100644 --- a/src/archImpl/neon.cpp +++ b/src/archImpl/neon.cpp @@ -2,7 +2,7 @@ namespace kiwi { - namespace sb + namespace lm { template<> struct LogExpSum diff --git a/src/archImpl/none.cpp b/src/archImpl/none.cpp index 42ab13b8..223f1d4a 100644 --- a/src/archImpl/none.cpp +++ b/src/archImpl/none.cpp @@ -2,7 +2,7 @@ namespace kiwi { - namespace sb + namespace lm { template class SkipBigramModel; template class SkipBigramModel; diff --git a/src/archImpl/sse2.cpp b/src/archImpl/sse2.cpp index ecab2c80..25eab888 100644 --- a/src/archImpl/sse2.cpp +++ b/src/archImpl/sse2.cpp @@ -2,7 +2,7 @@ namespace kiwi { - namespace sb + namespace lm { template<> struct LogExpSum diff --git a/src/archImpl/sse4_1.cpp b/src/archImpl/sse4_1.cpp index 8c7efa4d..d5759b47 100644 --- a/src/archImpl/sse4_1.cpp +++ b/src/archImpl/sse4_1.cpp @@ -2,7 +2,7 @@ namespace kiwi { - namespace sb + namespace lm { template<> struct LogExpSum diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index e77740a1..6269237c 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -36,6 +36,7 @@ + @@ -55,6 +56,7 @@ + @@ -63,10 +65,10 @@ + - @@ -123,9 +125,6 @@ - - EIGEN_VECTORIZE_AVX512;KIWI_ARCH_X86_64=1;__restrict__=__restrict;KIWI_USE_BTREE;KIWI_USE_MIMALLOC;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) - From fd173dc311629433b2484c4a5e92b8ed032e406f Mon 
Sep 17 00:00:00 2001 From: bab2min Date: Sun, 26 Jan 2025 16:23:44 +0900 Subject: [PATCH 08/53] Refactor PathEvaluator more --- include/kiwi/Form.h | 3 + src/BestPathContainer.hpp | 6 +- src/PathEvaluator.hpp | 286 +++++++++++++++++--------------------- 3 files changed, 134 insertions(+), 161 deletions(-) diff --git a/include/kiwi/Form.h b/include/kiwi/Form.h index 335a3138..547e2b89 100644 --- a/include/kiwi/Form.h +++ b/include/kiwi/Form.h @@ -166,6 +166,9 @@ namespace kiwi /** 분할된 형태소의 경우 원형 형태소를 반환한다. 그 외에는 자기 자신을 반환한다. */ const Morpheme* getCombined() const { return this + combined; } + + /** 현재 인스턴스가 단일 형태소인지 확인한다 */ + bool isSingle() const { return chunks.empty() || complex || saisiot; } }; /** diff --git a/src/BestPathContainer.hpp b/src/BestPathContainer.hpp index 6cee09cf..6d26a937 100644 --- a/src/BestPathContainer.hpp +++ b/src/BestPathContainer.hpp @@ -185,7 +185,7 @@ namespace kiwi // fill the rest information of resultOut newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + if (curMorph->isSingle()) { newPath.combineSocket = curMorph->combineSocket; newPath.ownFormId = ownFormId; @@ -230,7 +230,7 @@ namespace kiwi // fill the rest information of resultOut newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + if (curMorph->isSingle()) { newPath.combineSocket = curMorph->combineSocket; newPath.ownFormId = ownFormId; @@ -282,7 +282,7 @@ namespace kiwi // fill the rest information of resultOut newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + if (curMorph->isSingle()) { newPath.combineSocket = curMorph->combineSocket; newPath.ownFormId = ownFormId; diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index ae1bba27..e4f84446 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -180,6 +180,111 @@ namespace kiwi || m == Kiwi::SpecialMorph::doubleQuoteOpen || m == 
Kiwi::SpecialMorph::doubleQuoteClose; } + + template + inline void insertToPathContainer( + BestPathConatiner& bestPathCont, + const size_t topN, + const Vector& prevSpStates, + const Morpheme* curMorph, + const Morpheme* morphBase, + LmState&& state, + const float score, + const KGraphNode* node, + const WordLL& prevPath, + const RuleBasedScorer& ruleBasedScorer + ) + { + const auto insert = [&](uint8_t rootId) + { + const auto* prevMorpheme = &morphBase[prevPath.wid]; + auto spState = prevPath.spState; + if (rootId != commonRootId) + { + spState = prevSpStates[rootId]; + } + const float candScoreWithRule = score + ruleBasedScorer(prevMorpheme, spState); + + // update special state + if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; + if (ruleBasedScorer.curMorphSbType) + { + spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + } + + PathHash ph{ state, prevPath.rootId, spState }; + bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(state), spState); + }; + + if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) + { + for (uint8_t rootId = 0; rootId < prevSpStates.size(); ++rootId) + { + insert(rootId); + } + } + else + { + insert(commonRootId); + } + } + + class FormEvaluator // tests a candidate morpheme's vowel/polarity condition against the form that precedes it + { + const kchar_t* leftFormFirst; + const kchar_t* leftFormLast; + bool leftFormEndswithSSC; + POSTag prevTag; + + public: + template + FormEvaluator(const WordLL& prevPath, + const Vector& ownFormList, + const
Morpheme* morphBase + ) + { + if (prevPath.ownFormId) + { + leftFormFirst = ownFormList[prevPath.ownFormId - 1].data(); + leftFormLast = leftFormFirst + ownFormList[prevPath.ownFormId - 1].size(); // length must come from the same own-form entry as data() above, not from entry 0 + } + else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) + { + leftFormFirst = morphBase[prevPath.wid].kform->data(); + leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); + } + else + { + leftFormFirst = prevPath.morpheme->getForm().data(); + leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); + } + leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; + prevTag = prevPath.morpheme->tag; + } + + bool operator()(const Morpheme* curMorph, const float ignoreCondScore, float& candScore) const + { + const CondVowel cvowel = curMorph->vowel; + const CondPolarity cpolar = curMorph->polar; + if (prevTag == POSTag::ssc || leftFormEndswithSSC) + { + // when the previous morpheme is a closing bracket, the left-side combining condition is not applied + } + else if (ignoreCondScore) + { + candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 
0 : ignoreCondScore; + } + else + { + if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) return false; + } + return true; + } + }; + template struct LmEvalData { @@ -305,7 +410,7 @@ namespace kiwi } // if the morpheme has chunk set - if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + if (!curMorph->isSingle()) { // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 if (node->prev && node[-(int)node->prev].endPos < node->startPos @@ -390,8 +495,7 @@ namespace kiwi ) const { thread_local BestPathConatiner bestPathCont; - thread_local Vector rootIds; - + const auto* langMdl = kw->getLangModel(); const Morpheme* morphBase = kw->morphemes.data(); const auto spacePenalty = kw->spacePenalty; @@ -401,7 +505,7 @@ namespace kiwi const Morpheme* lastMorph; Wid firstWid; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + if (curMorph->isSingle()) { lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; firstWid = curMorph->lmMorphemeId; @@ -445,7 +549,7 @@ namespace kiwi if (prevPath.combineSocket) { // merge with only the same socket - if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + if (prevPath.combineSocket != curMorph->combineSocket || curMorph->isSingle()) { continue; } @@ -457,41 +561,11 @@ namespace kiwi firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; } - const kchar_t* leftFormFirst, * leftFormLast; - if (prevPath.ownFormId) - { - leftFormFirst = ownFormList[prevPath.ownFormId - 1].data(); - leftFormLast = leftFormFirst + ownFormList[0].size(); - } - else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) - { - leftFormFirst = morphBase[prevPath.wid].kform->data(); - leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); - } - else - { - leftFormFirst = prevPath.morpheme->getForm().data(); - leftFormLast = leftFormFirst + 
prevPath.morpheme->getForm().size(); - } - - const CondVowel cvowel = curMorph->vowel; - const CondPolarity cpolar = curMorph->polar; - const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; - if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) - { - // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 - } - else if (ignoreCondScore) - { - candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 0 : ignoreCondScore; - } - else - { - if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) continue; - } + FormEvaluator formEvaluator{ prevPath, ownFormList, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, candScore)) continue; auto cLmState = prevPath.lmState; - if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + if (curMorph->combineSocket && curMorph->isSingle()) { // no-op } @@ -504,7 +578,7 @@ namespace kiwi } float ll = cLmState.next(langMdl, firstWid); candScore += ll; - if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + if (!curMorph->isSingle()) { for (size_t i = 1; i < curMorph->chunks.size(); ++i) { @@ -520,41 +594,7 @@ namespace kiwi } } - if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) - { - rootIds.resize(prevSpStates.size()); - iota(rootIds.begin(), rootIds.end(), 0); - } - else - { - rootIds.resize(1); - rootIds[0] = commonRootId; - } - - for (auto rootId : rootIds) - { - const auto* prevMorpheme = &morphBase[prevPath.wid]; - auto spState = prevPath.spState; - if (rootId != commonRootId) - { - spState = prevSpStates[rootId]; - } - const float candScoreWithRule = candScore + ruleBasedScorer(prevMorpheme, spState); - - // update special state - if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; - else if 
(ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; - if (ruleBasedScorer.curMorphSbType) - { - spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); - } - - PathHash ph{ cLmState, prevPath.rootId, spState }; - bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(cLmState), spState); - } - + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(cLmState), candScore, node, prevPath, ruleBasedScorer); continueFor:; } } @@ -574,7 +614,6 @@ namespace kiwi const Vector>>& cache, size_t ownFormId, const Vector& morphs, - const Vector& morphScores, const KGraphNode* node, const KGraphNode* startNode, const size_t topN, @@ -585,7 +624,6 @@ namespace kiwi ) const { thread_local BestPathConatiner bestPathCont; - thread_local Vector rootIds; thread_local Vector> evalMatrix; thread_local Vector nextWids; @@ -604,31 +642,13 @@ namespace kiwi for (auto& prevPath : cache[prev - startNode]) { ++prevId; - - const kchar_t* leftFormFirst, * leftFormLast; - if (prevPath.ownFormId) - { - leftFormFirst = ownForms[prevPath.ownFormId - 1].data(); - leftFormLast = leftFormFirst + ownForms[0].size(); - } - else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) - { - leftFormFirst = morphBase[prevPath.wid].kform->data(); - leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); - } - else - { - leftFormFirst = prevPath.morpheme->getForm().data(); - leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); - } - const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == 
POSTag::ssc; - + FormEvaluator formEvaluator{ prevPath, ownForms, morphBase }; for (size_t curId = 0; curId < morphs.size(); ++curId) { const auto curMorph = morphs[curId]; - float candScore = prevPath.accScore + curMorph->userScore + nodeLevelDiscount + morphScores[curId]; + float candScore = prevPath.accScore + curMorph->userScore + nodeLevelDiscount; Wid firstWid; - if (curMorph->chunks.empty() || curMorph->complex) + if (curMorph->isSingle()) { firstWid = curMorph->lmMorphemeId; } @@ -640,7 +660,7 @@ namespace kiwi if (prevPath.combineSocket) { // merge with only the same socket - if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex)) + if (prevPath.combineSocket != curMorph->combineSocket || curMorph->isSingle()) { goto invalidCandidate; } @@ -652,24 +672,10 @@ namespace kiwi firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; } - const CondVowel cvowel = curMorph->vowel; - const CondPolarity cpolar = curMorph->polar; - - if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) - { - // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 - } - else if (ignoreCondScore) - { - candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 0 : ignoreCondScore; - } - else - { - if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) goto invalidCandidate; - } + if (!formEvaluator(curMorph, ignoreCondScore, candScore)) continue; size_t length = 0; - if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex)) + if (curMorph->combineSocket && curMorph->isSingle()) { // no op } @@ -680,7 +686,7 @@ namespace kiwi goto invalidCandidate; } - if (curMorph->chunks.empty() || curMorph->complex) + if (curMorph->isSingle()) { length = 1; } @@ -737,7 +743,7 @@ namespace kiwi bestPathCont.clear(); const Morpheme* lastMorph; - if (curMorph->chunks.empty() || curMorph->complex) + if (curMorph->isSingle()) { lastMorph = curMorph->getCombined() ? 
curMorph->getCombined() : curMorph; } @@ -758,7 +764,7 @@ namespace kiwi } RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; - + const float morphScore = kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); size_t prevId = -1; for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { @@ -771,40 +777,7 @@ namespace kiwi continue; } - if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) - { - rootIds.resize(prevSpStates.size()); - iota(rootIds.begin(), rootIds.end(), 0); - } - else - { - rootIds.resize(1); - rootIds[0] = commonRootId; - } - - for (auto rootId : rootIds) - { - const auto* prevMorpheme = &morphBase[prevPath.wid]; - auto spState = prevPath.spState; - if (rootId != commonRootId) - { - spState = prevSpStates[rootId]; - } - const float candScoreWithRule = em.score + ruleBasedScorer(prevMorpheme, spState); - - // update special state - if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; - if (ruleBasedScorer.curMorphSbType) - { - spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); - } - - PathHash ph{ em.state, prevPath.rootId, spState }; - bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(em.state), spState); - } + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(em.state), em.score, node, prevPath, ruleBasedScorer); } } @@ -848,7 +821,6 @@ namespace kiwi { thread_local Vector maxScores; thread_local Vector 
validMorphCands; - thread_local Vector lbScores; const size_t langVocabSize = kw->langMdl->vocabSize(); auto* const node = startNode + nodeIdx; auto& nCache = cache[nodeIdx]; @@ -870,7 +842,6 @@ namespace kiwi const Morpheme* zCodaMorph = nullptr; const Morpheme* zSiotMorph = nullptr; validMorphCands.clear(); - lbScores.clear(); for (auto& curMorph : cands) { if (splitComplex && curMorph->getCombined()->complex) continue; @@ -888,7 +859,7 @@ namespace kiwi continue; } - if (!curMorph->chunks.empty() && !curMorph->complex) + if (!curMorph->isSingle()) { // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 if (node->prev && node[-(int)node->prev].endPos < node->startPos @@ -903,7 +874,6 @@ namespace kiwi } } validMorphCands.emplace_back(curMorph); - lbScores.emplace_back(kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag)); } for (bool ignoreCond : {false, true}) @@ -964,19 +934,19 @@ namespace kiwi if (topN > 1) { me.eval(nCache, kw, ownFormList, cache, - ownFormId, validMorphCands, lbScores, + ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } else if (useContainerForSmall) { me.eval(nCache, kw, ownFormList, cache, - ownFormId, validMorphCands, lbScores, + ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } else { me.eval(nCache, kw, ownFormList, cache, - ownFormId, validMorphCands, lbScores, + ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? 
-10 : 0, nodeLevelDiscount, prevSpStates); } if (!nCache.empty()) break; From 6d21f085258f063e8ef0e097fc31ee5215bdf026 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 26 Jan 2025 16:26:35 +0900 Subject: [PATCH 09/53] Refactor evaluator-cli --- src/KiwiBuilder.cpp | 4 - tools/Evaluator.cpp | 277 +++++++++++++++++++++++++++++++++++---- tools/Evaluator.h | 54 ++++++-- tools/evaluator_main.cpp | 184 ++++++++------------------ tools/runner.cpp | 2 +- 5 files changed, 354 insertions(+), 167 deletions(-) diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 360f15e9..ee74b8cc 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -795,10 +795,6 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio { langMdl = lm::PcLangModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, modelType == ModelType::pclm); } - else if (modelType == ModelType::pclmLocal) - { - langMdl.pclm = pclm::PCLanguageModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, false); - } if (!!(options & BuildOption::loadDefaultDict)) { diff --git a/tools/Evaluator.cpp b/tools/Evaluator.cpp index 1bef793c..996e8b59 100644 --- a/tools/Evaluator.cpp +++ b/tools/Evaluator.cpp @@ -4,12 +4,41 @@ #include #include "../src/StrUtils.h" #include "Evaluator.h" +#include "toolUtils.h" #include "LCS.hpp" using namespace std; using namespace kiwi; -TokenInfo parseWordPOS(const u16string& str) +unique_ptr Evaluator::create(const std::string& evalType) +{ + if (evalType == "morph") return std::make_unique(); + if (evalType == "disamb") return std::make_unique(); + throw runtime_error{ "Unknown Evaluator Type" }; +} + +const char* modelTypeToStr(ModelType type) +{ + switch (type) + { + case ModelType::knlm: return "knlm"; + case ModelType::knlmTransposed: return "knlm-transposed"; + case ModelType::sbg: return "sbg"; + case ModelType::pclm: return "pclm"; + case ModelType::pclmLocal: return "pclm-local"; + } + return "unknown"; +} + 
+inline ostream& operator<<(ostream& o, const kiwi::TokenInfo& t) +{ + o << utf16To8(t.str); + if (t.senseId) o << "__" << (int)t.senseId; + o << "/" << kiwi::tagToString(t.tag); + return o; +} + +inline TokenInfo parseWordPOS(const u16string& str) { auto p = str.rfind('/'); if (p == str.npos) return {}; @@ -36,13 +65,98 @@ TokenInfo parseWordPOS(const u16string& str) tagStr.erase(tagStr.begin() + tagStr.find('-'), tagStr.end()); } POSTag tag = toPOSTag(tagStr); - if (tag >= POSTag::max) throw runtime_error{ "Wrong Input '" + utf16To8(str.substr(p + 1)) + "'" }; + if (clearIrregular(tag) >= POSTag::max) throw runtime_error{ "Wrong Input '" + utf16To8(str.substr(p + 1)) + "'" }; return { form, tag, 0, 0 }; } -Evaluator::Evaluator(const std::string& testSetFile, const Kiwi* _kw, Match _matchOption, size_t _topN) - : kw{ _kw }, matchOption{ _matchOption }, topN{ _topN } +int Evaluator::operator()(const string& modelPath, + const string& output, + const vector& input, + bool normCoda, bool zCoda, bool multiDict, ModelType modelType, + float typoCostWeight, bool bTypo, bool cTypo, bool lTypo, + int repeat) +{ + try + { + if (typoCostWeight > 0 && !bTypo && !cTypo && !lTypo) + { + bTypo = true; + } + else if (typoCostWeight == 0) + { + bTypo = false; + cTypo = false; + lTypo = false; + } + + tutils::Timer timer; + auto option = (BuildOption::default_ & ~BuildOption::loadMultiDict) | (multiDict ? 
BuildOption::loadMultiDict : BuildOption::none); + auto typo = getDefaultTypoSet(DefaultTypoSet::withoutTypo); + + if (bTypo) + { + typo |= getDefaultTypoSet(DefaultTypoSet::basicTypoSet); + } + + if (cTypo) + { + typo |= getDefaultTypoSet(DefaultTypoSet::continualTypoSet); + } + + if (lTypo) + { + typo |= getDefaultTypoSet(DefaultTypoSet::lengtheningTypoSet); + } + + Kiwi kw = KiwiBuilder{ modelPath, 1, option, modelType }.build( + typo + ); + if (typoCostWeight > 0) kw.setTypoCostWeight(typoCostWeight); + + cout << "Loading Time : " << timer.getElapsed() << " ms" << endl; + cout << "ArchType : " << archToStr(kw.archType()) << endl; + cout << "Model Type : " << modelTypeToStr(kw.modelType()) << endl; + if (kw.getLangModel()) + { + cout << "LM Size : " << (kw.getLangModel()->getMemorySize() / 1024. / 1024.) << " MB" << endl; + } + cout << "Mem Usage : " << (tutils::getCurrentPhysicalMemoryUsage() / 1024.) << " MB\n" << endl; + + double avgMicro = 0, avgMacro = 0; + double cnt = 0; + for (auto& tf : input) + { + cout << "Test file: " << tf << endl; + try + { + auto result = eval(output, tf, kw, normCoda, zCoda, repeat); + avgMicro += result.first; + avgMacro += result.second; + ++cnt; + cout << "================" << endl; + } + catch (const std::exception& e) + { + cerr << e.what() << endl; + } + } + + cout << endl << "================" << endl; + cout << "Avg Score" << endl; + cout << avgMicro / cnt << ", " << avgMacro / cnt << endl; + cout << "================" << endl; + return 0; + } + catch (const exception& e) + { + cerr << e.what() << endl; + return -1; + } +} + +auto MorphEvaluator::loadTestset(const string& testSetFile) const -> vector { + vector ret; ifstream f{ testSetFile }; if (!f) throw std::ios_base::failure{ "Cannot open '" + testSetFile + "'" }; string line; @@ -60,27 +174,19 @@ Evaluator::Evaluator(const std::string& testSetFile, const Kiwi* _kw, Match _mat TestResult tr; tr.q = fd[0].to_string(); for (auto& t : tokens) 
tr.a.emplace_back(parseWordPOS(t)); - testsets.emplace_back(std::move(tr)); - } -} - -void Evaluator::run() -{ - for (auto& tr : testsets) - { - auto cands = kw->analyze(tr.q, topN, matchOption); - tr.r = cands[0].first; + ret.emplace_back(std::move(tr)); } + return ret; } -Evaluator::Score Evaluator::evaluate() +auto MorphEvaluator::computeScore(vector& preds, vector& errors) const -> Score { errors.clear(); size_t totalCount = 0, microCorrect = 0, microCount = 0; double totalScore = 0; - for (auto& tr : testsets) + for (auto& tr : preds) { if (tr.a != tr.r) { @@ -128,15 +234,31 @@ Evaluator::Score Evaluator::evaluate() return ret; } -ostream& operator<<(ostream& o, const kiwi::TokenInfo& t) +auto DisambEvaluator::computeScore(vector& preds, vector& errors) const -> Score { - o << utf16To8(t.str); - if (t.senseId) o << "__" << (int)t.senseId; - o << "/" << kiwi::tagToString(t.tag); - return o; + errors.clear(); + Score score; + for (auto& tr : preds) + { + bool correct = false; + for (auto& token : tr.result.first) + { + if (token.str == tr.target.str && + clearIrregular(token.tag) == clearIrregular(tr.target.tag)) + { + correct = true; + break; + } + } + if (correct) score.acc += 1; + else errors.emplace_back(tr); + score.totalCount++; + } + score.acc /= score.totalCount; + return score; } -void Evaluator::TestResult::writeResult(ostream& out) const +void MorphEvaluator::TestResult::writeResult(ostream& out) const { out << utf16To8(q) << '\t' << score << endl; for (auto& _r : da) @@ -151,3 +273,114 @@ void Evaluator::TestResult::writeResult(ostream& out) const out << endl; out << endl; } + +pair MorphEvaluator::eval(const string& output, const string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) +{ + const size_t topN = 1; + const Match matchOption = (normCoda ? Match::allWithNormalizing : Match::all) & ~(zCoda ? 
Match::none : Match::zCoda); + vector testsets = loadTestset(file), errors; + tutils::Timer total; + for (int i = 0; i < repeat; ++i) + { + for (auto& tr : testsets) + { + auto cands = kiwi.analyze(tr.q, topN, matchOption); + tr.r = cands[0].first; + } + } + double tm = total.getElapsed() / repeat; + auto score = computeScore(testsets, errors); + + cout << score.micro << ", " << score.macro << endl; + cout << "Total (" << score.totalCount << " lines) Time : " << tm << " ms" << endl; + cout << "Time per Line : " << tm / score.totalCount << " ms" << endl; + + if (!output.empty()) + { + const size_t last_slash_idx = file.find_last_of("\\/"); + string name; + if (last_slash_idx != file.npos) name = file.substr(last_slash_idx + 1); + else name = file; + + ofstream out{ output + "/" + name }; + out << score.micro << ", " << score.macro << endl; + out << "Total (" << score.totalCount << ") Time : " << tm << " ms" << endl; + out << "Time per Unit : " << tm / score.totalCount << " ms" << endl; + for (auto t : errors) + { + t.writeResult(out); + } + } + return make_pair(score.micro, score.macro); +} + +auto DisambEvaluator::loadTestset(const string& testSetFile) const -> vector +{ + vector ret; + ifstream f{ testSetFile }; + if (!f) throw std::ios_base::failure{ "Cannot open '" + testSetFile + "'" }; + string line; + while (getline(f, line)) + { + while (line.back() == '\n' || line.back() == '\r') line.pop_back(); + auto wstr = utf8To16(line); + auto fd = split(wstr, u'\t'); + if (fd.size() < 2) continue; + TestResult tr; + tr.target = parseWordPOS(fd[0].to_string()); + tr.text = fd[1].to_string(); + ret.emplace_back(move(tr)); + } + return ret; +} + +void DisambEvaluator::TestResult::writeResult(ostream& out) const +{ + out << target << '\t' << utf16To8(text) << '\t' << score << endl; + for (auto& _r : result.first) + { + out << _r << '\t'; + } + out << endl; + out << endl; +} + +pair DisambEvaluator::eval(const string& output, const string& file, kiwi::Kiwi& kiwi, bool 
normCoda, bool zCoda, int repeat) +{ + const size_t topN = 1; + const Match matchOption = (normCoda ? Match::allWithNormalizing : Match::all) & ~(zCoda ? Match::none : Match::zCoda); + vector testsets = loadTestset(file), errors; + tutils::Timer total; + for (int i = 0; i < repeat; ++i) + { + for (auto& tr : testsets) + { + auto cands = kiwi.analyze(tr.text, topN, matchOption); + tr.result = cands[0]; + } + } + double tm = total.getElapsed() / repeat; + auto score = computeScore(testsets, errors); + + cout << score.acc << endl; + cout << "Total (" << score.totalCount << " lines) Time : " << tm << " ms" << endl; + cout << "Time per Line : " << tm / score.totalCount << " ms" << endl; + + if (!output.empty()) + { + const size_t last_slash_idx = file.find_last_of("\\/"); + string name; + if (last_slash_idx != file.npos) name = file.substr(last_slash_idx + 1); + else name = file; + + ofstream out{ output + "/" + name }; + out << score.acc << endl; + out << "Total (" << score.totalCount << ") Time : " << tm << " ms" << endl; + out << "Time per Unit : " << tm / score.totalCount << " ms" << endl; + for (auto t : errors) + { + t.writeResult(out); + } + } + return make_pair(score.acc, score.acc); +} diff --git a/tools/Evaluator.h b/tools/Evaluator.h index a91161d3..77221306 100644 --- a/tools/Evaluator.h +++ b/tools/Evaluator.h @@ -3,7 +3,23 @@ class Evaluator { + virtual std::pair eval(const std::string& output, const std::string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) = 0; public: + + virtual ~Evaluator() = default; + + static std::unique_ptr create(const std::string& evalType); + + int operator()(const std::string& modelPath, + const std::string& output, + const std::vector& input, + bool normCoda, bool zCoda, bool multiDict, kiwi::ModelType modelType, + float typoCostWeight, bool bTypo, bool cTypo, bool lTypo, + int repeat); +}; + +class MorphEvaluator : public Evaluator +{ using AnswerType = std::vector; struct TestResult { @@ -15,7 +31,7 @@ 
class Evaluator float score; void writeResult(std::ostream& out) const; }; - + struct Score { double micro = 0; @@ -23,15 +39,31 @@ class Evaluator size_t totalCount = 0; }; -private: - std::vector testsets, errors; - const kiwi::Kiwi* kw = nullptr; - kiwi::Match matchOption; - size_t topN = 1; -public: - Evaluator(const std::string& testSetFile, const kiwi::Kiwi* _kw, kiwi::Match _matchOption = kiwi::Match::all, size_t topN = 1); - void run(); - Score evaluate(); - const std::vector& getErrors() const { return errors; } + std::pair eval(const std::string& output, const std::string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) override; + + std::vector loadTestset(const std::string& file) const; + Score computeScore(std::vector& preds, std::vector& errors) const; }; +class DisambEvaluator : public Evaluator +{ + struct TestResult + { + std::u16string text; + kiwi::TokenInfo target; + kiwi::TokenResult result; + float score = 0; + void writeResult(std::ostream& out) const; + }; + + struct Score + { + double acc = 0; + size_t totalCount = 0; + }; + + std::pair eval(const std::string& output, const std::string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) override; + + std::vector loadTestset(const std::string& file) const; + Score computeScore(std::vector& preds, std::vector& errors) const; +}; diff --git a/tools/evaluator_main.cpp b/tools/evaluator_main.cpp index 9dc876a4..d50af519 100644 --- a/tools/evaluator_main.cpp +++ b/tools/evaluator_main.cpp @@ -10,131 +10,6 @@ using namespace std; using namespace kiwi; - -const char* modelTypeToStr(ModelType type) -{ - switch (type) - { - case ModelType::knlm: return "knlm"; - case ModelType::knlmTransposed: return "knlm-transposed"; - case ModelType::sbg: return "sbg"; - case ModelType::pclm: return "pclm"; - } - return "unknown"; -} - -int doEvaluate(const string& modelPath, const string& output, const vector& input, - bool normCoda, bool zCoda, bool multiDict, ModelType modelType, - 
float typoCostWeight, bool bTypo, bool cTypo, bool lTypo, - int repeat) -{ - try - { - if (typoCostWeight > 0 && !bTypo && !cTypo && !lTypo) - { - bTypo = true; - } - else if (typoCostWeight == 0) - { - bTypo = false; - cTypo = false; - lTypo = false; - } - - tutils::Timer timer; - auto option = (BuildOption::default_ & ~BuildOption::loadMultiDict) | (multiDict ? BuildOption::loadMultiDict : BuildOption::none); - auto typo = getDefaultTypoSet(DefaultTypoSet::withoutTypo); - - if (bTypo) - { - typo |= getDefaultTypoSet(DefaultTypoSet::basicTypoSet); - } - - if (cTypo) - { - typo |= getDefaultTypoSet(DefaultTypoSet::continualTypoSet); - } - - if (lTypo) - { - typo |= getDefaultTypoSet(DefaultTypoSet::lengtheningTypoSet); - } - - Kiwi kw = KiwiBuilder{ modelPath, 1, option, modelType }.build( - typo - ); - if (typoCostWeight > 0) kw.setTypoCostWeight(typoCostWeight); - - cout << "Loading Time : " << timer.getElapsed() << " ms" << endl; - cout << "ArchType : " << archToStr(kw.archType()) << endl; - cout << "Model Type : " << modelTypeToStr(kw.modelType()) << endl; - if (kw.getKnLM()) - { - cout << "LM Size : " << (kw.getKnLM()->getMemory().size() / 1024. / 1024.) << " MB" << endl; - } - cout << "Mem Usage : " << (tutils::getCurrentPhysicalMemoryUsage() / 1024.) << " MB\n" << endl; - - double avgMicro = 0, avgMacro = 0; - double cnt = 0; - for (auto& tf : input) - { - cout << "Test file: " << tf << endl; - try - { - Evaluator test{ tf, &kw, (normCoda ? Match::allWithNormalizing : Match::all) & ~(zCoda ? 
Match::none : Match::zCoda) }; - tutils::Timer total; - for (int i = 0; i < repeat; ++i) - { - test.run(); - } - double tm = total.getElapsed() / repeat; - auto result = test.evaluate(); - - cout << result.micro << ", " << result.macro << endl; - cout << "Total (" << result.totalCount << " lines) Time : " << tm << " ms" << endl; - cout << "Time per Line : " << tm / result.totalCount << " ms" << endl; - - avgMicro += result.micro; - avgMacro += result.macro; - cnt++; - - if (!output.empty()) - { - const size_t last_slash_idx = tf.find_last_of("\\/"); - string name; - if (last_slash_idx != tf.npos) name = tf.substr(last_slash_idx + 1); - else name = tf; - - ofstream out{ output + "/" + name }; - out << result.micro << ", " << result.macro << endl; - out << "Total (" << result.totalCount << ") Time : " << tm << " ms" << endl; - out << "Time per Unit : " << tm / result.totalCount << " ms" << endl; - for (auto t : test.getErrors()) - { - t.writeResult(out); - } - } - cout << "================" << endl; - } - catch (const std::exception& e) - { - cerr << e.what() << endl; - } - } - - cout << endl << "================" << endl; - cout << "Avg Score" << endl; - cout << avgMicro / cnt << ", " << avgMacro / cnt << endl; - cout << "================" << endl; - return 0; - } - catch (const exception& e) - { - cerr << e.what() << endl; - return -1; - } -} - using namespace TCLAP; int main(int argc, const char* argv[]) @@ -152,11 +27,10 @@ int main(int argc, const char* argv[]) SwitchArg cTypo{ "", "ctypo", "make continual-typo-tolerant model", false }; SwitchArg lTypo{ "", "ltypo", "make lengthening-typo-tolerant model", false }; ValueArg repeat{ "", "repeat", "repeat evaluation for benchmark", false, 1, "int" }; - UnlabeledMultiArg files{ "files", "evaluation set files", true, "string" }; + UnlabeledMultiArg inputs{ "inputs", "evaluation set (--morph, --disamb)", false, "string" }; cmd.add(model); cmd.add(output); - cmd.add(files); cmd.add(noNormCoda); cmd.add(noZCoda); 
cmd.add(noMulti); @@ -166,6 +40,7 @@ int main(int argc, const char* argv[]) cmd.add(cTypo); cmd.add(lTypo); cmd.add(repeat); + cmd.add(inputs); try { @@ -195,13 +70,64 @@ int main(int argc, const char* argv[]) { kiwiModelType = ModelType::pclm; } + else if (v == "pclm-local") + { + kiwiModelType = ModelType::pclmLocal; + } else { cerr << "Invalid model type" << endl; return -1; } } - return doEvaluate(model, output, files.getValue(), - !noNormCoda, !noZCoda, !noMulti, kiwiModelType, typoWeight, bTypo, cTypo, lTypo, repeat); + + vector morphInputs, disambInputs; + + string currentType = ""; + for (auto& input : inputs.getValue()) + { + if (input.size() > 2 && input[0] == '-' && input[1] == '-') + { + currentType = input; + } + else + { + if (currentType == "--morph") + { + morphInputs.emplace_back(input); + } + else if (currentType == "--disamb") + { + disambInputs.emplace_back(input); + } + else + { + cerr << "Unknown argument: " << input << endl; + return -1; + } + } + } + + if (morphInputs.size()) + { + auto evaluator = Evaluator::create("morph"); + (*evaluator)(model, output, morphInputs, + !noNormCoda, !noZCoda, !noMulti, + kiwiModelType, + typoWeight, bTypo, cTypo, lTypo, + repeat); + cout << endl; + } + + if (disambInputs.size()) + { + auto evaluator = Evaluator::create("disamb"); + (*evaluator)(model, output, disambInputs, + !noNormCoda, !noZCoda, !noMulti, + kiwiModelType, + typoWeight, bTypo, cTypo, lTypo, + repeat); + cout << endl; + } } diff --git a/tools/runner.cpp b/tools/runner.cpp index 78954de2..3b8f8b7e 100644 --- a/tools/runner.cpp +++ b/tools/runner.cpp @@ -28,7 +28,7 @@ int run(const string& modelPath, bool benchmark, const string& output, const str { tutils::Timer timer; size_t lines = 0, bytes = 0; - Kiwi kw = KiwiBuilder{ modelPath, 1, BuildOption::default_, sbg }.build(typos > 0 ? DefaultTypoSet::basicTypoSet : DefaultTypoSet::withoutTypo); + Kiwi kw = KiwiBuilder{ modelPath, 1, BuildOption::default_, sbg ? 
ModelType::sbg : ModelType::knlm }.build(typos > 0 ? DefaultTypoSet::basicTypoSet : DefaultTypoSet::withoutTypo); cout << "Kiwi v" << KIWI_VERSION_STRING << endl; if (tolerance) From a0058b3c2fe45315d9894a5e976e8db8220e0873 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 9 Feb 2025 02:57:25 +0900 Subject: [PATCH 10/53] add `ArchType::avx_vnni` & `ArchType::avx512vnni` --- include/kiwi/ArchUtils.h | 14 + include/kiwi/TemplateUtils.hpp | 38 +++ src/ArchAvailable.h | 30 ++ src/ArchUtils.cpp | 8 +- src/SIMD.hpp | 492 ++++++++++++++++++++++++++++----- src/archImpl/avx2.cpp | 15 +- src/archImpl/avx512bw.cpp | 15 +- src/archImpl/neon.cpp | 10 - src/archImpl/sse2.cpp | 10 - src/archImpl/sse4_1.cpp | 10 - src/search.cpp | 326 ++++++++++++++++++++-- src/search.h | 49 ++++ 12 files changed, 871 insertions(+), 146 deletions(-) diff --git a/include/kiwi/ArchUtils.h b/include/kiwi/ArchUtils.h index ba2d0e85..c28e5f72 100644 --- a/include/kiwi/ArchUtils.h +++ b/include/kiwi/ArchUtils.h @@ -10,7 +10,9 @@ namespace kiwi sse2, sse4_1, avx2, + avx_vnni, avx512bw, + avx512vnni, neon, last = neon, }; @@ -57,12 +59,24 @@ namespace kiwi static constexpr size_t alignment = 32; }; + template<> + struct ArchInfo + { + static constexpr size_t alignment = 32; + }; + template<> struct ArchInfo { static constexpr size_t alignment = 64; }; + template<> + struct ArchInfo + { + static constexpr size_t alignment = 64; + }; + template<> struct ArchInfo { diff --git a/include/kiwi/TemplateUtils.hpp b/include/kiwi/TemplateUtils.hpp index 3d2c7038..8c69710b 100644 --- a/include/kiwi/TemplateUtils.hpp +++ b/include/kiwi/TemplateUtils.hpp @@ -157,5 +157,43 @@ namespace kiwi } }; } + + + template + struct SignedType { using type = IntTy; }; + + template<> + struct SignedType { using type = int8_t; }; + + template<> + struct SignedType { using type = int16_t; }; + + template<> + struct SignedType { using type = int32_t; }; + + template<> + struct SignedType { using type = int64_t; }; + + template<> + 
struct SignedType { using type = int16_t; }; + + + template + struct UnsignedType { using type = IntTy; }; + + template<> + struct UnsignedType { using type = uint8_t; }; + + template<> + struct UnsignedType { using type = uint16_t; }; + + template<> + struct UnsignedType { using type = uint32_t; }; + + template<> + struct UnsignedType { using type = uint64_t; }; + + template<> + struct UnsignedType { using type = uint16_t; }; } diff --git a/src/ArchAvailable.h b/src/ArchAvailable.h index 9a2a59f0..c13f4ff6 100644 --- a/src/ArchAvailable.h +++ b/src/ArchAvailable.h @@ -12,7 +12,9 @@ namespace kiwi using AvailableArch = tp::seq< #ifdef KIWI_USE_CPUINFO #if CPUINFO_ARCH_X86_64 + static_cast(ArchType::avx512vnni), static_cast(ArchType::avx512bw), + static_cast(ArchType::avx_vnni), static_cast(ArchType::avx2), static_cast(ArchType::sse4_1), #endif @@ -24,7 +26,9 @@ namespace kiwi #endif #else #ifdef KIWI_ARCH_X86_64 + static_cast(ArchType::avx512vnni), static_cast(ArchType::avx512bw), + static_cast(ArchType::avx_vnni), static_cast(ArchType::avx2), static_cast(ArchType::sse4_1), #endif @@ -38,4 +42,30 @@ namespace kiwi static_cast(ArchType::none), static_cast(ArchType::balanced) >; + + using QuantAvailableArch = tp::seq < +#ifdef KIWI_USE_CPUINFO +#if CPUINFO_ARCH_X86_64 + static_cast(ArchType::avx512vnni), + static_cast(ArchType::avx512bw), + static_cast(ArchType::avx_vnni), + static_cast(ArchType::avx2), + static_cast(ArchType::sse4_1) +#endif +#if CPUINFO_ARCH_ARM64 + static_cast(ArchType::neon) +#endif +#else +#ifdef KIWI_ARCH_X86_64 + static_cast(ArchType::avx512vnni), + static_cast(ArchType::avx512bw), + static_cast(ArchType::avx_vnni), + static_cast(ArchType::avx2), + static_cast(ArchType::sse4_1) +#endif +#ifdef KIWI_ARCH_ARM64 + static_cast(ArchType::neon) +#endif +#endif + >; } diff --git a/src/ArchUtils.cpp b/src/ArchUtils.cpp index 9342f365..4f553222 100644 --- a/src/ArchUtils.cpp +++ b/src/ArchUtils.cpp @@ -12,7 +12,9 @@ ArchType kiwi::getBestArch() #ifdef 
KIWI_USE_CPUINFO cpuinfo_initialize(); #if CPUINFO_ARCH_X86_64 + if (cpuinfo_has_x86_avx512vnni()) return ArchType::avx512vnni; if (cpuinfo_has_x86_avx512bw()) return ArchType::avx512bw; + if (cpuinfo_has_x86_avx_vnni_int8()) return ArchType::avx_vnni; if (cpuinfo_has_x86_avx2()) return ArchType::avx2; if (cpuinfo_has_x86_sse4_1()) return ArchType::sse4_1; #endif @@ -24,7 +26,7 @@ ArchType kiwi::getBestArch() #endif #else #ifdef KIWI_ARCH_X86_64 - return ArchType::avx512bw; + return ArchType::avx512vnni; #elif defined(__x86_64__) || defined(KIWI_ARCH_X86) return ArchType::sse2; #elif defined(KIWI_ARCH_ARM64) @@ -43,7 +45,9 @@ namespace kiwi "sse2", "sse4_1", "avx2", + "avx_vnni", "avx512bw", + "avx512vnni", "neon", }; @@ -51,7 +55,7 @@ namespace kiwi { if (arch <= ArchType::balanced) return arch; #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 || KIWI_ARCH_X86_64 || KIWI_ARCH_X86 - if (ArchType::sse2 <= arch && arch <= ArchType::avx512bw && arch <= best) + if (ArchType::sse2 <= arch && arch <= ArchType::avx512vnni && arch <= best) { return arch; } diff --git a/src/SIMD.hpp b/src/SIMD.hpp index bf9834bc..e3bada00 100644 --- a/src/SIMD.hpp +++ b/src/SIMD.hpp @@ -27,9 +27,8 @@ namespace kiwi using IntPacket = typename PacketTrait::IntPacket; template - class OperatorBase + struct OperatorBase { - public: enum { packetSize = PacketTrait::size }; using FPacket = typename PacketTrait::FloatPacket; @@ -40,7 +39,9 @@ namespace kiwi static STRONG_INLINE FPacket mulf(FPacket a, FPacket b) { return O::mulf(a, b); } static STRONG_INLINE FPacket divf(FPacket a, FPacket b) { return O::divf(a, b); } static STRONG_INLINE FPacket maddf(FPacket a, FPacket b, FPacket c) { return O::maddf(a, b, c); } + static STRONG_INLINE IPacket set1i(int32_t a) { return O::set1i(a); } static STRONG_INLINE FPacket set1f(float a) { return O::set1f(a); } + static STRONG_INLINE FPacket set1frombits(uint32_t a) { return O::reinterpret_as_float(O::set1i(a)); } static STRONG_INLINE FPacket loadf(const float* 
a) { return O::loadf(a); } static STRONG_INLINE void storef(float* a, FPacket b) { return O::storef(a, b); } static STRONG_INLINE FPacket maxf(FPacket a, FPacket b) { return O::maxf(a, b); } @@ -48,16 +49,35 @@ namespace kiwi static STRONG_INLINE FPacket floorf(FPacket a) { return O::floorf(a); } static STRONG_INLINE FPacket negatef(FPacket a) { return O::negatef(a); } static STRONG_INLINE FPacket zerof() { return O::zerof(); } + static STRONG_INLINE FPacket cast_to_float(IPacket a) { return O::cast_to_float(a); } static STRONG_INLINE IPacket cast_to_int(FPacket a) { return O::cast_to_int(a); } static STRONG_INLINE FPacket reinterpret_as_float(IPacket a) { return O::reinterpret_as_float(a); } + static STRONG_INLINE IPacket reinterpret_as_int(FPacket a) { return O::reinterpret_as_int(a); } static STRONG_INLINE float firstf(FPacket a) { return O::firstf(a); } static STRONG_INLINE float redsumf(FPacket a) { return O::redsumf(a); } static STRONG_INLINE float redmaxf(FPacket a) { return O::redmaxf(a); } static STRONG_INLINE FPacket redmaxbf(FPacket a) { return O::redmaxbf(a); } + static STRONG_INLINE IPacket band(IPacket a, IPacket b) { return O::band(a, b); } + static STRONG_INLINE FPacket band(FPacket a, FPacket b) { return O::band(a, b); } + + static STRONG_INLINE IPacket bor(IPacket a, IPacket b) { return O::bor(a, b); } + static STRONG_INLINE FPacket bor(FPacket a, FPacket b) { return O::bor(a, b); } + + static STRONG_INLINE IPacket select(IPacket mask, IPacket a, IPacket b) { return O::select(mask, a, b); } + static STRONG_INLINE FPacket select(FPacket mask, FPacket a, FPacket b) { return O::select(mask, a, b); } + + static STRONG_INLINE FPacket cmp_eq(FPacket a, FPacket b) { return O::cmp_eq(a, b); } + static STRONG_INLINE FPacket cmp_le(FPacket a, FPacket b) { return O::cmp_le(a, b); } + static STRONG_INLINE FPacket cmp_lt(FPacket a, FPacket b) { return O::cmp_lt(a, b); } + static STRONG_INLINE FPacket cmp_lt_or_nan(FPacket a, FPacket b) { return 
O::cmp_lt_or_nan(a, b); } + template static STRONG_INLINE IPacket sll(IPacket a) { return O::template sll(a); } + template + static STRONG_INLINE IPacket srl(IPacket a) { return O::template srl(a); } + static STRONG_INLINE FPacket ldexpf_fast(FPacket a, FPacket exponent) { static constexpr int exponentBits = 8, mantissaBits = 23; @@ -70,6 +90,19 @@ namespace kiwi return mulf(a, reinterpret_as_float(sll(e))); } + static STRONG_INLINE FPacket frexpf_fast(FPacket x, FPacket& exp) + { + // ignore nan, inf, 0, denormalized numbers. + const IPacket exp_mask = set1i(0x7F800000), + inv_exp_mask = set1i(~0x7F800000), + norm_exp = set1i(126 << 23); + const FPacket exp_bias = set1f(126); + IPacket ix = reinterpret_as_int(x); + exp = subf(cast_to_float(srl<23>(band(ix, exp_mask))), exp_bias); + ix = bor(band(ix, inv_exp_mask), norm_exp); + return reinterpret_as_float(ix); + } + static STRONG_INLINE FPacket expf(FPacket _x) { const FPacket cst_1 = set1f(1.0f); @@ -117,10 +150,94 @@ namespace kiwi // TODO: replace pldexp with faster implementation since y in [-1, 1). return maxf(ldexpf_fast(y, m), _x); } + + static STRONG_INLINE FPacket logf(FPacket _x) + { + FPacket x = _x; + + const FPacket cst_1 = set1f(1.0f); + const FPacket cst_neg_half = set1f(-0.5f); + // The smallest non denormalized float number. + const FPacket cst_min_norm_pos = set1frombits(0x00800000u); + const FPacket cst_minus_inf = set1frombits(0xff800000u); + const FPacket cst_pos_inf = set1frombits(0x7f800000u); + + // Polynomial coefficients. 
+ const FPacket cst_cephes_SQRTHF = set1f(0.707106781186547524f); + const FPacket cst_cephes_log_p0 = set1f(7.0376836292E-2f); + const FPacket cst_cephes_log_p1 = set1f(-1.1514610310E-1f); + const FPacket cst_cephes_log_p2 = set1f(1.1676998740E-1f); + const FPacket cst_cephes_log_p3 = set1f(-1.2420140846E-1f); + const FPacket cst_cephes_log_p4 = set1f(+1.4249322787E-1f); + const FPacket cst_cephes_log_p5 = set1f(-1.6668057665E-1f); + const FPacket cst_cephes_log_p6 = set1f(+2.0000714765E-1f); + const FPacket cst_cephes_log_p7 = set1f(-2.4999993993E-1f); + const FPacket cst_cephes_log_p8 = set1f(+3.3333331174E-1f); + + // Truncate input values to the minimum positive normal. + x = maxf(x, cst_min_norm_pos); + + FPacket e; + // extract significant in the range [0.5,1) and exponent + x = frexpf_fast(x, e); + + // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2)) + // and shift by -1. The values are then centered around 0, which improves + // the stability of the polynomial evaluation. + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + FPacket mask = cmp_lt(x, cst_cephes_SQRTHF); + FPacket tmp = band(x, mask); + x = subf(x, cst_1); + e = subf(e, band(cst_1, mask)); + x = addf(x, tmp); + + FPacket x2 = mulf(x, x); + FPacket x3 = mulf(x2, x); + + // Evaluate the polynomial approximant of degree 8 in three parts, probably + // to improve instruction-level parallelism. 
+ FPacket y, y1, y2; + y = maddf(cst_cephes_log_p0, x, cst_cephes_log_p1); + y1 = maddf(cst_cephes_log_p3, x, cst_cephes_log_p4); + y2 = maddf(cst_cephes_log_p6, x, cst_cephes_log_p7); + y = maddf(y, x, cst_cephes_log_p2); + y1 = maddf(y1, x, cst_cephes_log_p5); + y2 = maddf(y2, x, cst_cephes_log_p8); + y = maddf(y, x3, y1); + y = maddf(y, x3, y2); + y = mulf(y, x3); + + y = maddf(cst_neg_half, x2, y); + x = addf(x, y); + + const FPacket cst_ln2 = set1f(0.69314718f); + x = maddf(e, cst_ln2, x); + + FPacket invalid_mask = cmp_lt_or_nan(_x, zerof()); + FPacket iszero_mask = cmp_eq(_x, zerof()); + FPacket pos_inf_mask = cmp_eq(_x, cst_pos_inf); + // Filter out invalid inputs, i.e.: + // - negative arg will be NAN + // - 0 will be -INF + // - +INF will be +INF + return select(iszero_mask, cst_minus_inf, + bor(select(pos_inf_mask, cst_pos_inf, x), invalid_mask)); + } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + return 0; + } }; + template + struct OperatorImpl; + template - class Operator; + struct Operator; } } @@ -142,17 +259,16 @@ namespace kiwi struct PacketTrait : public PacketTrait {}; -#if defined(_MSC_VER) || defined(__SSE2__) || defined(__AVX2__) - template<> - class Operator : public OperatorBase> +#if defined(_MSC_VER) || defined(__SSE2__) || defined(__SSE4_1__) || defined(__AVX2__) + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE __m128 addf(__m128 a, __m128 b) { return _mm_add_ps(a, b); } static STRONG_INLINE __m128 subf(__m128 a, __m128 b) { return _mm_sub_ps(a, b); } static STRONG_INLINE __m128 mulf(__m128 a, __m128 b) { return _mm_mul_ps(a, b); } static STRONG_INLINE __m128 divf(__m128 a, __m128 b) { return _mm_div_ps(a, b); } - static STRONG_INLINE __m128 maddf(__m128 a, __m128 b, __m128 c) { return addf(mulf(a, b), c); } + static STRONG_INLINE __m128 maddf(__m128 a, __m128 b, __m128 c) { return O::addf(O::mulf(a, b), c); } + static STRONG_INLINE __m128i 
set1i(int32_t a) { return _mm_set1_epi32(a); } static STRONG_INLINE __m128 set1f(float a) { return _mm_set1_ps(a); } static STRONG_INLINE __m128 loadf(const float* a) { return _mm_load_ps(a); } static STRONG_INLINE void storef(float* a, __m128 b) { return _mm_store_ps(a, b); } @@ -165,29 +281,44 @@ namespace kiwi return _mm_and_ps(a, mask); } - static STRONG_INLINE __m128 selectf(__m128 mask, __m128 a, __m128 b) + static STRONG_INLINE __m128 band(__m128 a, __m128 b) { return _mm_and_ps(a, b); } + static STRONG_INLINE __m128i band(__m128i a, __m128i b) { return _mm_and_si128(a, b); } + + static STRONG_INLINE __m128 bor(__m128 a, __m128 b) { return _mm_or_ps(a, b); } + static STRONG_INLINE __m128i bor(__m128i a, __m128i b) { return _mm_or_si128(a, b); } + + static STRONG_INLINE __m128 select(__m128 mask, __m128 a, __m128 b) { return _mm_or_ps(_mm_and_ps(mask, a), _mm_andnot_ps(mask, b)); } + static STRONG_INLINE __m128i select(__m128i mask, __m128i a, __m128i b) + { + return _mm_or_si128(_mm_and_si128(mask, a), _mm_andnot_si128(mask, b)); + } + + static STRONG_INLINE __m128 cmp_eq(__m128 a, __m128 b) { return _mm_cmpeq_ps(a, b); } + static STRONG_INLINE __m128 cmp_le(__m128 a, __m128 b) { return _mm_cmple_ps(a, b); } + static STRONG_INLINE __m128 cmp_lt(__m128 a, __m128 b) { return _mm_cmplt_ps(a, b); } + static STRONG_INLINE __m128 cmp_lt_or_nan(__m128 a, __m128 b) { return _mm_cmpnge_ps(a, b); } static STRONG_INLINE __m128 rint(__m128 a) { - const __m128 limit = set1f(static_cast(1 << 23)); - const __m128 abs_a = absf(a); - __m128 r = addf(abs_a, limit); + const __m128 limit = O::set1f(static_cast(1 << 23)); + const __m128 abs_a = O::absf(a); + __m128 r = O::addf(abs_a, limit); #ifdef __GNUC__ __asm__("" : "+g,x" (r)); #endif - r = subf(r, limit); + r = O::subf(r, limit); - r = selectf(_mm_cmplt_ps(abs_a, limit), - selectf(_mm_cmplt_ps(a, zerof()), negatef(r), r), a); + r = O::select(_mm_cmplt_ps(abs_a, limit), + O::select(_mm_cmplt_ps(a, O::zerof()), 
O::negatef(r), r), a); return r; } static STRONG_INLINE __m128 floorf(__m128 a) { - const __m128 cst_1 = set1f(1.0f); + const __m128 cst_1 = O::set1f(1.0f); __m128 tmp = rint(a); __m128 mask = _mm_cmpgt_ps(tmp, a); mask = _mm_and_ps(mask, cst_1); @@ -197,8 +328,13 @@ namespace kiwi static STRONG_INLINE __m128 zerof() { return _mm_setzero_ps(); } static STRONG_INLINE __m128 negatef(__m128 a) { return subf(zerof(), a); } static STRONG_INLINE __m128i cast_to_int(__m128 a) { return _mm_cvtps_epi32(a); } + static STRONG_INLINE __m128 cast_to_float(__m128i a) { return _mm_cvtepi32_ps(a); } + static STRONG_INLINE __m128i reinterpret_as_int(__m128 a) { return _mm_castps_si128(a); } static STRONG_INLINE __m128 reinterpret_as_float(__m128i a) { return _mm_castsi128_ps(a); } + template static STRONG_INLINE __m128i sll(__m128i a) { return _mm_slli_epi32(a, bit); } + template static STRONG_INLINE __m128i srl(__m128i a) { return _mm_srli_epi32(a, bit); } + static STRONG_INLINE float firstf(__m128 a) { return _mm_cvtss_f32(a); } static STRONG_INLINE float redsumf(__m128 a) @@ -219,49 +355,46 @@ namespace kiwi return _mm_max_ps(tmp, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(2, 3, 0, 1))); } }; -#endif -#if defined(_MSC_VER) || defined(__SSE4_1__) || defined(__AVX2__) template<> - class Operator : public OperatorBase> + struct Operator : public OperatorImpl> { - public: - - static STRONG_INLINE __m128 addf(__m128 a, __m128 b) { return _mm_add_ps(a, b); } - static STRONG_INLINE __m128 subf(__m128 a, __m128 b) { return _mm_sub_ps(a, b); } - static STRONG_INLINE __m128 mulf(__m128 a, __m128 b) { return _mm_mul_ps(a, b); } - static STRONG_INLINE __m128 divf(__m128 a, __m128 b) { return _mm_div_ps(a, b); } - static STRONG_INLINE __m128 maddf(__m128 a, __m128 b, __m128 c) { return addf(mulf(a, b), c); } - static STRONG_INLINE __m128 set1f(float a) { return _mm_set1_ps(a); } - static STRONG_INLINE __m128 loadf(const float* a) { return _mm_load_ps(a); } - static STRONG_INLINE void storef(float* 
a, __m128 b) { return _mm_store_ps(a, b); } - static STRONG_INLINE __m128 maxf(__m128 a, __m128 b) { return _mm_max_ps(a, b); } - static STRONG_INLINE __m128 minf(__m128 a, __m128 b) { return _mm_min_ps(a, b); } - static STRONG_INLINE __m128 floorf(__m128 a) { return _mm_floor_ps(a); } - static STRONG_INLINE __m128 zerof() { return _mm_setzero_ps(); } - static STRONG_INLINE __m128 negatef(__m128 a) { return subf(zerof(), a); } - static STRONG_INLINE __m128i cast_to_int(__m128 a) { return _mm_cvtps_epi32(a); } - static STRONG_INLINE __m128 reinterpret_as_float(__m128i a) { return _mm_castsi128_ps(a); } - template static STRONG_INLINE __m128i sll(__m128i a) { return _mm_slli_epi32(a, bit); } - static STRONG_INLINE float firstf(__m128 a) { return _mm_cvtss_f32(a); } + }; +#endif - static STRONG_INLINE float redsumf(__m128 a) +#if defined(_MSC_VER) || defined(__SSE4_1__) || defined(__AVX2__) + template + struct OperatorImpl : public OperatorImpl + { + static STRONG_INLINE __m128 select(__m128 mask, __m128 a, __m128 b) + { + return _mm_blendv_ps(b, a, mask); + } + static STRONG_INLINE __m128i select(__m128i mask, __m128i a, __m128i b) { - __m128 tmp = _mm_add_ps(a, _mm_movehl_ps(a, a)); - return firstf(_mm_add_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1))); + return _mm_blendv_epi32(b, a, mask); } - static STRONG_INLINE float redmaxf(__m128 a) - { - __m128 tmp = _mm_max_ps(a, _mm_movehl_ps(a, a)); - return firstf(_mm_max_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1))); - } + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m128i pa, pb, sum = _mm_setzero_si128(); + __m128i one16 = _mm_set1_epi16(1), pt; + for (size_t i = 0; i < size; i += 16) + { + pa = _mm_loadu_si128(reinterpret_cast(a + i)); + pb = _mm_loadu_si128(reinterpret_cast(b + i)); + pt = _mm_maddubs_epi16(pa, pb); + sum = _mm_add_epi32(sum, _mm_madd_epi16(pt, one16)); + } + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(sum); + } + }; - 
static STRONG_INLINE __m128 redmaxbf(__m128 a) - { - __m128 tmp = _mm_max_ps(a, _mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 0, 3, 2))); - return _mm_max_ps(tmp, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(2, 3, 0, 1))); - } + template<> + struct Operator : public OperatorImpl> + { }; #endif @@ -274,16 +407,15 @@ namespace kiwi using FloatPacket = __m256; }; - template<> - class Operator : public OperatorBase> + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE __m256 addf(__m256 a, __m256 b) { return _mm256_add_ps(a, b); } static STRONG_INLINE __m256 subf(__m256 a, __m256 b) { return _mm256_sub_ps(a, b); } static STRONG_INLINE __m256 mulf(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); } static STRONG_INLINE __m256 divf(__m256 a, __m256 b) { return _mm256_div_ps(a, b); } static STRONG_INLINE __m256 maddf(__m256 a, __m256 b, __m256 c) { return _mm256_fmadd_ps(a, b, c); } + static STRONG_INLINE __m256i set1i(int32_t a) { return _mm256_set1_epi32(a); } static STRONG_INLINE __m256 set1f(float a) { return _mm256_set1_ps(a); } static STRONG_INLINE __m256 loadf(const float* a) { return _mm256_load_ps(a); } static STRONG_INLINE void storef(float* a, __m256 b) { return _mm256_store_ps(a, b); } @@ -293,8 +425,11 @@ namespace kiwi static STRONG_INLINE __m256 zerof() { return _mm256_setzero_ps(); } static STRONG_INLINE __m256 negatef(__m256 a) { return subf(zerof(), a); } static STRONG_INLINE __m256i cast_to_int(__m256 a) { return _mm256_cvtps_epi32(a); } + static STRONG_INLINE __m256 cast_to_float(__m256i a) { return _mm256_cvtepi32_ps(a); } + static STRONG_INLINE __m256i reinterpret_as_int(__m256 a) { return _mm256_castps_si256(a); } static STRONG_INLINE __m256 reinterpret_as_float(__m256i a) { return _mm256_castsi256_ps(a); } template static STRONG_INLINE __m256i sll(__m256i a) { return _mm256_slli_epi32(a, bit); } + template static STRONG_INLINE __m256i srl(__m256i a) { return _mm256_srli_epi32(a, bit); } static STRONG_INLINE float firstf(__m256 a) 
{ return _mm256_cvtss_f32(a); } static STRONG_INLINE float redsumf(__m256 a) @@ -314,8 +449,149 @@ namespace kiwi __m256 tmp = _mm256_max_ps(a, _mm256_permute2f128_ps(a, a, 1)); tmp = _mm256_max_ps(tmp, _mm256_shuffle_ps(tmp, tmp, _MM_SHUFFLE(1, 0, 3, 2))); return _mm256_max_ps(tmp, _mm256_shuffle_ps(tmp, tmp, _MM_SHUFFLE(2, 3, 0, 1))); + } + + static STRONG_INLINE __m256 band(__m256 a, __m256 b) { return _mm256_and_ps(a, b); } + static STRONG_INLINE __m256i band(__m256i a, __m256i b) { return _mm256_and_si256(a, b); } + + static STRONG_INLINE __m256 bor(__m256 a, __m256 b) { return _mm256_or_ps(a, b); } + static STRONG_INLINE __m256i bor(__m256i a, __m256i b) { return _mm256_or_si256(a, b); } + + static STRONG_INLINE __m256 select(__m256 mask, __m256 a, __m256 b) { return _mm256_blendv_ps(b, a, mask); } + static STRONG_INLINE __m256i select(__m256i mask, __m256i a, __m256i b) { return _mm256_blendv_epi32(b, a, mask); } + + static STRONG_INLINE __m256 cmp_eq(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_EQ_OQ); } + static STRONG_INLINE __m256 cmp_le(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_LE_OQ); } + static STRONG_INLINE __m256 cmp_lt(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_LT_OQ); } + static STRONG_INLINE __m256 cmp_lt_or_nan(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); } + + static STRONG_INLINE void load_transposed(const float* a, size_t stride, + __m256& r0, __m256& r1, __m256& r2, __m256& r3, + __m256& r4, __m256& r5, __m256& r6, __m256& r7 + ) { + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + + r0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[0 * stride + 0])), _mm_load_ps(&a[4 * stride + 0]), 1); + r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[1 * stride + 0])), _mm_load_ps(&a[5 * stride + 0]), 1); + r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[2 * stride + 0])), _mm_load_ps(&a[6 * stride + 0]), 1); + r3 = 
_mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[3 * stride + 0])), _mm_load_ps(&a[7 * stride + 0]), 1); + r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[0 * stride + 4])), _mm_load_ps(&a[4 * stride + 4]), 1); + r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[1 * stride + 4])), _mm_load_ps(&a[5 * stride + 4]), 1); + r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[2 * stride + 4])), _mm_load_ps(&a[6 * stride + 4]), 1); + r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[3 * stride + 4])), _mm_load_ps(&a[7 * stride + 4]), 1); + + t0 = _mm256_unpacklo_ps(r0, r1); + t1 = _mm256_unpackhi_ps(r0, r1); + t2 = _mm256_unpacklo_ps(r2, r3); + t3 = _mm256_unpackhi_ps(r2, r3); + t4 = _mm256_unpacklo_ps(r4, r5); + t5 = _mm256_unpackhi_ps(r4, r5); + t6 = _mm256_unpacklo_ps(r6, r7); + t7 = _mm256_unpackhi_ps(r6, r7); + + r0 = _mm256_shuffle_ps(t0, t2, 0x44); + r1 = _mm256_shuffle_ps(t0, t2, 0xEE); + r2 = _mm256_shuffle_ps(t1, t3, 0x44); + r3 = _mm256_shuffle_ps(t1, t3, 0xEE); + r4 = _mm256_shuffle_ps(t4, t6, 0x44); + r5 = _mm256_shuffle_ps(t4, t6, 0xEE); + r6 = _mm256_shuffle_ps(t5, t7, 0x44); + r7 = _mm256_shuffle_ps(t5, t7, 0xEE); + } + + static STRONG_INLINE void store_transposed(float* a, size_t stride, + __m256 r0, __m256 r1, __m256 r2, __m256 r3, + __m256 r4, __m256 r5, __m256 r6, __m256 r7 + ) + { + __m256 t0 = _mm256_unpacklo_ps(r0, r1); + __m256 t1 = _mm256_unpackhi_ps(r0, r1); + __m256 t2 = _mm256_unpacklo_ps(r2, r3); + __m256 t3 = _mm256_unpackhi_ps(r2, r3); + __m256 t4 = _mm256_unpacklo_ps(r4, r5); + __m256 t5 = _mm256_unpackhi_ps(r4, r5); + __m256 t6 = _mm256_unpacklo_ps(r6, r7); + __m256 t7 = _mm256_unpackhi_ps(r6, r7); + + r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0)); + r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2)); + r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0)); + r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2)); + r4 = _mm256_shuffle_ps(t4, t6, 
_MM_SHUFFLE(1, 0, 1, 0)); + r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2)); + r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0)); + r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2)); + + t0 = _mm256_permute2f128_ps(r0, r4, 0x20); + t1 = _mm256_permute2f128_ps(r1, r5, 0x20); + t2 = _mm256_permute2f128_ps(r2, r6, 0x20); + t3 = _mm256_permute2f128_ps(r3, r7, 0x20); + t4 = _mm256_permute2f128_ps(r0, r4, 0x31); + t5 = _mm256_permute2f128_ps(r1, r5, 0x31); + t6 = _mm256_permute2f128_ps(r2, r6, 0x31); + t7 = _mm256_permute2f128_ps(r3, r7, 0x31); + + _mm256_store_ps(&a[0 * stride], t0); + _mm256_store_ps(&a[1 * stride], t1); + _mm256_store_ps(&a[2 * stride], t2); + _mm256_store_ps(&a[3 * stride], t3); + _mm256_store_ps(&a[4 * stride], t4); + _mm256_store_ps(&a[5 * stride], t5); + _mm256_store_ps(&a[6 * stride], t6); + _mm256_store_ps(&a[7 * stride], t7); + } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m256i pa, pb, acc = _mm256_setzero_si256(); + __m256i one16 = _mm256_set1_epi16(1), pt; + for (size_t i = 0; i < size; i += 32) + { + pa = _mm256_loadu_si256(reinterpret_cast(&a[i])); + pb = _mm256_loadu_si256(reinterpret_cast(&b[i])); + pt = _mm256_maddubs_epi16(pa, pb); + acc = _mm256_add_epi32(acc, _mm256_madd_epi16(pt, one16)); + } + // reduce sum of eight int32_t to one int32_t + __m256i sum = _mm256_hadd_epi32(acc, acc); + sum = _mm256_hadd_epi32(sum, sum); + return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); } }; + + template<> + struct Operator : public OperatorImpl> + { + }; + + template<> + struct PacketTrait : public PacketTrait + { + }; + + template + struct OperatorImpl : public OperatorImpl + { + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m256i pa, pb, acc = _mm256_setzero_si256(); + for (size_t i = 0; i < size; i += 32) + { + pa = _mm256_loadu_si256(reinterpret_cast(&a[i])); + pb = _mm256_loadu_si256(reinterpret_cast(&b[i])); 
+ acc = _mm256_dpbusd_epi32(acc, pa, pb); + } + // reduce sum of eight int32_t to one int32_t + __m256i sum = _mm256_hadd_epi32(acc, acc); + sum = _mm256_hadd_epi32(sum, sum); + return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); + } + }; + + template<> + struct Operator : public OperatorImpl> + { + }; #endif #if defined(_MSC_VER) || defined(__AVX512F__) || defined(__AVX512BW__) @@ -327,16 +603,15 @@ namespace kiwi using FloatPacket = __m512; }; - template<> - class Operator : public OperatorBase> + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE __m512 addf(__m512 a, __m512 b) { return _mm512_add_ps(a, b); } static STRONG_INLINE __m512 subf(__m512 a, __m512 b) { return _mm512_sub_ps(a, b); } static STRONG_INLINE __m512 mulf(__m512 a, __m512 b) { return _mm512_mul_ps(a, b); } static STRONG_INLINE __m512 divf(__m512 a, __m512 b) { return _mm512_div_ps(a, b); } static STRONG_INLINE __m512 maddf(__m512 a, __m512 b, __m512 c) { return _mm512_fmadd_ps(a, b, c); } + static STRONG_INLINE __m512i set1i(int32_t a) { return _mm512_set1_epi32(a); } static STRONG_INLINE __m512 set1f(float a) { return _mm512_set1_ps(a); } static STRONG_INLINE __m512 loadf(const float* a) { return _mm512_load_ps(a); } static STRONG_INLINE void storef(float* a, __m512 b) { return _mm512_store_ps(a, b); } @@ -346,8 +621,11 @@ namespace kiwi static STRONG_INLINE __m512 zerof() { return _mm512_setzero_ps(); } static STRONG_INLINE __m512 negatef(__m512 a) { return subf(zerof(), a); } static STRONG_INLINE __m512i cast_to_int(__m512 a) { return _mm512_cvtps_epi32(a); } + static STRONG_INLINE __m512 cast_to_float(__m512i a) { return _mm512_cvtepi32_ps(a); } + static STRONG_INLINE __m512i reinterpret_as_int(__m512 a) { return _mm512_castps_si512(a); } static STRONG_INLINE __m512 reinterpret_as_float(__m512i a) { return _mm512_castsi512_ps(a); } template static STRONG_INLINE __m512i sll(__m512i a) { return _mm512_slli_epi32(a, bit); } + template static 
STRONG_INLINE __m512i srl(__m512i a) { return _mm512_srli_epi32(a, bit); } static STRONG_INLINE float firstf(__m512 a) { return _mm512_cvtss_f32(a); } static STRONG_INLINE float redsumf(__m512 a) @@ -376,8 +654,91 @@ namespace kiwi static STRONG_INLINE __m512 redmaxbf(__m512 a) { return set1f(redmaxf(a)); + } + + static STRONG_INLINE __m512 band(__m512 a, __m512 b) { return _mm512_and_ps(a, b); } + static STRONG_INLINE __m512i band(__m512i a, __m512i b) { return _mm512_and_si512(a, b); } + + static STRONG_INLINE __m512 bor(__m512 a, __m512 b) { return _mm512_or_ps(a, b); } + static STRONG_INLINE __m512i bor(__m512i a, __m512i b) { return _mm512_or_si512(a, b); } + + static STRONG_INLINE __m512 select(__m512 mask, __m512 a, __m512 b) + { + __mmask16 mask16 = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mask16, a, b); + } + static STRONG_INLINE __m512i select(__m512i mask, __m512i a, __m512i b) + { + __mmask16 mask16 = _mm512_cmp_epi32_mask(mask, _mm512_setzero_si512(), _MM_CMPINT_EQ); + return _mm512_mask_blend_epi32(mask16, a, b); + } + + static STRONG_INLINE __m512 cmp_eq(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + } + static STRONG_INLINE __m512 cmp_le(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + } + static STRONG_INLINE __m512 cmp_lt(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + } + static STRONG_INLINE __m512 cmp_lt_or_nan(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); 
+ } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m512i pa, pb, acc = _mm512_setzero_si512(); + __m512i one16 = _mm512_set1_epi16(1), pt; + for (size_t i = 0; i < size; i += 64) + { + pa = _mm512_loadu_si512(reinterpret_cast(&a[i])); + pb = _mm512_loadu_si512(reinterpret_cast(&b[i])); + pt = _mm512_maddubs_epi16(pa, pb); + acc = _mm512_add_epi32(acc, _mm512_madd_epi16(pt, one16)); + } + return _mm512_reduce_add_epi32(acc); } }; + + template<> + struct Operator : public OperatorImpl> + { + }; + + template<> + struct PacketTrait : public PacketTrait + { + }; + + template + struct OperatorImpl : public OperatorImpl + { + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m512i pa, pb, acc = _mm512_setzero_si512(); + for (size_t i = 0; i < size; i += 64) + { + pa = _mm512_loadu_si512(reinterpret_cast(&a[i])); + pb = _mm512_loadu_si512(reinterpret_cast(&b[i])); + acc = _mm512_dpbusd_epi32(acc, pa, pb); + } + return _mm512_reduce_add_epi32(acc); + } + }; + + template<> + struct Operator : public OperatorImpl> + { + }; #endif } } @@ -396,11 +757,9 @@ namespace kiwi using FloatPacket = float32x4_t; }; - template<> - class Operator : public OperatorBase> + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE float32x4_t addf(float32x4_t a, float32x4_t b) { return vaddq_f32(a, b); } static STRONG_INLINE float32x4_t subf(float32x4_t a, float32x4_t b) { return vsubq_f32(a, b); } static STRONG_INLINE float32x4_t mulf(float32x4_t a, float32x4_t b) { return vmulq_f32(a, b); } @@ -435,6 +794,11 @@ namespace kiwi { return set1f(redmaxf(a)); } + }; + + template<> + struct Operator : public OperatorImpl> + { }; } } diff --git a/src/archImpl/avx2.cpp b/src/archImpl/avx2.cpp index 3d8bd19a..6ff2c16e 100644 --- a/src/archImpl/avx2.cpp +++ b/src/archImpl/avx2.cpp @@ -4,19 +4,14 @@ namespace kiwi { namespace lm { - template<> - struct LogExpSum - { - template - 
float operator()(const float* arr, std::integral_constant) - { - return logExpSumImpl(arr); - } - }; - template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; } } diff --git a/src/archImpl/avx512bw.cpp b/src/archImpl/avx512bw.cpp index 03f6d8b2..867418cd 100644 --- a/src/archImpl/avx512bw.cpp +++ b/src/archImpl/avx512bw.cpp @@ -4,19 +4,14 @@ namespace kiwi { namespace lm { - template<> - struct LogExpSum - { - template - float operator()(const float* arr, std::integral_constant) - { - return logExpSumImpl(arr); - } - }; - template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; } } diff --git a/src/archImpl/neon.cpp b/src/archImpl/neon.cpp index 87ffe397..06c3a5ad 100644 --- a/src/archImpl/neon.cpp +++ b/src/archImpl/neon.cpp @@ -4,16 +4,6 @@ namespace kiwi { namespace lm { - template<> - struct LogExpSum - { - template - float operator()(const float* arr, std::integral_constant) - { - return logExpSumImpl(arr); - } - }; - template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; diff --git a/src/archImpl/sse2.cpp b/src/archImpl/sse2.cpp index 25eab888..8df7a094 100644 --- a/src/archImpl/sse2.cpp +++ b/src/archImpl/sse2.cpp @@ -4,16 +4,6 @@ namespace kiwi { namespace lm { - template<> - struct LogExpSum - { - template - float operator()(const float* arr, std::integral_constant) - { - return logExpSumImpl(arr); - } - }; - template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; diff --git a/src/archImpl/sse4_1.cpp b/src/archImpl/sse4_1.cpp index d5759b47..bff67239 100644 --- 
a/src/archImpl/sse4_1.cpp +++ b/src/archImpl/sse4_1.cpp @@ -4,16 +4,6 @@ namespace kiwi { namespace lm { - template<> - struct LogExpSum - { - template - float operator()(const float* arr, std::integral_constant) - { - return logExpSumImpl(arr); - } - }; - template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; diff --git a/src/search.cpp b/src/search.cpp index 6f14f0ef..821a93bd 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -1,9 +1,11 @@ #include #include +#include #include #include +#include #include "ArchAvailable.h" #include "search.h" @@ -14,11 +16,22 @@ template bool detail::searchImpl(const uint32_t*, size_t, uint32_t, size_t&);\ template bool detail::searchImpl(const uint64_t*, size_t, uint64_t, size_t&);\ template bool detail::searchImpl(const char16_t*, size_t, char16_t, size_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint8_t, uint32_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint16_t, uint32_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint32_t, uint32_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint64_t, uint32_t&);\ + template bool detail::searchKVImpl(const void*, size_t, char16_t, uint32_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint8_t, uint64_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint16_t, uint64_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint32_t, uint64_t&);\ + template bool detail::searchKVImpl(const void*, size_t, uint64_t, uint64_t&);\ + template bool detail::searchKVImpl(const void*, size_t, char16_t, uint64_t&);\ template Vector detail::reorderImpl(const uint8_t*, size_t);\ template Vector detail::reorderImpl(const uint16_t*, size_t);\ template Vector detail::reorderImpl(const uint32_t*, size_t);\ template Vector detail::reorderImpl(const uint64_t*, size_t);\ - template Vector detail::reorderImpl(const char16_t*, size_t); + template Vector 
detail::reorderImpl(const char16_t*, size_t);\ + template size_t detail::getPacketSizeImpl(); #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 || KIWI_ARCH_X86 || KIWI_ARCH_X86_64 #include @@ -149,12 +162,27 @@ namespace kiwi template bool detail::searchImpl(const IntTy* keys, size_t size, IntTy target, size_t& ret) { - return OptimizedImpl::template search(keys, size, target, ret); + return OptimizedImpl::search(keys, size, target, ret); + + } + + template + bool detail::searchKVImpl(const void* kv, size_t size, IntTy target, ValueTy& ret) + { + return OptimizedImpl::searchKV(kv, size, target, ret); + } + + template + size_t detail::getPacketSizeImpl() + { + return OptimizedImpl::packetSize; } template<> struct OptimizedImpl { + static constexpr size_t packetSize = 0; + template static Vector reorder(const IntTy* keys, size_t size) { @@ -166,12 +194,28 @@ namespace kiwi { return bstSearch(keys, size, target, ret); } + + template + static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + { + size_t idx; + const IntTy* keys = reinterpret_cast(kv); + const ValueTy* values = reinterpret_cast(keys + size); + if (search(keys, size, target, idx)) + { + ret = values[idx]; + return true; + } + else return false; + } }; INSTANTIATE_IMPL(ArchType::none); template<> struct OptimizedImpl { + static constexpr size_t packetSize = 0; + template static Vector reorder(const IntTy* keys, size_t size) { @@ -197,26 +241,22 @@ namespace kiwi ret = left1; return true; } + + template + static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + { + size_t idx; + const IntTy* keys = reinterpret_cast(kv); + const ValueTy* values = reinterpret_cast(keys + size); + if (search(keys, size, target, idx)) + { + ret = values[idx]; + return true; + } + else return false; + } }; INSTANTIATE_IMPL(ArchType::balanced); - - template - struct SignedType { using type = IntTy; }; - - template<> - struct SignedType { using type = int8_t; }; - - template<> - struct 
SignedType { using type = int16_t; }; - - template<> - struct SignedType { using type = int32_t; }; - - template<> - struct SignedType { using type = int64_t; }; - - template<> - struct SignedType { using type = int16_t; }; } } @@ -401,21 +441,83 @@ namespace kiwi return false; } + template + ARCH_TARGET("sse2") + bool nstSearchKVSSE2(const uint8_t* kv, size_t size, IntTy target, ValueTy& ret) + { + size_t i = 0, r; + + __m128i ptarget, pkey, peq, pgt; + switch (sizeof(IntTy)) + { + case 1: + ptarget = _mm_set1_epi8(target); + break; + case 2: + ptarget = _mm_set1_epi16(target); + break; + case 4: + ptarget = _mm_set1_epi32(target); + break; + } + + while (i < size) + { + pkey = _mm_loadu_si128(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm_cmpeq_epi8(ptarget, pkey); + pgt = _mm_cmpgt_epi8(ptarget, pkey); + break; + case 2: + peq = _mm_cmpeq_epi16(ptarget, pkey); + pgt = _mm_cmpgt_epi16(ptarget, pkey); + break; + case 4: + peq = _mm_cmpeq_epi32(ptarget, pkey); + pgt = _mm_cmpgt_epi32(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + ret = values[r]; + return true; + } + + r = utils::popcount((uint32_t)_mm_movemask_epi8(pgt)) / sizeof(IntTy); + i = i * n + (n - 1) * (r + 1); + } + return false; + } + template<> struct OptimizedImpl { + static constexpr size_t packetSize = 16; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchSSE2<16 / sizeof(IntTy) + 
1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchSSE2((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVSSE2((const uint8_t*)kv, size, (SignedIntTy)target, ret); } }; INSTANTIATE_IMPL(ArchType::sse2); @@ -642,6 +744,66 @@ namespace kiwi return false; } + template + ARCH_TARGET("avx2") + bool nstSearchKVAVX2(const uint8_t* kv, size_t size, IntTy target, ValueTy& ret) + { + size_t i = 0, r; + + __m256i ptarget, pkey, peq, pgt; + switch (sizeof(IntTy)) + { + case 1: + ptarget = _mm256_set1_epi8(target); + break; + case 2: + ptarget = _mm256_set1_epi16(target); + break; + case 4: + ptarget = _mm256_set1_epi32(target); + break; + case 8: + ptarget = _mm256_set1_epi64x(target); + break; + } + + while (i < size) + { + pkey = _mm256_loadu_si256(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm256_cmpeq_epi8(ptarget, pkey); + pgt = _mm256_cmpgt_epi8(ptarget, pkey); + break; + case 2: + peq = _mm256_cmpeq_epi16(ptarget, pkey); + pgt = _mm256_cmpgt_epi16(ptarget, pkey); + break; + case 4: + peq = _mm256_cmpeq_epi32(ptarget, pkey); + pgt = _mm256_cmpgt_epi32(ptarget, pkey); + break; + case 8: + peq = _mm256_cmpeq_epi64(ptarget, pkey); + pgt = _mm256_cmpgt_epi64(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + ret = values[r]; + return true; + } + + r = utils::popcount((uint32_t)_mm256_movemask_epi8(pgt)) / sizeof(IntTy); + i = i * n + (n - 1) * (r + 1); + } + return false; + } + template ARCH_TARGET("avx512bw") bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) @@ -832,21 +994,93 @@ 
namespace kiwi return false; } + template + ARCH_TARGET("avx512bw") + bool nstSearchKVAVX512(const uint8_t* kv, size_t size, IntTy target, ValueTy& ret) + { + size_t i = 0, r; + const IntTy* keys; + + __m512i ptarget, pkey; + uint64_t peq, pgt; + switch (sizeof(IntTy)) + { + case 1: + ptarget = _mm512_set1_epi8(target); + break; + case 2: + ptarget = _mm512_set1_epi16(target); + break; + case 4: + ptarget = _mm512_set1_epi32(target); + break; + case 8: + ptarget = _mm512_set1_epi64(target); + break; + } + + while (i < size) + { + keys = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))]); + pkey = _mm512_loadu_si512(reinterpret_cast(keys)); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm512_cmpeq_epi8_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi8_mask(ptarget, pkey); + break; + case 2: + peq = _mm512_cmpeq_epi16_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi16_mask(ptarget, pkey); + break; + case 4: + peq = _mm512_cmpeq_epi32_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi32_mask(ptarget, pkey); + break; + case 8: + peq = _mm512_cmpeq_epi64_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi64_mask(ptarget, pkey); + break; + } + + if (testEqMask(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&keys[groupSize]); + ret = values[r]; + return true; + } + + r = utils::popcount(pgt); + i = i * n + (n - 1) * (r + 1); + } + return false; + } + template<> struct OptimizedImpl { + static constexpr size_t packetSize = 16; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchSSE2<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, 
(SignedIntTy)target, ret); + return nstSearchSSE2((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVSSE2((const uint8_t*)kv, size, (SignedIntTy)target, ret); } }; INSTANTIATE_IMPL(ArchType::sse4_1); @@ -854,40 +1088,70 @@ namespace kiwi template<> struct OptimizedImpl { + static constexpr size_t packetSize = 32; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<32 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchAVX2<32 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchAVX2((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVAVX2((const uint8_t*)kv, size, (SignedIntTy)target, ret); } }; INSTANTIATE_IMPL(ArchType::avx2); + template<> + struct OptimizedImpl : public OptimizedImpl + { + }; + INSTANTIATE_IMPL(ArchType::avx_vnni); + template<> struct OptimizedImpl { + static constexpr size_t packetSize = 64; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<64 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchAVX512<64 / sizeof(IntTy) + 1>((const 
SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchAVX512((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVAVX512((const uint8_t*)kv, size, (SignedIntTy)target, ret); } }; INSTANTIATE_IMPL(ArchType::avx512bw); + + template<> + struct OptimizedImpl : public OptimizedImpl + { + }; + INSTANTIATE_IMPL(ArchType::avx512vnni); } } #endif @@ -1023,18 +1287,20 @@ namespace kiwi template<> struct OptimizedImpl { + static constexpr size_t packetSize = 16; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchNeon<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchNeon((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); } }; INSTANTIATE_IMPL(ArchType::neon); diff --git a/src/search.h b/src/search.h index e3a118de..f3c4fa9d 100644 --- a/src/search.h +++ b/src/search.h @@ -20,8 +20,14 @@ namespace kiwi template bool searchImpl(const IntTy* keys, size_t size, IntTy target, size_t& ret); + template + bool searchKVImpl(const void* keys, size_t size, IntTy target, ValueTy& ret); + template Vector reorderImpl(const IntTy* keys, size_t size); + + template + size_t getPacketSizeImpl(); } template @@ -50,6 +56,37 @@ namespace kiwi } } + template + void prepareKV(void* dest, size_t size, Vector& tempBuf) + { + const size_t packetSize = detail::getPacketSizeImpl() / sizeof(IntTy); + if (size <= 1 || packetSize <= 1) return; + auto order = detail::reorderImpl(reinterpret_cast(dest), 
size); + if (order.empty()) return; + + if (tempBuf.size() < (sizeof(IntTy) + sizeof(Value)) * size) + { + tempBuf.resize((sizeof(IntTy) + sizeof(Value)) * size); + } + std::memcpy(tempBuf.data(), dest, (sizeof(IntTy) + sizeof(Value)) * size); + auto tempKeys = (IntTy*)tempBuf.data(); + auto tempValues = (Value*)(tempKeys + size); + for (size_t i = 0; i < size; i += packetSize) + { + const size_t groupSize = std::min(packetSize, size - i); + for (size_t j = 0; j < groupSize; ++j) + { + *reinterpret_cast(dest) = tempKeys[order[i + j]]; + dest = reinterpret_cast(dest) + sizeof(IntTy); + } + for (size_t j = 0; j < groupSize; ++j) + { + *reinterpret_cast(dest) = tempValues[order[i + j]]; + dest = reinterpret_cast(dest) + sizeof(Value); + } + } + } + template bool search(const IntTy* keys, const Value* values, size_t size, IntTy target, Out& ret) { @@ -73,5 +110,17 @@ namespace kiwi } else return false; } + + template + bool searchKV(const void* kv, size_t size, IntTy target, Out& ret) + { + typename UnsignedType::type out; + if (detail::searchKVImpl(kv, size, target, out)) + { + ret = out; + return true; + } + else return false; + } } } From 3fbbbfafa1d0ff54d57dd9e2ad4685292762b829 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 9 Feb 2025 03:00:07 +0900 Subject: [PATCH 11/53] Implement quantized PcLM --- include/kiwi/PCLanguageModel.h | 14 +- include/kiwi/Types.h | 2 + src/KiwiBuilder.cpp | 7 +- src/PCLanguageModel.cpp | 1079 ++++++++++++++++++++++++---- src/PCLanguageModel.hpp | 304 ++++++-- src/SkipBigramModelImpl.hpp | 253 ++++++- src/qgemm.cpp | 743 +++++++++++++++++++ src/qgemm.h | 38 + vsproj/kiwi_shared_library.vcxproj | 2 + 9 files changed, 2218 insertions(+), 224 deletions(-) create mode 100644 src/qgemm.cpp create mode 100644 src/qgemm.h diff --git a/include/kiwi/PCLanguageModel.h b/include/kiwi/PCLanguageModel.h index c58835c0..536abd36 100644 --- a/include/kiwi/PCLanguageModel.h +++ b/include/kiwi/PCLanguageModel.h @@ -37,21 +37,21 @@ namespace kiwi 
class PcLangModelBase : public ILangModel { protected: - utils::MemoryObject base; + const size_t memorySize = 0; + PcLangModelHeader header; - PcLangModelBase(utils::MemoryObject&& mem) : base{ std::move(mem) } + PcLangModelBase(const utils::MemoryObject& mem) : memorySize{ mem.size() }, header{ *reinterpret_cast(mem.get()) } { } public: virtual ~PcLangModelBase() {} - size_t vocabSize() const override { return getHeader().vocabSize; } - ModelType getType() const override { return ModelType::pclm; } - size_t getMemorySize() const override { return base.size(); } + size_t vocabSize() const override { return header.vocabSize; } + size_t getMemorySize() const override { return memorySize; } - const PcLangModelHeader& getHeader() const { return *reinterpret_cast(base.get()); } + const PcLangModelHeader& getHeader() const { return header; } static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, bool reorderContextIdx = true); - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false, bool quantized = true); }; } } diff --git a/include/kiwi/Types.h b/include/kiwi/Types.h index 78144de0..1a7a260c 100644 --- a/include/kiwi/Types.h +++ b/include/kiwi/Types.h @@ -310,6 +310,8 @@ namespace kiwi sbg = 2, /**< Skip-Bigram Model */ pclm = 3, /**< Pre-computed Context Language Model */ pclmLocal = 4, /**< Pre-computed Context Language Model (Only local context) */ + pclmQuantized = 5, /**< Pre-computed Context Language Model (quantized) */ + pclmLocalQuantized = 6, /**< Pre-computed Context Language Model (Only local context, quantized) */ knlmTransposed, }; diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index ee74b8cc..24a9459c 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -791,9 +791,12 @@ 
KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio { langMdl = lm::SkipBigramModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), utils::MMap(modelPath + string{ "/skipbigram.mdl" }), archType); } - else if (modelType == ModelType::pclm || modelType == ModelType::pclmLocal) + else if (modelType == ModelType::pclm || modelType == ModelType::pclmLocal + || modelType == ModelType::pclmQuantized || modelType == ModelType::pclmLocalQuantized) { - langMdl = lm::PcLangModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, modelType == ModelType::pclm); + langMdl = lm::PcLangModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, + (modelType == ModelType::pclm || modelType == ModelType::pclmQuantized), + (modelType == ModelType::pclmQuantized || modelType == ModelType::pclmLocalQuantized)); } if (!!(options & BuildOption::loadDefaultDict)) diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index c5e3c220..161b0628 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -6,11 +6,278 @@ #include "PCLanguageModel.hpp" #include "StrUtils.h" #include "FrozenTrie.hpp" +#include "qgemm.h" using namespace std; namespace kiwi { + template + struct MorphemeEvaluator> + { + using LmState = lm::PcLMState; + + template + void eval( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + size_t ownFormId, + const Vector& morphs, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const size_t totalPrevPathes, + const float ignoreCondScore, + const float nodeLevelDiscount, + const Vector& prevSpStates + ) const + { + thread_local BestPathConatiner bestPathCont; + thread_local Vector*> regularPrevPathes; + thread_local Vector*>> combiningPrevPathes; + thread_local Vector regularMorphs, regularDistantMorphs, combiningLMorphs, combiningRMorphs; + thread_local Vector prevLmStates, nextLmStates; + thread_local Vector 
nextWids, nextDistantWids; + thread_local Vector scores; + + const auto* langMdl = static_cast*>(kw->getLangModel()); + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + const size_t langVocabSize = langMdl->vocabSize(); + + regularPrevPathes.clear(); + combiningPrevPathes.clear(); + regularMorphs.clear(); + regularDistantMorphs.clear(); + combiningLMorphs.clear(); + combiningRMorphs.clear(); + prevLmStates.clear(); + nextWids.clear(); + nextDistantWids.clear(); + + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& prevPath : cache[prev - startNode]) + { + if (prevPath.combineSocket) + { + combiningPrevPathes.emplace_back(prev, &prevPath); + continue; + } + regularPrevPathes.emplace_back(&prevPath); + prevLmStates.emplace_back(prevPath.lmState); + } + } + + for (auto& curMorph : morphs) + { + if (curMorph->combineSocket) + { + (curMorph->isSingle() ? 
combiningLMorphs : combiningRMorphs).emplace_back(curMorph); + continue; + } + Wid firstWid; + if (curMorph->isSingle()) + { + firstWid = curMorph->lmMorphemeId; + } + else + { + firstWid = curMorph->chunks[0]->lmMorphemeId; + } + + if (morphBase[firstWid].tag == POSTag::p) + { + continue; + } + if (windowSize > 0 && langMdl->distantTokenMask(firstWid)) + { + regularDistantMorphs.emplace_back(curMorph); + nextDistantWids.emplace_back(firstWid); + } + else + { + regularMorphs.emplace_back(curMorph); + nextWids.emplace_back(firstWid); + } + } + + if (windowSize > 0) + { + regularMorphs.insert(regularMorphs.end(), regularDistantMorphs.begin(), regularDistantMorphs.end()); + nextWids.insert(nextWids.end(), nextDistantWids.begin(), nextDistantWids.end()); + } + + if (nextWids.size() > 0) + { + nextLmStates.resize(prevLmStates.size() * nextWids.size()); + scores.resize(prevLmStates.size() * nextWids.size()); + langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); + } + + for (size_t curId = 0; curId < regularMorphs.size(); ++curId) + { + const auto* curMorph = regularMorphs[curId]; + bestPathCont.clear(); + + size_t length = 1; + const Morpheme* lastMorph; + if (curMorph->isSingle()) + { + lastMorph = curMorph->getCombined() ? 
curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + length = curMorph->chunks.size(); + } + + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } + + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const float morphScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + size_t prevId = -1; + for (auto* prevPath : regularPrevPathes) + { + ++prevId; + auto& state = nextLmStates[prevId * regularMorphs.size() + curId]; + auto score = prevPath->accScore + morphScore + scores[prevId * regularMorphs.size() + curId]; + + FormEvaluator formEvaluator{ *prevPath, ownForms, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, score)) continue; + + for (size_t i = 1; i < length; ++i) + { + const auto wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + goto continueFor; + } + score += state.next(langMdl, wid); + } + + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(state), score, node, *prevPath, ruleBasedScorer); + continueFor:; + } + + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } + + for (auto* curMorph : combiningLMorphs) + { + bestPathCont.clear(); + const Morpheme* lastMorph; + if (curMorph->isSingle()) + { + lastMorph = curMorph->getCombined() ? 
curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + } + + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const float morphScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + for (auto* prevPath : regularPrevPathes) + { + auto state = prevPath->lmState; + float score = prevPath->accScore + morphScore; + + FormEvaluator formEvaluator{ *prevPath, ownForms, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, score)) continue; + + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(state), score, node, *prevPath, ruleBasedScorer); + } + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } + + for (auto* curMorph : combiningRMorphs) + { + bestPathCont.clear(); + size_t length = 1; + const Morpheme* lastMorph; + if (curMorph->isSingle()) + { + lastMorph = curMorph->getCombined() ? 
curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + length = curMorph->chunks.size(); + } + + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } + + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const float morphScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + for (auto& p : combiningPrevPathes) + { + auto* prev = p.first; + auto* prevPath = p.second; + float score = prevPath->accScore + morphScore; + // merge with only the same socket + if (prevPath->combineSocket != curMorph->combineSocket || curMorph->isSingle()) + { + continue; + } + if (prev->endPos < node->startPos) + { + if (allowedSpaceBetweenChunk) score -= spacePenalty; + else continue; + } + Wid firstWid = morphBase[prevPath->wid].getCombined()->lmMorphemeId; + auto state = prevPath->lmState; + score += state.next(langMdl, firstWid); + + for (size_t i = 1; i < length; ++i) + { + const auto wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + goto continueFor2; + } + score += state.next(langMdl, wid); + } + + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(state), score, node, *prevPath, ruleBasedScorer); + continueFor2:; + } + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } + } + }; + namespace lm { inline float half2float(uint16_t h) @@ -33,6 +300,14 @@ namespace kiwi } } + inline void addBias(uint8_t* out, const int8_t* ints, size_t n) + { + for (size_t i = 0; i < n; ++i) + { + out[i] = ints[i] + 128; + } + } + template void logsoftmaxInplace(Arr& arr) { @@ -40,15 +315,15 @@ namespace kiwi arr -= std::log(arr.exp().sum()); } - template - 
PcLangModel::PcLangModel(utils::MemoryObject&& mem) : PcLangModelBase{ std::move(mem) } + template + PcLangModel::PcLangModel(utils::MemoryObject&& mem) : PcLangModelBase{ mem } { - auto* ptr = reinterpret_cast(base.get()); - auto& header = getHeader(); + auto* ptr = reinterpret_cast(mem.get()); Vector nodeSizes(header.numNodes); streamvbyte_decode_0124(reinterpret_cast(ptr + header.nodeOffset), nodeSizes.data(), header.numNodes); - keyData = make_unique(header.numNodes - 1); + keyValueData = make_unique((header.numNodes - 1) * (sizeof(KeyType) + sizeof(int32_t))); + auto keyData = make_unique(header.numNodes - 1); if (std::is_same::value) { streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), (uint32_t*)keyData.get(), header.numNodes - 1); @@ -70,7 +345,7 @@ namespace kiwi } nodeData = make_unique(numNonLeafNodes); - valueData = make_unique(header.numNodes - 1); + auto valueData = make_unique(header.numNodes - 1); size_t nonLeafIdx = 0, leafIdx = 0, nextOffset = 0; Vector> keyRanges; @@ -106,11 +381,31 @@ namespace kiwi } } + uint8_t* kvDataPtr = keyValueData.get(); + nonLeafIdx = 0; + nextOffset = 0; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (!nodeSizes[i]) continue; + auto& node = nodeData[nonLeafIdx]; + memcpy(kvDataPtr, &keyData[nextOffset], node.numNexts * sizeof(KeyType)); + kvDataPtr += node.numNexts * sizeof(KeyType); + memcpy(kvDataPtr, &valueData[nextOffset], node.numNexts * sizeof(int32_t)); + kvDataPtr += node.numNexts * sizeof(int32_t); + nextOffset += node.numNexts; + nonLeafIdx++; + } + + allRootValueData = make_unique(header.vocabSize); + for (size_t i = 0; i < nodeData[0].numNexts; ++i) + { + allRootValueData[keyData[i]] = valueData[i]; + } Vector tempBuf; for (size_t i = 0; i < nonLeafIdx; ++i) { auto& node = nodeData[i]; - nst::prepare(&keyData[node.nextOffset], &valueData[node.nextOffset], node.numNexts, tempBuf); + nst::prepareKV(&keyValueData[node.nextOffset * (sizeof(KeyType) + sizeof(int32_t))], node.numNexts, 
tempBuf); } Deque dq; @@ -132,139 +427,637 @@ namespace kiwi } } + { + const size_t contextEmbSize = header.contextSize * contextEmbStride(); + const size_t distantEmbSize = windowSize > 0 ? header.vocabSize * distantEmbStride() : 0; + const size_t outputEmbSize = header.vocabSize * outputEmbStride(); + const size_t positionConfSize = windowSize > 0 ? (header.windowSize + 1) * sizeof(float) : 0; + const size_t distantMaskSize = windowSize > 0 ? (header.vocabSize + 7) / 8 : 0; + + allEmbs = make_unique(contextEmbSize + outputEmbSize + distantEmbSize + positionConfSize + distantMaskSize); + auto p = allEmbs.get(); + contextEmbPtr = reinterpret_cast(p); + distantEmbPtr = windowSize > 0 ? reinterpret_cast(p += contextEmbSize) : nullptr; + outputEmbPtr = reinterpret_cast(p += distantEmbSize); + positionConfidPtr = windowSize > 0 ? reinterpret_cast(p += outputEmbSize) : nullptr; + distantMaskPtr = windowSize > 0 ? reinterpret_cast(p += positionConfSize) : nullptr; + } + auto* eptr = ptr + header.embOffset; - contextEmb = make_unique(header.contextSize * header.dim); - contextBias = make_unique(header.contextSize); - contextValidTokenSum = make_unique(header.contextSize); - contextConf = make_unique(header.contextSize); - if (useDistantTokens) + auto* optr = const_cast(contextEmbPtr); + for (size_t i = 0; i < header.contextSize; ++i) { - distantEmb = make_unique(header.vocabSize * header.dim); - distantBias = make_unique(header.vocabSize); - distantConf = make_unique(header.vocabSize); - positionConf = make_unique(header.windowSize); + if (quantized) + { + addBias(optr, reinterpret_cast(eptr), header.dim); + optr += header.dim; + eptr += header.dim; + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // scale + optr += sizeof(float); + eptr += sizeof(uint16_t); + } + else + { + const float scale = half2float(*reinterpret_cast(eptr + header.dim)); + dequantize(reinterpret_cast(optr), reinterpret_cast(eptr), header.dim, scale); + optr += header.dim * 
sizeof(float); + eptr += header.dim + sizeof(uint16_t); + } + + *reinterpret_cast(optr) = -half2float(*reinterpret_cast(eptr)); // bias + optr += sizeof(float); + eptr += sizeof(uint16_t); + if (windowSize > 0) + { + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // confidence + optr += sizeof(float); + eptr += sizeof(uint16_t); + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // valid token sum + optr += sizeof(float); + eptr += sizeof(uint16_t); + } + else + { + eptr += sizeof(uint16_t) * 2; + } } - outputEmb = make_unique(header.vocabSize * header.dim); - const uint16_t* contextEmbScale = reinterpret_cast(eptr + header.contextSize * header.dim); - for (size_t i = 0; i < header.contextSize; ++i) + optr = const_cast(outputEmbPtr); + for (size_t i = 0; i < header.vocabSize; ++i) { - dequantize(&contextEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(contextEmbScale[i])); - eptr += header.dim; + auto* qvals = reinterpret_cast(eptr); + if (quantized) + { + memcpy(optr, qvals, header.dim); + optr += header.dim; + eptr += header.dim; + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); + optr += sizeof(float); + eptr += sizeof(uint16_t); + *reinterpret_cast(optr) = accumulate(qvals, qvals + header.dim, 0) * 128; + optr += sizeof(int32_t); + } + else + { + const float scale = half2float(*reinterpret_cast(eptr + header.dim)); + dequantize(reinterpret_cast(optr), qvals, header.dim, scale); + optr += header.dim * sizeof(float); + eptr += header.dim + sizeof(uint16_t); + } } - eptr += header.contextSize * sizeof(uint16_t); - for (size_t i = 0; i < header.contextSize; ++i) + + if (windowSize > 0) { - contextBias[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); + optr = const_cast(distantEmbPtr); + for (size_t i = 0; i < header.vocabSize; ++i) + { + if (quantized) + { + addBias(optr, reinterpret_cast(eptr), header.dim); + optr += header.dim; + eptr += header.dim; + *reinterpret_cast(optr) = 
half2float(*reinterpret_cast(eptr)); // scale + optr += sizeof(float); + eptr += sizeof(uint16_t); + } + else + { + const float scale = half2float(*reinterpret_cast(eptr + header.dim)); + dequantize(reinterpret_cast(optr), reinterpret_cast(eptr), header.dim, scale); + optr += header.dim * sizeof(float); + eptr += header.dim + sizeof(uint16_t); + } + + *reinterpret_cast(optr) = -half2float(*reinterpret_cast(eptr)); // bias + optr += sizeof(float); + eptr += sizeof(uint16_t); + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // confidence + optr += sizeof(float); + eptr += sizeof(uint16_t); + if (quantized) + { + optr += sizeof(float); + } + } + + const_cast(positionConfidPtr)[0] = 0; + for (size_t i = 0; i < header.windowSize; ++i) + { + const_cast(positionConfidPtr)[i + 1] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + + optr = const_cast(distantMaskPtr); + const size_t compressedDistantMaskSize = (header.vocabSize + 7) / 8; + std::copy(eptr, eptr + compressedDistantMaskSize, optr); } - for (size_t i = 0; i < header.contextSize; ++i) + } + + template + float PcLangModel::progress(int32_t& nodeIdx, + uint32_t& contextIdx, + size_t& historyPos, + std::array& history, + KeyType next) const + { + const bool validDistantToken = distantTokenMask(next); + float ll = 0; + + if (windowSize > 0 && validDistantToken) { - contextValidTokenSum[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); + if constexpr (quantized) + { + int32_t contextIdcs[1 + windowSize]; + float lls[(1 + windowSize) * 2]; + int32_t nextIdx[1] = { next }; + + copy(positionConfidPtr, positionConfidPtr + windowSize + 1, lls); + lls[0] += getContextConfid(contextIdx); + contextIdcs[0] = contextIdx; + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[(historyPos + i) % windowSize]; + lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; + contextIdcs[i + 1] = (historyToken ? 
historyToken : 0) + header.contextSize; + } + LogSoftmax{}(lls, std::integral_constant()); + + qgemm::scatteredGEMMBaseline( + 1 + windowSize, 1, header.dim, + getContextQuantEmb(0), contextIdcs, contextEmbStride(), + getOutputQuantEmb(0), nextIdx, outputEmbStride(), + &lls[1 + windowSize], 1); + + for (size_t i = 0; i < 1 + windowSize; ++i) + { + lls[i] += lls[i + 1 + windowSize]; + } + lls[0] -= getContextValidTokenSum(contextIdx); + ll = LogSumExp{}(lls, std::integral_constant()); + ll += getContextValidTokenSum(contextIdx); + } + else + { + thread_local Eigen::MatrixXf mat; + mat.resize(header.dim, 1 + windowSize); + thread_local Eigen::VectorXf lls; + lls.resize(1 + windowSize); + + lls = Eigen::Map{ positionConfidPtr, windowSize + 1 }; + lls[0] += getContextConfid(contextIdx); + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[(historyPos + i) % windowSize]; + lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; + } + logsoftmaxInplace(lls.array()); + + mat.col(0) = Eigen::Map{ getContextEmb(contextIdx), header.dim }; + lls[0] += getContextBias(contextIdx); + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[(historyPos + i) % windowSize]; + if (historyToken) mat.col(i + 1) = Eigen::Map{ getDistantEmb(historyToken), header.dim }; + else mat.col(i + 1).setZero(); + lls[i + 1] += getDistantBias(historyToken); + } + lls.tail(windowSize).array() += getContextValidTokenSum(contextIdx); + Eigen::Map outputVec{ getOutputEmb(next), header.dim }; + lls += mat.transpose() * outputVec; + ll = LogSumExp{}(lls.data(), std::integral_constant()); + } } - for (size_t i = 0; i < header.contextSize; ++i) + else { - contextConf[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); + if constexpr (quantized) + { + const auto* contextPtr = getContextQuantEmb(contextIdx); + const auto* outputPtr = getOutputQuantEmb(next); + int32_t acc = qgemm::dotprod(contextPtr, outputPtr, header.dim); 
+ const float contextScale = *reinterpret_cast(contextPtr + header.dim), + outputScale = *reinterpret_cast(outputPtr + header.dim), + contextBias = *reinterpret_cast(contextPtr + header.dim + sizeof(float)); + const int32_t hsum = *reinterpret_cast(outputPtr + header.dim + sizeof(float)); + acc -= hsum; + ll = acc * contextScale * outputScale + contextBias; + } + else + { + ll = getContextBias(contextIdx); + Eigen::Map contextVec{ getContextEmb(contextIdx), header.dim }; + Eigen::Map outputVec{ getOutputEmb(next), header.dim }; + ll += (contextVec.transpose() * outputVec)[0]; + } } + + contextIdx = progressContextNode(nodeIdx, next); + if (windowSize > 0) + { + if (history[windowSize]) + { + history[historyPos] = history[windowSize]; + historyPos = (historyPos + 1) % windowSize; + } + history[windowSize] = validDistantToken ? next : 0; + } + return ll; + } - const uint16_t* distantEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); - for (size_t i = 0; i < header.vocabSize; ++i) + // specialization for windowSize > 0 + template + template + auto PcLangModel::nextState( + const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next) const -> LmStateType + { + LmStateType ret = state; + ret.contextIdx = progressContextNode(ret.node, next); + if (ret.history[windowSize]) { - if (useDistantTokens) dequantize(&distantEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(distantEmbScale[i])); - eptr += header.dim; + ret.history[ret.historyPos] = ret.history[windowSize]; + ret.historyPos = (ret.historyPos + 1) % windowSize; } - eptr += header.vocabSize * sizeof(uint16_t); - for (size_t i = 0; i < header.vocabSize; ++i) + ret.history[windowSize] = distantTokenMask(next) ? 
next : 0; + return ret; + } + + // specialization for windowSize == 0 + template + template + auto PcLangModel::nextState( + const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next) const -> LmStateType + { + LmStateType ret = state; + ret.contextIdx = progressContextNode(ret.node, next); + return ret; + } + + inline uint64_t mergePair(uint32_t a, uint32_t b) + { + return ((uint64_t)a << 32) | b; + } + + inline pair splitPair(uint64_t a) + { + return make_pair(a >> 32, a & 0xFFFFFFFF); + } + + template + template + void PcLangModel::progressMatrix( + const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const + { + static constexpr size_t scoreBatchSize = 32; + thread_local Vector contextIdcs, historyIdcs, nextIdcs; + thread_local Vector inverseContextIdcs, inverseHistoryIdcs, inverseNextIdcs; + thread_local Vector inputEmbBuf, outputEmbBuf, resultBuf, confidenceBuf; + thread_local Vector scoreBuf; + thread_local Vector contextIdcs2, nextIdcs2; + + contextIdcs.resize(prevStateSize); + historyIdcs.clear(); + nextIdcs.resize(nextIdSize); + inverseContextIdcs.resize(prevStateSize); + inverseHistoryIdcs.clear(); + inverseHistoryIdcs.resize(prevStateSize * windowSize, -1); + inverseNextIdcs.resize(nextIdSize); + if (quantized) { - if (useDistantTokens) distantBias[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); + contextIdcs2.clear(); + nextIdcs2.clear(); } - for (size_t i = 0; i < header.vocabSize; ++i) + else { - if (useDistantTokens) distantConf[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); + inputEmbBuf.resize(prevStateSize * header.dim); + outputEmbBuf.resize(nextIdSize * header.dim); + } + confidenceBuf.resize(prevStateSize * 2); + scoreBuf.resize(scoreBatchSize * (windowSize + 2)); + + const size_t numInvalidDistantTokens 
= nextIdSize - numValidDistantTokens; + for (size_t i = 0; i < nextIdSize; ++i) + { + nextIdcs[i] = mergePair(nextIds[i], i); } - for (size_t i = 0; i < header.windowSize; ++i) + sort(nextIdcs.begin(), nextIdcs.begin() + numInvalidDistantTokens); + sort(nextIdcs.begin() + numInvalidDistantTokens, nextIdcs.end()); + size_t uniqOutputSize = 0; + for (size_t i = 0; i < nextIdSize; ++i) { - if (useDistantTokens) positionConf[i] = half2float(*reinterpret_cast(eptr)); - eptr += sizeof(uint16_t); + const auto nextId = splitPair(nextIdcs[i]).first; + const auto idx = splitPair(nextIdcs[i]).second; + if (i == 0 || nextId != splitPair(nextIdcs[i - 1]).first) + { + if (quantized) + { + nextIdcs2.emplace_back(nextId); + } + else + { + copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &outputEmbBuf[uniqOutputSize * header.dim]); + } + uniqOutputSize++; + } + inverseNextIdcs[idx] = uniqOutputSize - 1; } + resultBuf.resize(prevStateSize * uniqOutputSize); - const uint16_t* outputEmbScale = reinterpret_cast(eptr + header.vocabSize * header.dim); - for (size_t i = 0; i < header.vocabSize; ++i) + for (size_t i = 0; i < prevStateSize; ++i) { - dequantize(&outputEmb[i * header.dim], reinterpret_cast(eptr), header.dim, half2float(outputEmbScale[i])); - eptr += header.dim; + contextIdcs[i] = mergePair(prevStates[i].contextIdx, i); + } + sort(contextIdcs.begin(), contextIdcs.end()); + size_t uniqInputSize = 0; + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto contextId = splitPair(contextIdcs[i]).first; + const auto idx = splitPair(contextIdcs[i]).second; + if (i == 0 || contextId != splitPair(contextIdcs[i - 1]).first) + { + if (quantized) + { + contextIdcs2.emplace_back(contextId); + } + else + { + copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &inputEmbBuf[uniqInputSize * header.dim]); + fill(&resultBuf[uniqInputSize * uniqOutputSize], &resultBuf[(uniqInputSize + 1) * uniqOutputSize], getContextBias(contextId)); + } + 
confidenceBuf[uniqInputSize * 2] = getContextConfid(contextId); + confidenceBuf[uniqInputSize * 2 + 1] = getContextValidTokenSum(contextId); + uniqInputSize++; + } + inverseContextIdcs[idx] = uniqInputSize - 1; } - eptr += header.vocabSize * sizeof(uint16_t); - if (useDistantTokens) + size_t uniqHistorySize = 0; + if (prevStateSize <= 8) // use vector for small size { - const size_t compressedDistantMaskSize = (header.vocabSize + 7) / 8; - distantMask = make_unique(compressedDistantMaskSize); - std::copy(eptr, eptr + compressedDistantMaskSize, distantMask.get()); + for (size_t i = 0; i < prevStateSize; ++i) + { + for (size_t j = 0; j < windowSize; ++j) + { + const auto historyToken = prevStates[i].history[(j + prevStates[i].historyPos) % windowSize]; + if (historyToken) + { + historyIdcs.emplace_back(mergePair(historyToken, i * windowSize + j)); + } + } + } + sort(historyIdcs.begin(), historyIdcs.end()); + uniqHistorySize = 0; + for (size_t i = 0; i < historyIdcs.size(); ++i) + { + const auto historyToken = splitPair(historyIdcs[i]).first; + const auto idx = splitPair(historyIdcs[i]).second; + if (i == 0 || historyToken != splitPair(historyIdcs[i - 1]).first) + { + uniqHistorySize++; + } + inverseHistoryIdcs[idx] = uniqHistorySize - 1; + } + inputEmbBuf.resize((uniqInputSize + uniqHistorySize) * header.dim); + confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); + resultBuf.resize((uniqInputSize + uniqHistorySize) * uniqOutputSize); + + uniqHistorySize = 0; + for (size_t i = 0; i < historyIdcs.size(); ++i) + { + const auto historyToken = splitPair(historyIdcs[i]).first; + const auto idx = splitPair(historyIdcs[i]).second; + if (i == 0 || historyToken != splitPair(historyIdcs[i - 1]).first) + { + if (quantized) + { + contextIdcs2.emplace_back(historyToken + header.contextSize); + } + else + { + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &inputEmbBuf[(uniqInputSize + uniqHistorySize) * header.dim]); + 
fill(&resultBuf[(uniqInputSize + uniqHistorySize) * uniqOutputSize], &resultBuf[(uniqInputSize + uniqHistorySize + 1) * uniqOutputSize], getDistantBias(historyToken)); + } + confidenceBuf[uniqInputSize * 2 + uniqHistorySize] = getDistantConfid(historyToken); + uniqHistorySize++; + } + } } - } + else // use map for large size + { + thread_local UnorderedMap historyMap; + thread_local Vector uniqHistoryTokens; + historyMap.clear(); + uniqHistoryTokens.clear(); + for (size_t i = 0; i < prevStateSize; ++i) + { + for (size_t j = 0; j < windowSize; ++j) + { + const auto historyToken = prevStates[i].history[(j + prevStates[i].historyPos) % windowSize]; + if (!historyToken) continue; + const auto idx = i * windowSize + j; + auto inserted = historyMap.emplace(historyToken, historyMap.size()); + inverseHistoryIdcs[idx] = inserted.first->second; + if (inserted.second) uniqHistoryTokens.emplace_back(historyToken); + } + } + uniqHistorySize = historyMap.size(); - template - float PcLangModel::progress(int32_t& nodeIdx, - uint32_t& contextIdx, - size_t& historyPos, - std::array& history, - KeyType next) const - { - const auto& header = getHeader(); - const bool validDistantToken = distantTokenMask(next); - float ll = 0; + inputEmbBuf.resize((uniqInputSize + uniqHistorySize)* header.dim); + confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); + resultBuf.resize((uniqInputSize + uniqHistorySize)* uniqOutputSize); - thread_local Eigen::MatrixXf mat; - mat.resize(header.dim, 1 + windowSize); - thread_local Eigen::VectorXf lls; - lls.resize(1 + windowSize); - if (useDistantTokens && validDistantToken) + for (size_t i = 0; i < uniqHistoryTokens.size(); ++i) + { + const auto historyToken = uniqHistoryTokens[i]; + if (quantized) + { + contextIdcs2.emplace_back(historyToken + header.contextSize); + } + else + { + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &inputEmbBuf[(uniqInputSize + i) * header.dim]); + fill(&resultBuf[(uniqInputSize + i) * 
uniqOutputSize], &resultBuf[(uniqInputSize + i + 1) * uniqOutputSize], getDistantBias(historyToken)); + } + confidenceBuf[uniqInputSize * 2 + i] = getDistantConfid(historyToken); + } + } + + Eigen::Map resultMap{ resultBuf.data(), (Eigen::Index)uniqOutputSize, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; + + if constexpr (quantized) { - lls[0] = contextConf[contextIdx]; - lls.tail(windowSize) = Eigen::Map{ &positionConf[0], windowSize }; - for (size_t i = 0; i < windowSize; ++i) + qgemm::scatteredGEMMOpt( + uniqInputSize + uniqHistorySize, uniqOutputSize, header.dim, + getContextQuantEmb(0), contextIdcs2.data(), contextEmbStride(), + getOutputQuantEmb(0), nextIdcs2.data(), outputEmbStride(), + resultBuf.data(), uniqOutputSize); + } + else + { + Eigen::Map inputMap{ inputEmbBuf.data(), header.dim, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; + Eigen::Map outputMap{ outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; + resultMap += outputMap.transpose() * inputMap; + } + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto state = prevStates[i]; + for (size_t j = 0; j < numInvalidDistantTokens; ++j) { - const auto historyToken = history[(historyPos + i) % windowSize]; - lls[i + 1] += historyToken ? 
distantConf[historyToken] : -99999; + outScores[i * nextIdSize + j] = resultMap(inverseNextIdcs[j], inverseContextIdcs[i]); + outStates[i * nextIdSize + j] = nextState<_windowSize>(state, nextIds[j]); } - logsoftmaxInplace(lls.array()); + } - mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; - lls[0] -= contextBias[contextIdx]; - for (size_t i = 0; i < windowSize; ++i) + auto* validTokenSumBuf = scoreBuf.data() + scoreBatchSize * (windowSize + 1); + + for (size_t i = 0; i < prevStateSize * numValidDistantTokens; i += scoreBatchSize) + { + const size_t batchSize = std::min(scoreBatchSize, prevStateSize * numValidDistantTokens - i); + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + scoreBuf[j] = confidenceBuf[inverseContextIdcs[pIdx] * 2]; + validTokenSumBuf[j] = confidenceBuf[inverseContextIdcs[pIdx] * 2 + 1]; + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = inverseHistoryIdcs[pIdx * windowSize + k]; + scoreBuf[j + (k + 1) * scoreBatchSize] = idx == -1 ? 
-99999 : confidenceBuf[uniqInputSize * 2 + idx]; + } + } + Eigen::Map> scoreMap{ scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; + scoreMap.rowwise() += Eigen::Map>{ positionConfidPtr, 1, windowSize + 1 }; + LogSoftmaxTransposed{}(scoreBuf.data(), batchSize, scoreBatchSize); + scoreMap.rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + scoreBuf[j] += resultMap(inverseNextIdcs[nIdx], inverseContextIdcs[pIdx]); + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = inverseHistoryIdcs[pIdx * windowSize + k]; + if (idx != -1) + { + scoreBuf[j + (k + 1) * scoreBatchSize] += resultMap(inverseNextIdcs[nIdx], uniqInputSize + idx); + } + } + } + LogSumExpTransposed{}(scoreBuf.data(), batchSize, scoreBatchSize); + + for (size_t j = 0; j < batchSize; ++j) { - const auto historyToken = history[(historyPos + i) % windowSize]; - if (historyToken) mat.col(i + 1) = Eigen::Map{ &distantEmb[historyToken * header.dim], header.dim }; - else mat.col(i + 1).setZero(); - lls[i + 1] -= distantBias[historyToken]; + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + outScores[pIdx * nextIdSize + nIdx] = scoreBuf[j]; + outStates[pIdx * nextIdSize + nIdx] = nextState(prevStates[pIdx], nextIds[nIdx]); } - lls.tail(windowSize).array() += contextValidTokenSum[contextIdx]; - Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; - lls += mat.transpose() * outputVec; - ll = LogExpSum{}(lls.data(), std::integral_constant()); + } + } + + template + template + void PcLangModel::progressMatrix( + const typename std::enable_if<_windowSize == 0, LmStateType>::type* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + 
LmStateType* outStates, float* outScores) const + { + thread_local Vector contextIdcs, nextIdcs; + thread_local Vector inverseContextIdcs, inverseNextIdcs; + thread_local Vector inputEmbBuf, outputEmbBuf, resultBuf; + thread_local Vector contextIdcs2, nextIdcs2; + + contextIdcs.resize(prevStateSize); + nextIdcs.resize(nextIdSize); + inverseContextIdcs.resize(prevStateSize); + inverseNextIdcs.resize(nextIdSize); + if (quantized) + { + contextIdcs2.clear(); + nextIdcs2.clear(); } else { - lls[0] = -contextBias[contextIdx]; - mat.col(0) = Eigen::Map{ &contextEmb[contextIdx * header.dim], header.dim }; - Eigen::Map outputVec{ &outputEmb[next * header.dim], header.dim }; - lls.head(1) += mat.transpose() * outputVec; - ll = lls[0]; + inputEmbBuf.resize(prevStateSize * header.dim); + outputEmbBuf.resize(nextIdSize * header.dim); + } + + for (size_t i = 0; i < nextIdSize; ++i) + { + nextIdcs[i] = mergePair(nextIds[i], i); } + sort(nextIdcs.begin(), nextIdcs.end()); + size_t uniqOutputSize = 0; + for (size_t i = 0; i < nextIdSize; ++i) + { + const auto nextId = splitPair(nextIdcs[i]).first; + const auto idx = splitPair(nextIdcs[i]).second; + if (i == 0 || nextId != splitPair(nextIdcs[i - 1]).first) + { + if (quantized) + { + nextIdcs2.emplace_back(nextId); + } + else + { + copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &outputEmbBuf[uniqOutputSize * header.dim]); + } + uniqOutputSize++; + } + inverseNextIdcs[idx] = uniqOutputSize - 1; + } + resultBuf.resize(max(prevStateSize * uniqOutputSize, (size_t)64)); - contextIdx = progressContextNode(nodeIdx, next); - if (history[windowSize]) + for (size_t i = 0; i < prevStateSize; ++i) { - history[historyPos] = history[windowSize]; - historyPos = (historyPos + 1) % windowSize; + contextIdcs[i] = mergePair(prevStates[i].contextIdx, i); + } + sort(contextIdcs.begin(), contextIdcs.end()); + size_t uniqInputSize = 0; + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto contextId = splitPair(contextIdcs[i]).first; 
+ const auto idx = splitPair(contextIdcs[i]).second; + if (i == 0 || contextId != splitPair(contextIdcs[i - 1]).first) + { + if (quantized) + { + contextIdcs2.emplace_back(contextId); + } + else + { + copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &inputEmbBuf[uniqInputSize * header.dim]); + fill(&resultBuf[uniqInputSize * uniqOutputSize], &resultBuf[(uniqInputSize + 1) * uniqOutputSize], getContextBias(contextId)); + } + uniqInputSize++; + } + inverseContextIdcs[idx] = uniqInputSize - 1; + } + + Eigen::Map resultMap{ resultBuf.data(), (Eigen::Index)uniqOutputSize, (Eigen::Index)uniqInputSize }; + if constexpr (quantized) + { + qgemm::scatteredGEMMOpt( + uniqInputSize, uniqOutputSize, header.dim, + getContextQuantEmb(0), contextIdcs2.data(), contextEmbStride(), + getOutputQuantEmb(0), nextIdcs2.data(), outputEmbStride(), + resultBuf.data(), uniqOutputSize); + } + else + { + Eigen::Map inputMap{ inputEmbBuf.data(), header.dim, (Eigen::Index)uniqInputSize }; + Eigen::Map outputMap{ outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; + resultMap += outputMap.transpose() * inputMap; + } + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + for (size_t j = 0; j < nextIdSize; ++j) + { + outStates[i * nextIdSize + j] = nextState<_windowSize>(state, nextIds[j]); + outScores[i * nextIdSize + j] = resultMap(inverseNextIdcs[j], inverseContextIdcs[i]); + } } - history[windowSize] = validDistantToken ? 
next : 0; - return ll; } utils::MemoryObject PcLangModelBase::build(const string& contextDefinition, const string& embedding, bool reorderContextId) @@ -498,85 +1291,107 @@ namespace kiwi writePadding(ostr); ostr.write((const char*)compressedValues.data(), compressedValues.size()); writePadding(ostr); - ostr.write((const char*)contextEmb.data(), contextEmb.size()); - ostr.write((const char*)contextEmbScale.data(), contextEmbScale.size() * sizeof(uint16_t)); - ostr.write((const char*)contextEmbBias.data(), contextEmbBias.size() * sizeof(uint16_t)); - ostr.write((const char*)contextValidTokenSum.data(), contextValidTokenSum.size() * sizeof(uint16_t)); - ostr.write((const char*)contextConfidence.data(), contextConfidence.size() * sizeof(uint16_t)); - ostr.write((const char*)distantEmb.data(), distantEmb.size()); - ostr.write((const char*)distantEmbScale.data(), distantEmbScale.size() * sizeof(uint16_t)); - ostr.write((const char*)distantEmbBias.data(), distantEmbBias.size() * sizeof(uint16_t)); - ostr.write((const char*)distantConfidence.data(), distantConfidence.size() * sizeof(uint16_t)); + + for (size_t i = 0; i < contextSize; ++i) + { + ostr.write((const char*)&contextEmb[i * dim], dim); + ostr.write((const char*)&contextEmbScale[i], sizeof(uint16_t)); + ostr.write((const char*)&contextEmbBias[i], sizeof(uint16_t)); + ostr.write((const char*)&contextConfidence[i], sizeof(uint16_t)); + ostr.write((const char*)&contextValidTokenSum[i], sizeof(uint16_t)); + } + for (size_t i = 0; i < outputSize; ++i) + { + ostr.write((const char*)&outputEmb[i * dim], dim); + ostr.write((const char*)&outputEmbScale[i], sizeof(uint16_t)); + } + for (size_t i = 0; i < outputSize; ++i) + { + ostr.write((const char*)&distantEmb[i * dim], dim); + ostr.write((const char*)&distantEmbScale[i], sizeof(uint16_t)); + ostr.write((const char*)&distantEmbBias[i], sizeof(uint16_t)); + ostr.write((const char*)&distantConfidence[i], sizeof(uint16_t)); + } ostr.write((const 
char*)positionConfidence.data(), positionConfidence.size() * sizeof(uint16_t)); - ostr.write((const char*)outputEmb.data(), outputEmb.size()); - ostr.write((const char*)outputEmbScale.data(), outputEmbScale.size() * sizeof(uint16_t)); ostr.write((const char*)distantMask.data(), distantMask.size()); return mem; } - template - void* PcLangModel::getFindBestPathFn() const + template + void* PcLangModel::getFindBestPathFn() const { - return (void*)&BestPathFinder::findBestPath>; + return (void*)&BestPathFinder::findBestPath>; } - template - void* PcLangModel::getNewJoinerFn() const + template + void* PcLangModel::getNewJoinerFn() const { return (void*)&newJoinerWithKiwi; } - template + template inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) { auto& header = *reinterpret_cast(mem.get()); + if (!useDistantTokens) + { + return make_unique>(std::move(mem)); + } + switch (header.windowSize) { - case 4: - return make_unique>(std::move(mem)); case 7: - return make_unique>(std::move(mem)); - case 8: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); default: throw std::runtime_error{ "Unsupported `window_size` : " + std::to_string((size_t)header.windowSize) }; }; } - template + template std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) { auto& header = *reinterpret_cast(mem.get()); switch (header.keySize) { - case 1: - return createOptimizedModelWithWindowSize(std::move(mem)); case 2: - return createOptimizedModelWithWindowSize(std::move(mem)); + return createOptimizedModelWithWindowSize(std::move(mem)); case 4: - return createOptimizedModelWithWindowSize(std::move(mem)); + return createOptimizedModelWithWindowSize(std::move(mem)); default: throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; } } - using FnCreateOptimizedModel = decltype(&createOptimizedModel); + using FnCreateOptimizedModel = decltype(&createOptimizedModel); - template + template struct 
CreateOptimizedModelGetter { template struct Wrapper { - static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i), useDistantTokens>; + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i), useDistantTokens, quantized>; }; }; - std::unique_ptr PcLangModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens) + std::unique_ptr PcLangModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens, bool quantized) { - static tp::Table tableWithoutDistantTokens{ CreateOptimizedModelGetter{} }, - tableWithDistantTokens{ CreateOptimizedModelGetter{} }; - auto fn = (useDistantTokens ? tableWithDistantTokens : tableWithoutDistantTokens)[static_cast(archType)]; + static tp::Table tables[] = { + CreateOptimizedModelGetter{}, + CreateOptimizedModelGetter{}, + }; + static tp::Table quantTables[] = { + CreateOptimizedModelGetter{}, + CreateOptimizedModelGetter{}, + }; + + if (quantized) + { + auto fn = quantTables[useDistantTokens ? 1 : 0][static_cast(archType)]; + if (fn) return (*fn)(std::move(mem)); + std::cerr << "Quantization is not supported for " << archToStr(archType) << ". Fall back to non-quantized model." << std::endl; + } + auto fn = tables[useDistantTokens ? 
1 : 0][static_cast(archType)]; if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; return (*fn)(std::move(mem)); } diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index c6e1e5b7..8f94dba3 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -15,43 +15,128 @@ namespace kiwi { namespace lm { - template + template class PcLMState; - template + template class PcLangModel : public PcLangModelBase { using MyNode = Node; std::unique_ptr nodeData; - std::unique_ptr keyData; - std::unique_ptr valueData; - std::unique_ptr contextEmb; - std::unique_ptr contextBias; - std::unique_ptr contextValidTokenSum; - std::unique_ptr contextConf; - std::unique_ptr distantEmb; - std::unique_ptr distantBias; - std::unique_ptr distantConf; - std::unique_ptr positionConf; - std::unique_ptr outputEmb; - std::unique_ptr distantMask; + std::unique_ptr keyValueData; + std::unique_ptr allRootValueData; + std::unique_ptr allEmbs; + const uint8_t* contextEmbPtr = nullptr; // [numContexts, (dim + scale? + bias + confid + vts)] + const uint8_t* outputEmbPtr = nullptr; // [numOutputs, (dim + scale? + sum?)] + const uint8_t* distantEmbPtr = nullptr; // [numOutputs, (dim + scale? + bias + confid + pad?)] + const float* positionConfidPtr = nullptr; + const uint8_t* distantMaskPtr = nullptr; + + inline size_t contextEmbStride() const + { + if (quantized) return header.dim + (windowSize > 0 ? 4 : 2) * sizeof(float); + else return (header.dim + (windowSize > 0 ? 
3 : 1)) * sizeof(float); + } + + inline size_t outputEmbStride() const + { + if (quantized) return header.dim + 2 * sizeof(float); + else return header.dim * sizeof(float); + } + + inline size_t distantEmbStride() const + { + if (quantized) return header.dim + 4 * sizeof(float); + else return (header.dim + 2) * sizeof(float); + } + + inline const float* getContextEmb(uint32_t idx) const + { + return reinterpret_cast(contextEmbPtr + idx * contextEmbStride()); + } + + inline const uint8_t* getContextQuantEmb(uint32_t idx) const + { + return contextEmbPtr + idx * contextEmbStride(); + } + + inline float getContextBias(uint32_t idx) const + { + const size_t offset = quantized ? + (header.dim + sizeof(float)) + : (header.dim * sizeof(float)); + return *reinterpret_cast(contextEmbPtr + idx * contextEmbStride() + offset); + } + + inline float getContextConfid(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? + (header.dim + 2 * sizeof(float)) + : (header.dim + 1) * sizeof(float); + return *reinterpret_cast(contextEmbPtr + idx * contextEmbStride() + offset); + } + + inline float getContextValidTokenSum(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? 
+ (header.dim + 3 * sizeof(float)) + : (header.dim + 2) * sizeof(float); + return *reinterpret_cast(contextEmbPtr + idx * contextEmbStride() + offset); + } + + inline const float* getOutputEmb(uint32_t idx) const + { + return reinterpret_cast(outputEmbPtr + idx * outputEmbStride()); + } + + inline const int8_t* getOutputQuantEmb(uint32_t idx) const + { + return reinterpret_cast(outputEmbPtr + idx * outputEmbStride()); + } + + inline const float* getDistantEmb(uint32_t idx) const + { + return reinterpret_cast(distantEmbPtr + idx * distantEmbStride()); + } + + inline const uint8_t* getDistantQuantEmb(uint32_t idx) const + { + return distantEmbPtr + idx * distantEmbStride(); + } + + inline float getDistantBias(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? + (header.dim + sizeof(float)) + : (header.dim * sizeof(float)); + return *reinterpret_cast(distantEmbPtr + idx * distantEmbStride() + offset); + } + + inline float getDistantConfid(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? 
+ (header.dim + 2 * sizeof(float)) + : (header.dim + 1) * sizeof(float); + return *reinterpret_cast(distantEmbPtr + idx * distantEmbStride() + offset); + } MyNode* findLowerNode(MyNode* node, KeyType k) const { while (node->lower) { auto* lowerNode = node + node->lower; - auto* keys = &keyData[lowerNode->nextOffset]; - auto* values = &valueData[lowerNode->nextOffset]; + auto* kvs = &keyValueData[lowerNode->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; int32_t found; - if (nst::search( - keys, - values, + if (nst::searchKV( + kvs, lowerNode->numNexts, k, - found - ) && found >= 0) + found) && found >= 0) { return lowerNode + found; } @@ -65,16 +150,13 @@ namespace kiwi while (node->lower) { auto* lowerNode = node + node->lower; - auto* keys = &keyData[lowerNode->nextOffset]; - auto* values = &valueData[lowerNode->nextOffset]; + auto* kvs = &keyValueData[lowerNode->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; int32_t found; - if (nst::search( - keys, - values, + if (nst::searchKV( + kvs, lowerNode->numNexts, k, - found - )) + found)) { if (found >= 0) { @@ -92,10 +174,23 @@ namespace kiwi public: using VocabType = KeyType; - using LmStateType = PcLMState; + using LmStateType = PcLMState; PcLangModel(utils::MemoryObject&& mem); + ModelType getType() const override + { + if (quantized) + { + if (windowSize > 0) return ModelType::pclmQuantized; + else return ModelType::pclmLocalQuantized; + } + else + { + if (windowSize > 0) return ModelType::pclm; + else return ModelType::pclmLocal; + } + } void* getFindBestPathFn() const override; void* getNewJoinerFn() const override; @@ -105,19 +200,29 @@ namespace kiwi { int32_t v; auto* node = &nodeData[nodeIdx]; - auto* keys = &keyData[node->nextOffset]; - auto* values = &valueData[node->nextOffset]; - PREFETCH_T0(node + node->lower); - if (!nst::search( - keys, - values, - node->numNexts, next, v - )) + auto* kvs = &keyValueData[node->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; + if (node != nodeData.get()) { - 
if (!node->lower) return 0; - nodeIdx += node->lower; - PREFETCH_T0(&keyData[nodeData[nodeIdx].nextOffset]); - continue; + //ScopedTimer<> timer(node->numNexts <= 16 ? 0 : node->numNexts <= 272 ? 1 : 2); + PREFETCH_T0(node + node->lower); + if (!nst::searchKV( + kvs, + node->numNexts, next, v + )) + { + if (!node->lower) return 0; + nodeIdx += node->lower; + PREFETCH_T0(&keyValueData[nodeData[nodeIdx].nextOffset * (sizeof(KeyType) + sizeof(int32_t))]); + continue; + } + } + else + { + v = allRootValueData[next]; + if (v == 0) + { + return 0; + } } // non-leaf node @@ -132,17 +237,30 @@ namespace kiwi while (node->lower) { node += node->lower; + auto* lkvs = &keyValueData[node->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; int32_t lv; - if (nst::search( - &keyData[node->nextOffset], - &valueData[node->nextOffset], - node->numNexts, next, lv - )) + if (node != nodeData.get()) { + //ScopedTimer<> timer(node->numNexts <= 16 ? 0 : node->numNexts <= 272 ? 1 : 2); + if (nst::searchKV( + lkvs, + node->numNexts, next, lv + )) + { + if (lv > 0) + { + node += lv; + nodeIdx = node - &nodeData[0]; + return (uint32_t)-v; + } + } + } + else + { + lv = allRootValueData[next]; if (lv > 0) { - node += lv; - nodeIdx = node - &nodeData[0]; + nodeIdx = lv; return (uint32_t)-v; } } @@ -155,23 +273,45 @@ namespace kiwi inline bool distantTokenMask(uint32_t idx) const { - if (useDistantTokens) return (distantMask[idx / 8] & (1 << (idx % 8))) != 0; + if (windowSize > 0) return (distantMaskPtr[idx / 8] & (1 << (idx % 8))) != 0; else return false; } float progress(int32_t& nodeIdx, uint32_t& contextIdx, size_t& historyPos, - std::array& history, + std::array& history, KeyType next) const; + template + LmStateType nextState(const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next) const; + + template + LmStateType nextState(const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next) const; + + /* + * �� prevStateSize���� ���¿� 
nextIdSize���� ���� ��ū�� �޾Ƽ�, �� ���º��� ���� ��ū�� ������ Ȯ���� ����ϰ� �� ���¸� ��ȯ�Ѵ�. + * �� ���°��� outStates�� ����ǰ�, �� ���º� Ȯ������ outScores�� ����ȴ�. + * nextIdSize���� ���� ��ū �� ������ numValidDistantTokens���� ��ū�� ��ȿ�� distant ��ū���� ó���ȴ�. + */ + template + void progressMatrix(const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; + + template + void progressMatrix(const typename std::enable_if<(_windowSize == 0), LmStateType>::type* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; }; - template - struct PcLMState : public LmStateBase> + template + struct PcLMState : public LmStateBase> { int32_t node = 0; uint32_t contextIdx = 0; + size_t historyPos = 0; + std::array history = { {0,} }; static constexpr ArchType arch = _arch; static constexpr bool transposed = true; @@ -181,68 +321,74 @@ namespace kiwi bool operator==(const PcLMState& other) const { - return node == other.node; + if (node != other.node) return false; + for (size_t i = windowSize / 2; i < windowSize; ++i) + { + if (history[(historyPos + i) % windowSize] != other.history[(other.historyPos + i) % windowSize]) + { + return false; + } + } + return true; } - float nextImpl(const PcLangModel* lm, VocabTy next) + float nextImpl(const PcLangModel* lm, VocabTy next) { - size_t historyPos = 0; - std::array history = { {0,} }; return lm->progress(node, contextIdx, historyPos, history, next); } }; - template - struct PcLMState : public LmStateBase> + template + struct PcLMState<0, _arch, VocabTy, quantized> : public LmStateBase> { - static constexpr bool useDistantTokens = true; - int32_t node = 0; uint32_t contextIdx = 0; - size_t historyPos = 0; - std::array history = { {0,} }; - + static constexpr 
ArchType arch = _arch; static constexpr bool transposed = true; + static constexpr size_t windowSize = 0; PcLMState() = default; PcLMState(const ILangModel* lm) {} bool operator==(const PcLMState& other) const { - return node == other.node && historyPos == other.historyPos && history == other.history; + return node == other.node; } - float nextImpl(const PcLangModel* lm, VocabTy next) + float nextImpl(const PcLangModel* lm, VocabTy next) { + size_t historyPos = 0; + std::array history = { {0,} }; return lm->progress(node, contextIdx, historyPos, history, next); } }; } - template - struct Hash> + template + struct Hash> { - size_t operator()(const lm::PcLMState& state) const + size_t operator()(const lm::PcLMState& state) const { Hash hasher; - return hasher(state.node); + std::hash vocabHasher; + size_t ret = hasher(state.node); + for (size_t i = windowSize / 2; i < windowSize; ++i) + { + const auto historyToken = state.history[(state.historyPos + i) % windowSize]; + ret = vocabHasher(historyToken) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + return ret; } }; - template - struct Hash> + template + struct Hash> { - size_t operator()(const lm::PcLMState& state) const + size_t operator()(const lm::PcLMState<0, arch, VocabTy, quantized>& state) const { Hash hasher; - std::hash vocabHasher; - size_t ret = hasher(state.node); - for (size_t i = 0; i < state.history.size(); ++i) - { - ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); - } - return ret; + return hasher(state.node); } }; } diff --git a/src/SkipBigramModelImpl.hpp b/src/SkipBigramModelImpl.hpp index c89a4d78..404ea6da 100644 --- a/src/SkipBigramModelImpl.hpp +++ b/src/SkipBigramModelImpl.hpp @@ -8,8 +8,18 @@ namespace kiwi { namespace lm { - template - struct LogExpSum + template + struct LogSumExp + { + template + float operator()(const float* arr, std::integral_constant) + { + return logSumExpImpl(arr); + } + }; + + template<> + struct LogSumExp { template 
float operator()(const float* arr, std::integral_constant) @@ -24,9 +34,18 @@ namespace kiwi } }; + template<> + struct LogSumExp : public LogSumExp + { + }; + template - float logExpSumImpl(const float* arr) + float logSumExpImpl(const float* arr) { + if ((archType == ArchType::avx512bw || archType == ArchType::avx512vnni) && size < 16) + { + return logSumExpImpl(arr); + } simd::Operator op; auto pmax = op.loadf(arr); @@ -44,6 +63,232 @@ namespace kiwi return std::log(op.redsumf(sum)) + op.firstf(pmax); } + template + struct LogSoftmax + { + template + void operator()(float* arr, std::integral_constant) + { + return logSoftmaxImpl(arr); + } + }; + + template<> + struct LogSoftmax + { + template + void operator()(float* arr, std::integral_constant) + { + float maxValue = *std::max_element(arr, arr + size); + float sum = 0; + for (size_t i = 0; i < size; ++i) + { + sum += std::exp(arr[i] - maxValue); + } + maxValue += std::log(sum); + for (size_t i = 0; i < size; ++i) + { + arr[i] -= maxValue; + } + } + }; + + template<> + struct LogSoftmax : public LogSoftmax + { + }; + + template + void logSoftmaxImpl(float* arr) + { + if ((archType == ArchType::avx512bw || archType == ArchType::avx512vnni) && size < 16) + { + return logSoftmaxImpl(arr); + } + simd::Operator op; + + auto pmax = op.loadf(arr); + for (size_t i = op.packetSize; i < size; i += op.packetSize) + { + pmax = op.maxf(pmax, op.loadf(&arr[i])); + } + pmax = op.redmaxbf(pmax); + + auto sum = op.zerof(); + for (size_t i = 0; i < size; i += op.packetSize) + { + sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); + } + pmax = op.addf(op.logf(op.set1f(op.redsumf(sum))), pmax); + for (size_t i = 0; i < size; i += op.packetSize) + { + op.storef(&arr[i], op.subf(op.loadf(&arr[i]), pmax)); + } + } + + template + struct LogSoftmaxTransposed; + + template + struct LogSoftmaxTransposed + { + static constexpr size_t size = 8; + + void block(float* arr, size_t stride) + { + simd::Operator op; + 
simd::FloatPacket a0 = op.loadf(arr), + a1 = op.loadf(arr + stride), + a2 = op.loadf(arr + stride * 2), + a3 = op.loadf(arr + stride * 3), + a4 = op.loadf(arr + stride * 4), + a5 = op.loadf(arr + stride * 5), + a6 = op.loadf(arr + stride * 6), + a7 = op.loadf(arr + stride * 7); + // find maximum + auto m = op.maxf(a0, a1); + m = op.maxf(m, a2); + m = op.maxf(m, a3); + m = op.maxf(m, a4); + m = op.maxf(m, a5); + m = op.maxf(m, a6); + m = op.maxf(m, a7); + + // subtract maximum + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + // exp, reduce sum and log + m = op.expf(a0); + m = op.addf(m, op.expf(a1)); + m = op.addf(m, op.expf(a2)); + m = op.addf(m, op.expf(a3)); + m = op.addf(m, op.expf(a4)); + m = op.addf(m, op.expf(a5)); + m = op.addf(m, op.expf(a6)); + m = op.addf(m, op.expf(a7)); + m = op.logf(m); + + // subtract + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + op.storef(arr, a0); + op.storef(arr + stride, a1); + op.storef(arr + stride * 2, a2); + op.storef(arr + stride * 3, a3); + op.storef(arr + stride * 4, a4); + op.storef(arr + stride * 5, a5); + op.storef(arr + stride * 6, a6); + op.storef(arr + stride * 7, a7); + } + + void operator()(float* arr, size_t batchSize, size_t stride) + { + simd::Operator op; + for (size_t i = 0; i < batchSize; i += op.packetSize) + { + block(arr, stride); + arr += op.packetSize; + } + } + }; + + template<> + struct LogSoftmaxTransposed : public LogSoftmaxTransposed + { + }; + + template<> + struct LogSoftmaxTransposed : public LogSoftmaxTransposed + { + }; + + template + struct LogSumExpTransposed; + + template + struct LogSumExpTransposed + { + static constexpr size_t size = 8; + + void block(float* arr, size_t stride) + { + simd::Operator op; + 
simd::FloatPacket a0 = op.loadf(arr), + a1 = op.loadf(arr + stride), + a2 = op.loadf(arr + stride * 2), + a3 = op.loadf(arr + stride * 3), + a4 = op.loadf(arr + stride * 4), + a5 = op.loadf(arr + stride * 5), + a6 = op.loadf(arr + stride * 6), + a7 = op.loadf(arr + stride * 7); + // find maximum + auto m = op.maxf(a0, a1); + m = op.maxf(m, a2); + m = op.maxf(m, a3); + m = op.maxf(m, a4); + m = op.maxf(m, a5); + m = op.maxf(m, a6); + m = op.maxf(m, a7); + + // subtract maximum + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + // exp, reduce sum and log + auto s = op.expf(a0); + s = op.addf(s, op.expf(a1)); + s = op.addf(s, op.expf(a2)); + s = op.addf(s, op.expf(a3)); + s = op.addf(s, op.expf(a4)); + s = op.addf(s, op.expf(a5)); + s = op.addf(s, op.expf(a6)); + s = op.addf(s, op.expf(a7)); + s = op.logf(s); + + op.storef(arr, op.addf(m, s)); + } + + void operator()(float* arr, size_t batchSize, size_t stride) + { + simd::Operator op; + for (size_t i = 0; i < batchSize; i += op.packetSize) + { + block(arr, stride); + arr += op.packetSize; + } + } + }; + + template<> + struct LogSumExpTransposed : public LogSumExpTransposed + { + }; + + template<> + struct LogSumExpTransposed : public LogSumExpTransposed + { + }; + template float SkipBigramModel::evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const { @@ -70,7 +315,7 @@ namespace kiwi arr[i + windowSize] = out; } } - return LogExpSum{}(arr, std::integral_constant{}) - logWindowSize; + return LogSumExp{}(arr, std::integral_constant{}) - logWindowSize; } } } diff --git a/src/qgemm.cpp b/src/qgemm.cpp new file mode 100644 index 00000000..e4fb888e --- /dev/null +++ b/src/qgemm.cpp @@ -0,0 +1,743 @@ +#include +#include +#include +#include +#include "qgemm.h" +#include "SIMD.hpp" + +namespace kiwi +{ + namespace qgemm + { + template + int32_t dotprod( + const 
uint8_t* a, const int8_t* b, size_t n + ) + { + simd::Operator op; + return op.dotprod(a, b, n); + } + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + inline void packScatteredAPanel(uint8_t* out, size_t ld, const uint8_t* base, const int32_t* idx, size_t scale, size_t m, size_t k) + { + for (size_t i = 0; i < m; ++i) + { + const uint8_t* src = base + idx[i] * scale; + memcpy(out + i * ld, src, k); + } + } + + template + inline void packScatteredBPanel(int8_t* out, size_t ld, int32_t* sum, + const int8_t* base, const int32_t* sumBase, const int32_t* idx, + size_t scale, size_t sumScale, size_t n, size_t k) + { + int32_t* pout = reinterpret_cast(out); + + for (size_t i = 0; i < n; i += blockSize) + { + const size_t innerN = std::min(blockSize, n - i); + for (size_t j = 0; j < k; j += 4) + { + for (size_t x = 0; x < innerN; ++x) + { + const int8_t* src = base + idx[i + x] * scale; + *pout++ = *reinterpret_cast(&src[j]); + } + pout += (blockSize - innerN); + } + + for (size_t x = 0; x < innerN; ++x) + { + sum[i + x] = sumBase[idx[i + x] * sumScale]; + } + } + } + + template + inline void qgemmKernel( + size_t m, size_t n, size_t k, + const uint8_t* a, const int8_t* b, + const float* aScale, const float* bScale, + const float* aBias, const int32_t* sumBuffer, + float* out, size_t ld) + { + // quantized sub-block gemm(m=4, n=64) + static constexpr size_t blockNStride = blockNSize * 4; + __m512i pa, pb[4], psum[16]; + __m512 paScale, paBias, pbScale[4], r; + + for (size_t i = 0; i < n; n += blockNSize * 4) + { + psum[0] = psum[4] = psum[8] = psum[12] = _mm512_loadu_si512(sumBuffer); + psum[1] = psum[5] = psum[9] = psum[13] = _mm512_loadu_si512(sumBuffer + 
blockNSize); + psum[2] = psum[6] = psum[10] = psum[14] = _mm512_loadu_si512(sumBuffer + blockNSize * 2); + psum[3] = psum[7] = psum[11] = psum[15] = _mm512_loadu_si512(sumBuffer + blockNSize * 3); + + for (size_t j = 0; j < k; j += 4) + { + pb[0] = _mm512_loadu_si512(b); + pb[1] = _mm512_loadu_si512(b + blockNStride * 1); + pb[2] = _mm512_loadu_si512(b + blockNStride * 2); + pb[3] = _mm512_loadu_si512(b + blockNStride * 3); + + pa = _mm512_set1_epi32(*reinterpret_cast(a)); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa, pb[0]); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa, pb[1]); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa, pb[2]); + psum[3] = _mm512_dpbusd_epi32(psum[3], pa, pb[3]); + + pa = _mm512_set1_epi32(*reinterpret_cast(a + k)); + psum[4] = _mm512_dpbusd_epi32(psum[4], pa, pb[0]); + psum[5] = _mm512_dpbusd_epi32(psum[5], pa, pb[1]); + psum[6] = _mm512_dpbusd_epi32(psum[6], pa, pb[2]); + psum[7] = _mm512_dpbusd_epi32(psum[7], pa, pb[3]); + + pa = _mm512_set1_epi32(*reinterpret_cast(a + k * 2)); + psum[8] = _mm512_dpbusd_epi32(psum[8], pa, pb[0]); + psum[9] = _mm512_dpbusd_epi32(psum[9], pa, pb[1]); + psum[10] = _mm512_dpbusd_epi32(psum[10], pa, pb[2]); + psum[11] = _mm512_dpbusd_epi32(psum[11], pa, pb[3]); + + pa = _mm512_set1_epi32(*reinterpret_cast(a + k * 3)); + psum[12] = _mm512_dpbusd_epi32(psum[12], pa, pb[0]); + psum[13] = _mm512_dpbusd_epi32(psum[13], pa, pb[1]); + psum[14] = _mm512_dpbusd_epi32(psum[14], pa, pb[2]); + psum[15] = _mm512_dpbusd_epi32(psum[15], pa, pb[3]); + + a += 4; + b += blockNStride * 4; + } + pbScale[0] = _mm512_loadu_ps(bScale); + pbScale[1] = _mm512_loadu_ps(bScale + blockNSize); + pbScale[2] = _mm512_loadu_ps(bScale + blockNSize * 2); + pbScale[3] = _mm512_loadu_ps(bScale + blockNSize * 3); + + paScale = _mm512_set1_ps(*aScale++); + paBias = _mm512_set1_ps(*aBias++); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale[0]), paScale, paBias); + _mm512_storeu_ps(out, r); + r = 
_mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[1]), pbScale[1]), paScale, paBias); + _mm512_storeu_ps(out + blockNSize, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[2]), pbScale[2]), paScale, paBias); + _mm512_storeu_ps(out + blockNSize * 2, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[3]), pbScale[3]), paScale, paBias); + _mm512_storeu_ps(out + blockNSize * 3, r); + + paScale = _mm512_set1_ps(*aScale++); + paBias = _mm512_set1_ps(*aBias++); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[4]), pbScale[0]), paScale, paBias); + _mm512_storeu_ps(out + ld, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[5]), pbScale[1]), paScale, paBias); + _mm512_storeu_ps(out + ld + blockNSize, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[6]), pbScale[2]), paScale, paBias); + _mm512_storeu_ps(out + ld + blockNSize * 2, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[7]), pbScale[3]), paScale, paBias); + _mm512_storeu_ps(out + ld + blockNSize * 3, r); + + paScale = _mm512_set1_ps(*aScale++); + paBias = _mm512_set1_ps(*aBias++); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[8]), pbScale[0]), paScale, paBias); + _mm512_storeu_ps(out + ld * 2, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[9]), pbScale[1]), paScale, paBias); + _mm512_storeu_ps(out + ld * 2 + blockNSize, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[10]), pbScale[2]), paScale, paBias); + _mm512_storeu_ps(out + ld * 2 + blockNSize * 2, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[11]), pbScale[3]), paScale, paBias); + _mm512_storeu_ps(out + ld * 2 + blockNSize * 3, r); + + paScale = _mm512_set1_ps(*aScale++); + paBias = _mm512_set1_ps(*aBias++); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[12]), pbScale[0]), paScale, paBias); + _mm512_storeu_ps(out + ld * 3, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[13]), 
pbScale[1]), paScale, paBias); + _mm512_storeu_ps(out + ld * 3 + blockNSize, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[14]), pbScale[2]), paScale, paBias); + _mm512_storeu_ps(out + ld * 3 + blockNSize * 2, r); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[15]), pbScale[3]), paScale, paBias); + _mm512_storeu_ps(out + ld * 3 + blockNSize * 3, r); + sumBuffer += blockNSize * 4; + out += blockNSize * 4; + } + } + + template + void scatteredGEMM( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + // assert k <= 384 + constexpr size_t packM = 48, packN = 256, packK = 384; + thread_local uint8_t buffer[packM * packK + packN * packK]; + thread_local int32_t sumBuffer[packN]; + uint8_t* aBuffer = buffer; + int8_t* bBuffer = reinterpret_cast(buffer + packM * packK); + + for (size_t ni = 0; ni < n; ni += packN) + { + const size_t microN = std::min(packN, n - ni); + packScatteredBPanel(bBuffer, packK, sumBuffer, bBase, reinterpret_cast(bBase + k + 4), bIdx + ni, bIdxScale, bIdxScale / 4, microN, k); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + packScatteredAPanel(aBuffer, packK, aBase, aIdx + mi, aIdxScale, microM, k); + + //qgemmKernel<16>(microM, microN, k, aBuffer, bBuffer, sumBuffer, nullptr, n); + } + } + } + + template void scatteredGEMM( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + + template + void scatteredGEMMBaseline( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + thread_local Vector buffer; + buffer.resize((m + n) * (k + 8)); + uint8_t* aBuffer = 
buffer.data(); + int8_t* bBuffer = reinterpret_cast(aBuffer + m * (k + 8)); + simd::Operator op; + + for (size_t i = 0; i < m; ++i) + { + std::memcpy(aBuffer + i * (k + 8), &aBase[aIdx[i] * aIdxScale], k + 8); + } + for (size_t i = 0; i < n; ++i) + { + std::memcpy(bBuffer + i * (k + 8), &bBase[bIdx[i] * bIdxScale], k + 8); + } + + for (size_t i = 0; i < m; ++i) + { + for (size_t j = 0; j < n; ++j) + { + const auto* aPtr = aBuffer + i * (k + 8); + const auto* bPtr = bBuffer + j * (k + 8); + int32_t acc = op.dotprod(aPtr, bPtr, k); + const float contextScale = *reinterpret_cast(aPtr + k), + outputScale = *reinterpret_cast(bPtr + k), + contextBias = *reinterpret_cast(aPtr + k + 4); + const int32_t hsum = *reinterpret_cast(bPtr + k + 4); + c[i * ldc + j] = (acc - hsum) * contextScale * outputScale + contextBias; + } + } + } + + inline void pack16x4( + void* out, + const void* a0, + const void* a1, + const void* a2, + const void* a3, + const void* a4, + const void* a5, + const void* a6, + const void* a7, + const void* a8, + const void* a9, + const void* a10, + const void* a11, + const void* a12, + const void* a13, + const void* a14, + const void* a15 + ) + { + // 00, 01, 02, 03, 40, 41, 42, 43 + auto p0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a0)), _mm_loadu_epi32(a4), 1); + auto p1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a1)), _mm_loadu_epi32(a5), 1); + auto p2 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a2)), _mm_loadu_epi32(a6), 1); + auto p3 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a3)), _mm_loadu_epi32(a7), 1); + auto p4 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a8)), _mm_loadu_epi32(a12), 1); + auto p5 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a9)), _mm_loadu_epi32(a13), 1); + auto p6 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a10)), _mm_loadu_epi32(a14), 1); + auto p7 = 
_mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a11)), _mm_loadu_epi32(a15), 1); + + // 00, 10, 01, 11, 40, 50, 41, 51 + auto q0 = _mm256_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 42, 52, 43, 53 + auto q1 = _mm256_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 60, 70, 61, 71 + auto q2 = _mm256_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 62, 72, 63, 73 + auto q3 = _mm256_unpackhi_epi32(p2, p3); + auto q4 = _mm256_unpacklo_epi32(p4, p5); + auto q5 = _mm256_unpackhi_epi32(p4, p5); + auto q6 = _mm256_unpacklo_epi32(p6, p7); + auto q7 = _mm256_unpackhi_epi32(p6, p7); + + // 00, 10, 20, 30, 40, 50, 60, 70 + p0 = _mm256_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 41, 51, 61, 71 + p1 = _mm256_unpackhi_epi64(q0, q2); + p2 = _mm256_unpacklo_epi64(q1, q3); + p3 = _mm256_unpackhi_epi64(q1, q3); + p4 = _mm256_unpacklo_epi64(q4, q6); + p5 = _mm256_unpackhi_epi64(q4, q6); + p6 = _mm256_unpacklo_epi64(q5, q7); + p7 = _mm256_unpackhi_epi64(q5, q7); + + auto* pout = reinterpret_cast<__m256i*>(out); + _mm256_storeu_si256(pout++, p0); + _mm256_storeu_si256(pout++, p4); + _mm256_storeu_si256(pout++, p1); + _mm256_storeu_si256(pout++, p5); + _mm256_storeu_si256(pout++, p2); + _mm256_storeu_si256(pout++, p6); + _mm256_storeu_si256(pout++, p3); + _mm256_storeu_si256(pout++, p7); + } + + inline void packScatteredGEMVAPanel(uint8_t* out, size_t ld, + float* scale, float* bias, + const uint8_t* base, const int32_t* idx, + size_t idxScale, size_t m, size_t k) + { + static constexpr size_t blockSize = 16; + int32_t* pout = reinterpret_cast(out); + + for (size_t i = 0; i < m; i += blockSize) + { + const size_t innerM = std::min(blockSize, m - i); + size_t j; + for (j = 0; j < (k & ~15); j += 16) + { + pack16x4(pout, + base + idx[i] * idxScale + j, + base + (1 < innerM ? idx[i + 1] * idxScale + j : 0), + base + (2 < innerM ? idx[i + 2] * idxScale + j : 0), + base + (3 < innerM ? idx[i + 3] * idxScale + j : 0), + base + (4 < innerM ? 
idx[i + 4] * idxScale + j : 0), + base + (5 < innerM ? idx[i + 5] * idxScale + j : 0), + base + (6 < innerM ? idx[i + 6] * idxScale + j : 0), + base + (7 < innerM ? idx[i + 7] * idxScale + j : 0), + base + (8 < innerM ? idx[i + 8] * idxScale + j : 0), + base + (9 < innerM ? idx[i + 9] * idxScale + j : 0), + base + (10 < innerM ? idx[i + 10] * idxScale + j : 0), + base + (11 < innerM ? idx[i + 11] * idxScale + j : 0), + base + (12 < innerM ? idx[i + 12] * idxScale + j : 0), + base + (13 < innerM ? idx[i + 13] * idxScale + j : 0), + base + (14 < innerM ? idx[i + 14] * idxScale + j : 0), + base + (15 < innerM ? idx[i + 15] * idxScale + j : 0) + ); + pout += 64; + } + + for (; j < k; j += 4) + { + for (size_t x = 0; x < innerM; ++x) + { + const uint8_t* src = base + idx[i + x] * idxScale; + *pout++ = *reinterpret_cast(&src[j]); + } + pout += (blockSize - innerM); + } + + for (size_t x = 0; x < innerM; ++x) + { + scale[i + x] = *reinterpret_cast(&base[idx[i + x] * idxScale + k]); + bias[i + x] = *reinterpret_cast(&base[idx[i + x] * idxScale + k + 4]); + } + } + } + + inline void qgemvKernel16(size_t m, size_t k, + const uint8_t* a, const int8_t* b, + const float* aScale, float bScale, + const float* aBias, int32_t bSum, + float* c) + { + __m512i pa, pb, psum, pbSum; + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_set1_ps(bScale); + pbSum = _mm512_set1_epi32(-bSum); + + for (size_t j = 0; j < m; j += 16) + { + psum = pbSum; + for (size_t i = 0; i < k; i += 4) + { + pa = _mm512_loadu_si512(a); + + pb = _mm512_set1_epi32(*reinterpret_cast(b + i)); + psum = _mm512_dpbusd_epi32(psum, pa, pb); + + a += 64; + } + paScale = _mm512_loadu_ps(aScale); + paBias = _mm512_loadu_ps(aBias); + + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum), pbScale), paScale, paBias); + _mm512_storeu_ps(c, r); + aScale += 16; + aBias += 16; + c += 16; + } + } + + void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const 
int8_t* b, + float* c + ) + { + constexpr size_t packM = 128, packN = 1, packK = 384; + thread_local uint8_t buffer[packM * packK + packN * packK + packM * 2 * sizeof(float)]; + thread_local uint8_t optABuffer[packM * packK]; + uint8_t* aBuffer = buffer; + int8_t* bBuffer = reinterpret_cast(aBuffer + packM * packK); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + packScatteredGEMVAPanel(aBuffer, packK, aScale, aBias, aBase, aIdx, aIdxScale, microM, k); + qgemvKernel16(microM, k, aBuffer, bBuffer, aScale, bScale, aBias, bSum, c); + aIdx += microM; + c += microM; + } + } + + inline void qgemv2Kernel16(size_t m, size_t k, + const uint8_t* a, const int8_t* b, size_t ldb, + const float* aScale, float bScale[2], + const float* aBias, int32_t bSum[2], + float* c) + { + __m512i pa, pb[2], psum[2], pbSum[2]; + __m512 paScale, paBias, pbScale[2], r[2], t[2]; + pbScale[0] = _mm512_set1_ps(bScale[0]); + pbScale[1] = _mm512_set1_ps(bScale[1]); + pbSum[0] = _mm512_set1_epi32(-bSum[0]); + pbSum[1] = _mm512_set1_epi32(-bSum[1]); + + for (size_t j = 0; j < m; j += 16) + { + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + for (size_t i = 0; i < k; i += 4) + { + pa = _mm512_loadu_si512(a); + + pb[0] = _mm512_set1_epi32(*reinterpret_cast(b + i)); + pb[1] = _mm512_set1_epi32(*reinterpret_cast(b + ldb + i)); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa, pb[0]); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa, pb[1]); + + a += 64; + } + paScale = _mm512_loadu_ps(aScale); + paBias = _mm512_loadu_ps(aBias); + + // 00, 10, 20, 30, 40, 50, 60, 70 + r[0] = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale[0]), paScale, paBias); + r[1] = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[1]), pbScale[1]), paScale, paBias); + + 
// 00, 01, 10, 11, 40, 41, 50, 51 + t[0] = _mm512_unpacklo_ps(r[0], r[1]); + // 20, 21, 30, 31, 60, 61, 70, 71 + t[1] = _mm512_unpackhi_ps(r[0], r[1]); + + _mm_storeu_ps(c, _mm512_extractf32x4_ps(t[0], 0)); + _mm_storeu_ps(c + 4, _mm512_extractf32x4_ps(t[1], 0)); + _mm_storeu_ps(c + 8, _mm512_extractf32x4_ps(t[0], 1)); + _mm_storeu_ps(c + 12, _mm512_extractf32x4_ps(t[1], 1)); + _mm_storeu_ps(c + 16, _mm512_extractf32x4_ps(t[0], 2)); + _mm_storeu_ps(c + 20, _mm512_extractf32x4_ps(t[1], 2)); + _mm_storeu_ps(c + 24, _mm512_extractf32x4_ps(t[0], 3)); + _mm_storeu_ps(c + 28, _mm512_extractf32x4_ps(t[1], 3)); + + aScale += 16; + aBias += 16; + c += 32; + } + } + + void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 128, packN = 2, packK = 384; + thread_local uint8_t buffer[packM * packK + packN * packK + packM * 2 * sizeof(float)]; + thread_local uint8_t optABuffer[packM * packK]; + uint8_t* aBuffer = buffer; + int8_t* bBuffer = reinterpret_cast(aBuffer + packM * packK); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + float bScale[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) + }; + int32_t bSum[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) + }; + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + packScatteredGEMVAPanel(aBuffer, packK, aScale, aBias, aBase, aIdx, aIdxScale, microM, k); + qgemv2Kernel16(microM, k, aBuffer, bBuffer, packK, aScale, bScale, aBias, bSum, c); + aIdx += microM; + c += microM * 2; + } + } + + template + void scatteredGEMMSmall( + size_t, size_t, 
size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + static_assert(m <= 3, "m should be less than or equal to 3"); + static_assert(n <= 3, "n should be less than or equal to 3"); + __m512i pa[3], pb[3], psum[3][3]; + const uint8_t* aPtr[3]; + const int8_t* bPtr[3]; + + psum[0][0] = _mm512_setzero_si512(); + if (m > 1) psum[1][0] = _mm512_setzero_si512(); + if (m > 2) psum[2][0] = _mm512_setzero_si512(); + if (n > 1) psum[0][1] = _mm512_setzero_si512(); + if (m > 1 && n > 1) psum[1][1] = _mm512_setzero_si512(); + if (m > 2 && n > 1) psum[2][1] = _mm512_setzero_si512(); + if (n > 2) psum[0][2] = _mm512_setzero_si512(); + if (m > 1 && n > 2) psum[1][2] = _mm512_setzero_si512(); + if (m > 2 && n > 2) psum[2][2] = _mm512_setzero_si512(); + + aPtr[0] = aBase + aIdx[0] * aIdxScale; + if (m > 1) aPtr[1] = aBase + aIdx[1] * aIdxScale; + if (m > 2) aPtr[2] = aBase + aIdx[2] * aIdxScale; + + bPtr[0] = bBase + bIdx[0] * bIdxScale; + if (n > 1) bPtr[1] = bBase + bIdx[1] * bIdxScale; + if (n > 2) bPtr[2] = bBase + bIdx[2] * bIdxScale; + + for (size_t x = 0; x < k; x += 64) + { + if (m > 0) + { + pa[0] = _mm512_loadu_si512(aPtr[0]); + aPtr[0] += 64; + } + if (m > 1) + { + pa[1] = _mm512_loadu_si512(aPtr[1]); + aPtr[1] += 64; + } + if (m > 2) + { + pa[2] = _mm512_loadu_si512(aPtr[2]); + aPtr[2] += 64; + } + + if (n > 0) + { + pb[0] = _mm512_loadu_si512(bPtr[0]); + bPtr[0] += 64; + } + if (n > 1) + { + pb[1] = _mm512_loadu_si512(bPtr[1]); + bPtr[1] += 64; + } + if (n > 2) + { + pb[2] = _mm512_loadu_si512(bPtr[2]); + bPtr[2] += 64; + } + + psum[0][0] = _mm512_dpbusd_epi32(psum[0][0], pa[0], pb[0]); + if (m > 1) psum[1][0] = _mm512_dpbusd_epi32(psum[1][0], pa[1], pb[0]); + if (m > 2) psum[2][0] = _mm512_dpbusd_epi32(psum[2][0], pa[2], pb[0]); + if (n > 1) psum[0][1] = _mm512_dpbusd_epi32(psum[0][1], pa[0], pb[1]); + if (m > 1 && n > 1) psum[1][1] = 
_mm512_dpbusd_epi32(psum[1][1], pa[1], pb[1]); + if (m > 2 && n > 1) psum[2][1] = _mm512_dpbusd_epi32(psum[2][1], pa[2], pb[1]); + if (n > 2) psum[0][2] = _mm512_dpbusd_epi32(psum[0][2], pa[0], pb[2]); + if (m > 1 && n > 2) psum[1][2] = _mm512_dpbusd_epi32(psum[1][2], pa[1], pb[2]); + if (m > 2 && n > 2) psum[2][2] = _mm512_dpbusd_epi32(psum[2][2], pa[2], pb[2]); + } + + float contextScale[3], outputScale[3], contextBias[3]; + int32_t hsum[3]; + + if (m > 0) + { + contextScale[0] = *reinterpret_cast(aPtr[0]); + contextBias[0] = *reinterpret_cast(aPtr[0] + 4); + } + if (m > 1) + { + contextScale[1] = *reinterpret_cast(aPtr[1]); + contextBias[1] = *reinterpret_cast(aPtr[1] + 4); + } + if (m > 2) + { + contextScale[2] = *reinterpret_cast(aPtr[2]); + contextBias[2] = *reinterpret_cast(aPtr[2] + 4); + } + + + if (n > 0) + { + outputScale[0] = *reinterpret_cast(bPtr[0]); + hsum[0] = *reinterpret_cast(bPtr[0] + 4); + } + if (n > 1) + { + outputScale[1] = *reinterpret_cast(bPtr[1]); + hsum[1] = *reinterpret_cast(bPtr[1] + 4); + } + if (n > 2) + { + outputScale[2] = *reinterpret_cast(bPtr[2]); + hsum[2] = *reinterpret_cast(bPtr[2] + 4); + } + + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][0]); + c[0] = (acc - hsum[0]) * contextScale[0] * outputScale[0] + contextBias[0]; + } + if (m > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][0]); + c[ldc] = (acc - hsum[0]) * contextScale[1] * outputScale[0] + contextBias[1]; + } + if (m > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][0]); + c[ldc * 2] = (acc - hsum[0]) * contextScale[2] * outputScale[0] + contextBias[2]; + } + if (n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][1]); + c[1] = (acc - hsum[1]) * contextScale[0] * outputScale[1] + contextBias[0]; + } + if (m > 1 && n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][1]); + c[ldc + 1] = (acc - hsum[1]) * contextScale[1] * outputScale[1] + contextBias[1]; + } + if (m > 2 && n > 1) + { + int32_t acc = 
_mm512_reduce_add_epi32(psum[2][1]); + c[ldc * 2 + 1] = (acc - hsum[1]) * contextScale[2] * outputScale[1] + contextBias[2]; + } + if (n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][2]); + c[2] = (acc - hsum[2]) * contextScale[0] * outputScale[2] + contextBias[0]; + } + if (m > 1 && n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][2]); + c[ldc + 2] = (acc - hsum[2]) * contextScale[1] * outputScale[2] + contextBias[1]; + } + if (m > 2 && n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][2]); + c[ldc * 2 + 2] = (acc - hsum[2]) * contextScale[2] * outputScale[2] + contextBias[2]; + } + } + + template + void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + using Fn = decltype(&scatteredGEMMBaseline); + static constexpr Fn fnTable[] = { + scatteredGEMMBaseline, + scatteredGEMMSmall, + scatteredGEMMSmall, + scatteredGEMMSmall, + scatteredGEMMSmall, + scatteredGEMMSmall, + scatteredGEMMSmall, + scatteredGEMMSmall, + scatteredGEMMSmall + }; + + if (m <= 3 && n <= 3) + { + return (*fnTable[(m - 1) * 3 + (n - 1)])(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + + if (n == 1 && ldc == 1) + { + return scatteredGEMV(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); + } + + if (m >= 4 && n == 2 && ldc == 2) + { + return scatteredGEMV2(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const 
int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } +} diff --git a/src/qgemm.h b/src/qgemm.h new file mode 100644 index 00000000..9871582c --- /dev/null +++ b/src/qgemm.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include + +namespace kiwi +{ + namespace qgemm + { + template + int32_t dotprod( + const uint8_t* a, const int8_t* b, size_t n + ); + + template + void scatteredGEMM( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + + template + void scatteredGEMMBaseline( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + + template + void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } +} diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index 6269237c..94915842 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -74,6 +74,7 @@ + @@ -126,6 +127,7 @@ + From 
ac5a78a51a81d33ca65f9dc3e90cbbb11337e607 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 9 Feb 2025 03:00:57 +0900 Subject: [PATCH 12/53] Update evaluator & pclm_builder --- tools/Evaluator.cpp | 2 ++ tools/evaluator_main.cpp | 8 ++++++++ tools/pclm_builder.cpp | 8 +++++--- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tools/Evaluator.cpp b/tools/Evaluator.cpp index 996e8b59..c9cefbc6 100644 --- a/tools/Evaluator.cpp +++ b/tools/Evaluator.cpp @@ -26,6 +26,8 @@ const char* modelTypeToStr(ModelType type) case ModelType::sbg: return "sbg"; case ModelType::pclm: return "pclm"; case ModelType::pclmLocal: return "pclm-local"; + case ModelType::pclmQuantized: return "pclm-quant"; + case ModelType::pclmLocalQuantized: return "pclm-local-quant"; } return "unknown"; } diff --git a/tools/evaluator_main.cpp b/tools/evaluator_main.cpp index d50af519..4530dd76 100644 --- a/tools/evaluator_main.cpp +++ b/tools/evaluator_main.cpp @@ -74,6 +74,14 @@ int main(int argc, const char* argv[]) { kiwiModelType = ModelType::pclmLocal; } + else if (v == "pclm-quant") + { + kiwiModelType = ModelType::pclmQuantized; + } + else if (v == "pclm-local-quant") + { + kiwiModelType = ModelType::pclmLocalQuantized; + } else { cerr << "Invalid model type" << endl; diff --git a/tools/pclm_builder.cpp b/tools/pclm_builder.cpp index 1542e80e..64e1d3eb 100644 --- a/tools/pclm_builder.cpp +++ b/tools/pclm_builder.cpp @@ -9,13 +9,13 @@ using namespace std; using namespace kiwi; -int run(const std::string& morphemeDef, const std::string& contextDef, const std::string& embedding, size_t minCnt, const std::string& output) +int run(const std::string& morphemeDef, const std::string& contextDef, const std::string& embedding, size_t minCnt, const std::string& output, bool reorderContextIdx = true) { try { tutils::Timer timer; KiwiBuilder::buildMorphData(morphemeDef, output, minCnt); - auto ret = pclm::PCLanguageModelBase::build(contextDef, embedding); + auto ret = 
lm::PcLangModelBase::build(contextDef, embedding, reorderContextIdx); ret.writeToFile(output + "/pclm.mdl"); double tm = timer.getElapsed(); cout << "Total: " << tm << " ms " << endl; @@ -41,12 +41,14 @@ int main(int argc, const char* argv[]) ValueArg emb{ "e", "emb", "embedding file", true, "", "string" }; ValueArg minCnt{ "n", "min-cnt", "min count of morpheme", false, 10, "int" }; ValueArg output{ "o", "output", "", true, "", "string" }; + SwitchArg preserveContextIdx{ "p", "preserve-context-idx", "preserve context index", false }; cmd.add(mdef); cmd.add(cdef); cmd.add(emb); cmd.add(minCnt); cmd.add(output); + cmd.add(preserveContextIdx); try { @@ -58,5 +60,5 @@ int main(int argc, const char* argv[]) return -1; } - return run(mdef, cdef, emb, minCnt, output); + return run(mdef, cdef, emb, minCnt, output, !preserveContextIdx); } From 5281c43c307304e349a2c323ba1d0ec480759066 Mon Sep 17 00:00:00 2001 From: bab2min Date: Wed, 12 Feb 2025 02:12:22 +0900 Subject: [PATCH 13/53] Optimize memory layout of `PcLMState::history` --- src/BestPathContainer.hpp | 19 ++----------------- src/PCLanguageModel.cpp | 19 ++++++++----------- src/PCLanguageModel.hpp | 29 ++++++++++------------------- src/SkipBigramModel.cpp | 26 ++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 47 deletions(-) diff --git a/src/BestPathContainer.hpp b/src/BestPathContainer.hpp index 6d26a937..3255b6bc 100644 --- a/src/BestPathContainer.hpp +++ b/src/BestPathContainer.hpp @@ -78,23 +78,8 @@ namespace kiwi { size_t operator()(const PathHash& p) const { - size_t ret = 0; - if (sizeof(PathHash) % sizeof(size_t)) - { - auto ptr = reinterpret_cast(&p); - for (size_t i = 0; i < sizeof(PathHash) / sizeof(uint32_t); ++i) - { - ret ^= ptr[i]; - } - } - else - { - auto ptr = reinterpret_cast(&p); - for (size_t i = 0; i < sizeof(PathHash) / sizeof(size_t); ++i) - { - ret ^= ptr[i]; - } - } + size_t ret = Hash{}(p.lmState); + ret ^= *reinterpret_cast(&p.rootId); return ret; } }; diff --git 
a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index 161b0628..ce1b900f 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -556,7 +556,6 @@ namespace kiwi template float PcLangModel::progress(int32_t& nodeIdx, uint32_t& contextIdx, - size_t& historyPos, std::array& history, KeyType next) const { @@ -576,13 +575,13 @@ namespace kiwi contextIdcs[0] = contextIdx; for (size_t i = 0; i < windowSize; ++i) { - const auto historyToken = history[(historyPos + i) % windowSize]; + const auto historyToken = history[i]; lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; contextIdcs[i + 1] = (historyToken ? historyToken : 0) + header.contextSize; } LogSoftmax{}(lls, std::integral_constant()); - qgemm::scatteredGEMMBaseline( + qgemm::scatteredGEMMOpt( 1 + windowSize, 1, header.dim, getContextQuantEmb(0), contextIdcs, contextEmbStride(), getOutputQuantEmb(0), nextIdx, outputEmbStride(), @@ -607,7 +606,7 @@ namespace kiwi lls[0] += getContextConfid(contextIdx); for (size_t i = 0; i < windowSize; ++i) { - const auto historyToken = history[(historyPos + i) % windowSize]; + const auto historyToken = history[i]; lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; } logsoftmaxInplace(lls.array()); @@ -616,7 +615,7 @@ namespace kiwi lls[0] += getContextBias(contextIdx); for (size_t i = 0; i < windowSize; ++i) { - const auto historyToken = history[(historyPos + i) % windowSize]; + const auto historyToken = history[i]; if (historyToken) mat.col(i + 1) = Eigen::Map{ getDistantEmb(historyToken), header.dim }; else mat.col(i + 1).setZero(); lls[i + 1] += getDistantBias(historyToken); @@ -655,8 +654,7 @@ namespace kiwi { if (history[windowSize]) { - history[historyPos] = history[windowSize]; - historyPos = (historyPos + 1) % windowSize; + memcpy(&history[0], &history[1], windowSize * sizeof(KeyType)); } history[windowSize] = validDistantToken ? 
next : 0; } @@ -673,8 +671,7 @@ namespace kiwi ret.contextIdx = progressContextNode(ret.node, next); if (ret.history[windowSize]) { - ret.history[ret.historyPos] = ret.history[windowSize]; - ret.historyPos = (ret.historyPos + 1) % windowSize; + memcpy(&ret.history[0], &ret.history[1], windowSize * sizeof(KeyType)); } ret.history[windowSize] = distantTokenMask(next) ? next : 0; return ret; @@ -798,7 +795,7 @@ namespace kiwi { for (size_t j = 0; j < windowSize; ++j) { - const auto historyToken = prevStates[i].history[(j + prevStates[i].historyPos) % windowSize]; + const auto historyToken = prevStates[i].history[j]; if (historyToken) { historyIdcs.emplace_back(mergePair(historyToken, i * windowSize + j)); @@ -852,7 +849,7 @@ namespace kiwi { for (size_t j = 0; j < windowSize; ++j) { - const auto historyToken = prevStates[i].history[(j + prevStates[i].historyPos) % windowSize]; + const auto historyToken = prevStates[i].history[j]; if (!historyToken) continue; const auto idx = i * windowSize + j; auto inserted = historyMap.emplace(historyToken, historyMap.size()); diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index 8f94dba3..79cfd0fe 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -279,7 +279,6 @@ namespace kiwi float progress(int32_t& nodeIdx, uint32_t& contextIdx, - size_t& historyPos, std::array& history, KeyType next) const; @@ -310,7 +309,6 @@ namespace kiwi { int32_t node = 0; uint32_t contextIdx = 0; - size_t historyPos = 0; std::array history = { {0,} }; static constexpr ArchType arch = _arch; @@ -321,20 +319,18 @@ namespace kiwi bool operator==(const PcLMState& other) const { + static constexpr size_t cmpStart = windowSize / 2; if (node != other.node) return false; - for (size_t i = windowSize / 2; i < windowSize; ++i) + if (memcmp(&history[cmpStart], &other.history[cmpStart], (windowSize - cmpStart) * sizeof(VocabTy))) { - if (history[(historyPos + i) % windowSize] != other.history[(other.historyPos + i) % windowSize]) - 
{ return false; } - } return true; } float nextImpl(const PcLangModel* lm, VocabTy next) { - return lm->progress(node, contextIdx, historyPos, history, next); + return lm->progress(node, contextIdx, history, next); } }; @@ -358,9 +354,8 @@ namespace kiwi float nextImpl(const PcLangModel* lm, VocabTy next) { - size_t historyPos = 0; std::array history = { {0,} }; - return lm->progress(node, contextIdx, historyPos, history, next); + return lm->progress(node, contextIdx, history, next); } }; } @@ -370,14 +365,10 @@ namespace kiwi { size_t operator()(const lm::PcLMState& state) const { - Hash hasher; - std::hash vocabHasher; - size_t ret = hasher(state.node); - for (size_t i = windowSize / 2; i < windowSize; ++i) - { - const auto historyToken = state.history[(state.historyPos + i) % windowSize]; - ret = vocabHasher(historyToken) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); - } + size_t ret = (uint32_t)(state.node * (size_t)2654435761); + static constexpr size_t cmpStart = windowSize - sizeof(size_t) / sizeof(VocabTy); + const auto h = *reinterpret_cast(&state.history[cmpStart]); + ret = h ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); return ret; } }; @@ -387,8 +378,8 @@ namespace kiwi { size_t operator()(const lm::PcLMState<0, arch, VocabTy, quantized>& state) const { - Hash hasher; - return hasher(state.node); + size_t ret = (uint32_t)(state.node * (size_t)2654435761); + return ret; } }; } diff --git a/src/SkipBigramModel.cpp b/src/SkipBigramModel.cpp index 380ea6d7..2fd4b0d7 100644 --- a/src/SkipBigramModel.cpp +++ b/src/SkipBigramModel.cpp @@ -32,6 +32,32 @@ namespace kiwi } }; + template + struct Hash>> + { + size_t operator()(const PathHash>& state) const + { + size_t ret = 0; + if (sizeof(state) % sizeof(size_t)) + { + auto ptr = reinterpret_cast(&state); + for (size_t i = 0; i < sizeof(state) / sizeof(uint32_t); ++i) + { + ret = ptr[i] ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + } + else + { + auto ptr = reinterpret_cast(&state); + for 
(size_t i = 0; i < sizeof(state) / sizeof(size_t); ++i) + { + ret = ptr[i] ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + } + return ret; + } + }; + namespace lm { template From f5929083abac94a5f9592f5ad1a87b7cbdf01c9c Mon Sep 17 00:00:00 2001 From: bab2min Date: Wed, 12 Feb 2025 02:13:37 +0900 Subject: [PATCH 14/53] Optimize QGemm kernel more --- include/kiwi/PCLanguageModel.h | 2 +- src/PCLanguageModel.cpp | 17 +- src/qgemm.cpp | 627 +++++++++++++++++++++++++-------- 3 files changed, 501 insertions(+), 145 deletions(-) diff --git a/include/kiwi/PCLanguageModel.h b/include/kiwi/PCLanguageModel.h index 536abd36..664869e0 100644 --- a/include/kiwi/PCLanguageModel.h +++ b/include/kiwi/PCLanguageModel.h @@ -50,7 +50,7 @@ namespace kiwi const PcLangModelHeader& getHeader() const { return header; } - static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, bool reorderContextIdx = true); + static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, size_t maxContextLength = -1, bool reorderContextIdx = true); static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false, bool quantized = true); }; } diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index ce1b900f..fd0ce140 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -12,6 +12,11 @@ using namespace std; namespace kiwi { + inline size_t padMultipleOf(size_t n, size_t multiple) + { + return (n + multiple - 1) / multiple * multiple; + } + template struct MorphemeEvaluator> { @@ -816,7 +821,7 @@ namespace kiwi } inputEmbBuf.resize((uniqInputSize + uniqHistorySize) * header.dim); confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); - resultBuf.resize((uniqInputSize + uniqHistorySize) * uniqOutputSize); + resultBuf.resize(padMultipleOf(uniqInputSize + uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); uniqHistorySize = 0; 
for (size_t i = 0; i < historyIdcs.size(); ++i) @@ -861,7 +866,7 @@ namespace kiwi inputEmbBuf.resize((uniqInputSize + uniqHistorySize)* header.dim); confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); - resultBuf.resize((uniqInputSize + uniqHistorySize)* uniqOutputSize); + resultBuf.resize(padMultipleOf(uniqInputSize + uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); for (size_t i = 0; i < uniqHistoryTokens.size(); ++i) { @@ -1003,7 +1008,7 @@ namespace kiwi } inverseNextIdcs[idx] = uniqOutputSize - 1; } - resultBuf.resize(max(prevStateSize * uniqOutputSize, (size_t)64)); + resultBuf.resize(padMultipleOf(prevStateSize, 8) * padMultipleOf(uniqOutputSize, 8)); for (size_t i = 0; i < prevStateSize; ++i) { @@ -1057,7 +1062,7 @@ namespace kiwi } } - utils::MemoryObject PcLangModelBase::build(const string& contextDefinition, const string& embedding, bool reorderContextId) + utils::MemoryObject PcLangModelBase::build(const string& contextDefinition, const string& embedding, size_t maxContextLength, bool reorderContextId) { ifstream contextStr, embeddingStr; if (!openFile(contextStr, contextDefinition)) @@ -1090,6 +1095,10 @@ namespace kiwi if (id < 0) throw IOException{ "Invalid format : " + contextDefinition }; context.push_back(id); } + if (context.size() > maxContextLength) + { + continue; + } if (contextMap.size() < context.size()) contextMap.resize(context.size()); contextMap[context.size() - 1][context] = (uint32_t)clusterId; maxClusterId = max(maxClusterId, (uint32_t)(clusterId + 1)); diff --git a/src/qgemm.cpp b/src/qgemm.cpp index e4fb888e..5b66c917 100644 --- a/src/qgemm.cpp +++ b/src/qgemm.cpp @@ -5,10 +5,27 @@ #include "qgemm.h" #include "SIMD.hpp" +#define UNROLL4() do { {LOOP_BODY(0)} {LOOP_BODY(1)} {LOOP_BODY(2)} {LOOP_BODY(3)} } while(0) + namespace kiwi { namespace qgemm { + static constexpr size_t TLBSize = 32768; + + template + struct SharedThreadLocalBuffer + { + thread_local static uint8_t buffer[size]; + static uint8_t* get() + { + 
return buffer; + } + }; + + template + thread_local uint8_t SharedThreadLocalBuffer::buffer[size]; + template int32_t dotprod( const uint8_t* a, const int8_t* b, size_t n @@ -307,206 +324,528 @@ namespace kiwi _mm256_storeu_si256(pout++, p7); } - inline void packScatteredGEMVAPanel(uint8_t* out, size_t ld, - float* scale, float* bias, - const uint8_t* base, const int32_t* idx, - size_t idxScale, size_t m, size_t k) + inline void pack4x64to4x16x4( + const void* a0, const void* a1, const void* a2, const void* a3, + __m512i& p0, __m512i& p1, __m512i& p2, __m512i& p3 + ) { - static constexpr size_t blockSize = 16; - int32_t* pout = reinterpret_cast(out); + __m512i q0, q1, q2, q3; + // 00, 01, 02, 03, 04, 05, 06, 07, ... + p0 = _mm512_loadu_si512(a0); + p1 = _mm512_loadu_si512(a1); + p2 = _mm512_loadu_si512(a2); + p3 = _mm512_loadu_si512(a3); - for (size_t i = 0; i < m; i += blockSize) - { - const size_t innerM = std::min(blockSize, m - i); - size_t j; - for (j = 0; j < (k & ~15); j += 16) - { - pack16x4(pout, - base + idx[i] * idxScale + j, - base + (1 < innerM ? idx[i + 1] * idxScale + j : 0), - base + (2 < innerM ? idx[i + 2] * idxScale + j : 0), - base + (3 < innerM ? idx[i + 3] * idxScale + j : 0), - base + (4 < innerM ? idx[i + 4] * idxScale + j : 0), - base + (5 < innerM ? idx[i + 5] * idxScale + j : 0), - base + (6 < innerM ? idx[i + 6] * idxScale + j : 0), - base + (7 < innerM ? idx[i + 7] * idxScale + j : 0), - base + (8 < innerM ? idx[i + 8] * idxScale + j : 0), - base + (9 < innerM ? idx[i + 9] * idxScale + j : 0), - base + (10 < innerM ? idx[i + 10] * idxScale + j : 0), - base + (11 < innerM ? idx[i + 11] * idxScale + j : 0), - base + (12 < innerM ? idx[i + 12] * idxScale + j : 0), - base + (13 < innerM ? idx[i + 13] * idxScale + j : 0), - base + (14 < innerM ? idx[i + 14] * idxScale + j : 0), - base + (15 < innerM ? 
idx[i + 15] * idxScale + j : 0) - ); - pout += 64; - } - - for (; j < k; j += 4) - { - for (size_t x = 0; x < innerM; ++x) - { - const uint8_t* src = base + idx[i + x] * idxScale; - *pout++ = *reinterpret_cast(&src[j]); - } - pout += (blockSize - innerM); - } + // 00, 10, 01, 11, 04, 14, 05, 15, ... + q0 = _mm512_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 06, 16, 07, 17, ... + q1 = _mm512_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 24, 34, 25, 35, ... + q2 = _mm512_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 26, 36, 27, 37, ... + q3 = _mm512_unpackhi_epi32(p2, p3); - for (size_t x = 0; x < innerM; ++x) - { - scale[i + x] = *reinterpret_cast(&base[idx[i + x] * idxScale + k]); - bias[i + x] = *reinterpret_cast(&base[idx[i + x] * idxScale + k + 4]); - } - } + // 00, 10, 20, 30, 04, 14, 24, 34, ... + p0 = _mm512_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 05, 15, 25, 35, ... + p1 = _mm512_unpackhi_epi64(q0, q2); + // 02, 12, 22, 32, 06, 16, 26, 36, ... + p2 = _mm512_unpacklo_epi64(q1, q3); + // 03, 13, 23, 33, 07, 17, 27, 37, ... 
+ p3 = _mm512_unpackhi_epi64(q1, q3); } - inline void qgemvKernel16(size_t m, size_t k, - const uint8_t* a, const int8_t* b, - const float* aScale, float bScale, - const float* aBias, int32_t bSum, - float* c) + void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) { - __m512i pa, pb, psum, pbSum; + constexpr size_t packM = 16, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m512i pa[4], pb, pbs, psum, pbSum, pr = _mm512_setzero_si512(); __m512 paScale, paBias, pbScale, r; pbScale = _mm512_set1_ps(bScale); - pbSum = _mm512_set1_epi32(-bSum); - - for (size_t j = 0; j < m; j += 16) + pbSum = _mm512_set1_epi32(-bSum / 4); + __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), + shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), + shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), + shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); + + for (size_t mi = 0; mi < m; mi += packM) { - psum = pbSum; - for (size_t i = 0; i < k; i += 4) - { - pa = _mm512_loadu_si512(a); + const size_t microM = std::min(packM, m - mi); +#define LOOP_BODY(mj) \ + const int32_t aOffsets[4] = {\ + mj * 4 < microM ? aIdx[0] * aIdxScale : 0,\ + mj * 4 + 1 < microM ? aIdx[1] * aIdxScale : 0,\ + mj * 4 + 2 < microM ? aIdx[2] * aIdxScale : 0,\ + mj * 4 + 3 < microM ? 
aIdx[3] * aIdxScale : 0,\ + };\ + auto* aPtr = aBase;\ + psum = pbSum;\ + for (size_t j = 0; j < k; j += 64)\ + {\ + pack4x64to4x16x4(aPtr + aOffsets[0],\ + aPtr + aOffsets[1],\ + aPtr + aOffsets[2],\ + aPtr + aOffsets[3],\ + pa[0], pa[1], pa[2], pa[3]);\ + pb = _mm512_loadu_si512(bBuffer + j);\ + pbs = _mm512_permutexvar_epi32(shfIdx0, pb);\ + psum = _mm512_dpbusd_epi32(psum, pa[0], pbs);\ + pbs = _mm512_permutexvar_epi32(shfIdx1, pb);\ + psum = _mm512_dpbusd_epi32(psum, pa[1], pbs);\ + pbs = _mm512_permutexvar_epi32(shfIdx2, pb);\ + psum = _mm512_dpbusd_epi32(psum, pa[2], pbs);\ + pbs = _mm512_permutexvar_epi32(shfIdx3, pb);\ + psum = _mm512_dpbusd_epi32(psum, pa[3], pbs);\ + aPtr += 64;\ + }\ + for (size_t i = 0; i < 4; ++i)\ + {\ + aScale[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i]);\ + aBias[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i] + 4);\ + }\ + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4));\ + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8));\ + pr = _mm512_inserti32x4(pr, _mm512_castsi512_si128(psum), mj);\ + aIdx += 4; - pb = _mm512_set1_epi32(*reinterpret_cast(b + i)); - psum = _mm512_dpbusd_epi32(psum, pa, pb); + UNROLL4(); +#undef LOOP_BODY - a += 64; - } paScale = _mm512_loadu_ps(aScale); paBias = _mm512_loadu_ps(aBias); - - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum), pbScale), paScale, paBias); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(pr), pbScale), paScale, paBias); _mm512_storeu_ps(c, r); - aScale += 16; - aBias += 16; - c += 16; + c += microM; } } - void scatteredGEMV( + void scatteredGEMV8x1( size_t m, size_t k, const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, const int8_t* b, float* c ) { - constexpr size_t packM = 128, packN = 1, packK = 384; - thread_local uint8_t buffer[packM * packK + packN * packK + packM * 2 * sizeof(float)]; - thread_local uint8_t optABuffer[packM * packK]; - uint8_t* aBuffer = buffer; - int8_t* bBuffer = reinterpret_cast(aBuffer 
+ packM * packK); + constexpr size_t packM = 8, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); float* aScale = reinterpret_cast(bBuffer + packN * packK); float* aBias = aScale + packM; memcpy(bBuffer, b, k); float bScale = *reinterpret_cast(b + k); int32_t bSum = *reinterpret_cast(b + k + 4); + + __m512i pa[4], pb, pbs, psum, pbSum; + __m256 paScale, paBias, pbScale, r; + __m256i pr = _mm256_setzero_si256(); + pbScale = _mm256_set1_ps(bScale); + pbSum = _mm512_set1_epi32(-bSum / 4); + __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), + shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), + shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), + shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); + + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm512_loadu_si512(bBuffer + j); + pbs = _mm512_permutexvar_epi32(shfIdx0, pb); + psum = _mm512_dpbusd_epi32(psum, pa[0], pbs); + pbs = _mm512_permutexvar_epi32(shfIdx1, pb); + psum = _mm512_dpbusd_epi32(psum, pa[1], pbs); + pbs = _mm512_permutexvar_epi32(shfIdx2, pb); + psum = _mm512_dpbusd_epi32(psum, pa[2], pbs); + pbs = _mm512_permutexvar_epi32(shfIdx3, pb); + psum = _mm512_dpbusd_epi32(psum, pa[3], pbs); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); + pr = _mm512_castsi512_si256(psum); + aIdx += 4; + } + { + auto* aPtr = aBase; + psum = 
pbSum; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm512_loadu_si512(bBuffer + j); + pbs = _mm512_permutexvar_epi32(shfIdx0, pb); + psum = _mm512_dpbusd_epi32(psum, pa[0], pbs); + pbs = _mm512_permutexvar_epi32(shfIdx1, pb); + psum = _mm512_dpbusd_epi32(psum, pa[1], pbs); + pbs = _mm512_permutexvar_epi32(shfIdx2, pb); + psum = _mm512_dpbusd_epi32(psum, pa[2], pbs); + pbs = _mm512_permutexvar_epi32(shfIdx3, pb); + psum = _mm512_dpbusd_epi32(psum, pa[3], pbs); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); + pr = _mm256_inserti32x4(pr, _mm512_castsi512_si128(psum), 1); + aIdx += 4; + } + + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + c += 8; + } + + void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 2, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM * 2; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + float bScale[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) + }; + int32_t bSum[2] = { + *reinterpret_cast(bBase + bIdx[0] * 
bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) + }; + + __m512i pa[4], pb[2], psum[2], pbSum[2], pt[2]; + __m256 paScale, paBias, pbScale, r; + __m256i pr; + pbScale = _mm256_castsi256_ps(_mm256_set1_epi64x(*reinterpret_cast(bScale))); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), + shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), + shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), + shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); + for (size_t mi = 0; mi < m; mi += packM) { const size_t microM = std::min(packM, m - mi); - packScatteredGEMVAPanel(aBuffer, packK, aScale, aBias, aBase, aIdx, aIdxScale, microM, k); - qgemvKernel16(microM, k, aBuffer, bBuffer, aScale, bScale, aBias, bSum, c); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? 
aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[0])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[1])); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 2] = aScale[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 2] = aBias[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + + // 00, 01, 10, 11, ... + pt[0] = _mm512_unpacklo_epi32(psum[0], psum[1]); + // 20, 21, 30, 31, ... 
+ pt[1] = _mm512_unpackhi_epi32(psum[0], psum[1]); + + // 00, 01, 10, 11, 20, 21, 30, 31 + pr = _mm256_permute2x128_si256(_mm512_castsi512_si256(pt[0]), _mm512_castsi512_si256(pt[1]), 0x20); + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + aIdx += microM; - c += microM; + c += microM * 2; } } - inline void qgemv2Kernel16(size_t m, size_t k, - const uint8_t* a, const int8_t* b, size_t ldb, - const float* aScale, float bScale[2], - const float* aBias, int32_t bSum[2], - float* c) + void scatteredGEMV3( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) { - __m512i pa, pb[2], psum[2], pbSum[2]; - __m512 paScale, paBias, pbScale[2], r[2], t[2]; - pbScale[0] = _mm512_set1_ps(bScale[0]); - pbScale[1] = _mm512_set1_ps(bScale[1]); - pbSum[0] = _mm512_set1_epi32(-bSum[0]); - pbSum[1] = _mm512_set1_epi32(-bSum[1]); + constexpr size_t packM = 4, packN = 3, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); + float bScale[4] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), + 0 + }; + int32_t bSum[3] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4) + }; + __m512i pa[4], pb[3], psum[3], pbSum[3]; + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_permutexvar_ps( + 
_mm512_setr_epi32(0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 3, 3, 3), + _mm512_castps128_ps512(_mm_loadu_ps(bScale))); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); + __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), + shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), + shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), + shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15), + shfIdxT = _mm512_setr_epi32(0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4); - for (size_t j = 0; j < m; j += 16) + for (size_t mi = 0; mi < m; mi += packM) { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; psum[0] = pbSum[0]; psum[1] = pbSum[1]; - for (size_t i = 0; i < k; i += 4) + psum[2] = pbSum[2]; + for (size_t j = 0; j < k; j += 64) { - pa = _mm512_loadu_si512(a); + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[0])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[1], 
_mm512_permutexvar_epi32(shfIdx1, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[1])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[2])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[2])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[2])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[2])); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); - pb[0] = _mm512_set1_epi32(*reinterpret_cast(b + i)); - pb[1] = _mm512_set1_epi32(*reinterpret_cast(b + ldb + i)); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa, pb[0]); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa, pb[1]); + // 00, 10, 20, 30, 01, 11, 21, 31 + psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); - a += 64; - } - paScale = _mm512_loadu_ps(aScale); - paBias = _mm512_loadu_ps(aBias); + // 00, 01, 02, 10, 11, 12, 20, 21, 22, 30, 31, 32, ... 
+ psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( + 0, 4, 16, 1, 5, 17, 2, 6, 18, 3, 7, 19, 0, 0, 0, 0 + ), psum[2]); - // 00, 10, 20, 30, 40, 50, 60, 70 - r[0] = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale[0]), paScale, paBias); - r[1] = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[1]), pbScale[1]), paScale, paBias); - - // 00, 01, 10, 11, 40, 41, 50, 51 - t[0] = _mm512_unpacklo_ps(r[0], r[1]); - // 20, 21, 30, 31, 60, 61, 70, 71 - t[1] = _mm512_unpackhi_ps(r[0], r[1]); - - _mm_storeu_ps(c, _mm512_extractf32x4_ps(t[0], 0)); - _mm_storeu_ps(c + 4, _mm512_extractf32x4_ps(t[1], 0)); - _mm_storeu_ps(c + 8, _mm512_extractf32x4_ps(t[0], 1)); - _mm_storeu_ps(c + 12, _mm512_extractf32x4_ps(t[1], 1)); - _mm_storeu_ps(c + 16, _mm512_extractf32x4_ps(t[0], 2)); - _mm_storeu_ps(c + 20, _mm512_extractf32x4_ps(t[1], 2)); - _mm_storeu_ps(c + 24, _mm512_extractf32x4_ps(t[0], 3)); - _mm_storeu_ps(c + 28, _mm512_extractf32x4_ps(t[1], 3)); - - aScale += 16; - aBias += 16; - c += 32; + paScale = _mm512_castps128_ps512(_mm_loadu_ps(aScale)); + paScale = _mm512_permutexvar_ps(shfIdxT, paScale); + paBias = _mm512_castps128_ps512(_mm_loadu_ps(aBias)); + paBias = _mm512_permutexvar_ps(shfIdxT, paBias); + + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); + _mm512_mask_storeu_ps(c, 0x0FFF, r); + + aIdx += microM; + c += microM * 3; } } - void scatteredGEMV2( + void scatteredGEMV4( size_t m, size_t k, const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, float* c ) { - constexpr size_t packM = 128, packN = 2, packK = 384; - thread_local uint8_t buffer[packM * packK + packN * packK + packM * 2 * sizeof(float)]; - thread_local uint8_t optABuffer[packM * packK]; - uint8_t* aBuffer = buffer; - int8_t* bBuffer = reinterpret_cast(aBuffer + packM * packK); + constexpr size_t packM = 4, packN = 4, packK = 384; + auto* buffer = 
SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); float* aScale = reinterpret_cast(bBuffer + packN * packK); - float* aBias = aScale + packM; + float* aBias = aScale + packM * 4; memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); - float bScale[2] = { + memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); + memcpy(bBuffer + packK * 3, bBase + bIdx[3] * bIdxScale, k); + float bScale[4] = { *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k) }; - int32_t bSum[2] = { + int32_t bSum[4] = { *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k + 4) }; + __m512i pa[4], pb[4], psum[4], pbSum[4]; + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_broadcast_f32x4(_mm_loadu_ps(bScale)); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); + pbSum[3] = _mm512_set1_epi32(-bSum[3] / 4); + __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), + shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), + shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), + shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); + for (size_t mi = 0; mi < m; mi += packM) { const size_t microM = std::min(packM, m - mi); - packScatteredGEMVAPanel(aBuffer, packK, aScale, aBias, aBase, aIdx, aIdxScale, microM, k); - qgemv2Kernel16(microM, k, aBuffer, bBuffer, packK, 
aScale, bScale, aBias, bSum, c); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + psum[2] = pbSum[2]; + psum[3] = pbSum[3]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); + pb[3] = _mm512_loadu_si512(bBuffer + packK * 3 + j); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[0])); + psum[0] = _mm512_dpbusd_epi32(psum[0], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[0])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[1])); + psum[1] = _mm512_dpbusd_epi32(psum[1], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[1])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[2])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[2])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[2])); + psum[2] = _mm512_dpbusd_epi32(psum[2], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[2])); + psum[3] = _mm512_dpbusd_epi32(psum[3], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[3])); + psum[3] = _mm512_dpbusd_epi32(psum[3], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[3])); + psum[3] = 
_mm512_dpbusd_epi32(psum[3], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[3])); + psum[3] = _mm512_dpbusd_epi32(psum[3], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[3])); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 4] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 4] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); + psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 4)); + psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 8)); + + // 00, 10, 20, 30, 01, 11, 21, 31 + psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); + // 02, 12, 22, 32, 03, 13, 23, 33 + psum[2] = _mm512_inserti32x4(psum[2], _mm512_castsi512_si128(psum[3]), 1); + + // 00, 01, 02, 03, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33 + psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( + 0, 4, 16, 20, 1, 5, 17, 21, 2, 6, 18, 22, 3, 7, 19, 23 + ), psum[2]); + + paScale = _mm512_loadu_ps(aScale); + paScale = _mm512_shuffle_ps(paScale, paScale, 0); + paBias = _mm512_loadu_ps(aBias); + paBias = _mm512_shuffle_ps(paBias, paBias, 0); + + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); + _mm512_storeu_ps(c, r); + aIdx += microM; - c += microM * 2; + c += microM * 4; } } @@ -697,14 +1036,22 @@ namespace kiwi if (n == 1 && ldc == 1) { - return scatteredGEMV(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); + if (m == 8) + { + return scatteredGEMV8x1(m, k, aBase, aIdx, 
aIdxScale, bBase + bIdx[0] * bIdxScale, c); + } + else + { + return scatteredGEMV(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); + } } - if (m >= 4 && n == 2 && ldc == 2) + if (m >= 4) { - return scatteredGEMV2(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + if (n == 2 && ldc == 2) return scatteredGEMV2(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + if (n == 3 && ldc == 3) return scatteredGEMV3(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + if (n == 4 && ldc == 4) return scatteredGEMV4(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); } - return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); } From 108c4fdccd988fea2e2f714f2df09cf408152e62 Mon Sep 17 00:00:00 2001 From: bab2min Date: Wed, 12 Feb 2025 02:23:50 +0900 Subject: [PATCH 15/53] Fixed wrong initialization of pclm-local model --- src/PCLanguageModel.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index fd0ce140..e71c3453 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -442,10 +442,16 @@ namespace kiwi allEmbs = make_unique(contextEmbSize + outputEmbSize + distantEmbSize + positionConfSize + distantMaskSize); auto p = allEmbs.get(); contextEmbPtr = reinterpret_cast(p); - distantEmbPtr = windowSize > 0 ? reinterpret_cast(p += contextEmbSize) : nullptr; + distantEmbPtr = reinterpret_cast(p += contextEmbSize); outputEmbPtr = reinterpret_cast(p += distantEmbSize); - positionConfidPtr = windowSize > 0 ? reinterpret_cast(p += outputEmbSize) : nullptr; - distantMaskPtr = windowSize > 0 ? 
reinterpret_cast(p += positionConfSize) : nullptr; + positionConfidPtr = reinterpret_cast(p += outputEmbSize); + distantMaskPtr = reinterpret_cast(p += positionConfSize); + if (windowSize == 0) + { + distantEmbPtr = nullptr; + positionConfidPtr = nullptr; + distantMaskPtr = nullptr; + } } auto* eptr = ptr + header.embOffset; From aa29ff53bbfd9b384960e3d35b4f21caace59af3 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 15 Feb 2025 02:02:41 +0900 Subject: [PATCH 16/53] implement VLE for context Trie & optimize nst search --- include/kiwi/ArchUtils.h | 11 +-- src/Knlm.hpp | 2 +- src/PCLanguageModel.cpp | 144 ++++++++++++++++++--------- src/PCLanguageModel.hpp | 83 +++++++++------- src/archImpl/avx2.cpp | 5 - src/archImpl/avx512vnni.cpp | 18 ++++ src/archImpl/avx_vnni.cpp | 12 +++ src/search.cpp | 187 +++++++++++++++++++++++++++--------- src/search.h | 31 ++++-- 9 files changed, 343 insertions(+), 150 deletions(-) create mode 100644 src/archImpl/avx512vnni.cpp create mode 100644 src/archImpl/avx_vnni.cpp diff --git a/include/kiwi/ArchUtils.h b/include/kiwi/ArchUtils.h index c28e5f72..baa27ee5 100644 --- a/include/kiwi/ArchUtils.h +++ b/include/kiwi/ArchUtils.h @@ -27,16 +27,7 @@ namespace kiwi const char* archToStr(ArchType arch); template - struct ArchInfo; - - template<> - struct ArchInfo - { - static constexpr size_t alignment = 4; - }; - - template<> - struct ArchInfo + struct ArchInfo { static constexpr size_t alignment = 4; }; diff --git a/src/Knlm.hpp b/src/Knlm.hpp index 8c8dab82..18531153 100644 --- a/src/Knlm.hpp +++ b/src/Knlm.hpp @@ -367,7 +367,7 @@ namespace kiwi } }; - template + template struct KnLMState : public LmStateBase> { int32_t node = 0; diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index e71c3453..397a0a0e 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -17,10 +17,10 @@ namespace kiwi return (n + multiple - 1) / multiple * multiple; } - template - struct MorphemeEvaluator> + template + struct 
MorphemeEvaluator> { - using LmState = lm::PcLMState; + using LmState = lm::PcLMState; template void eval( @@ -47,7 +47,7 @@ namespace kiwi thread_local Vector nextWids, nextDistantWids; thread_local Vector scores; - const auto* langMdl = static_cast*>(kw->getLangModel()); + const auto* langMdl = static_cast*>(kw->getLangModel()); const Morpheme* morphBase = kw->morphemes.data(); const auto spacePenalty = kw->spacePenalty; const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; @@ -320,16 +320,26 @@ namespace kiwi arr -= std::log(arr.exp().sum()); } - template - PcLangModel::PcLangModel(utils::MemoryObject&& mem) : PcLangModelBase{ mem } + template + PcLangModel::PcLangModel(utils::MemoryObject&& mem) : PcLangModelBase{ mem } { auto* ptr = reinterpret_cast(mem.get()); Vector nodeSizes(header.numNodes); streamvbyte_decode_0124(reinterpret_cast(ptr + header.nodeOffset), nodeSizes.data(), header.numNodes); - keyValueData = make_unique((header.numNodes - 1) * (sizeof(KeyType) + sizeof(int32_t))); - auto keyData = make_unique(header.numNodes - 1); - if (std::is_same::value) + + static constexpr size_t kvAlignment = ArchInfo::alignment; + size_t paddedKVSize = 0; + for (size_t i = 0; i < nodeSizes.size(); ++i) + { + if (!nodeSizes[i]) continue; + paddedKVSize += padMultipleOf(nodeSizes[i] * (sizeof(VlKeyType) + sizeof(int32_t)), kvAlignment); + } + + keyValueData = make_unique(paddedKVSize + kvAlignment); + alignedKeyValueData = reinterpret_cast(padMultipleOf(reinterpret_cast(keyValueData.get()), kvAlignment)); + auto keyData = make_unique(header.numNodes - 1); + if (std::is_same::value) { streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), (uint32_t*)keyData.get(), header.numNodes - 1); } @@ -366,9 +376,8 @@ namespace kiwi } node.value = values[i]; node.numNexts = nodeSizes[i]; - node.nextOffset = nextOffset; + keyRanges.emplace_back(std::array{ nonLeafIdx, (size_t)nextOffset, (size_t)(nextOffset + node.numNexts) }); nextOffset += nodeSizes[i]; - 
keyRanges.emplace_back(std::array{ nonLeafIdx, (size_t)node.nextOffset, (size_t)(node.nextOffset + node.numNexts) }); nonLeafIdx++; } else @@ -386,22 +395,23 @@ namespace kiwi } } - uint8_t* kvDataPtr = keyValueData.get(); + uint8_t* kvDataPtr = const_cast(alignedKeyValueData); nonLeafIdx = 0; nextOffset = 0; for (size_t i = 0; i < header.numNodes; ++i) { if (!nodeSizes[i]) continue; auto& node = nodeData[nonLeafIdx]; - memcpy(kvDataPtr, &keyData[nextOffset], node.numNexts * sizeof(KeyType)); - kvDataPtr += node.numNexts * sizeof(KeyType); - memcpy(kvDataPtr, &valueData[nextOffset], node.numNexts * sizeof(int32_t)); - kvDataPtr += node.numNexts * sizeof(int32_t); + node.nextOffset = (uint32_t)(kvDataPtr - alignedKeyValueData); + memcpy(kvDataPtr, &keyData[nextOffset], node.numNexts * sizeof(VlKeyType)); + memcpy(kvDataPtr + node.numNexts * sizeof(VlKeyType), &valueData[nextOffset], node.numNexts * sizeof(int32_t)); + kvDataPtr += node.numNexts * (sizeof(VlKeyType) + sizeof(int32_t)); nextOffset += node.numNexts; nonLeafIdx++; } allRootValueData = make_unique(header.vocabSize); + memset(allRootValueData.get(), 0, sizeof(int32_t) * header.vocabSize); for (size_t i = 0; i < nodeData[0].numNexts; ++i) { allRootValueData[keyData[i]] = valueData[i]; @@ -410,7 +420,7 @@ namespace kiwi for (size_t i = 0; i < nonLeafIdx; ++i) { auto& node = nodeData[i]; - nst::prepareKV(&keyValueData[node.nextOffset * (sizeof(KeyType) + sizeof(int32_t))], node.numNexts, tempBuf); + nst::prepareKV(const_cast(&alignedKeyValueData[node.nextOffset]), node.numNexts, tempBuf); } Deque dq; @@ -419,14 +429,13 @@ namespace kiwi auto p = dq.front(); for (size_t i = 0; i < p->numNexts; ++i) { - auto k = keyData[p->nextOffset + i]; - auto v = valueData[p->nextOffset + i]; - if (v <= 0) continue; - auto* child = &p[v]; - child->lower = findLowerNode(p, k) - child; + auto kv = nst::extractKV(&alignedKeyValueData[p->nextOffset], p->numNexts, i); + if (kv.second <= 0) continue; + auto* child = 
&p[kv.second]; + child->lower = findLowerNode(p, kv.first) - child; if (child->value == 0) { - child->value = findLowerValue(p, k); + child->value = findLowerValue(p, kv.first); } dq.emplace_back(child); } @@ -564,8 +573,8 @@ namespace kiwi } } - template - float PcLangModel::progress(int32_t& nodeIdx, + template + float PcLangModel::progress(int32_t& nodeIdx, uint32_t& contextIdx, std::array& history, KeyType next) const @@ -673,9 +682,9 @@ namespace kiwi } // specialization for windowSize > 0 - template + template template - auto PcLangModel::nextState( + auto PcLangModel::nextState( const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next) const -> LmStateType { LmStateType ret = state; @@ -689,9 +698,9 @@ namespace kiwi } // specialization for windowSize == 0 - template + template template - auto PcLangModel::nextState( + auto PcLangModel::nextState( const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next) const -> LmStateType { LmStateType ret = state; @@ -709,9 +718,9 @@ namespace kiwi return make_pair(a >> 32, a & 0xFFFFFFFF); } - template + template template - void PcLangModel::progressMatrix( + void PcLangModel::progressMatrix( const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const @@ -894,6 +903,9 @@ namespace kiwi if constexpr (quantized) { + //thread_local Map, size_t> shapeCnt; + //shapeCnt[make_pair(uniqInputSize + uniqHistorySize, uniqOutputSize)]++; + //ScopedTimer<> timer{ 0 }; qgemm::scatteredGEMMOpt( uniqInputSize + uniqHistorySize, uniqOutputSize, header.dim, getContextQuantEmb(0), contextIdcs2.data(), contextEmbStride(), @@ -963,9 +975,9 @@ namespace kiwi } } - template + template template - void PcLangModel::progressMatrix( + void PcLangModel::progressMatrix( const typename std::enable_if<_windowSize == 0, 
LmStateType>::type* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const @@ -1076,7 +1088,8 @@ namespace kiwi throw IOException{ "Cannot open file : " + contextDefinition }; } - uint32_t maxClusterId = 0; + uint32_t maxClusterId = 0, maxContextId = 0; + size_t keySize = 0; using Node = utils::TrieNodeEx>>; utils::ContinuousTrie trie(1); { @@ -1100,6 +1113,7 @@ namespace kiwi auto id = stol(tokens[i].begin(), tokens[i].end()); if (id < 0) throw IOException{ "Invalid format : " + contextDefinition }; context.push_back(id); + maxContextId = max(maxContextId, (uint32_t)id); } if (context.size() > maxContextLength) { @@ -1141,11 +1155,47 @@ namespace kiwi } } + if (maxContextId <= 0xFFFF) + { + keySize = 2; + } + else if (maxContextId <= 0xFFFFF) + { + keySize = 3; // variable length key + } + else + { + keySize = 4; + } + for (auto& c : contextMap) { for (auto& p : c) { - trie.build(p.first.begin(), p.first.end(), p.second + 1); + if (keySize == 3) + { + static constexpr size_t tMax = (1 << 16) - (1 << 10) * 2; + context.clear(); + for (auto id : p.first) + { + if (id < tMax) + { + context.emplace_back(id); + } + else + { + id -= tMax; + const size_t high = id >> 10, low = id & 0x3FF; + context.emplace_back(tMax + high); + context.emplace_back(tMax + (1 << 10) + low); + } + } + trie.build(context.begin(), context.end(), p.second + 1); + } + else + { + trie.build(p.first.begin(), p.first.end(), p.second + 1); + } } } } @@ -1278,7 +1328,7 @@ namespace kiwi header.dim = dim; header.contextSize = contextSize; header.vocabSize = outputSize; - header.keySize = 4; + header.keySize = keySize; header.windowSize = windowSize; header.numNodes = nodeSizes.size(); @@ -1329,31 +1379,31 @@ namespace kiwi return mem; } - template - void* PcLangModel::getFindBestPathFn() const + template + void* PcLangModel::getFindBestPathFn() const { - return (void*)&BestPathFinder::findBestPath>; + 
return (void*)&BestPathFinder::findBestPath>; } - template - void* PcLangModel::getNewJoinerFn() const + template + void* PcLangModel::getNewJoinerFn() const { return (void*)&newJoinerWithKiwi; } - template + template inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) { auto& header = *reinterpret_cast(mem.get()); if (!useDistantTokens) { - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); } switch (header.windowSize) { case 7: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); default: throw std::runtime_error{ "Unsupported `window_size` : " + std::to_string((size_t)header.windowSize) }; }; @@ -1366,9 +1416,11 @@ namespace kiwi switch (header.keySize) { case 2: - return createOptimizedModelWithWindowSize(std::move(mem)); + return createOptimizedModelWithWindowSize(std::move(mem)); + case 3: + return createOptimizedModelWithWindowSize(std::move(mem)); case 4: - return createOptimizedModelWithWindowSize(std::move(mem)); + return createOptimizedModelWithWindowSize(std::move(mem)); default: throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; } diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index 79cfd0fe..af210369 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -15,16 +15,17 @@ namespace kiwi { namespace lm { - template + template class PcLMState; - template + template class PcLangModel : public PcLangModelBase { using MyNode = Node; std::unique_ptr nodeData; std::unique_ptr keyValueData; + const uint8_t* alignedKeyValueData = nullptr; std::unique_ptr allRootValueData; std::unique_ptr allEmbs; const uint8_t* contextEmbPtr = nullptr; // [numContexts, (dim + scale? 
+ bias + confid + vts)] @@ -130,13 +131,12 @@ namespace kiwi while (node->lower) { auto* lowerNode = node + node->lower; - auto* kvs = &keyValueData[lowerNode->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; + auto* kvs = &alignedKeyValueData[lowerNode->nextOffset]; int32_t found; - if (nst::searchKV( + if ((found = nst::searchKV( kvs, lowerNode->numNexts, - k, - found) && found >= 0) + k)) > 0) { return lowerNode + found; } @@ -150,13 +150,12 @@ namespace kiwi while (node->lower) { auto* lowerNode = node + node->lower; - auto* kvs = &keyValueData[lowerNode->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; + auto* kvs = &alignedKeyValueData[lowerNode->nextOffset]; int32_t found; - if (nst::searchKV( + if ((found = nst::searchKV( kvs, lowerNode->numNexts, - k, - found)) + k)) != 0) { if (found >= 0) { @@ -174,7 +173,7 @@ namespace kiwi public: using VocabType = KeyType; - using LmStateType = PcLMState; + using LmStateType = PcLMState; PcLangModel(utils::MemoryObject&& mem); @@ -196,23 +195,42 @@ namespace kiwi uint32_t progressContextNode(int32_t& nodeIdx, KeyType next) const { + if (std::is_same::value) + { + return progressContextNodeVl(nodeIdx, next); + } + + static constexpr size_t tMax = (1 << 16) - (1 << 10) * 2; + if (next < tMax) + { + return progressContextNodeVl(nodeIdx, next); + } + next -= tMax; + const size_t high = next >> 10, low = next & 0x3FF; + progressContextNodeVl(nodeIdx, tMax + high); + return progressContextNodeVl(nodeIdx, tMax + (1 << 10) + low); + } + + uint32_t progressContextNodeVl(int32_t& nodeIdx, VlKeyType next) const + { + static constexpr size_t N = 64 / sizeof(VlKeyType) + 1; while (1) { int32_t v; auto* node = &nodeData[nodeIdx]; - auto* kvs = &keyValueData[node->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; + auto* kvs = &alignedKeyValueData[node->nextOffset]; if (node != nodeData.get()) { //ScopedTimer<> timer(node->numNexts <= 16 ? 0 : node->numNexts <= 272 ? 
1 : 2); PREFETCH_T0(node + node->lower); - if (!nst::searchKV( + if ((v = nst::searchKV( kvs, - node->numNexts, next, v - )) + node->numNexts, next + )) == 0) { if (!node->lower) return 0; nodeIdx += node->lower; - PREFETCH_T0(&keyValueData[nodeData[nodeIdx].nextOffset * (sizeof(KeyType) + sizeof(int32_t))]); + PREFETCH_T0(&alignedKeyValueData[nodeData[nodeIdx].nextOffset]); continue; } } @@ -237,15 +255,14 @@ namespace kiwi while (node->lower) { node += node->lower; - auto* lkvs = &keyValueData[node->nextOffset * (sizeof(KeyType) + sizeof(int32_t))]; + auto* lkvs = &alignedKeyValueData[node->nextOffset]; int32_t lv; if (node != nodeData.get()) { - //ScopedTimer<> timer(node->numNexts <= 16 ? 0 : node->numNexts <= 272 ? 1 : 2); - if (nst::searchKV( + if ((lv = nst::searchKV( lkvs, - node->numNexts, next, lv - )) + node->numNexts, next + )) != 0) { if (lv > 0) { @@ -304,8 +321,8 @@ namespace kiwi LmStateType* outStates, float* outScores) const; }; - template - struct PcLMState : public LmStateBase> + template + struct PcLMState : public LmStateBase> { int32_t node = 0; uint32_t contextIdx = 0; @@ -328,14 +345,14 @@ namespace kiwi return true; } - float nextImpl(const PcLangModel* lm, VocabTy next) + float nextImpl(const PcLangModel* lm, VocabTy next) { return lm->progress(node, contextIdx, history, next); } }; - template - struct PcLMState<0, _arch, VocabTy, quantized> : public LmStateBase> + template + struct PcLMState<0, _arch, VocabTy, VlVocabTy, quantized> : public LmStateBase> { int32_t node = 0; uint32_t contextIdx = 0; @@ -352,7 +369,7 @@ namespace kiwi return node == other.node; } - float nextImpl(const PcLangModel* lm, VocabTy next) + float nextImpl(const PcLangModel* lm, VocabTy next) { std::array history = { {0,} }; return lm->progress(node, contextIdx, history, next); @@ -360,10 +377,10 @@ namespace kiwi }; } - template - struct Hash> + template + struct Hash> { - size_t operator()(const lm::PcLMState& state) const + size_t operator()(const 
lm::PcLMState& state) const { size_t ret = (uint32_t)(state.node * (size_t)2654435761); static constexpr size_t cmpStart = windowSize - sizeof(size_t) / sizeof(VocabTy); @@ -373,10 +390,10 @@ namespace kiwi } }; - template - struct Hash> + template + struct Hash> { - size_t operator()(const lm::PcLMState<0, arch, VocabTy, quantized>& state) const + size_t operator()(const lm::PcLMState<0, arch, VocabTy, VlVocabTy, quantized>& state) const { size_t ret = (uint32_t)(state.node * (size_t)2654435761); return ret; diff --git a/src/archImpl/avx2.cpp b/src/archImpl/avx2.cpp index 6ff2c16e..76e0c1d4 100644 --- a/src/archImpl/avx2.cpp +++ b/src/archImpl/avx2.cpp @@ -8,10 +8,5 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; - - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; } } diff --git a/src/archImpl/avx512vnni.cpp b/src/archImpl/avx512vnni.cpp new file mode 100644 index 00000000..f879d3ef --- /dev/null +++ b/src/archImpl/avx512vnni.cpp @@ -0,0 +1,18 @@ +#include "../SkipBigramModelImpl.hpp" +#include "../qgemm.h" + +namespace kiwi +{ + namespace lm + { + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; + } + + namespace qgemm + { + + } +} diff --git a/src/archImpl/avx_vnni.cpp b/src/archImpl/avx_vnni.cpp new file mode 100644 index 00000000..1f8add2e --- /dev/null +++ b/src/archImpl/avx_vnni.cpp @@ -0,0 +1,12 @@ +#include "../SkipBigramModelImpl.hpp" + +namespace kiwi +{ + namespace lm + { + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; + template class SkipBigramModel; + } +} diff --git a/src/search.cpp b/src/search.cpp index 821a93bd..58f060a6 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -16,16 +16,16 @@ template bool detail::searchImpl(const uint32_t*, size_t, uint32_t, 
size_t&);\ template bool detail::searchImpl(const uint64_t*, size_t, uint64_t, size_t&);\ template bool detail::searchImpl(const char16_t*, size_t, char16_t, size_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint8_t, uint32_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint16_t, uint32_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint32_t, uint32_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint64_t, uint32_t&);\ - template bool detail::searchKVImpl(const void*, size_t, char16_t, uint32_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint8_t, uint64_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint16_t, uint64_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint32_t, uint64_t&);\ - template bool detail::searchKVImpl(const void*, size_t, uint64_t, uint64_t&);\ - template bool detail::searchKVImpl(const void*, size_t, char16_t, uint64_t&);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint8_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint16_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint32_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint64_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, char16_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint8_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint16_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint32_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint64_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, char16_t);\ template Vector detail::reorderImpl(const uint8_t*, size_t);\ template Vector detail::reorderImpl(const uint16_t*, size_t);\ template Vector detail::reorderImpl(const uint32_t*, size_t);\ @@ -39,6 +39,14 @@ #include #endif +#if defined(_MSC_VER) +#define FORCE_INLINE __forceinline 
+#elif defined(__GNUC__) +#define FORCE_INLINE __attribute__((always_inline)) +#else +#define FORCE_INLINE inline +#endif + #ifdef __GNUC__ #define ARCH_TARGET(x) __attribute__((target(x))) #else @@ -167,9 +175,9 @@ namespace kiwi } template - bool detail::searchKVImpl(const void* kv, size_t size, IntTy target, ValueTy& ret) + ValueTy detail::searchKVImpl(const void* kv, size_t size, IntTy target) { - return OptimizedImpl::searchKV(kv, size, target, ret); + return OptimizedImpl::searchKV(kv, size, target); } template @@ -196,17 +204,16 @@ namespace kiwi } template - static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + static ValueTy searchKV(const void* kv, size_t size, IntTy target) { size_t idx; const IntTy* keys = reinterpret_cast(kv); const ValueTy* values = reinterpret_cast(keys + size); if (search(keys, size, target, idx)) { - ret = values[idx]; - return true; + return values[idx]; } - else return false; + else return 0; } }; INSTANTIATE_IMPL(ArchType::none); @@ -243,17 +250,16 @@ namespace kiwi } template - static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + static ValueTy searchKV(const void* kv, size_t size, IntTy target) { size_t idx; const IntTy* keys = reinterpret_cast(kv); const ValueTy* values = reinterpret_cast(keys + size); if (search(keys, size, target, idx)) { - ret = values[idx]; - return true; + return values[idx]; } - else return false; + else return 0; } }; INSTANTIATE_IMPL(ArchType::balanced); @@ -267,7 +273,7 @@ namespace kiwi { template ARCH_TARGET("sse2") - inline bool testEq(__m128i p, size_t offset, size_t size, size_t& ret) + FORCE_INLINE bool testEq(__m128i p, size_t offset, size_t size, size_t& ret) { uint32_t m = _mm_movemask_epi8(p); uint32_t b = utils::countTrailingZeroes(m); @@ -281,7 +287,7 @@ namespace kiwi template ARCH_TARGET("sse2") - bool nstSearchSSE2(const IntTy* keys, size_t size, IntTy target, size_t& ret) + FORCE_INLINE bool nstSearchSSE2(const IntTy* keys, size_t 
size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -443,7 +449,7 @@ namespace kiwi template ARCH_TARGET("sse2") - bool nstSearchKVSSE2(const uint8_t* kv, size_t size, IntTy target, ValueTy& ret) + FORCE_INLINE ValueTy nstSearchKVSSE2(const uint8_t* kv, size_t size, IntTy target) { size_t i = 0, r; @@ -461,6 +467,31 @@ namespace kiwi break; } + if (size < n) + { + pkey = _mm_loadu_si128(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm_cmpeq_epi8(ptarget, pkey); + break; + case 2: + peq = _mm_cmpeq_epi16(ptarget, pkey); + break; + case 4: + peq = _mm_cmpeq_epi32(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + while (i < size) { pkey = _mm_loadu_si128(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); @@ -484,14 +515,13 @@ namespace kiwi { const size_t groupSize = std::min(n - 1, size - i); const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); - ret = values[r]; - return true; + return values[r]; } r = utils::popcount((uint32_t)_mm_movemask_epi8(pgt)) / sizeof(IntTy); i = i * n + (n - 1) * (r + 1); } - return false; + return 0; } template<> @@ -514,10 +544,10 @@ namespace kiwi } template - static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + static ValueTy searchKV(const void* kv, size_t size, IntTy target) { using SignedIntTy = typename SignedType::type; - return nstSearchKVSSE2((const uint8_t*)kv, size, (SignedIntTy)target, ret); + return nstSearchKVSSE2((const uint8_t*)kv, size, target); } }; INSTANTIATE_IMPL(ArchType::sse2); @@ -532,7 +562,7 @@ namespace kiwi { template ARCH_TARGET("avx2") - inline bool testEq(__m256i p, size_t offset, size_t size, size_t& ret) + 
FORCE_INLINE bool testEq(__m256i p, size_t offset, size_t size, size_t& ret) { uint32_t m = _mm256_movemask_epi8(p); uint32_t b = utils::countTrailingZeroes(m); @@ -544,7 +574,7 @@ namespace kiwi return false; } - inline bool testEqMask(uint64_t m, size_t offset, size_t size, size_t& ret) + FORCE_INLINE bool testEqMask(uint64_t m, size_t offset, size_t size, size_t& ret) { uint32_t b = utils::countTrailingZeroes(m); if (m && (offset + b) < size) @@ -557,7 +587,7 @@ namespace kiwi template ARCH_TARGET("avx2") - bool nstSearchAVX2(const IntTy* keys, size_t size, IntTy target, size_t& ret) + FORCE_INLINE bool nstSearchAVX2(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -746,8 +776,13 @@ namespace kiwi template ARCH_TARGET("avx2") - bool nstSearchKVAVX2(const uint8_t* kv, size_t size, IntTy target, ValueTy& ret) + FORCE_INLINE ValueTy nstSearchKVAVX2(const uint8_t* kv, size_t size, IntTy target) { + if (size < (n + 1) / 2) + { + return nstSearchKVSSE2<(n + 1) / 2, IntTy, ValueTy>(kv, size, target); + } + size_t i = 0, r; __m256i ptarget, pkey, peq, pgt; @@ -767,6 +802,34 @@ namespace kiwi break; } + if (size < n) + { + pkey = _mm256_loadu_si256(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm256_cmpeq_epi8(ptarget, pkey); + break; + case 2: + peq = _mm256_cmpeq_epi16(ptarget, pkey); + break; + case 4: + peq = _mm256_cmpeq_epi32(ptarget, pkey); + break; + case 8: + peq = _mm256_cmpeq_epi64(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + while (i < size) { pkey = _mm256_loadu_si256(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); @@ -794,19 +857,18 @@ namespace kiwi { const size_t groupSize = std::min(n - 1, size - i); const 
ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); - ret = values[r]; - return true; + return values[r]; } r = utils::popcount((uint32_t)_mm256_movemask_epi8(pgt)) / sizeof(IntTy); i = i * n + (n - 1) * (r + 1); } - return false; + return 0; } template ARCH_TARGET("avx512bw") - bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) + FORCE_INLINE bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -996,8 +1058,13 @@ namespace kiwi template ARCH_TARGET("avx512bw") - bool nstSearchKVAVX512(const uint8_t* kv, size_t size, IntTy target, ValueTy& ret) + FORCE_INLINE ValueTy nstSearchKVAVX512(const uint8_t* kv, size_t size, IntTy target) { + if (size < (n + 1) / 2) + { + return nstSearchKVAVX2<(n + 1) / 2, IntTy, ValueTy>(kv, size, target); + } + size_t i = 0, r; const IntTy* keys; @@ -1019,6 +1086,35 @@ namespace kiwi break; } + if (size < n) + { + keys = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))]); + pkey = _mm512_loadu_si512(reinterpret_cast(keys)); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm512_cmpeq_epi8_mask(ptarget, pkey); + break; + case 2: + peq = _mm512_cmpeq_epi16_mask(ptarget, pkey); + break; + case 4: + peq = _mm512_cmpeq_epi32_mask(ptarget, pkey); + break; + case 8: + peq = _mm512_cmpeq_epi64_mask(ptarget, pkey); + break; + } + + if (testEqMask(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&keys[groupSize]); + return values[r]; + } + return 0; + } + while (i < size) { keys = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))]); @@ -1047,14 +1143,13 @@ namespace kiwi { const size_t groupSize = std::min(n - 1, size - i); const ValueTy* values = reinterpret_cast(&keys[groupSize]); - ret = values[r]; - return true; + return values[r]; } r = utils::popcount(pgt); i = i * n + (n - 1) * (r + 1); } - return false; + return 0; } 
template<> @@ -1077,10 +1172,10 @@ namespace kiwi } template - static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + static ValueTy searchKV(const void* kv, size_t size, IntTy target) { using SignedIntTy = typename SignedType::type; - return nstSearchKVSSE2((const uint8_t*)kv, size, (SignedIntTy)target, ret); + return nstSearchKVSSE2((const uint8_t*)kv, size, target); } }; INSTANTIATE_IMPL(ArchType::sse4_1); @@ -1105,10 +1200,10 @@ namespace kiwi } template - static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + static ValueTy searchKV(const void* kv, size_t size, IntTy target) { using SignedIntTy = typename SignedType::type; - return nstSearchKVAVX2((const uint8_t*)kv, size, (SignedIntTy)target, ret); + return nstSearchKVAVX2((const uint8_t*)kv, size, target); } }; INSTANTIATE_IMPL(ArchType::avx2); @@ -1139,10 +1234,10 @@ namespace kiwi } template - static bool searchKV(const void* kv, size_t size, IntTy target, ValueTy& ret) + static ValueTy searchKV(const void* kv, size_t size, IntTy target) { using SignedIntTy = typename SignedType::type; - return nstSearchKVAVX512((const uint8_t*)kv, size, (SignedIntTy)target, ret); + return nstSearchKVAVX512((const uint8_t*)kv, size, target); } }; INSTANTIATE_IMPL(ArchType::avx512bw); diff --git a/src/search.h b/src/search.h index f3c4fa9d..10845287 100644 --- a/src/search.h +++ b/src/search.h @@ -21,7 +21,7 @@ namespace kiwi bool searchImpl(const IntTy* keys, size_t size, IntTy target, size_t& ret); template - bool searchKVImpl(const void* keys, size_t size, IntTy target, ValueTy& ret); + ValueTy searchKVImpl(const void* keys, size_t size, IntTy target); template Vector reorderImpl(const IntTy* keys, size_t size); @@ -87,6 +87,25 @@ namespace kiwi } } + template + std::pair extractKV(const void* kv, size_t totSize, size_t idx) + { + const size_t packetSize = detail::getPacketSizeImpl() / sizeof(IntTy); + if (packetSize <= 1) + { + const auto* key = reinterpret_cast(kv); + 
const auto* value = reinterpret_cast(key + totSize); + return std::make_pair(key[idx], value[idx]); + } + + const size_t groupIdx = idx / packetSize; + const size_t groupOffset = idx % packetSize; + const auto* group = reinterpret_cast(kv) + groupIdx * packetSize * (sizeof(IntTy) + sizeof(Value)); + const auto* key = reinterpret_cast(group); + const auto* value = reinterpret_cast(key + std::min(packetSize, totSize - groupIdx * packetSize)); + return std::make_pair(key[groupOffset], value[groupOffset]); + } + template bool search(const IntTy* keys, const Value* values, size_t size, IntTy target, Out& ret) { @@ -112,15 +131,9 @@ namespace kiwi } template - bool searchKV(const void* kv, size_t size, IntTy target, Out& ret) + Out searchKV(const void* kv, size_t size, IntTy target) { - typename UnsignedType::type out; - if (detail::searchKVImpl(kv, size, target, out)) - { - ret = out; - return true; - } - else return false; + return detail::searchKVImpl::type>(kv, size, target); } } } From a2ff10a71e0045d578be142acc96ace496a5821d Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 15 Feb 2025 02:04:51 +0900 Subject: [PATCH 17/53] Add max-length argument for pclm context --- tools/pclm_builder.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tools/pclm_builder.cpp b/tools/pclm_builder.cpp index 64e1d3eb..618ea39a 100644 --- a/tools/pclm_builder.cpp +++ b/tools/pclm_builder.cpp @@ -9,13 +9,14 @@ using namespace std; using namespace kiwi; -int run(const std::string& morphemeDef, const std::string& contextDef, const std::string& embedding, size_t minCnt, const std::string& output, bool reorderContextIdx = true) +int run(const std::string& morphemeDef, const std::string& contextDef, const std::string& embedding, + size_t minCnt, size_t maxLength, const std::string& output, bool reorderContextIdx = true) { try { tutils::Timer timer; KiwiBuilder::buildMorphData(morphemeDef, output, minCnt); - auto ret = lm::PcLangModelBase::build(contextDef, 
embedding, reorderContextIdx); + auto ret = lm::PcLangModelBase::build(contextDef, embedding, maxLength, reorderContextIdx); ret.writeToFile(output + "/pclm.mdl"); double tm = timer.getElapsed(); cout << "Total: " << tm << " ms " << endl; @@ -40,6 +41,7 @@ int main(int argc, const char* argv[]) ValueArg cdef{ "c", "context-def", "context definition", true, "", "string" }; ValueArg emb{ "e", "emb", "embedding file", true, "", "string" }; ValueArg minCnt{ "n", "min-cnt", "min count of morpheme", false, 10, "int" }; + ValueArg maxLength{ "l", "max-length", "max length of n-grams", false, (size_t)-1, "int"}; ValueArg output{ "o", "output", "", true, "", "string" }; SwitchArg preserveContextIdx{ "p", "preserve-context-idx", "preserve context index", false }; @@ -47,6 +49,7 @@ int main(int argc, const char* argv[]) cmd.add(cdef); cmd.add(emb); cmd.add(minCnt); + cmd.add(maxLength); cmd.add(output); cmd.add(preserveContextIdx); @@ -60,5 +63,5 @@ int main(int argc, const char* argv[]) return -1; } - return run(mdef, cdef, emb, minCnt, output, !preserveContextIdx); + return run(mdef, cdef, emb, minCnt, maxLength, output, !preserveContextIdx); } From 9285d0517de9c31ffe87d3966babb7bf75c32de3 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 16 Feb 2025 21:40:24 +0900 Subject: [PATCH 18/53] Add an option for `useVLE` --- include/kiwi/PCLanguageModel.h | 2 +- src/PCLanguageModel.cpp | 4 ++-- tools/pclm_builder.cpp | 8 +++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/include/kiwi/PCLanguageModel.h b/include/kiwi/PCLanguageModel.h index 664869e0..00b975fd 100644 --- a/include/kiwi/PCLanguageModel.h +++ b/include/kiwi/PCLanguageModel.h @@ -50,7 +50,7 @@ namespace kiwi const PcLangModelHeader& getHeader() const { return header; } - static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, size_t maxContextLength = -1, bool reorderContextIdx = true); + static utils::MemoryObject build(const std::string& 
contextDefinition, const std::string& embedding, size_t maxContextLength = -1, bool useVLE = true, bool reorderContextIdx = true); static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false, bool quantized = true); }; } diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index 397a0a0e..53d302a1 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -1080,7 +1080,7 @@ namespace kiwi } } - utils::MemoryObject PcLangModelBase::build(const string& contextDefinition, const string& embedding, size_t maxContextLength, bool reorderContextId) + utils::MemoryObject PcLangModelBase::build(const string& contextDefinition, const string& embedding, size_t maxContextLength, bool useVLE, bool reorderContextId) { ifstream contextStr, embeddingStr; if (!openFile(contextStr, contextDefinition)) @@ -1159,7 +1159,7 @@ namespace kiwi { keySize = 2; } - else if (maxContextId <= 0xFFFFF) + else if (useVLE && maxContextId <= 0xFFFFF) { keySize = 3; // variable length key } diff --git a/tools/pclm_builder.cpp b/tools/pclm_builder.cpp index 618ea39a..8e825ff9 100644 --- a/tools/pclm_builder.cpp +++ b/tools/pclm_builder.cpp @@ -10,13 +10,13 @@ using namespace std; using namespace kiwi; int run(const std::string& morphemeDef, const std::string& contextDef, const std::string& embedding, - size_t minCnt, size_t maxLength, const std::string& output, bool reorderContextIdx = true) + size_t minCnt, size_t maxLength, const std::string& output, bool useVLE, bool reorderContextIdx = true) { try { tutils::Timer timer; KiwiBuilder::buildMorphData(morphemeDef, output, minCnt); - auto ret = lm::PcLangModelBase::build(contextDef, embedding, maxLength, reorderContextIdx); + auto ret = lm::PcLangModelBase::build(contextDef, embedding, maxLength, useVLE, reorderContextIdx); ret.writeToFile(output + "/pclm.mdl"); double tm = timer.getElapsed(); cout << "Total: " << tm << " ms " << endl; @@ -43,6 +43,7 @@ int main(int 
argc, const char* argv[]) ValueArg minCnt{ "n", "min-cnt", "min count of morpheme", false, 10, "int" }; ValueArg maxLength{ "l", "max-length", "max length of n-grams", false, (size_t)-1, "int"}; ValueArg output{ "o", "output", "", true, "", "string" }; + SwitchArg useVLE{ "", "use-vle", "use VLE", false }; SwitchArg preserveContextIdx{ "p", "preserve-context-idx", "preserve context index", false }; cmd.add(mdef); @@ -51,6 +52,7 @@ int main(int argc, const char* argv[]) cmd.add(minCnt); cmd.add(maxLength); cmd.add(output); + cmd.add(useVLE); cmd.add(preserveContextIdx); try @@ -63,5 +65,5 @@ int main(int argc, const char* argv[]) return -1; } - return run(mdef, cdef, emb, minCnt, maxLength, output, !preserveContextIdx); + return run(mdef, cdef, emb, minCnt, maxLength, output, useVLE, !preserveContextIdx); } From 30ec39bf27e33b81ca66c891e87380e0dabb95a9 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 16 Feb 2025 21:46:28 +0900 Subject: [PATCH 19/53] Add caching `nextState` & optimize `nextState` --- src/PCLanguageModel.cpp | 465 ++++++++++++++++++++++++++++++---------- src/PCLanguageModel.hpp | 54 +++-- 2 files changed, 397 insertions(+), 122 deletions(-) diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index 53d302a1..c7a28c58 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -118,9 +118,19 @@ namespace kiwi if (nextWids.size() > 0) { + if (prevLmStates.size() == 1 && nextWids.size() == 1) + { + nextLmStates.resize(1); + scores.resize(1); + nextLmStates[0] = prevLmStates[0]; + scores[0] = nextLmStates[0].next(langMdl, nextWids[0]); + } + else + { nextLmStates.resize(prevLmStates.size() * nextWids.size()); scores.resize(prevLmStates.size() * nextWids.size()); langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); + } } for (size_t curId = 0; curId < regularMorphs.size(); ++curId) @@ -681,17 +691,45 @@ namespace kiwi return ll; 
} + template + struct PcLangModel::TLSForProgressMatrix + { + Vector> contextCache; + Vector contextIdcs, historyIdcs, nextIdcs; + Vector inverseContextIdcs, inverseHistoryIdcs, inverseNextIdcs; + Vector resultBuf, confidenceBuf, scoreBuf; + UnorderedMap historyMap; + Vector uniqHistoryTokens; + Vector inputEmbBuf, outputEmbBuf; // only for non-quantized + Vector contextIdcs2, nextIdcs2; // only for quantized + }; + // specialization for windowSize > 0 template template - auto PcLangModel::nextState( - const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next) const -> LmStateType + inline auto PcLangModel::nextState( + const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next, + bool cacheIsValid, pair& cache) const -> LmStateType { - LmStateType ret = state; + LmStateType ret{ state.node }; // partially initialized + if (cacheIsValid) + { + ret.node = cache.first; + ret.contextIdx = cache.second; + } + else + { ret.contextIdx = progressContextNode(ret.node, next); - if (ret.history[windowSize]) + cache = std::make_pair(ret.node, ret.contextIdx); + } + + if (state.history[windowSize]) + { + memcpy(&ret.history[0], &state.history[1], windowSize * sizeof(KeyType)); + } + else { - memcpy(&ret.history[0], &ret.history[1], windowSize * sizeof(KeyType)); + memcpy(&ret.history[0], &state.history[0], windowSize * sizeof(KeyType)); } ret.history[windowSize] = distantTokenMask(next) ? 
next : 0; return ret; @@ -700,11 +738,21 @@ namespace kiwi // specialization for windowSize == 0 template template - auto PcLangModel::nextState( - const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next) const -> LmStateType + inline auto PcLangModel::nextState( + const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next, + bool cacheIsValid, pair& cache) const -> LmStateType { - LmStateType ret = state; + LmStateType ret{ state.node }; // partially initialized + if (cacheIsValid) + { + ret.node = cache.first; + ret.contextIdx = cache.second; + } + else + { ret.contextIdx = progressContextNode(ret.node, next); + cache = std::make_pair(ret.node, ret.contextIdx); + } return ret; } @@ -719,93 +767,87 @@ namespace kiwi } template - template - void PcLangModel::progressMatrix( - const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, + inline void PcLangModel::progressMatrixWSort( + TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const { static constexpr size_t scoreBatchSize = 32; - thread_local Vector contextIdcs, historyIdcs, nextIdcs; - thread_local Vector inverseContextIdcs, inverseHistoryIdcs, inverseNextIdcs; - thread_local Vector inputEmbBuf, outputEmbBuf, resultBuf, confidenceBuf; - thread_local Vector scoreBuf; - thread_local Vector contextIdcs2, nextIdcs2; - contextIdcs.resize(prevStateSize); - historyIdcs.clear(); - nextIdcs.resize(nextIdSize); - inverseContextIdcs.resize(prevStateSize); - inverseHistoryIdcs.clear(); - inverseHistoryIdcs.resize(prevStateSize * windowSize, -1); - inverseNextIdcs.resize(nextIdSize); + tls.contextIdcs.resize(prevStateSize); + tls.historyIdcs.clear(); + tls.nextIdcs.resize(nextIdSize); + tls.inverseContextIdcs.resize(prevStateSize); + tls.inverseHistoryIdcs.clear(); + 
tls.inverseHistoryIdcs.resize(prevStateSize * windowSize, -1); + tls.inverseNextIdcs.resize(nextIdSize); if (quantized) { - contextIdcs2.clear(); - nextIdcs2.clear(); + tls.contextIdcs2.clear(); + tls.nextIdcs2.clear(); } else { - inputEmbBuf.resize(prevStateSize * header.dim); - outputEmbBuf.resize(nextIdSize * header.dim); + tls.inputEmbBuf.resize(prevStateSize * header.dim); + tls.outputEmbBuf.resize(nextIdSize * header.dim); } - confidenceBuf.resize(prevStateSize * 2); - scoreBuf.resize(scoreBatchSize * (windowSize + 2)); + tls.confidenceBuf.resize(prevStateSize * 2); + tls.scoreBuf.resize(scoreBatchSize * (windowSize + 2)); const size_t numInvalidDistantTokens = nextIdSize - numValidDistantTokens; for (size_t i = 0; i < nextIdSize; ++i) { - nextIdcs[i] = mergePair(nextIds[i], i); + tls.nextIdcs[i] = mergePair(nextIds[i], i); } - sort(nextIdcs.begin(), nextIdcs.begin() + numInvalidDistantTokens); - sort(nextIdcs.begin() + numInvalidDistantTokens, nextIdcs.end()); + sort(tls.nextIdcs.begin(), tls.nextIdcs.begin() + numInvalidDistantTokens); + sort(tls.nextIdcs.begin() + numInvalidDistantTokens, tls.nextIdcs.end()); size_t uniqOutputSize = 0; for (size_t i = 0; i < nextIdSize; ++i) { - const auto nextId = splitPair(nextIdcs[i]).first; - const auto idx = splitPair(nextIdcs[i]).second; - if (i == 0 || nextId != splitPair(nextIdcs[i - 1]).first) + const auto nextId = splitPair(tls.nextIdcs[i]).first; + const auto idx = splitPair(tls.nextIdcs[i]).second; + if (i == 0 || nextId != splitPair(tls.nextIdcs[i - 1]).first) { if (quantized) { - nextIdcs2.emplace_back(nextId); + tls.nextIdcs2.emplace_back(nextId); } else { - copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &outputEmbBuf[uniqOutputSize * header.dim]); + copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &tls.outputEmbBuf[uniqOutputSize * header.dim]); } uniqOutputSize++; } - inverseNextIdcs[idx] = uniqOutputSize - 1; + tls.inverseNextIdcs[idx] = uniqOutputSize - 1; } - 
resultBuf.resize(prevStateSize * uniqOutputSize); + tls.resultBuf.resize(prevStateSize * uniqOutputSize); for (size_t i = 0; i < prevStateSize; ++i) { - contextIdcs[i] = mergePair(prevStates[i].contextIdx, i); + tls.contextIdcs[i] = mergePair(prevStates[i].contextIdx, i); } - sort(contextIdcs.begin(), contextIdcs.end()); + sort(tls.contextIdcs.begin(), tls.contextIdcs.end()); size_t uniqInputSize = 0; for (size_t i = 0; i < prevStateSize; ++i) { - const auto contextId = splitPair(contextIdcs[i]).first; - const auto idx = splitPair(contextIdcs[i]).second; - if (i == 0 || contextId != splitPair(contextIdcs[i - 1]).first) + const auto contextId = splitPair(tls.contextIdcs[i]).first; + const auto idx = splitPair(tls.contextIdcs[i]).second; + if (i == 0 || contextId != splitPair(tls.contextIdcs[i - 1]).first) { if (quantized) { - contextIdcs2.emplace_back(contextId); + tls.contextIdcs2.emplace_back(contextId); } else { - copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &inputEmbBuf[uniqInputSize * header.dim]); - fill(&resultBuf[uniqInputSize * uniqOutputSize], &resultBuf[(uniqInputSize + 1) * uniqOutputSize], getContextBias(contextId)); + copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &tls.inputEmbBuf[uniqInputSize * header.dim]); + fill(&tls.resultBuf[uniqInputSize * uniqOutputSize], &tls.resultBuf[(uniqInputSize + 1) * uniqOutputSize], getContextBias(contextId)); } - confidenceBuf[uniqInputSize * 2] = getContextConfid(contextId); - confidenceBuf[uniqInputSize * 2 + 1] = getContextValidTokenSum(contextId); + tls.confidenceBuf[uniqInputSize * 2] = getContextConfid(contextId); + tls.confidenceBuf[uniqInputSize * 2 + 1] = getContextValidTokenSum(contextId); uniqInputSize++; } - inverseContextIdcs[idx] = uniqInputSize - 1; + tls.inverseContextIdcs[idx] = uniqInputSize - 1; } size_t uniqHistorySize = 0; @@ -818,53 +860,54 @@ namespace kiwi const auto historyToken = prevStates[i].history[j]; if (historyToken) { - 
historyIdcs.emplace_back(mergePair(historyToken, i * windowSize + j)); + tls.historyIdcs.emplace_back(mergePair(historyToken, i * windowSize + j)); } } } - sort(historyIdcs.begin(), historyIdcs.end()); + sort(tls.historyIdcs.begin(), tls.historyIdcs.end()); uniqHistorySize = 0; - for (size_t i = 0; i < historyIdcs.size(); ++i) + for (size_t i = 0; i < tls.historyIdcs.size(); ++i) { - const auto historyToken = splitPair(historyIdcs[i]).first; - const auto idx = splitPair(historyIdcs[i]).second; - if (i == 0 || historyToken != splitPair(historyIdcs[i - 1]).first) + const auto historyToken = splitPair(tls.historyIdcs[i]).first; + const auto idx = splitPair(tls.historyIdcs[i]).second; + if (i == 0 || historyToken != splitPair(tls.historyIdcs[i - 1]).first) { uniqHistorySize++; } - inverseHistoryIdcs[idx] = uniqHistorySize - 1; + tls.inverseHistoryIdcs[idx] = uniqHistorySize - 1; + } + if (!quantized) + { + tls.inputEmbBuf.resize((uniqInputSize + uniqHistorySize) * header.dim); } - inputEmbBuf.resize((uniqInputSize + uniqHistorySize) * header.dim); - confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); - resultBuf.resize(padMultipleOf(uniqInputSize + uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); + tls.confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); + tls.resultBuf.resize(padMultipleOf(uniqInputSize + uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); uniqHistorySize = 0; - for (size_t i = 0; i < historyIdcs.size(); ++i) + for (size_t i = 0; i < tls.historyIdcs.size(); ++i) { - const auto historyToken = splitPair(historyIdcs[i]).first; - const auto idx = splitPair(historyIdcs[i]).second; - if (i == 0 || historyToken != splitPair(historyIdcs[i - 1]).first) + const auto historyToken = splitPair(tls.historyIdcs[i]).first; + const auto idx = splitPair(tls.historyIdcs[i]).second; + if (i == 0 || historyToken != splitPair(tls.historyIdcs[i - 1]).first) { if (quantized) { - contextIdcs2.emplace_back(historyToken + header.contextSize); + 
tls.contextIdcs2.emplace_back(historyToken + header.contextSize); } else { - copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &inputEmbBuf[(uniqInputSize + uniqHistorySize) * header.dim]); - fill(&resultBuf[(uniqInputSize + uniqHistorySize) * uniqOutputSize], &resultBuf[(uniqInputSize + uniqHistorySize + 1) * uniqOutputSize], getDistantBias(historyToken)); + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &tls.inputEmbBuf[(uniqInputSize + uniqHistorySize) * header.dim]); + fill(&tls.resultBuf[(uniqInputSize + uniqHistorySize) * uniqOutputSize], &tls.resultBuf[(uniqInputSize + uniqHistorySize + 1) * uniqOutputSize], getDistantBias(historyToken)); } - confidenceBuf[uniqInputSize * 2 + uniqHistorySize] = getDistantConfid(historyToken); + tls.confidenceBuf[uniqInputSize * 2 + uniqHistorySize] = getDistantConfid(historyToken); uniqHistorySize++; } } } else // use map for large size { - thread_local UnorderedMap historyMap; - thread_local Vector uniqHistoryTokens; - historyMap.clear(); - uniqHistoryTokens.clear(); + tls.historyMap.clear(); + tls.uniqHistoryTokens.clear(); for (size_t i = 0; i < prevStateSize; ++i) { for (size_t j = 0; j < windowSize; ++j) @@ -872,34 +915,35 @@ namespace kiwi const auto historyToken = prevStates[i].history[j]; if (!historyToken) continue; const auto idx = i * windowSize + j; - auto inserted = historyMap.emplace(historyToken, historyMap.size()); - inverseHistoryIdcs[idx] = inserted.first->second; - if (inserted.second) uniqHistoryTokens.emplace_back(historyToken); + auto inserted = tls.historyMap.emplace(historyToken, tls.historyMap.size()); + tls.inverseHistoryIdcs[idx] = inserted.first->second; + if (inserted.second) tls.uniqHistoryTokens.emplace_back(historyToken); } } - uniqHistorySize = historyMap.size(); - - inputEmbBuf.resize((uniqInputSize + uniqHistorySize)* header.dim); - confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); - resultBuf.resize(padMultipleOf(uniqInputSize 
+ uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); + uniqHistorySize = tls.historyMap.size(); + if (!quantized) + { + tls.inputEmbBuf.resize((uniqInputSize + uniqHistorySize) * header.dim); + } + tls.confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); + tls.resultBuf.resize(padMultipleOf(uniqInputSize + uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); - for (size_t i = 0; i < uniqHistoryTokens.size(); ++i) + for (size_t i = 0; i < tls.uniqHistoryTokens.size(); ++i) { - const auto historyToken = uniqHistoryTokens[i]; + const auto historyToken = tls.uniqHistoryTokens[i]; if (quantized) { - contextIdcs2.emplace_back(historyToken + header.contextSize); + tls.contextIdcs2.emplace_back(historyToken + header.contextSize); } else { - copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &inputEmbBuf[(uniqInputSize + i) * header.dim]); - fill(&resultBuf[(uniqInputSize + i) * uniqOutputSize], &resultBuf[(uniqInputSize + i + 1) * uniqOutputSize], getDistantBias(historyToken)); + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &tls.inputEmbBuf[(uniqInputSize + i) * header.dim]); + fill(&tls.resultBuf[(uniqInputSize + i) * uniqOutputSize], &tls.resultBuf[(uniqInputSize + i + 1) * uniqOutputSize], getDistantBias(historyToken)); } - confidenceBuf[uniqInputSize * 2 + i] = getDistantConfid(historyToken); + tls.confidenceBuf[uniqInputSize * 2 + i] = getDistantConfid(historyToken); } } - - Eigen::Map resultMap{ resultBuf.data(), (Eigen::Index)uniqOutputSize, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; + Eigen::Map resultMap{ tls.resultBuf.data(), (Eigen::Index)uniqOutputSize, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; if constexpr (quantized) { @@ -908,27 +952,201 @@ namespace kiwi //ScopedTimer<> timer{ 0 }; qgemm::scatteredGEMMOpt( uniqInputSize + uniqHistorySize, uniqOutputSize, header.dim, - getContextQuantEmb(0), contextIdcs2.data(), contextEmbStride(), - getOutputQuantEmb(0), 
nextIdcs2.data(), outputEmbStride(), - resultBuf.data(), uniqOutputSize); + getContextQuantEmb(0), tls.contextIdcs2.data(), contextEmbStride(), + getOutputQuantEmb(0), tls.nextIdcs2.data(), outputEmbStride(), + tls.resultBuf.data(), uniqOutputSize); } else { - Eigen::Map inputMap{ inputEmbBuf.data(), header.dim, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; - Eigen::Map outputMap{ outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; + Eigen::Map inputMap{ tls.inputEmbBuf.data(), header.dim, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; + Eigen::Map outputMap{ tls.outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; resultMap += outputMap.transpose() * inputMap; } + + pair contextCache; + for (size_t j = 0; j < nextIdSize; ++j) + { for (size_t i = 0; i < prevStateSize; ++i) { - const auto state = prevStates[i]; + const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; + outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); + } + } + + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; for (size_t j = 0; j < numInvalidDistantTokens; ++j) { - outScores[i * nextIdSize + j] = resultMap(inverseNextIdcs[j], inverseContextIdcs[i]); - outStates[i * nextIdSize + j] = nextState<_windowSize>(state, nextIds[j]); + outScores[i * nextIdSize + j] = resultMap(tls.inverseNextIdcs[j], tls.inverseContextIdcs[i]); } } - auto* validTokenSumBuf = scoreBuf.data() + scoreBatchSize * (windowSize + 1); + auto* validTokenSumBuf = tls.scoreBuf.data() + scoreBatchSize * (windowSize + 1); + for (size_t i = 0; i < prevStateSize * numValidDistantTokens; i += scoreBatchSize) + { + const size_t batchSize = std::min(scoreBatchSize, prevStateSize * numValidDistantTokens - i); + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + 
const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + tls.scoreBuf[j] = tls.confidenceBuf[tls.inverseContextIdcs[pIdx] * 2]; + validTokenSumBuf[j] = tls.confidenceBuf[tls.inverseContextIdcs[pIdx] * 2 + 1]; + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; + tls.scoreBuf[j + (k + 1) * scoreBatchSize] = idx == -1 ? -99999 : tls.confidenceBuf[uniqInputSize * 2 + idx]; + } + } + Eigen::Map> scoreMap{ tls.scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; + scoreMap.rowwise() += Eigen::Map>{ positionConfidPtr, 1, windowSize + 1 }; + LogSoftmaxTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); + scoreMap.rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + tls.scoreBuf[j] += resultMap(tls.inverseNextIdcs[nIdx], tls.inverseContextIdcs[pIdx]); + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; + if (idx != -1) + { + tls.scoreBuf[j + (k + 1) * scoreBatchSize] += resultMap(tls.inverseNextIdcs[nIdx], uniqInputSize + idx); + } + } + } + LogSumExpTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); + + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + const bool cacheIsValid = pIdx > 0 && prevStates[pIdx].node == prevStates[pIdx - 1].node; + outScores[pIdx * nextIdSize + nIdx] = tls.scoreBuf[j]; + } + } + } + + template + inline void PcLangModel::progressMatrixWOSort( + TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const + { 
+ static constexpr size_t scoreBatchSize = 32; + + if (quantized) + { + tls.contextIdcs2.clear(); + tls.nextIdcs2.clear(); + } + else + { + tls.inputEmbBuf.resize(prevStateSize * (1 + windowSize) * header.dim); + tls.outputEmbBuf.resize(nextIdSize * header.dim); + } + tls.confidenceBuf.resize(prevStateSize * (2 + windowSize)); + tls.scoreBuf.resize(scoreBatchSize * (windowSize + 2)); + tls.resultBuf.resize(padMultipleOf(prevStateSize * (1 + windowSize), 8) * padMultipleOf(nextIdSize, 8)); + + const size_t numInvalidDistantTokens = nextIdSize - numValidDistantTokens; + for (size_t i = 0; i < nextIdSize; ++i) + { + const auto nextId = nextIds[i]; + if (quantized) + { + tls.nextIdcs2.emplace_back(nextId); + } + else + { + copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &tls.outputEmbBuf[i * header.dim]); + } + } + + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto contextId = prevStates[i].contextIdx; + if (quantized) + { + tls.contextIdcs2.emplace_back(contextId); + } + else + { + copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &tls.inputEmbBuf[i * header.dim]); + fill(&tls.resultBuf[i * nextIdSize], &tls.resultBuf[(i + 1) * nextIdSize], getContextBias(contextId)); + } + tls.confidenceBuf[i * 2] = getContextConfid(contextId); + tls.confidenceBuf[i * 2 + 1] = getContextValidTokenSum(contextId); + } + + size_t uniqHistorySize = 0; + tls.inverseHistoryIdcs.clear(); + for (size_t i = 0; i < prevStateSize; ++i) + { + for (size_t j = 0; j < windowSize; ++j) + { + const auto historyToken = prevStates[i].history[j]; + if (historyToken) + { + if (quantized) + { + tls.contextIdcs2.emplace_back(historyToken + header.contextSize); + } + else + { + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &tls.inputEmbBuf[(prevStateSize + uniqHistorySize) * header.dim]); + fill(&tls.resultBuf[(prevStateSize + uniqHistorySize) * nextIdSize], &tls.resultBuf[(prevStateSize + uniqHistorySize + 1) * nextIdSize], 
getDistantBias(historyToken)); + } + tls.confidenceBuf[prevStateSize * 2 + uniqHistorySize] = getDistantConfid(historyToken); + uniqHistorySize++; + } + tls.inverseHistoryIdcs.emplace_back(historyToken ? uniqHistorySize - 1 : -1); + + } + } + + Eigen::Map resultMap{ tls.resultBuf.data(), (Eigen::Index)nextIdSize, (Eigen::Index)(prevStateSize + uniqHistorySize) }; + + if constexpr (quantized) + { + qgemm::scatteredGEMMOpt( + prevStateSize + uniqHistorySize, nextIdSize, header.dim, + getContextQuantEmb(0), tls.contextIdcs2.data(), contextEmbStride(), + getOutputQuantEmb(0), tls.nextIdcs2.data(), outputEmbStride(), + tls.resultBuf.data(), nextIdSize); + } + else + { + Eigen::Map inputMap{ tls.inputEmbBuf.data(), header.dim, (Eigen::Index)(prevStateSize + uniqHistorySize) }; + Eigen::Map outputMap{ tls.outputEmbBuf.data(), header.dim, (Eigen::Index)nextIdSize }; + resultMap += outputMap.transpose() * inputMap; + } + + pair contextCache; + for (size_t j = 0; j < nextIdSize; ++j) + { + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; + outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); + } + } + + + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + for (size_t j = 0; j < numInvalidDistantTokens; ++j) + { + outScores[i * nextIdSize + j] = resultMap(j, i); + } + } + + auto* validTokenSumBuf = tls.scoreBuf.data() + scoreBatchSize * (windowSize + 1); for (size_t i = 0; i < prevStateSize * numValidDistantTokens; i += scoreBatchSize) { @@ -937,44 +1155,61 @@ namespace kiwi { const auto pIdx = (i + j) / numValidDistantTokens; const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; - scoreBuf[j] = confidenceBuf[inverseContextIdcs[pIdx] * 2]; - validTokenSumBuf[j] = confidenceBuf[inverseContextIdcs[pIdx] * 2 + 1]; + tls.scoreBuf[j] = tls.confidenceBuf[pIdx * 2]; + 
validTokenSumBuf[j] = tls.confidenceBuf[pIdx * 2 + 1]; for (size_t k = 0; k < windowSize; ++k) { - const auto idx = inverseHistoryIdcs[pIdx * windowSize + k]; - scoreBuf[j + (k + 1) * scoreBatchSize] = idx == -1 ? -99999 : confidenceBuf[uniqInputSize * 2 + idx]; + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; + tls.scoreBuf[j + (k + 1) * scoreBatchSize] = idx == -1 ? -99999 : tls.confidenceBuf[prevStateSize * 2 + idx]; } } - Eigen::Map> scoreMap{ scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; + Eigen::Map> scoreMap{ tls.scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; scoreMap.rowwise() += Eigen::Map>{ positionConfidPtr, 1, windowSize + 1 }; - LogSoftmaxTransposed{}(scoreBuf.data(), batchSize, scoreBatchSize); + LogSoftmaxTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); scoreMap.rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; for (size_t j = 0; j < batchSize; ++j) { const auto pIdx = (i + j) / numValidDistantTokens; const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; - scoreBuf[j] += resultMap(inverseNextIdcs[nIdx], inverseContextIdcs[pIdx]); + tls.scoreBuf[j] += resultMap(nIdx, pIdx); for (size_t k = 0; k < windowSize; ++k) { - const auto idx = inverseHistoryIdcs[pIdx * windowSize + k]; + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; if (idx != -1) { - scoreBuf[j + (k + 1) * scoreBatchSize] += resultMap(inverseNextIdcs[nIdx], uniqInputSize + idx); + tls.scoreBuf[j + (k + 1) * scoreBatchSize] += resultMap(nIdx, prevStateSize + idx); } } } - LogSumExpTransposed{}(scoreBuf.data(), batchSize, scoreBatchSize); + LogSumExpTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); for (size_t j = 0; j < batchSize; ++j) { const auto pIdx = (i + j) / numValidDistantTokens; const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; - outScores[pIdx * nextIdSize + nIdx] = scoreBuf[j]; - outStates[pIdx * nextIdSize + 
nIdx] = nextState(prevStates[pIdx], nextIds[nIdx]); + outScores[pIdx * nextIdSize + nIdx] = tls.scoreBuf[j]; } } } + template + template + void PcLangModel::progressMatrix( + const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const + { + thread_local TLSForProgressMatrix tls; + if (prevStateSize <= (quantized ? 16 : 8) && nextIdSize <= 16) + { + return progressMatrixWOSort(tls, prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores); + } + else + { + return progressMatrixWSort(tls, prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores); + } + } + template template void PcLangModel::progressMatrix( @@ -1069,12 +1304,22 @@ namespace kiwi Eigen::Map outputMap{ outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; resultMap += outputMap.transpose() * inputMap; } + + pair contextCache; + for (size_t j = 0; j < nextIdSize; ++j) + { for (size_t i = 0; i < prevStateSize; ++i) { const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; + outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); + } + } + + for (size_t i = 0; i < prevStateSize; ++i) + { for (size_t j = 0; j < nextIdSize; ++j) { - outStates[i * nextIdSize + j] = nextState<_windowSize>(state, nextIds[j]); outScores[i * nextIdSize + j] = resultMap(inverseNextIdcs[j], inverseContextIdcs[i]); } } diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index af210369..4b09b2f2 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -300,21 +300,33 @@ namespace kiwi KeyType next) const; template - LmStateType nextState(const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next) const; + LmStateType nextState(const typename 
std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next, + bool cacheIsValid, std::pair& cache) const; template - LmStateType nextState(const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next) const; + LmStateType nextState(const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next, + bool cacheIsValid, std::pair& cache) const; /* - * �� prevStateSize���� ���¿� nextIdSize���� ���� ��ū�� �޾Ƽ�, �� ���º��� ���� ��ū�� ������ Ȯ���� ����ϰ� �� ���¸� ��ȯ�Ѵ�. - * �� ���°��� outStates�� ����ǰ�, �� ���º� Ȯ������ outScores�� ����ȴ�. - * nextIdSize���� ���� ��ū �� ������ numValidDistantTokens���� ��ū�� ��ȿ�� distant ��ū���� ó���ȴ�. + * 총 prevStateSize개의 상태와 nextIdSize개의 다음 토큰을 받아서, 각 상태별로 다음 토큰이 등장할 확률을 계산하고 새 상태를 반환한다. + * 새 상태값은 outStates에 저장되고, 각 상태별 확률값은 outScores에 저장된다. + * nextIdSize개의 다음 토큰 중 마지막 numValidDistantTokens개의 토큰은 유효한 distant 토큰으로 처리된다. */ template void progressMatrix(const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const; + struct TLSForProgressMatrix; + + inline void progressMatrixWSort(TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; + + inline void progressMatrixWOSort(TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; + template void progressMatrix(const typename std::enable_if<(_windowSize == 0), LmStateType>::type* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, @@ -325,14 +337,23 @@ namespace kiwi struct PcLMState : public LmStateBase> { int32_t 
node = 0; - uint32_t contextIdx = 0; - std::array history = { {0,} }; + uint32_t contextIdx; + std::array history; static constexpr ArchType arch = _arch; static constexpr bool transposed = true; - PcLMState() = default; - PcLMState(const ILangModel* lm) {} + PcLMState() : contextIdx{ 0 }, history { { 0, } } + { + } + + PcLMState(const ILangModel* lm) : contextIdx{ 0 }, history{ {0,} } + { + } + + PcLMState(int32_t _node) : node{ _node } // partially initialized state + { + } bool operator==(const PcLMState& other) const { @@ -355,14 +376,23 @@ namespace kiwi struct PcLMState<0, _arch, VocabTy, VlVocabTy, quantized> : public LmStateBase> { int32_t node = 0; - uint32_t contextIdx = 0; + uint32_t contextIdx; static constexpr ArchType arch = _arch; static constexpr bool transposed = true; static constexpr size_t windowSize = 0; - PcLMState() = default; - PcLMState(const ILangModel* lm) {} + PcLMState() : contextIdx{ 0 } + { + } + + PcLMState(const ILangModel* lm) : contextIdx{ 0 } + { + } + + PcLMState(int32_t _node) : node{ _node } // partially initialized state + { + } bool operator==(const PcLMState& other) const { From 7affd86c3266650e13a7208dfa506fb3a85fb2b8 Mon Sep 17 00:00:00 2001 From: bab2min Date: Thu, 20 Feb 2025 01:16:15 +0900 Subject: [PATCH 20/53] Optimize `BestPathContainer` for small and medium sizes --- src/BestPathContainer.hpp | 259 ++++++++++++++++++++++++++++++++------ src/PCLanguageModel.cpp | 7 +- src/PCLanguageModel.hpp | 16 ++- src/PathEvaluator.hpp | 20 ++- src/search.cpp | 121 +++++++++++++++++- src/search.h | 10 ++ 6 files changed, 385 insertions(+), 48 deletions(-) diff --git a/src/BestPathContainer.hpp b/src/BestPathContainer.hpp index 3255b6bc..d97f0fab 100644 --- a/src/BestPathContainer.hpp +++ b/src/BestPathContainer.hpp @@ -12,22 +12,25 @@ namespace kiwi enum class PathEvaluatingMode { topN, - top1, top1Small, + top1Medium, + top1, }; template struct WordLL { + LmState lmState; + uint8_t prevRootId = 0; + SpecialState spState; + 
uint8_t rootId = 0; + const Morpheme* morpheme = nullptr; float accScore = 0, accTypoCost = 0; const WordLL* parent = nullptr; - LmState lmState; Wid wid = 0; uint16_t ownFormId = 0; uint8_t combineSocket = 0; - uint8_t rootId = 0; - SpecialState spState; WordLL() = default; @@ -38,7 +41,7 @@ namespace kiwi parent{ _parent }, lmState{ _lmState }, spState{ _spState }, - rootId{ parent ? parent->rootId : (uint8_t)0 } + rootId{ _parent ? _parent->rootId : (uint8_t)0 } { } @@ -47,6 +50,34 @@ namespace kiwi if (parent) return parent->root(); else return this; } + + bool equalTo(const LmState& lmState, uint8_t prevRootId, SpecialState spState) const + { + return (this->prevRootId == prevRootId & this->spState == spState) && this->lmState == lmState; + } + + bool operator==(const WordLL& o) const + { + return equalTo(o.lmState, o.prevRootId, o.spState); + } + }; + + template + struct Hash> + { + size_t operator()(const WordLL& p) const + { + size_t ret = Hash{}(p.lmState); + ret = *reinterpret_cast(&p.prevRootId) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + return ret; + } + + size_t operator()(const LmState& lmState, uint8_t prevRootId, uint8_t spState) const + { + size_t ret = Hash{}(lmState); + ret = ((uint16_t)(prevRootId) | ((uint16_t)spState << 8)) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + return ret; + } }; static constexpr uint8_t commonRootId = -1; @@ -79,7 +110,7 @@ namespace kiwi size_t operator()(const PathHash& p) const { size_t ret = Hash{}(p.lmState); - ret ^= *reinterpret_cast(&p.rootId); + ret = *reinterpret_cast(&p.rootId) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); return ret; } }; @@ -110,6 +141,12 @@ namespace kiwi template class BestPathConatiner; + template + struct BestPathContainerTraits + { + static constexpr size_t maxSize = -1; + }; + template class BestPathConatiner { @@ -117,15 +154,17 @@ namespace kiwi UnorderedMap, std::pair> bestPathIndex; Vector> bestPathValues; public: + inline void clear() { 
bestPathIndex.clear(); bestPathValues.clear(); } - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, + inline void insert(size_t topN, uint8_t prevRootId, uint8_t rootId, const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) { + PathHash ph{ lmState, prevRootId, spState }; auto inserted = bestPathIndex.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); if (inserted.second) { @@ -183,22 +222,25 @@ namespace kiwi template class BestPathConatiner { - UnorderedMap, WordLL> bestPathes; + UnorderedSet> bestPathes; public: inline void clear() { bestPathes.clear(); } - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, + inline void insert(size_t topN, uint8_t prevRootId, uint8_t rootId, const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) { WordLL newPath{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + newPath.prevRootId = prevRootId; if (rootId != commonRootId) newPath.rootId = rootId; - auto inserted = bestPathes.emplace(ph, newPath); + auto inserted = bestPathes.emplace(newPath); if (!inserted.second) { - auto& target = inserted.first->second; + // this is dangerous, but we can update the key safely + // because an equality between the two objects is guaranteed + auto& target = const_cast&>(*inserted.first); if (accScore > target.accScore) { target = newPath; @@ -210,7 +252,7 @@ namespace kiwi { for (auto& p : bestPathes) { - resultOut.emplace_back(move(p.second)); + resultOut.emplace_back(move(p)); auto& newPath = resultOut.back(); // fill the rest information of resultOut @@ -224,55 +266,200 @@ namespace kiwi } }; - template - class BestPathConatiner + template<> + struct BestPathContainerTraits { - Vector> bestPathIndicesSmall; - Vector> bestPathValuesSmall; + static constexpr size_t maxSize = (sizeof(size_t) == 8 ? 
64 : 32) * 2; + }; + + template<> + struct BestPathContainerTraits + { + static constexpr size_t maxSize = BestPathContainerTraits::maxSize * 4; + }; + + template + class BucketedHashContainer + { + static constexpr size_t bucketSize = 1 << bucketBits; + + std::array::maxSize>, bucketSize> hashes; + std::array>, bucketSize> values; + public: + BucketedHashContainer() + { + for (auto& v : values) + { + v.reserve(BestPathContainerTraits::maxSize); + } + } inline void clear() { - bestPathIndicesSmall.clear(); - bestPathValuesSmall.clear(); + for (auto& v : values) + { + v.clear(); + } } - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, + template + inline void insertOptimized(size_t topN, uint8_t prevRootId, uint8_t rootId, const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) { - const auto it = find(bestPathIndicesSmall.begin(), bestPathIndicesSmall.end(), ph); - if (it == bestPathIndicesSmall.end()) + static constexpr size_t numBits = sizeof(size_t) * 8; + const size_t h = Hash>{}(lmState, prevRootId, spState); + const size_t bucket = (h >> 8) & (bucketSize - 1); + auto& hash = hashes[bucket]; + auto& value = values[bucket]; + + size_t it = value.size(); + size_t bits[2]; + bits[0] = nst::findAll(hash.data(), std::min(value.size(), numBits), (uint8_t)h); + bits[1] = value.size() > numBits ? 
nst::findAll(hash.data() + numBits, value.size() - numBits, (uint8_t)h) : 0; + while (bits[0]) + { + const size_t i = utils::countTrailingZeroes(bits[0]); + if (value[i].equalTo(lmState, prevRootId, spState)) + { + it = i; + goto breakloop; + } + bits[0] &= ~((size_t)1 << i); + } + while (bits[1]) { - bestPathIndicesSmall.push_back(ph); - bestPathValuesSmall.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); - if (rootId != commonRootId) bestPathValuesSmall.back().rootId = rootId; + const size_t i = utils::countTrailingZeroes(bits[1]); + if (value[i].equalTo(lmState, prevRootId, spState)) + { + it = i + numBits; + goto breakloop; + } + bits[1] &= ~((size_t)1 << i); + } + + breakloop:; + if (it >= value.size()) + { + if (value.size() < hash.size()) + { + hash[value.size()] = h; + value.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); + value.back().prevRootId = prevRootId; + if (rootId != commonRootId) value.back().rootId = rootId; + } + else + { + // skip insertion if container is full. + // this isn't correct, but it rarely happens + } } else { - auto& target = bestPathValuesSmall[it - bestPathIndicesSmall.begin()]; + auto& target = value[it]; if (accScore > target.accScore) { - target = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + target.morpheme = morph; + target.accScore = accScore; + target.accTypoCost = accTypoCost; + target.parent = parent; + target.lmState = move(lmState); + target.spState = spState; + target.rootId = parent ? 
parent->rootId : 0; if (rootId != commonRootId) target.rootId = rootId; } } } - inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + inline void insert(size_t topN, uint8_t prevRootId, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) { - for (auto& p : bestPathValuesSmall) + static constexpr ArchType archType = LmState::arch; + if constexpr (archType != ArchType::none && archType != ArchType::balanced) { - resultOut.emplace_back(move(p)); - auto& newPath = resultOut.back(); + return insertOptimized(topN, prevRootId, rootId, morph, accScore, accTypoCost, parent, move(lmState), spState); + } - // fill the rest information of resultOut - newPath.wid = lastSeqId; - if (curMorph->isSingle()) + const size_t h = Hash>{}(lmState, prevRootId, spState); + const size_t bucket = (h >> 8) & (bucketSize - 1); + auto& hash = hashes[bucket]; + auto& value = values[bucket]; + + const auto hashEnd = hash.begin() + value.size(); + auto it = find(hash.begin(), hashEnd, (uint8_t)h); + while (it != hashEnd) + { + if (value[it - hash.begin()].equalTo(lmState, prevRootId, spState)) { - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; + break; + } + ++it; + it = find(it, hashEnd, (uint8_t)h); + } + + if (it == hashEnd) + { + if (value.size() < hash.size()) + { + hash[value.size()] = h; + value.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); + value.back().prevRootId = prevRootId; + if (rootId != commonRootId) value.back().rootId = rootId; + } + else + { + // skip insertion if container is full. 
+ // this isn't correct, but it rarely happens + } + } + else + { + auto& target = value[it - hash.begin()]; + if (accScore > target.accScore) + { + target.morpheme = morph; + target.accScore = accScore; + target.accTypoCost = accTypoCost; + target.parent = parent; + target.lmState = move(lmState); + target.spState = spState; + target.rootId = parent ? parent->rootId : 0; + if (rootId != commonRootId) target.rootId = rootId; + } + } + } + + inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + { + for (auto& v : values) + { + for (auto& p : v) + { + resultOut.emplace_back(move(p)); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + newPath.wid = lastSeqId; + if (curMorph->isSingle()) + { + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } } } } }; -} \ No newline at end of file + + + template + class alignas(BestPathContainerTraits::maxSize) BestPathConatiner + : public BucketedHashContainer + { + }; + + template + class alignas(BestPathContainerTraits::maxSize) BestPathConatiner + : public BucketedHashContainer + { + }; +} diff --git a/src/PCLanguageModel.cpp b/src/PCLanguageModel.cpp index c7a28c58..53641d3f 100644 --- a/src/PCLanguageModel.cpp +++ b/src/PCLanguageModel.cpp @@ -73,10 +73,15 @@ namespace kiwi continue; } regularPrevPathes.emplace_back(&prevPath); - prevLmStates.emplace_back(prevPath.lmState); } } + prevLmStates.resize(regularPrevPathes.size()); + for (size_t i = 0; i < regularPrevPathes.size(); ++i) + { + prevLmStates[i] = regularPrevPathes[i]->lmState; + } + for (auto& curMorph : morphs) { if (curMorph->combineSocket) diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index 4b09b2f2..5891f18c 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -407,15 +407,23 @@ namespace kiwi }; } + static constexpr size_t largePrime = sizeof(size_t) == 8 ? 
2305843009213693951ll : 2654435761ll; + + inline size_t rol(size_t x, size_t r) + { + return (x << r) | (x >> (sizeof(size_t) * 8 - r)); + } + template struct Hash> { size_t operator()(const lm::PcLMState& state) const { - size_t ret = (uint32_t)(state.node * (size_t)2654435761); + size_t ret = (state.node * largePrime) ^ rol(state.node, sizeof(size_t) * 4 + 1); static constexpr size_t cmpStart = windowSize - sizeof(size_t) / sizeof(VocabTy); - const auto h = *reinterpret_cast(&state.history[cmpStart]); - ret = h ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + size_t h = *reinterpret_cast(&state.history[cmpStart]); + h = (h * largePrime) ^ rol(h, sizeof(size_t) * 4 - 1); + ret = h ^ rol(ret, 3); return ret; } }; @@ -425,7 +433,7 @@ namespace kiwi { size_t operator()(const lm::PcLMState<0, arch, VocabTy, VlVocabTy, quantized>& state) const { - size_t ret = (uint32_t)(state.node * (size_t)2654435761); + size_t ret = (state.node * largePrime) ^ rol(state.node, sizeof(size_t) * 4 + 1); return ret; } }; diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index e4f84446..86bbb94f 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -215,8 +215,7 @@ namespace kiwi spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); } - PathHash ph{ state, prevPath.rootId, spState }; - bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(state), spState); + bestPathCont.insert(topN, prevPath.rootId, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(state), spState); }; if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) @@ -354,7 +353,6 @@ namespace kiwi { totalPrevPathes += cache[prev - startNode].size(); } - const bool useContainerForSmall = totalPrevPathes <= 48; for (bool ignoreCond : {false, true}) { @@ -430,11 +428,16 @@ 
namespace kiwi evalSingleMorpheme(nCache, node, ownFormId, curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); } - else if (useContainerForSmall) + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) { evalSingleMorpheme(nCache, node, ownFormId, curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); } + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } else { evalSingleMorpheme(nCache, node, ownFormId, @@ -928,7 +931,6 @@ namespace kiwi { totalPrevPathes += cache[prev - startNode].size(); } - const bool useContainerForSmall = totalPrevPathes <= 48; MorphemeEvaluator me; if (topN > 1) @@ -937,12 +939,18 @@ namespace kiwi ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } - else if (useContainerForSmall) + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) { me.eval(nCache, kw, ownFormList, cache, ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) + { + me.eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, + node, startNode, topN, totalPrevPathes, ignoreCond ? 
-10 : 0, nodeLevelDiscount, prevSpStates); + } else { me.eval(nCache, kw, ownFormList, cache, diff --git a/src/search.cpp b/src/search.cpp index 58f060a6..2aca6f11 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -31,7 +31,8 @@ template Vector detail::reorderImpl(const uint32_t*, size_t);\ template Vector detail::reorderImpl(const uint64_t*, size_t);\ template Vector detail::reorderImpl(const char16_t*, size_t);\ - template size_t detail::getPacketSizeImpl(); + template size_t detail::getPacketSizeImpl();\ + template size_t detail::findAllImpl(const uint8_t*, size_t, uint8_t);\ #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 || KIWI_ARCH_X86 || KIWI_ARCH_X86_64 #include @@ -186,6 +187,12 @@ namespace kiwi return OptimizedImpl::packetSize; } + template + size_t detail::findAllImpl(const uint8_t* arr, size_t size, uint8_t key) + { + return OptimizedImpl::findAll(arr, size, key); + } + template<> struct OptimizedImpl { @@ -215,6 +222,16 @@ namespace kiwi } else return 0; } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + size_t ret = 0; + for (size_t i = 0; i < size; ++i) + { + ret |= (size_t)(arr[i] == key) << i; + } + return ret; + } }; INSTANTIATE_IMPL(ArchType::none); @@ -261,6 +278,16 @@ namespace kiwi } else return 0; } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + size_t ret = 0; + for (size_t i = 0; i < size; ++i) + { + ret |= (size_t)(arr[i] == key) << i; + } + return ret; + } }; INSTANTIATE_IMPL(ArchType::balanced); } @@ -524,6 +551,48 @@ namespace kiwi return 0; } + ARCH_TARGET("sse2") + inline size_t findAllSSE2(const uint8_t* arr, size_t size, uint8_t key) + { + __m128i pkey = _mm_set1_epi8(key); + if (size <= 16) + { + __m128i parr = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i pcmp = _mm_cmpeq_epi8(pkey, parr); + return (size_t)_mm_movemask_epi8(pcmp) & (((size_t)1 << size) - 1); + } + else if (size <= 32) + { + __m128i parr0 = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i parr1 = 
_mm_loadu_si128(reinterpret_cast(arr + 16)); + __m128i pcmp0 = _mm_cmpeq_epi8(pkey, parr0); + __m128i pcmp1 = _mm_cmpeq_epi8(pkey, parr1); + return ((size_t)_mm_movemask_epi8(pcmp0) | ((size_t)_mm_movemask_epi8(pcmp1) << 16)) & (((size_t)1 << size) - 1); + } + else if (size <= 48) + { + __m128i parr0 = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i parr1 = _mm_loadu_si128(reinterpret_cast(arr + 16)); + __m128i parr2 = _mm_loadu_si128(reinterpret_cast(arr + 32)); + __m128i pcmp0 = _mm_cmpeq_epi8(pkey, parr0); + __m128i pcmp1 = _mm_cmpeq_epi8(pkey, parr1); + __m128i pcmp2 = _mm_cmpeq_epi8(pkey, parr2); + return ((size_t)_mm_movemask_epi8(pcmp0) | ((size_t)_mm_movemask_epi8(pcmp1) << 16) | ((size_t)_mm_movemask_epi8(pcmp2) << 32)) & (((size_t)1 << size) - 1); + } + else + { + __m128i parr0 = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i parr1 = _mm_loadu_si128(reinterpret_cast(arr + 16)); + __m128i parr2 = _mm_loadu_si128(reinterpret_cast(arr + 32)); + __m128i parr3 = _mm_loadu_si128(reinterpret_cast(arr + 48)); + __m128i pcmp0 = _mm_cmpeq_epi8(pkey, parr0); + __m128i pcmp1 = _mm_cmpeq_epi8(pkey, parr1); + __m128i pcmp2 = _mm_cmpeq_epi8(pkey, parr2); + __m128i pcmp3 = _mm_cmpeq_epi8(pkey, parr3); + return ((size_t)_mm_movemask_epi8(pcmp0) | ((size_t)_mm_movemask_epi8(pcmp1) << 16) | ((size_t)_mm_movemask_epi8(pcmp2) << 32) | ((size_t)_mm_movemask_epi8(pcmp3) << 48)) & (((size_t)1 << size) - 1); + } + } + template<> struct OptimizedImpl { @@ -549,6 +618,11 @@ namespace kiwi using SignedIntTy = typename SignedType::type; return nstSearchKVSSE2((const uint8_t*)kv, size, target); } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllSSE2(arr, size, key); + } }; INSTANTIATE_IMPL(ArchType::sse2); } @@ -866,6 +940,27 @@ namespace kiwi return 0; } + ARCH_TARGET("avx2") + inline size_t findAllAVX2(const uint8_t* arr, size_t size, uint8_t key) + { + if (size <= 32) + { + __m256i pkey = _mm256_set1_epi8(key); + __m256i parr = 
_mm256_loadu_si256(reinterpret_cast(arr)); + __m256i pcmp = _mm256_cmpeq_epi8(pkey, parr); + return (size_t)_mm256_movemask_epi8(pcmp) & (((size_t)1 << size) - 1); + } + else + { + __m256i pkey = _mm256_set1_epi8(key); + __m256i parr0 = _mm256_loadu_si256(reinterpret_cast(arr)); + __m256i parr1 = _mm256_loadu_si256(reinterpret_cast(arr + 32)); + __m256i pcmp0 = _mm256_cmpeq_epi8(pkey, parr0); + __m256i pcmp1 = _mm256_cmpeq_epi8(pkey, parr1); + return ((size_t)_mm256_movemask_epi8(pcmp0) | ((size_t)_mm256_movemask_epi8(pcmp1) << 32)) & (((size_t)1 << size) - 1); + } + } + template ARCH_TARGET("avx512bw") FORCE_INLINE bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) @@ -1152,6 +1247,15 @@ namespace kiwi return 0; } + ARCH_TARGET("avx512bw") + inline size_t findAllAVX512(const uint8_t* arr, size_t size, uint8_t key) + { + __m512i pkey = _mm512_set1_epi8(key); + __m512i parr = _mm512_loadu_si512(reinterpret_cast(arr)); + __mmask64 pcmp = _mm512_cmpeq_epi8_mask(pkey, parr); + return (size_t)pcmp & (((size_t)1 << size) - 1); + } + template<> struct OptimizedImpl { @@ -1177,6 +1281,11 @@ namespace kiwi using SignedIntTy = typename SignedType::type; return nstSearchKVSSE2((const uint8_t*)kv, size, target); } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllSSE2(arr, size, key); + } }; INSTANTIATE_IMPL(ArchType::sse4_1); @@ -1205,6 +1314,11 @@ namespace kiwi using SignedIntTy = typename SignedType::type; return nstSearchKVAVX2((const uint8_t*)kv, size, target); } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllAVX2(arr, size, key); + } }; INSTANTIATE_IMPL(ArchType::avx2); @@ -1239,6 +1353,11 @@ namespace kiwi using SignedIntTy = typename SignedType::type; return nstSearchKVAVX512((const uint8_t*)kv, size, target); } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllAVX512(arr, size, key); + } }; 
INSTANTIATE_IMPL(ArchType::avx512bw); diff --git a/src/search.h b/src/search.h index 10845287..6ec129f3 100644 --- a/src/search.h +++ b/src/search.h @@ -28,6 +28,9 @@ namespace kiwi template size_t getPacketSizeImpl(); + + template + size_t findAllImpl(const uint8_t* arr, size_t size, uint8_t key); } template @@ -135,5 +138,12 @@ namespace kiwi { return detail::searchKVImpl::type>(kv, size, target); } + + template + size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + if (size == 0) return 0; + return detail::findAllImpl(arr, size, key); + } } } From 1fa906d5fc7b8f27daa2310297aee29fd6553d6f Mon Sep 17 00:00:00 2001 From: bab2min Date: Fri, 21 Feb 2025 00:23:11 +0900 Subject: [PATCH 21/53] Optimize hashes for PcLM --- src/PCLanguageModel.hpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/PCLanguageModel.hpp b/src/PCLanguageModel.hpp index 5891f18c..8bb995f9 100644 --- a/src/PCLanguageModel.hpp +++ b/src/PCLanguageModel.hpp @@ -414,12 +414,21 @@ namespace kiwi return (x << r) | (x >> (sizeof(size_t) * 8 - r)); } + template<> + struct Hash + { + size_t operator()(uint32_t v) const + { + return ((size_t)v * largePrime) ^ rol((size_t)v, sizeof(size_t) * 4 + 1); + } + }; + template struct Hash> { size_t operator()(const lm::PcLMState& state) const { - size_t ret = (state.node * largePrime) ^ rol(state.node, sizeof(size_t) * 4 + 1); + size_t ret = Hash{}(state.node); static constexpr size_t cmpStart = windowSize - sizeof(size_t) / sizeof(VocabTy); size_t h = *reinterpret_cast(&state.history[cmpStart]); h = (h * largePrime) ^ rol(h, sizeof(size_t) * 4 - 1); @@ -433,7 +442,7 @@ namespace kiwi { size_t operator()(const lm::PcLMState<0, arch, VocabTy, VlVocabTy, quantized>& state) const { - size_t ret = (state.node * largePrime) ^ rol(state.node, sizeof(size_t) * 4 + 1); + size_t ret = Hash{}(state.node); return ret; } }; From d2e2b9ba02c741d6b17d837d39b0cd720adbdbeb Mon Sep 17 00:00:00 2001 From: bab2min Date: Fri, 21 Feb 
2025 01:02:57 +0900 Subject: [PATCH 22/53] rename `PcLM` to `CoNgram`(cong) --- CMakeLists.txt | 18 +++- .../{PCLanguageModel.h => CoNgramModel.h} | 14 ++-- include/kiwi/Types.h | 10 +-- src/{PCLanguageModel.cpp => CoNgramModel.cpp} | 84 +++++++++---------- src/{PCLanguageModel.hpp => CoNgramModel.hpp} | 55 ++++++------ src/KiwiBuilder.cpp | 11 ++- tools/Evaluator.cpp | 10 +-- tools/{pclm_builder.cpp => cong_builder.cpp} | 6 +- tools/evaluator_main.cpp | 16 ++-- ...{build_pclm.vcxproj => build_cong.vcxproj} | 7 +- vsproj/kiwi_shared_library.vcxproj | 14 +++- 11 files changed, 135 insertions(+), 110 deletions(-) rename include/kiwi/{PCLanguageModel.h => CoNgramModel.h} (66%) rename src/{PCLanguageModel.cpp => CoNgramModel.cpp} (94%) rename src/{PCLanguageModel.hpp => CoNgramModel.hpp} (86%) rename tools/{pclm_builder.cpp => cong_builder.cpp} (91%) rename vsproj/{build_pclm.vcxproj => build_cong.vcxproj} (97%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 36d93cb6..79092974 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,18 +45,21 @@ endif() set ( CORE_SRCS src/ArchUtils.cpp src/Combiner.cpp + src/CoNgramModel.cpp + src/Dataset.cpp src/Form.cpp src/FeatureTestor.cpp src/FileUtils.cpp - src/Dataset.cpp src/Joiner.cpp src/Kiwi.cpp src/KiwiBuilder.cpp src/Knlm.cpp src/KTrie.cpp src/PatternMatcher.cpp + src/qgemm.cpp src/search.cpp src/ScriptType.cpp + src/SkipBigramModel.cpp src/SubstringExtractor.cpp src/SwTokenizer.cpp src/TagUtils.cpp @@ -84,6 +87,11 @@ include_directories( third_party/cpp-btree ) include_directories( third_party/variant/include ) include_directories( third_party/eigen ) include_directories( third_party/json/include ) +include_directories( third_party/streamvbyte/include ) +add_subdirectory( third_party/streamvbyte ) +set ( STREAMBYTE_OBJECTS + $ +) if(KIWI_USE_CPUINFO) message(STATUS "Use cpuinfo") include_directories( third_party/cpuinfo/include ) @@ -156,7 +164,9 @@ if (KIWI_CPU_ARCH MATCHES "x86_64") set( CORE_SRCS ${CORE_SRCS} 
src/archImpl/avx2.cpp + src/archImpl/avx_vnni.cpp src/archImpl/avx512bw.cpp + src/archImpl/avx512vnni.cpp ) endif() if(MSVC) @@ -164,14 +174,18 @@ if (KIWI_CPU_ARCH MATCHES "x86_64") set_source_files_properties(src/archImpl/sse4_1.cpp PROPERTIES COMPILE_FLAGS "/arch:SSE2") if (KIWI_USE_CPUINFO) set_source_files_properties(src/archImpl/avx2.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX2") + set_source_files_properties(src/archImpl/avx_vnni.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX2") set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX512") + set_source_files_properties(src/archImpl/avx512vnni.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX512") endif() else() set_source_files_properties(src/archImpl/sse2.cpp PROPERTIES COMPILE_FLAGS "-msse2") set_source_files_properties(src/archImpl/sse4_1.cpp PROPERTIES COMPILE_FLAGS "-msse2 -msse4.1") if (KIWI_USE_CPUINFO) set_source_files_properties(src/archImpl/avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma") + set_source_files_properties(src/archImpl/avx_vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavxvnni") set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw") + set_source_files_properties(src/archImpl/avx512vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vnni") endif() endif() elseif (KIWI_CPU_ARCH MATCHES "arm64") @@ -191,12 +205,14 @@ add_library( "${PROJECT_NAME}_static" STATIC ${CORE_SRCS} src/capi/kiwi_c.cpp ${CPUINFO_OBJECTS_STATIC} + ${STREAMBYTE_OBJECTS} ) add_library( "${PROJECT_NAME}" SHARED ${CORE_SRCS} src/capi/kiwi_c.cpp ${CPUINFO_OBJECTS_SHARED} + ${STREAMBYTE_OBJECTS} ) # Install the kiwi library as well as header files to (`include/kiwi` directory) diff --git a/include/kiwi/PCLanguageModel.h b/include/kiwi/CoNgramModel.h similarity index 66% rename from include/kiwi/PCLanguageModel.h rename to include/kiwi/CoNgramModel.h index 00b975fd..3cffd91b 100644 --- a/include/kiwi/PCLanguageModel.h +++ 
b/include/kiwi/CoNgramModel.h @@ -15,7 +15,7 @@ namespace kiwi { namespace lm { - struct PcLangModelHeader + struct CoNgramModelHeader { uint64_t vocabSize, contextSize; uint16_t dim; @@ -34,24 +34,24 @@ namespace kiwi uint32_t nextOffset = 0; }; - class PcLangModelBase : public ILangModel + class CoNgramModelBase : public ILangModel { protected: const size_t memorySize = 0; - PcLangModelHeader header; + CoNgramModelHeader header; - PcLangModelBase(const utils::MemoryObject& mem) : memorySize{ mem.size() }, header{ *reinterpret_cast(mem.get()) } + CoNgramModelBase(const utils::MemoryObject& mem) : memorySize{ mem.size() }, header{ *reinterpret_cast(mem.get()) } { } public: - virtual ~PcLangModelBase() {} + virtual ~CoNgramModelBase() {} size_t vocabSize() const override { return header.vocabSize; } size_t getMemorySize() const override { return memorySize; } - const PcLangModelHeader& getHeader() const { return header; } + const CoNgramModelHeader& getHeader() const { return header; } static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, size_t maxContextLength = -1, bool useVLE = true, bool reorderContextIdx = true); - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false, bool quantized = true); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false, bool quantized = true); }; } } diff --git a/include/kiwi/Types.h b/include/kiwi/Types.h index 1a7a260c..e4f86f05 100644 --- a/include/kiwi/Types.h +++ b/include/kiwi/Types.h @@ -1,4 +1,4 @@ -/** +/** * @file Types.h * @author bab2min (bab2min@gmail.com) * @brief Kiwi C++ API에 쓰이는 주요 타입들을 모아놓은 헤더 파일 @@ -308,10 +308,10 @@ namespace kiwi none = 0, /**< Select default model */ knlm = 1, /**< Kneser-Ney Language Model */ sbg = 2, /**< Skip-Bigram Model */ - pclm = 3, /**< Pre-computed Context Language Model */ - pclmLocal = 4, /**< 
Pre-computed Context Language Model (Only local context) */ - pclmQuantized = 5, /**< Pre-computed Context Language Model (quantized) */ - pclmLocalQuantized = 6, /**< Pre-computed Context Language Model (Only local context, quantized) */ + cong = 3, /**< Contextual N-gram embedding Language Model (Only local context) */ + congGlobal = 4, /**< Contextual N-gram embedding Language Model (local and global context) */ + congFp32 = 5, /**< Contextual N-gram embedding Language Model (Only local context, non-quantized(slow) version) */ + congGlobalFp32 = 6, /**< Contextual N-gram embedding Language Model (local and global context, non-quantized(slow) version) */ knlmTransposed, }; diff --git a/src/PCLanguageModel.cpp b/src/CoNgramModel.cpp similarity index 94% rename from src/PCLanguageModel.cpp rename to src/CoNgramModel.cpp index 53641d3f..ba209ae4 100644 --- a/src/PCLanguageModel.cpp +++ b/src/CoNgramModel.cpp @@ -3,7 +3,7 @@ #include "PathEvaluator.hpp" #include "Joiner.hpp" #include "Kiwi.hpp" -#include "PCLanguageModel.hpp" +#include "CoNgramModel.hpp" #include "StrUtils.h" #include "FrozenTrie.hpp" #include "qgemm.h" @@ -18,9 +18,9 @@ namespace kiwi } template - struct MorphemeEvaluator> + struct MorphemeEvaluator> { - using LmState = lm::PcLMState; + using LmState = lm::CoNgramState; template void eval( @@ -47,7 +47,7 @@ namespace kiwi thread_local Vector nextWids, nextDistantWids; thread_local Vector scores; - const auto* langMdl = static_cast*>(kw->getLangModel()); + const auto* langMdl = static_cast*>(kw->getLangModel()); const Morpheme* morphBase = kw->morphemes.data(); const auto spacePenalty = kw->spacePenalty; const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; @@ -132,9 +132,9 @@ namespace kiwi } else { - nextLmStates.resize(prevLmStates.size() * nextWids.size()); - scores.resize(prevLmStates.size() * nextWids.size()); - langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), 
nextLmStates.data(), scores.data()); + nextLmStates.resize(prevLmStates.size() * nextWids.size()); + scores.resize(prevLmStates.size() * nextWids.size()); + langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); } } @@ -336,7 +336,7 @@ namespace kiwi } template - PcLangModel::PcLangModel(utils::MemoryObject&& mem) : PcLangModelBase{ mem } + CoNgramModel::CoNgramModel(utils::MemoryObject&& mem) : CoNgramModelBase{ mem } { auto* ptr = reinterpret_cast(mem.get()); @@ -589,7 +589,7 @@ namespace kiwi } template - float PcLangModel::progress(int32_t& nodeIdx, + float CoNgramModel::progress(int32_t& nodeIdx, uint32_t& contextIdx, std::array& history, KeyType next) const @@ -621,7 +621,6 @@ namespace kiwi getContextQuantEmb(0), contextIdcs, contextEmbStride(), getOutputQuantEmb(0), nextIdx, outputEmbStride(), &lls[1 + windowSize], 1); - for (size_t i = 0; i < 1 + windowSize; ++i) { lls[i] += lls[i + 1 + windowSize]; @@ -667,6 +666,7 @@ namespace kiwi { const auto* contextPtr = getContextQuantEmb(contextIdx); const auto* outputPtr = getOutputQuantEmb(next); + int32_t acc = qgemm::dotprod(contextPtr, outputPtr, header.dim); const float contextScale = *reinterpret_cast(contextPtr + header.dim), outputScale = *reinterpret_cast(outputPtr + header.dim), @@ -697,7 +697,7 @@ namespace kiwi } template - struct PcLangModel::TLSForProgressMatrix + struct CoNgramModel::TLSForProgressMatrix { Vector> contextCache; Vector contextIdcs, historyIdcs, nextIdcs; @@ -712,7 +712,7 @@ namespace kiwi // specialization for windowSize > 0 template template - inline auto PcLangModel::nextState( + inline auto CoNgramModel::nextState( const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next, bool cacheIsValid, pair& cache) const -> LmStateType { @@ -724,7 +724,7 @@ namespace kiwi } else { - ret.contextIdx = progressContextNode(ret.node, next); + ret.contextIdx = 
progressContextNode(ret.node, next); cache = std::make_pair(ret.node, ret.contextIdx); } @@ -743,7 +743,7 @@ namespace kiwi // specialization for windowSize == 0 template template - inline auto PcLangModel::nextState( + inline auto CoNgramModel::nextState( const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next, bool cacheIsValid, pair& cache) const -> LmStateType { @@ -755,7 +755,7 @@ namespace kiwi } else { - ret.contextIdx = progressContextNode(ret.node, next); + ret.contextIdx = progressContextNode(ret.node, next); cache = std::make_pair(ret.node, ret.contextIdx); } return ret; @@ -772,7 +772,7 @@ namespace kiwi } template - inline void PcLangModel::progressMatrixWSort( + inline void CoNgramModel::progressMatrixWSort( TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const @@ -948,13 +948,11 @@ namespace kiwi tls.confidenceBuf[uniqInputSize * 2 + i] = getDistantConfid(historyToken); } } + Eigen::Map resultMap{ tls.resultBuf.data(), (Eigen::Index)uniqOutputSize, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; if constexpr (quantized) { - //thread_local Map, size_t> shapeCnt; - //shapeCnt[make_pair(uniqInputSize + uniqHistorySize, uniqOutputSize)]++; - //ScopedTimer<> timer{ 0 }; qgemm::scatteredGEMMOpt( uniqInputSize + uniqHistorySize, uniqOutputSize, header.dim, getContextQuantEmb(0), tls.contextIdcs2.data(), contextEmbStride(), @@ -971,8 +969,8 @@ namespace kiwi pair contextCache; for (size_t j = 0; j < nextIdSize; ++j) { - for (size_t i = 0; i < prevStateSize; ++i) - { + for (size_t i = 0; i < prevStateSize; ++i) + { const auto& state = prevStates[i]; const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); @@ -1036,7 +1034,7 @@ namespace kiwi } template - inline void 
PcLangModel::progressMatrixWOSort( + inline void CoNgramModel::progressMatrixWOSort( TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const @@ -1047,7 +1045,7 @@ namespace kiwi { tls.contextIdcs2.clear(); tls.nextIdcs2.clear(); - } + } else { tls.inputEmbBuf.resize(prevStateSize * (1 + windowSize) * header.dim); @@ -1199,7 +1197,7 @@ namespace kiwi template template - void PcLangModel::progressMatrix( + void CoNgramModel::progressMatrix( const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const @@ -1217,7 +1215,7 @@ namespace kiwi template template - void PcLangModel::progressMatrix( + void CoNgramModel::progressMatrix( const typename std::enable_if<_windowSize == 0, LmStateType>::type* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const @@ -1313,9 +1311,9 @@ namespace kiwi pair contextCache; for (size_t j = 0; j < nextIdSize; ++j) { - for (size_t i = 0; i < prevStateSize; ++i) - { - const auto& state = prevStates[i]; + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); } @@ -1330,7 +1328,7 @@ namespace kiwi } } - utils::MemoryObject PcLangModelBase::build(const string& contextDefinition, const string& embedding, size_t maxContextLength, bool useVLE, bool reorderContextId) + utils::MemoryObject CoNgramModelBase::build(const string& contextDefinition, const string& embedding, size_t maxContextLength, bool useVLE, bool reorderContextId) { ifstream contextStr, 
embeddingStr; if (!openFile(contextStr, contextDefinition)) @@ -1573,8 +1571,8 @@ namespace kiwi distantMask.resize(compressedDistantMaskSize); } - PcLangModelHeader header; - memset(&header, 0, sizeof(PcLangModelHeader)); + CoNgramModelHeader header; + memset(&header, 0, sizeof(CoNgramModelHeader)); header.dim = dim; header.contextSize = contextSize; header.vocabSize = outputSize; @@ -1583,7 +1581,7 @@ namespace kiwi header.numNodes = nodeSizes.size(); size_t finalSize = 0; - header.nodeOffset = alignedOffsetInc(finalSize, sizeof(PcLangModelHeader)); + header.nodeOffset = alignedOffsetInc(finalSize, sizeof(CoNgramModelHeader)); header.keyOffset = alignedOffsetInc(finalSize, compressedNodeSizes.size()); header.valueOffset = alignedOffsetInc(finalSize, compressedKeys.size()); header.embOffset = alignedOffsetInc(finalSize, compressedValues.size()); @@ -1595,7 +1593,7 @@ namespace kiwi utils::MemoryOwner mem{ finalSize }; utils::omstream ostr{ (char*)mem.get(), (std::ptrdiff_t)mem.size() }; - ostr.write((const char*)&header, sizeof(PcLangModelHeader)); + ostr.write((const char*)&header, sizeof(CoNgramModelHeader)); writePadding(ostr); ostr.write((const char*)compressedNodeSizes.data(), compressedNodeSizes.size()); writePadding(ostr); @@ -1630,39 +1628,39 @@ namespace kiwi } template - void* PcLangModel::getFindBestPathFn() const + void* CoNgramModel::getFindBestPathFn() const { - return (void*)&BestPathFinder::findBestPath>; + return (void*)&BestPathFinder::findBestPath>; } template - void* PcLangModel::getNewJoinerFn() const + void* CoNgramModel::getNewJoinerFn() const { return (void*)&newJoinerWithKiwi; } template - inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) + inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) { - auto& header = *reinterpret_cast(mem.get()); + auto& header = *reinterpret_cast(mem.get()); if (!useDistantTokens) { - return make_unique>(std::move(mem)); + return 
make_unique>(std::move(mem)); } switch (header.windowSize) { case 7: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); default: throw std::runtime_error{ "Unsupported `window_size` : " + std::to_string((size_t)header.windowSize) }; }; } template - std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) { - auto& header = *reinterpret_cast(mem.get()); + auto& header = *reinterpret_cast(mem.get()); switch (header.keySize) { case 2: @@ -1688,7 +1686,7 @@ namespace kiwi }; }; - std::unique_ptr PcLangModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens, bool quantized) + std::unique_ptr CoNgramModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens, bool quantized) { static tp::Table tables[] = { CreateOptimizedModelGetter{}, diff --git a/src/PCLanguageModel.hpp b/src/CoNgramModel.hpp similarity index 86% rename from src/PCLanguageModel.hpp rename to src/CoNgramModel.hpp index 8bb995f9..68cbd62c 100644 --- a/src/PCLanguageModel.hpp +++ b/src/CoNgramModel.hpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include "ArchAvailable.h" #include "search.h" @@ -16,10 +16,10 @@ namespace kiwi namespace lm { template - class PcLMState; + class CoNgramState; template - class PcLangModel : public PcLangModelBase + class CoNgramModel : public CoNgramModelBase { using MyNode = Node; @@ -173,21 +173,21 @@ namespace kiwi public: using VocabType = KeyType; - using LmStateType = PcLMState; + using LmStateType = CoNgramState; - PcLangModel(utils::MemoryObject&& mem); + CoNgramModel(utils::MemoryObject&& mem); ModelType getType() const override { if (quantized) { - if (windowSize > 0) return ModelType::pclmQuantized; - else return ModelType::pclmLocalQuantized; + if (windowSize > 0) return ModelType::congGlobal; + else return ModelType::cong; } else { - if (windowSize > 0) return ModelType::pclm; - else return 
ModelType::pclmLocal; + if (windowSize > 0) return ModelType::congGlobalFp32; + else return ModelType::congFp32; } } void* getFindBestPathFn() const override; @@ -221,7 +221,6 @@ namespace kiwi auto* kvs = &alignedKeyValueData[node->nextOffset]; if (node != nodeData.get()) { - //ScopedTimer<> timer(node->numNexts <= 16 ? 0 : node->numNexts <= 272 ? 1 : 2); PREFETCH_T0(node + node->lower); if ((v = nst::searchKV( kvs, @@ -334,7 +333,7 @@ namespace kiwi }; template - struct PcLMState : public LmStateBase> + struct CoNgramState : public LmStateBase> { int32_t node = 0; uint32_t contextIdx; @@ -343,37 +342,37 @@ namespace kiwi static constexpr ArchType arch = _arch; static constexpr bool transposed = true; - PcLMState() : contextIdx{ 0 }, history { { 0, } } + CoNgramState() : contextIdx{ 0 }, history { { 0, } } { } - PcLMState(const ILangModel* lm) : contextIdx{ 0 }, history{ {0,} } + CoNgramState(const ILangModel* lm) : contextIdx{ 0 }, history{ {0,} } { } - PcLMState(int32_t _node) : node{ _node } // partially initialized state + CoNgramState(int32_t _node) : node{ _node } // partially initialized state { } - bool operator==(const PcLMState& other) const + bool operator==(const CoNgramState& other) const { static constexpr size_t cmpStart = windowSize / 2; if (node != other.node) return false; if (memcmp(&history[cmpStart], &other.history[cmpStart], (windowSize - cmpStart) * sizeof(VocabTy))) { - return false; - } + return false; + } return true; } - float nextImpl(const PcLangModel* lm, VocabTy next) + float nextImpl(const CoNgramModel* lm, VocabTy next) { return lm->progress(node, contextIdx, history, next); } }; template - struct PcLMState<0, _arch, VocabTy, VlVocabTy, quantized> : public LmStateBase> + struct CoNgramState<0, _arch, VocabTy, VlVocabTy, quantized> : public LmStateBase> { int32_t node = 0; uint32_t contextIdx; @@ -382,24 +381,24 @@ namespace kiwi static constexpr bool transposed = true; static constexpr size_t windowSize = 0; - PcLMState() : 
contextIdx{ 0 } + CoNgramState() : contextIdx{ 0 } { } - PcLMState(const ILangModel* lm) : contextIdx{ 0 } + CoNgramState(const ILangModel* lm) : contextIdx{ 0 } { } - PcLMState(int32_t _node) : node{ _node } // partially initialized state + CoNgramState(int32_t _node) : node{ _node } // partially initialized state { } - bool operator==(const PcLMState& other) const + bool operator==(const CoNgramState& other) const { return node == other.node; } - float nextImpl(const PcLangModel* lm, VocabTy next) + float nextImpl(const CoNgramModel* lm, VocabTy next) { std::array history = { {0,} }; return lm->progress(node, contextIdx, history, next); @@ -424,9 +423,9 @@ namespace kiwi }; template - struct Hash> + struct Hash> { - size_t operator()(const lm::PcLMState& state) const + size_t operator()(const lm::CoNgramState& state) const { size_t ret = Hash{}(state.node); static constexpr size_t cmpStart = windowSize - sizeof(size_t) / sizeof(VocabTy); @@ -438,9 +437,9 @@ namespace kiwi }; template - struct Hash> + struct Hash> { - size_t operator()(const lm::PcLMState<0, arch, VocabTy, VlVocabTy, quantized>& state) const + size_t operator()(const lm::CoNgramState<0, arch, VocabTy, VlVocabTy, quantized>& state) const { size_t ret = Hash{}(state.node); return ret; diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 24a9459c..1398dbd5 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -16,7 +16,7 @@ #include "RaggedVector.hpp" #include "SkipBigramTrainer.hpp" #include "SkipBigramModel.hpp" -#include "PCLanguageModel.hpp" +#include "CoNgramModel.hpp" #include "SortUtils.hpp" using namespace std; @@ -791,12 +791,11 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio { langMdl = lm::SkipBigramModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), utils::MMap(modelPath + string{ "/skipbigram.mdl" }), archType); } - else if (modelType == ModelType::pclm || modelType == ModelType::pclmLocal - || modelType == 
ModelType::pclmQuantized || modelType == ModelType::pclmLocalQuantized) + else if (ModelType::cong <= modelType && modelType <= ModelType::congGlobalFp32 ) { - langMdl = lm::PcLangModelBase::create(utils::MMap(modelPath + string{ "/pclm.mdl" }), archType, - (modelType == ModelType::pclm || modelType == ModelType::pclmQuantized), - (modelType == ModelType::pclmQuantized || modelType == ModelType::pclmLocalQuantized)); + langMdl = lm::CoNgramModelBase::create(utils::MMap(modelPath + string{ "/cong.mdl" }), archType, + (modelType == ModelType::congGlobal || modelType == ModelType::congGlobalFp32), + (modelType == ModelType::cong || modelType == ModelType::congGlobal)); } if (!!(options & BuildOption::loadDefaultDict)) diff --git a/tools/Evaluator.cpp b/tools/Evaluator.cpp index c9cefbc6..12511a85 100644 --- a/tools/Evaluator.cpp +++ b/tools/Evaluator.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -24,10 +24,10 @@ const char* modelTypeToStr(ModelType type) case ModelType::knlm: return "knlm"; case ModelType::knlmTransposed: return "knlm-transposed"; case ModelType::sbg: return "sbg"; - case ModelType::pclm: return "pclm"; - case ModelType::pclmLocal: return "pclm-local"; - case ModelType::pclmQuantized: return "pclm-quant"; - case ModelType::pclmLocalQuantized: return "pclm-local-quant"; + case ModelType::cong: return "cong"; + case ModelType::congGlobal: return "cong-global"; + case ModelType::congFp32: return "cong-fp32"; + case ModelType::congGlobalFp32: return "cong-global-fp32"; } return "unknown"; } diff --git a/tools/pclm_builder.cpp b/tools/cong_builder.cpp similarity index 91% rename from tools/pclm_builder.cpp rename to tools/cong_builder.cpp index 8e825ff9..6dc9b38d 100644 --- a/tools/pclm_builder.cpp +++ b/tools/cong_builder.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include "toolUtils.h" @@ -16,8 +16,8 @@ int run(const std::string& morphemeDef, const std::string& contextDef, const std { tutils::Timer timer; 
KiwiBuilder::buildMorphData(morphemeDef, output, minCnt); - auto ret = lm::PcLangModelBase::build(contextDef, embedding, maxLength, useVLE, reorderContextIdx); - ret.writeToFile(output + "/pclm.mdl"); + auto ret = lm::CoNgramModelBase::build(contextDef, embedding, maxLength, useVLE, reorderContextIdx); + ret.writeToFile(output + "/cong.mdl"); double tm = timer.getElapsed(); cout << "Total: " << tm << " ms " << endl; return 0; diff --git a/tools/evaluator_main.cpp b/tools/evaluator_main.cpp index 4530dd76..483771e8 100644 --- a/tools/evaluator_main.cpp +++ b/tools/evaluator_main.cpp @@ -66,21 +66,21 @@ int main(int argc, const char* argv[]) { kiwiModelType = ModelType::knlmTransposed; } - else if (v == "pclm") + else if (v == "cong") { - kiwiModelType = ModelType::pclm; + kiwiModelType = ModelType::cong; } - else if (v == "pclm-local") + else if (v == "cong-global") { - kiwiModelType = ModelType::pclmLocal; + kiwiModelType = ModelType::congGlobal; } - else if (v == "pclm-quant") + else if (v == "cong-fp32") { - kiwiModelType = ModelType::pclmQuantized; + kiwiModelType = ModelType::congFp32; } - else if (v == "pclm-local-quant") + else if (v == "cong-global-fp32") { - kiwiModelType = ModelType::pclmLocalQuantized; + kiwiModelType = ModelType::congGlobalFp32; } else { diff --git a/vsproj/build_pclm.vcxproj b/vsproj/build_cong.vcxproj similarity index 97% rename from vsproj/build_pclm.vcxproj rename to vsproj/build_cong.vcxproj index dd26dd5a..b2012b69 100644 --- a/vsproj/build_pclm.vcxproj +++ b/vsproj/build_cong.vcxproj @@ -32,6 +32,7 @@ Win32Proj KiwiRun 10.0 + build_cong @@ -131,6 +132,7 @@ NDEBUG;_CONSOLE;%(PreprocessorDefinitions) MultiThreaded /utf-8 %(AdditionalOptions) + stdcpp17 Console @@ -163,6 +165,7 @@ WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) MultiThreadedDebug /utf-8 %(AdditionalOptions) + stdcpp17 Console @@ -176,6 +179,7 @@ _DEBUG;_CONSOLE;%(PreprocessorDefinitions) MultiThreadedDebug /utf-8 %(AdditionalOptions) + stdcpp17 Console @@ -204,6 
+208,7 @@ WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) MultiThreaded /utf-8 %(AdditionalOptions) + stdcpp17 Console @@ -217,7 +222,7 @@ - + diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index 94915842..dc2a5e0c 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -41,7 +41,7 @@ - + @@ -73,7 +73,7 @@ - + @@ -102,6 +102,14 @@ AdvancedVectorExtensions512 AdvancedVectorExtensions512 + + AdvancedVectorExtensions512 + AdvancedVectorExtensions512 + + + AdvancedVectorExtensions2 + AdvancedVectorExtensions2 + NotSet true @@ -126,7 +134,7 @@ - + From 1734e677bb7058f1fd912190767af5fe500ee22a Mon Sep 17 00:00:00 2001 From: bab2min Date: Fri, 21 Feb 2025 01:18:38 +0900 Subject: [PATCH 23/53] Add automatic model type detection --- include/kiwi/Kiwi.h | 8 +++++++- include/kiwi/Utils.h | 1 + src/FileUtils.cpp | 15 +++++++++++++++ src/KiwiBuilder.cpp | 29 +++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index b3ebc9be..49834c6c 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -629,6 +629,12 @@ namespace kiwi void addAllomorphsToRule(); public: + + /** + * @brief 주어진 모델 경로로부터 모델의 타입을 추정한다. + */ + static ModelType getModelType(const std::string& modelPath); + /** * @brief KiwiBuilder의 기본 생성자 * @@ -667,7 +673,7 @@ namespace kiwi * @param numThreads 모델 및 형태소 분석에 사용할 스레드 개수 * @param options 생성 옵션. `kiwi::BuildOption`을 참조 */ - KiwiBuilder(const std::string& modelPath, size_t numThreads = 0, BuildOption options = BuildOption::default_, ModelType modelType = ModelType::knlm); + KiwiBuilder(const std::string& modelPath, size_t numThreads = 0, BuildOption options = BuildOption::default_, ModelType modelType = ModelType::none); /** * @brief 현재 KiwiBuilder 객체가 유효한 분석 모델을 로딩한 상태인지 알려준다. 
diff --git a/include/kiwi/Utils.h b/include/kiwi/Utils.h index b2a76e81..c3b1680d 100644 --- a/include/kiwi/Utils.h +++ b/include/kiwi/Utils.h @@ -388,5 +388,6 @@ namespace kiwi std::ifstream& openFile(std::ifstream& f, const std::string& filePath, std::ios_base::openmode mode = std::ios_base::in); std::ofstream& openFile(std::ofstream& f, const std::string& filePath, std::ios_base::openmode mode = std::ios_base::out); + bool isOpenable(const std::string& filePath); } diff --git a/src/FileUtils.cpp b/src/FileUtils.cpp index 07d219eb..b8376421 100644 --- a/src/FileUtils.cpp +++ b/src/FileUtils.cpp @@ -48,4 +48,19 @@ namespace kiwi f.exceptions(exc); return f; } + + bool isOpenable(const string& filePath) + { + ifstream ifs; + try + { + openFile(ifs, filePath); + } + catch (const IOException&) + { + return false; + } + return true; + } + } diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 1398dbd5..75bf48c0 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -772,6 +772,26 @@ void KiwiBuilder::saveMorphBin(std::ostream& os) const serializer::writeMany(os, serializer::toKey("KIWI"), forms, morphemes); } +ModelType KiwiBuilder::getModelType(const string& modelPath) +{ + if (isOpenable(modelPath + "/cong.mdl")) + { + return ModelType::congGlobal; + } + else if (isOpenable(modelPath + "/skipbigram.mdl")) + { + return ModelType::sbg; + } + else if (isOpenable(modelPath + "/sj.knlm")) + { + return ModelType::knlm; + } + else + { + return ModelType::none; + } +} + KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOption _options, ModelType _modelType) : detector{ modelPath, _numThreads }, options{ _options }, modelType{ _modelType }, numThreads{ _numThreads ? 
_numThreads : thread::hardware_concurrency() } { @@ -783,6 +803,15 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio loadMorphBin(iss); } + if (modelType == ModelType::none) + { + modelType = getModelType(modelPath); + if (modelType == ModelType::none) + { + throw runtime_error{ "Cannot find any valid model files in the given path" }; + } + } + if (modelType == ModelType::knlm || modelType == ModelType::knlmTransposed) { langMdl = lm::KnLangModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), archType, modelType == ModelType::knlmTransposed); From 37e0469181f2241f8442334ff503420accc99f2b Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 3 Mar 2025 19:52:26 +0900 Subject: [PATCH 24/53] Implement dispatcher of CoNgramModel --- include/kiwi/Joiner.h | 1 + src/BestPathContainer.hpp | 41 +- src/CoNgramModel.cpp | 93 ++- src/CoNgramModel.hpp | 2 +- src/MathFunc.h | 21 + src/MathFunc.hpp | 313 ++++++++++ src/PathEvaluator.hpp | 11 +- src/SIMD.hpp | 48 +- src/SkipBigramModelImpl.hpp | 283 +-------- src/archImpl/avx2.cpp | 117 ++++ src/archImpl/avx2_qgemm.hpp | 456 ++++++++++++++ src/archImpl/avx512_qgemm.hpp | 673 ++++++++++++++++++++ src/archImpl/avx512bw.cpp | 158 ++++- src/archImpl/avx512vnni.cpp | 127 +++- src/archImpl/avx_vnni.cpp | 93 +++ src/archImpl/none.cpp | 66 ++ src/archImpl/sse2.cpp | 41 ++ src/archImpl/sse4_1.cpp | 169 ++++++ src/gemm.h | 24 + src/qgemm.cpp | 1078 +-------------------------------- src/qgemm.hpp | 190 ++++++ src/search.cpp | 32 +- src/search.h | 1 + 23 files changed, 2599 insertions(+), 1439 deletions(-) create mode 100644 src/MathFunc.h create mode 100644 src/MathFunc.hpp create mode 100644 src/archImpl/avx2_qgemm.hpp create mode 100644 src/archImpl/avx512_qgemm.hpp create mode 100644 src/gemm.h create mode 100644 src/qgemm.hpp diff --git a/include/kiwi/Joiner.h b/include/kiwi/Joiner.h index 98247929..6cd4af89 100644 --- a/include/kiwi/Joiner.h +++ b/include/kiwi/Joiner.h @@ -137,6 +137,7 @@ 
namespace kiwi { this->~ErasedVector(); new (this) ErasedVector{ other }; + return *this; } ErasedVector& operator=(ErasedVector&& other) diff --git a/src/BestPathContainer.hpp b/src/BestPathContainer.hpp index d97f0fab..a8b86649 100644 --- a/src/BestPathContainer.hpp +++ b/src/BestPathContainer.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include namespace kiwi { @@ -53,7 +54,7 @@ namespace kiwi bool equalTo(const LmState& lmState, uint8_t prevRootId, SpecialState spState) const { - return (this->prevRootId == prevRootId & this->spState == spState) && this->lmState == lmState; + return ((this->prevRootId == prevRootId) & (this->spState == spState)) && this->lmState == lmState; } bool operator==(const WordLL& o) const @@ -165,10 +166,10 @@ namespace kiwi const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) { PathHash ph{ lmState, prevRootId, spState }; - auto inserted = bestPathIndex.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); + auto inserted = bestPathIndex.emplace(ph, std::make_pair((uint32_t)bestPathValues.size(), 1)); if (inserted.second) { - bestPathValues.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); + bestPathValues.emplace_back(morph, accScore, accTypoCost, parent, std::move(lmState), spState); if (rootId != commonRootId) bestPathValues.back().rootId = rootId; bestPathValues.resize(bestPathValues.size() + topN - 1); } @@ -176,21 +177,21 @@ namespace kiwi { auto bestPathFirst = bestPathValues.begin() + inserted.first->second.first; auto bestPathLast = bestPathValues.begin() + inserted.first->second.first + inserted.first->second.second; - if (distance(bestPathFirst, bestPathLast) < topN) + if (std::distance(bestPathFirst, bestPathLast) < topN) { - *bestPathLast = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + *bestPathLast = WordLL{ morph, accScore, accTypoCost, parent, std::move(lmState), spState }; if (rootId != 
commonRootId) bestPathLast->rootId = rootId; - push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); + std::push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); ++inserted.first->second.second; } else { if (accScore > bestPathFirst->accScore) { - pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); - *(bestPathLast - 1) = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + std::pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + *(bestPathLast - 1) = WordLL{ morph, accScore, accTypoCost, parent, std::move(lmState), spState }; if (rootId != commonRootId) (*(bestPathLast - 1)).rootId = rootId; - push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + std::push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); } } } @@ -204,7 +205,7 @@ namespace kiwi const auto size = p.second.second; for (size_t i = 0; i < size; ++i) { - resultOut.emplace_back(move(bestPathValues[index + i])); + resultOut.emplace_back(std::move(bestPathValues[index + i])); auto& newPath = resultOut.back(); // fill the rest information of resultOut @@ -232,7 +233,7 @@ namespace kiwi inline void insert(size_t topN, uint8_t prevRootId, uint8_t rootId, const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) { - WordLL newPath{ morph, accScore, accTypoCost, parent, move(lmState), spState }; + WordLL newPath{ morph, accScore, accTypoCost, parent, std::move(lmState), spState }; newPath.prevRootId = prevRootId; if (rootId != commonRootId) newPath.rootId = rootId; auto inserted = bestPathes.emplace(newPath); @@ -252,7 +253,7 @@ namespace kiwi { for (auto& p : bestPathes) { - resultOut.emplace_back(move(p)); + resultOut.emplace_back(std::move(p)); auto& newPath = resultOut.back(); // fill the rest information of resultOut @@ -344,7 +345,7 @@ namespace kiwi if (value.size() < hash.size()) { hash[value.size()] = h; - value.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); 
+ value.emplace_back(morph, accScore, accTypoCost, parent, std::move(lmState), spState); value.back().prevRootId = prevRootId; if (rootId != commonRootId) value.back().rootId = rootId; } @@ -363,7 +364,7 @@ namespace kiwi target.accScore = accScore; target.accTypoCost = accTypoCost; target.parent = parent; - target.lmState = move(lmState); + target.lmState = std::move(lmState); target.spState = spState; target.rootId = parent ? parent->rootId : 0; if (rootId != commonRootId) target.rootId = rootId; @@ -377,7 +378,7 @@ namespace kiwi static constexpr ArchType archType = LmState::arch; if constexpr (archType != ArchType::none && archType != ArchType::balanced) { - return insertOptimized(topN, prevRootId, rootId, morph, accScore, accTypoCost, parent, move(lmState), spState); + return insertOptimized(topN, prevRootId, rootId, morph, accScore, accTypoCost, parent, std::move(lmState), spState); } const size_t h = Hash>{}(lmState, prevRootId, spState); @@ -386,7 +387,7 @@ namespace kiwi auto& value = values[bucket]; const auto hashEnd = hash.begin() + value.size(); - auto it = find(hash.begin(), hashEnd, (uint8_t)h); + auto it = std::find(hash.begin(), hashEnd, (uint8_t)h); while (it != hashEnd) { if (value[it - hash.begin()].equalTo(lmState, prevRootId, spState)) @@ -394,7 +395,7 @@ namespace kiwi break; } ++it; - it = find(it, hashEnd, (uint8_t)h); + it = std::find(it, hashEnd, (uint8_t)h); } if (it == hashEnd) @@ -402,7 +403,7 @@ namespace kiwi if (value.size() < hash.size()) { hash[value.size()] = h; - value.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); + value.emplace_back(morph, accScore, accTypoCost, parent, std::move(lmState), spState); value.back().prevRootId = prevRootId; if (rootId != commonRootId) value.back().rootId = rootId; } @@ -421,7 +422,7 @@ namespace kiwi target.accScore = accScore; target.accTypoCost = accTypoCost; target.parent = parent; - target.lmState = move(lmState); + target.lmState = std::move(lmState); 
target.spState = spState; target.rootId = parent ? parent->rootId : 0; if (rootId != commonRootId) target.rootId = rootId; @@ -435,7 +436,7 @@ namespace kiwi { for (auto& p : v) { - resultOut.emplace_back(move(p)); + resultOut.emplace_back(std::move(p)); auto& newPath = resultOut.back(); // fill the rest information of resultOut diff --git a/src/CoNgramModel.cpp b/src/CoNgramModel.cpp index ba209ae4..40dc38a7 100644 --- a/src/CoNgramModel.cpp +++ b/src/CoNgramModel.cpp @@ -7,6 +7,7 @@ #include "StrUtils.h" #include "FrozenTrie.hpp" #include "qgemm.h" +#include "gemm.h" using namespace std; @@ -134,7 +135,7 @@ namespace kiwi { nextLmStates.resize(prevLmStates.size() * nextWids.size()); scores.resize(prevLmStates.size() * nextWids.size()); - langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); + langMdl->template progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); } } @@ -328,13 +329,6 @@ namespace kiwi } } - template - void logsoftmaxInplace(Arr& arr) - { - arr -= arr.maxCoeff(); - arr -= std::log(arr.exp().sum()); - } - template CoNgramModel::CoNgramModel(utils::MemoryObject&& mem) : CoNgramModelBase{ mem } { @@ -603,9 +597,9 @@ namespace kiwi { int32_t contextIdcs[1 + windowSize]; float lls[(1 + windowSize) * 2]; - int32_t nextIdx[1] = { next }; + int32_t nextIdx[1] = { (int32_t)next }; - copy(positionConfidPtr, positionConfidPtr + windowSize + 1, lls); + memcpy(lls, positionConfidPtr, (windowSize + 1) * sizeof(float)); lls[0] += getContextConfid(contextIdx); contextIdcs[0] = contextIdx; for (size_t i = 0; i < windowSize; ++i) @@ -614,8 +608,7 @@ namespace kiwi lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; contextIdcs[i + 1] = (historyToken ? 
historyToken : 0) + header.contextSize; } - LogSoftmax{}(lls, std::integral_constant()); - + logSoftmax(lls, windowSize + 1); qgemm::scatteredGEMMOpt( 1 + windowSize, 1, header.dim, getContextQuantEmb(0), contextIdcs, contextEmbStride(), @@ -626,7 +619,7 @@ namespace kiwi lls[i] += lls[i + 1 + windowSize]; } lls[0] -= getContextValidTokenSum(contextIdx); - ll = LogSumExp{}(lls, std::integral_constant()); + ll = logSumExp(lls, windowSize + 1); ll += getContextValidTokenSum(contextIdx); } else @@ -636,28 +629,32 @@ namespace kiwi thread_local Eigen::VectorXf lls; lls.resize(1 + windowSize); - lls = Eigen::Map{ positionConfidPtr, windowSize + 1 }; + memcpy(lls.data(), positionConfidPtr, (windowSize + 1) * sizeof(float)); lls[0] += getContextConfid(contextIdx); for (size_t i = 0; i < windowSize; ++i) { const auto historyToken = history[i]; lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; } - logsoftmaxInplace(lls.array()); + logSoftmax(lls.data(), windowSize + 1); - mat.col(0) = Eigen::Map{ getContextEmb(contextIdx), header.dim }; + memcpy(mat.col(0).data(), getContextEmb(contextIdx), header.dim * sizeof(float)); lls[0] += getContextBias(contextIdx); for (size_t i = 0; i < windowSize; ++i) { const auto historyToken = history[i]; - if (historyToken) mat.col(i + 1) = Eigen::Map{ getDistantEmb(historyToken), header.dim }; - else mat.col(i + 1).setZero(); + if (historyToken) memcpy(mat.col(i + 1).data(), getDistantEmb(historyToken), header.dim * sizeof(float)); + else memset(mat.col(i + 1).data(), 0, header.dim * sizeof(float)); lls[i + 1] += getDistantBias(historyToken); } lls.tail(windowSize).array() += getContextValidTokenSum(contextIdx); Eigen::Map outputVec{ getOutputEmb(next), header.dim }; - lls += mat.transpose() * outputVec; - ll = LogSumExp{}(lls.data(), std::integral_constant()); + gemm::template gemv( + mat.cols(), mat.rows(), + mat.data(), mat.colStride(), + outputVec.data(), lls.data() + ); + ll = logSumExp(lls.data(), windowSize + 1); } 
} else @@ -680,7 +677,11 @@ namespace kiwi ll = getContextBias(contextIdx); Eigen::Map contextVec{ getContextEmb(contextIdx), header.dim }; Eigen::Map outputVec{ getOutputEmb(next), header.dim }; - ll += (contextVec.transpose() * outputVec)[0]; + gemm::template gemv( + 1, header.dim, + contextVec.data(), contextVec.colStride(), + outputVec.data(), &ll + ); } } @@ -963,7 +964,12 @@ namespace kiwi { Eigen::Map inputMap{ tls.inputEmbBuf.data(), header.dim, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; Eigen::Map outputMap{ tls.outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; - resultMap += outputMap.transpose() * inputMap; + gemm::template gemm( + outputMap.cols(), inputMap.cols(), inputMap.rows(), + outputMap.data(), outputMap.colStride(), + inputMap.data(), inputMap.colStride(), + resultMap.data(), resultMap.colStride() + ); } pair contextCache; @@ -1005,8 +1011,8 @@ namespace kiwi } Eigen::Map> scoreMap{ tls.scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; scoreMap.rowwise() += Eigen::Map>{ positionConfidPtr, 1, windowSize + 1 }; - LogSoftmaxTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); - scoreMap.rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; + logSoftmaxTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); + scoreMap.template rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; for (size_t j = 0; j < batchSize; ++j) { const auto pIdx = (i + j) / numValidDistantTokens; @@ -1021,7 +1027,7 @@ namespace kiwi } } } - LogSumExpTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); + logSumExpTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); for (size_t j = 0; j < batchSize; ++j) { @@ -1125,7 +1131,12 @@ namespace kiwi { Eigen::Map inputMap{ tls.inputEmbBuf.data(), header.dim, (Eigen::Index)(prevStateSize + uniqHistorySize) }; Eigen::Map outputMap{ tls.outputEmbBuf.data(), header.dim, (Eigen::Index)nextIdSize 
}; - resultMap += outputMap.transpose() * inputMap; + gemm::template gemm( + outputMap.cols(), inputMap.cols(), inputMap.rows(), + outputMap.data(), outputMap.colStride(), + inputMap.data(), inputMap.colStride(), + resultMap.data(), resultMap.colStride() + ); } pair contextCache; @@ -1168,8 +1179,8 @@ namespace kiwi } Eigen::Map> scoreMap{ tls.scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; scoreMap.rowwise() += Eigen::Map>{ positionConfidPtr, 1, windowSize + 1 }; - LogSoftmaxTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); - scoreMap.rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; + logSoftmaxTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); + scoreMap.template rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; for (size_t j = 0; j < batchSize; ++j) { const auto pIdx = (i + j) / numValidDistantTokens; @@ -1184,7 +1195,7 @@ namespace kiwi } } } - LogSumExpTransposed{}(tls.scoreBuf.data(), batchSize, scoreBatchSize); + logSumExpTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); for (size_t j = 0; j < batchSize; ++j) { @@ -1305,7 +1316,12 @@ namespace kiwi { Eigen::Map inputMap{ inputEmbBuf.data(), header.dim, (Eigen::Index)uniqInputSize }; Eigen::Map outputMap{ outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; - resultMap += outputMap.transpose() * inputMap; + gemm::template gemm( + outputMap.cols(), inputMap.cols(), inputMap.rows(), + outputMap.data(), outputMap.colStride(), + inputMap.data(), inputMap.colStride(), + resultMap.data(), resultMap.colStride() + ); } pair contextCache; @@ -1328,6 +1344,23 @@ namespace kiwi } } + static constexpr size_t serialAlignment = 16; + inline size_t alignedOffsetInc(size_t& offset, size_t inc, size_t alignment = serialAlignment) + { + return offset = (offset + inc + alignment - 1) & ~(alignment - 1); + } + + inline std::ostream& writePadding(std::ostream& os, size_t alignment = 
serialAlignment) + { + const size_t pos = os.tellp(); + size_t pad = ((pos + alignment - 1) & ~(alignment - 1)) - pos; + for (size_t i = 0; i < pad; ++i) + { + os.put(0); + } + return os; + } + utils::MemoryObject CoNgramModelBase::build(const string& contextDefinition, const string& embedding, size_t maxContextLength, bool useVLE, bool reorderContextId) { ifstream contextStr, embeddingStr; diff --git a/src/CoNgramModel.hpp b/src/CoNgramModel.hpp index 68cbd62c..0906863c 100644 --- a/src/CoNgramModel.hpp +++ b/src/CoNgramModel.hpp @@ -9,7 +9,7 @@ #include "ArchAvailable.h" #include "search.h" #include "streamvbyte.h" -#include "SkipBigramModelImpl.hpp" +#include "MathFunc.h" namespace kiwi { diff --git a/src/MathFunc.h b/src/MathFunc.h new file mode 100644 index 00000000..036a6cec --- /dev/null +++ b/src/MathFunc.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace kiwi +{ + namespace lm + { + template + float logSumExp(const float* arr, size_t size); + + template + void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + + template + void logSoftmax(float* arr, size_t size); + + template + void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } +} diff --git a/src/MathFunc.hpp b/src/MathFunc.hpp new file mode 100644 index 00000000..fbf68602 --- /dev/null +++ b/src/MathFunc.hpp @@ -0,0 +1,313 @@ +#pragma once +#include +#include "MathFunc.h" +#include "SIMD.hpp" + +namespace kiwi +{ + namespace lm + { + template + float logSumExpImpl(const float* arr) + { + simd::Operator op; + + auto pmax = op.loadf(arr); + for (size_t i = op.packetSize; i < size; i += op.packetSize) + { + pmax = op.maxf(pmax, op.loadf(&arr[i])); + } + pmax = op.redmaxbf(pmax); + + auto sum = op.zerof(); + for (size_t i = 0; i < size; i += op.packetSize) + { + sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); + } + return std::log(op.redsumf(sum)) + op.firstf(pmax); + } + + template + struct LogSumExp + { + template + 
float operator()(const float* arr, std::integral_constant) + { + return logSumExpImpl(arr); + } + }; + + template<> + struct LogSumExp + { + template + float operator()(const float* arr, std::integral_constant) + { + float maxValue = *std::max_element(arr, arr + size); + float sum = 0; + for (size_t i = 0; i < size; ++i) + { + sum += std::exp(arr[i] - maxValue); + } + return std::log(sum) + maxValue; + } + }; + + template<> + struct LogSumExp : public LogSumExp + { + }; + + template + void logSoftmaxImpl(float* arr) + { + simd::Operator op; + + auto pmax = op.loadf(arr); + for (size_t i = op.packetSize; i < size; i += op.packetSize) + { + pmax = op.maxf(pmax, op.loadf(&arr[i])); + } + pmax = op.redmaxbf(pmax); + + auto sum = op.zerof(); + for (size_t i = 0; i < size; i += op.packetSize) + { + sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); + } + pmax = op.addf(op.logf(op.set1f(op.redsumf(sum))), pmax); + for (size_t i = 0; i < size; i += op.packetSize) + { + op.storef(&arr[i], op.subf(op.loadf(&arr[i]), pmax)); + } + } + + template + struct LogSoftmax + { + template + void operator()(float* arr, std::integral_constant) + { + return logSoftmaxImpl(arr); + } + }; + + template<> + struct LogSoftmax + { + template + void operator()(float* arr, std::integral_constant) + { + float maxValue = *std::max_element(arr, arr + size); + float sum = 0; + for (size_t i = 0; i < size; ++i) + { + sum += std::exp(arr[i] - maxValue); + } + maxValue += std::log(sum); + for (size_t i = 0; i < size; ++i) + { + arr[i] -= maxValue; + } + } + }; + + template<> + struct LogSoftmax : public LogSoftmax + { + }; + + template + struct LogSoftmaxTransposed; + + template + struct LogSoftmaxTransposed + { + static constexpr size_t size = 8; + + void block(float* arr, size_t stride) + { + simd::Operator op; + simd::FloatPacket a0 = op.loadf(arr), + a1 = op.loadf(arr + stride), + a2 = op.loadf(arr + stride * 2), + a3 = op.loadf(arr + stride * 3), + a4 = op.loadf(arr + stride * 4), + a5 
= op.loadf(arr + stride * 5), + a6 = op.loadf(arr + stride * 6), + a7 = op.loadf(arr + stride * 7); + // find maximum + auto m = op.maxf(a0, a1); + m = op.maxf(m, a2); + m = op.maxf(m, a3); + m = op.maxf(m, a4); + m = op.maxf(m, a5); + m = op.maxf(m, a6); + m = op.maxf(m, a7); + + // subtract maximum + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + // exp, reduce sum and log + m = op.expf(a0); + m = op.addf(m, op.expf(a1)); + m = op.addf(m, op.expf(a2)); + m = op.addf(m, op.expf(a3)); + m = op.addf(m, op.expf(a4)); + m = op.addf(m, op.expf(a5)); + m = op.addf(m, op.expf(a6)); + m = op.addf(m, op.expf(a7)); + m = op.logf(m); + + // subtract + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + op.storef(arr, a0); + op.storef(arr + stride, a1); + op.storef(arr + stride * 2, a2); + op.storef(arr + stride * 3, a3); + op.storef(arr + stride * 4, a4); + op.storef(arr + stride * 5, a5); + op.storef(arr + stride * 6, a6); + op.storef(arr + stride * 7, a7); + } + + void operator()(float* arr, size_t batchSize, size_t stride) + { + simd::Operator op; + for (size_t i = 0; i < batchSize; i += op.packetSize) + { + block(arr, stride); + arr += op.packetSize; + } + } + }; + + template<> + struct LogSoftmaxTransposed : public LogSoftmaxTransposed + { + }; + + template<> + struct LogSoftmaxTransposed : public LogSoftmaxTransposed + { + }; + + template + struct LogSumExpTransposed; + + template + struct LogSumExpTransposed + { + static constexpr size_t size = 8; + + void block(float* arr, size_t stride) + { + simd::Operator op; + simd::FloatPacket a0 = op.loadf(arr), + a1 = op.loadf(arr + stride), + a2 = op.loadf(arr + stride * 2), + a3 = op.loadf(arr + stride * 3), + a4 = op.loadf(arr + stride * 4), + a5 = 
op.loadf(arr + stride * 5), + a6 = op.loadf(arr + stride * 6), + a7 = op.loadf(arr + stride * 7); + // find maximum + auto m = op.maxf(a0, a1); + m = op.maxf(m, a2); + m = op.maxf(m, a3); + m = op.maxf(m, a4); + m = op.maxf(m, a5); + m = op.maxf(m, a6); + m = op.maxf(m, a7); + + // subtract maximum + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + // exp, reduce sum and log + auto s = op.expf(a0); + s = op.addf(s, op.expf(a1)); + s = op.addf(s, op.expf(a2)); + s = op.addf(s, op.expf(a3)); + s = op.addf(s, op.expf(a4)); + s = op.addf(s, op.expf(a5)); + s = op.addf(s, op.expf(a6)); + s = op.addf(s, op.expf(a7)); + s = op.logf(s); + + op.storef(arr, op.addf(m, s)); + } + + void operator()(float* arr, size_t batchSize, size_t stride) + { + simd::Operator op; + for (size_t i = 0; i < batchSize; i += op.packetSize) + { + block(arr, stride); + arr += op.packetSize; + } + } + }; + + template<> + struct LogSumExpTransposed : public LogSumExpTransposed + { + }; + + template<> + struct LogSumExpTransposed : public LogSumExpTransposed + { + }; + + template + float logSumExp(const float* arr, size_t size) + { + if (size == 8) return LogSumExp()(arr, std::integral_constant()); + if (size == 16) return LogSumExp()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template + void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride) + { + if (size == 8) return LogSumExpTransposed{}(arr, batchSize, stride); + throw std::runtime_error("Unsupported size"); + } + + template + void logSoftmax(float* arr, size_t size) + { + if (size == 8) return LogSoftmax()(arr, std::integral_constant()); + if (size == 16) return LogSoftmax()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template + void logSoftmaxTransposed(float* arr, size_t size, size_t 
batchSize, size_t stride) + { + if (size == 8) return LogSoftmaxTransposed{}(arr, batchSize, stride); + throw std::runtime_error("Unsupported size"); + } + } +} diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index 86bbb94f..fccfb34d 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -640,6 +640,7 @@ namespace kiwi nextWids.clear(); size_t prevId = -1; + size_t length; for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { for (auto& prevPath : cache[prev - startNode]) @@ -677,7 +678,7 @@ namespace kiwi if (!formEvaluator(curMorph, ignoreCondScore, candScore)) continue; - size_t length = 0; + length = 0; if (curMorph->combineSocket && curMorph->isSingle()) { // no op @@ -935,25 +936,25 @@ namespace kiwi MorphemeEvaluator me; if (topN > 1) { - me.eval(nCache, kw, ownFormList, cache, + me.template eval(nCache, kw, ownFormList, cache, ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } else if (totalPrevPathes <= BestPathContainerTraits::maxSize) { - me.eval(nCache, kw, ownFormList, cache, + me.template eval(nCache, kw, ownFormList, cache, ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } else if (totalPrevPathes <= BestPathContainerTraits::maxSize) { - me.eval(nCache, kw, ownFormList, cache, + me.template eval(nCache, kw, ownFormList, cache, ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } else { - me.eval(nCache, kw, ownFormList, cache, + me.template eval(nCache, kw, ownFormList, cache, ownFormId, validMorphCands, node, startNode, topN, totalPrevPathes, ignoreCond ? 
-10 : 0, nodeLevelDiscount, prevSpStates); } diff --git a/src/SIMD.hpp b/src/SIMD.hpp index e3bada00..6b9c2a71 100644 --- a/src/SIMD.hpp +++ b/src/SIMD.hpp @@ -11,6 +11,15 @@ #define STRONG_INLINE inline #endif +#if defined(_MSC_VER) +#define FORCE_INLINE __forceinline +#elif defined(__GNUC__) +#define FORCE_INLINE __attribute__((always_inline)) +#else +#define FORCE_INLINE inline +#endif + + #include "ArchAvailable.h" namespace kiwi @@ -370,9 +379,10 @@ namespace kiwi { return _mm_blendv_ps(b, a, mask); } + static STRONG_INLINE __m128i select(__m128i mask, __m128i a, __m128i b) { - return _mm_blendv_epi32(b, a, mask); + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(b), _mm_castsi128_ps(a), _mm_castsi128_ps(mask))); } static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) @@ -457,8 +467,14 @@ namespace kiwi static STRONG_INLINE __m256 bor(__m256 a, __m256 b) { return _mm256_or_ps(a, b); } static STRONG_INLINE __m256i bor(__m256i a, __m256i b) { return _mm256_or_si256(a, b); } - static STRONG_INLINE __m256 select(__m256 mask, __m256 a, __m256 b) { return _mm256_blendv_ps(b, a, mask); } - static STRONG_INLINE __m256i select(__m256i mask, __m256i a, __m256i b) { return _mm256_blendv_epi32(b, a, mask); } + static STRONG_INLINE __m256 select(__m256 mask, __m256 a, __m256 b) + { + return _mm256_blendv_ps(b, a, mask); + } + static STRONG_INLINE __m256i select(__m256i mask, __m256i a, __m256i b) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(b), _mm256_castsi256_ps(a), _mm256_castsi256_ps(mask))); + } static STRONG_INLINE __m256 cmp_eq(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_EQ_OQ); } static STRONG_INLINE __m256 cmp_le(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_LE_OQ); } @@ -794,6 +810,32 @@ namespace kiwi { return set1f(redmaxf(a)); } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + int32x4_t pa, pb, sum = vdupq_n_s32(0); + for (size_t i = 
0; i < size; i += 16) + { + pa = vreinterpretq_s32_u32(vmovl_u16(vld1_u16(reinterpret_cast(a + i)))); + pb = vreinterpretq_s32_s8(vld1_s8(reinterpret_cast(b + i))); + sum = vpadalq_s16(sum, vmull_s8(vget_low_s8(vreinterpretq_s8_s32(pa)), vget_low_s8(pb))); + sum = vpadalq_s16(sum, vmull_s8(vget_high_s8(vreinterpretq_s8_s32(pa)), vget_high_s8(pb))); + } + return vgetq_lane_s32(vpadd_s32(vpadd_s32(sum, sum), sum), 0); + } + + static STRONG_INLINE int32_t dotprod(const int8_t* a, const int8_t* b, size_t size) + { + int32x4_t pa, pb, sum = vdupq_n_s32(0); + for (size_t i = 0; i < size; i += 16) + { + pa = vreinterpretq_s32_s8(vld1q_s8(a + i)); + pb = vreinterpretq_s32_s8(vld1q_s8(b + i)); + sum = vpadalq_s16(sum, vmull_s8(vget_low_s8(pa), vget_low_s8(pb))); + sum = vpadalq_s16(sum, vmull_s8(vget_high_s8(pa), vget_high_s8(pb))); + } + return vgetq_lane_s32(vpadd_s32(vpadd_s32(sum, sum), sum), 0); + } }; template<> diff --git a/src/SkipBigramModelImpl.hpp b/src/SkipBigramModelImpl.hpp index 404ea6da..2310e1d0 100644 --- a/src/SkipBigramModelImpl.hpp +++ b/src/SkipBigramModelImpl.hpp @@ -2,293 +2,12 @@ #include #include "SkipBigramModel.hpp" -#include "SIMD.hpp" +#include "MathFunc.hpp" namespace kiwi { namespace lm { - template - struct LogSumExp - { - template - float operator()(const float* arr, std::integral_constant) - { - return logSumExpImpl(arr); - } - }; - - template<> - struct LogSumExp - { - template - float operator()(const float* arr, std::integral_constant) - { - float maxValue = *std::max_element(arr, arr + size); - float sum = 0; - for (size_t i = 0; i < size; ++i) - { - sum += std::exp(arr[i] - maxValue); - } - return std::log(sum) + maxValue; - } - }; - - template<> - struct LogSumExp : public LogSumExp - { - }; - - template - float logSumExpImpl(const float* arr) - { - if ((archType == ArchType::avx512bw || archType == ArchType::avx512vnni) && size < 16) - { - return logSumExpImpl(arr); - } - simd::Operator op; - - auto pmax = op.loadf(arr); - for 
(size_t i = op.packetSize; i < size; i += op.packetSize) - { - pmax = op.maxf(pmax, op.loadf(&arr[i])); - } - pmax = op.redmaxbf(pmax); - - auto sum = op.zerof(); - for (size_t i = 0; i < size; i += op.packetSize) - { - sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); - } - return std::log(op.redsumf(sum)) + op.firstf(pmax); - } - - template - struct LogSoftmax - { - template - void operator()(float* arr, std::integral_constant) - { - return logSoftmaxImpl(arr); - } - }; - - template<> - struct LogSoftmax - { - template - void operator()(float* arr, std::integral_constant) - { - float maxValue = *std::max_element(arr, arr + size); - float sum = 0; - for (size_t i = 0; i < size; ++i) - { - sum += std::exp(arr[i] - maxValue); - } - maxValue += std::log(sum); - for (size_t i = 0; i < size; ++i) - { - arr[i] -= maxValue; - } - } - }; - - template<> - struct LogSoftmax : public LogSoftmax - { - }; - - template - void logSoftmaxImpl(float* arr) - { - if ((archType == ArchType::avx512bw || archType == ArchType::avx512vnni) && size < 16) - { - return logSoftmaxImpl(arr); - } - simd::Operator op; - - auto pmax = op.loadf(arr); - for (size_t i = op.packetSize; i < size; i += op.packetSize) - { - pmax = op.maxf(pmax, op.loadf(&arr[i])); - } - pmax = op.redmaxbf(pmax); - - auto sum = op.zerof(); - for (size_t i = 0; i < size; i += op.packetSize) - { - sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); - } - pmax = op.addf(op.logf(op.set1f(op.redsumf(sum))), pmax); - for (size_t i = 0; i < size; i += op.packetSize) - { - op.storef(&arr[i], op.subf(op.loadf(&arr[i]), pmax)); - } - } - - template - struct LogSoftmaxTransposed; - - template - struct LogSoftmaxTransposed - { - static constexpr size_t size = 8; - - void block(float* arr, size_t stride) - { - simd::Operator op; - simd::FloatPacket a0 = op.loadf(arr), - a1 = op.loadf(arr + stride), - a2 = op.loadf(arr + stride * 2), - a3 = op.loadf(arr + stride * 3), - a4 = op.loadf(arr + stride * 4), - a5 = 
op.loadf(arr + stride * 5), - a6 = op.loadf(arr + stride * 6), - a7 = op.loadf(arr + stride * 7); - // find maximum - auto m = op.maxf(a0, a1); - m = op.maxf(m, a2); - m = op.maxf(m, a3); - m = op.maxf(m, a4); - m = op.maxf(m, a5); - m = op.maxf(m, a6); - m = op.maxf(m, a7); - - // subtract maximum - a0 = op.subf(a0, m); - a1 = op.subf(a1, m); - a2 = op.subf(a2, m); - a3 = op.subf(a3, m); - a4 = op.subf(a4, m); - a5 = op.subf(a5, m); - a6 = op.subf(a6, m); - a7 = op.subf(a7, m); - - // exp, reduce sum and log - m = op.expf(a0); - m = op.addf(m, op.expf(a1)); - m = op.addf(m, op.expf(a2)); - m = op.addf(m, op.expf(a3)); - m = op.addf(m, op.expf(a4)); - m = op.addf(m, op.expf(a5)); - m = op.addf(m, op.expf(a6)); - m = op.addf(m, op.expf(a7)); - m = op.logf(m); - - // subtract - a0 = op.subf(a0, m); - a1 = op.subf(a1, m); - a2 = op.subf(a2, m); - a3 = op.subf(a3, m); - a4 = op.subf(a4, m); - a5 = op.subf(a5, m); - a6 = op.subf(a6, m); - a7 = op.subf(a7, m); - - op.storef(arr, a0); - op.storef(arr + stride, a1); - op.storef(arr + stride * 2, a2); - op.storef(arr + stride * 3, a3); - op.storef(arr + stride * 4, a4); - op.storef(arr + stride * 5, a5); - op.storef(arr + stride * 6, a6); - op.storef(arr + stride * 7, a7); - } - - void operator()(float* arr, size_t batchSize, size_t stride) - { - simd::Operator op; - for (size_t i = 0; i < batchSize; i += op.packetSize) - { - block(arr, stride); - arr += op.packetSize; - } - } - }; - - template<> - struct LogSoftmaxTransposed : public LogSoftmaxTransposed - { - }; - - template<> - struct LogSoftmaxTransposed : public LogSoftmaxTransposed - { - }; - - template - struct LogSumExpTransposed; - - template - struct LogSumExpTransposed - { - static constexpr size_t size = 8; - - void block(float* arr, size_t stride) - { - simd::Operator op; - simd::FloatPacket a0 = op.loadf(arr), - a1 = op.loadf(arr + stride), - a2 = op.loadf(arr + stride * 2), - a3 = op.loadf(arr + stride * 3), - a4 = op.loadf(arr + stride * 4), - a5 = 
op.loadf(arr + stride * 5), - a6 = op.loadf(arr + stride * 6), - a7 = op.loadf(arr + stride * 7); - // find maximum - auto m = op.maxf(a0, a1); - m = op.maxf(m, a2); - m = op.maxf(m, a3); - m = op.maxf(m, a4); - m = op.maxf(m, a5); - m = op.maxf(m, a6); - m = op.maxf(m, a7); - - // subtract maximum - a0 = op.subf(a0, m); - a1 = op.subf(a1, m); - a2 = op.subf(a2, m); - a3 = op.subf(a3, m); - a4 = op.subf(a4, m); - a5 = op.subf(a5, m); - a6 = op.subf(a6, m); - a7 = op.subf(a7, m); - - // exp, reduce sum and log - auto s = op.expf(a0); - s = op.addf(s, op.expf(a1)); - s = op.addf(s, op.expf(a2)); - s = op.addf(s, op.expf(a3)); - s = op.addf(s, op.expf(a4)); - s = op.addf(s, op.expf(a5)); - s = op.addf(s, op.expf(a6)); - s = op.addf(s, op.expf(a7)); - s = op.logf(s); - - op.storef(arr, op.addf(m, s)); - } - - void operator()(float* arr, size_t batchSize, size_t stride) - { - simd::Operator op; - for (size_t i = 0; i < batchSize; i += op.packetSize) - { - block(arr, stride); - arr += op.packetSize; - } - } - }; - - template<> - struct LogSumExpTransposed : public LogSumExpTransposed - { - }; - - template<> - struct LogSumExpTransposed : public LogSumExpTransposed - { - }; - template float SkipBigramModel::evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const { diff --git a/src/archImpl/avx2.cpp b/src/archImpl/avx2.cpp index 76e0c1d4..0bc09f8a 100644 --- a/src/archImpl/avx2.cpp +++ b/src/archImpl/avx2.cpp @@ -1,4 +1,27 @@ #include "../SkipBigramModelImpl.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenAVX2 +#include + +namespace kiwi +{ + namespace qgemm + { + // emulate _mm256_dpbusd_epi32 using AVX2 + static FORCE_INLINE __m256i dpbusd(__m256i src, __m256i a, __m256i b) + { + __m256i one16 = _mm256_set1_epi16(1); + __m256i t0 = _mm256_maddubs_epi16(a, b); + __m256i t1 = _mm256_madd_epi16(t0, one16); + return _mm256_add_epi32(src, t1); + } + } +} + +#define DPBUSD dpbusd +#include "avx2_qgemm.hpp" namespace kiwi { @@ -8,5 
+31,99 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV2_256(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMSmall_256(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t 
strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/archImpl/avx2_qgemm.hpp b/src/archImpl/avx2_qgemm.hpp new file mode 100644 index 00000000..59fb89ee --- /dev/null +++ b/src/archImpl/avx2_qgemm.hpp @@ -0,0 +1,456 @@ +#pragma once +#include "../qgemm.hpp" +#include + +namespace kiwi +{ + namespace qgemm + { + inline void pack4x32to4x8x4( + const void* a0, const void* a1, const void* a2, const void* a3, + __m256i& p0, __m256i& p1, __m256i& p2, __m256i& p3 + ) + { + __m256i q0, q1, q2, q3; + // 00, 01, 02, 03, 04, 05, 06, 07, ... + p0 = _mm256_loadu_si256((const __m256i*)a0); + p1 = _mm256_loadu_si256((const __m256i*)a1); + p2 = _mm256_loadu_si256((const __m256i*)a2); + p3 = _mm256_loadu_si256((const __m256i*)a3); + + // 00, 10, 01, 11, 04, 14, 05, 15, ... + q0 = _mm256_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 06, 16, 07, 17, ... + q1 = _mm256_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 24, 34, 25, 35, ... + q2 = _mm256_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 26, 36, 27, 37, ... + q3 = _mm256_unpackhi_epi32(p2, p3); + + // 00, 10, 20, 30, 04, 14, 24, 34, ... + p0 = _mm256_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 05, 15, 25, 35, ... + p1 = _mm256_unpackhi_epi64(q0, q2); + // 02, 12, 22, 32, 06, 16, 26, 36, ... + p2 = _mm256_unpacklo_epi64(q1, q3); + // 03, 13, 23, 33, 07, 17, 27, 37, ... 
+ p3 = _mm256_unpackhi_epi64(q1, q3); + } + + inline void scatteredGEMV_256( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 4, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m256i pa[4], pb, pbs, psum, pbSum; + __m128 paScale, paBias, pbScale, r; + pbScale = _mm_set1_ps(bScale); + pbSum = _mm256_set1_epi32(-bSum / 2); + __m128i pr; + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pbs = _mm256_shuffle_epi32(pb, 0x00); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm256_shuffle_epi32(pb, 0x55); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xAA); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xFF); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + pr = _mm_add_epi32(_mm256_castsi256_si128(psum), _mm256_extracti128_si256(psum, 1)); + aIdx += 4; + + paScale = _mm_loadu_ps(aScale); + paBias = _mm_loadu_ps(aBias); + r = _mm_fmadd_ps(_mm_mul_ps(_mm_cvtepi32_ps(pr), pbScale), paScale, paBias); + 
_mm_storeu_ps(c, r); + c += microM; + } + } + + inline void scatteredGEMV8x1_256( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 8, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m256i pa[4], pb, pbs, psum, pbSum; + __m256 paScale, paBias, pbScale, r; + __m256i pr = _mm256_setzero_si256(); + pbScale = _mm256_set1_ps(bScale); + pbSum = _mm256_set1_epi32(-bSum / 2); + + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pbs = _mm256_shuffle_epi32(pb, 0x00); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm256_shuffle_epi32(pb, 0x55); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xAA); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xFF); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + pr = _mm256_add_epi32(psum, _mm256_castsi128_si256(_mm256_extracti128_si256(psum, 1))); + aIdx += 4; + } + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pbs = 
_mm256_shuffle_epi32(pb, 0x00); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm256_shuffle_epi32(pb, 0x55); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xAA); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xFF); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm256_add_epi32(psum, _mm256_castsi128_si256(_mm256_extracti128_si256(psum, 1))); + pr = _mm256_inserti128_si256(pr, _mm256_castsi256_si128(psum), 1); + aIdx += 4; + } + + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + c += 8; + } + + inline void scatteredGEMV2_256( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 2, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM * 2; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + float bScale[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) + }; + int32_t bSum[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) + }; + + __m256i pa[4], pb[2], psum[2], pbSum[2], pt[2]; + __m256 paScale, paBias, pbScale, r; + __m256i pr; + pbScale = _mm256_castsi256_ps(_mm256_set1_epi64x(*reinterpret_cast(bScale))); + pbSum[0] = _mm256_set1_epi32(-bSum[0] / 2); + pbSum[1] = _mm256_set1_epi32(-bSum[1] / 2); + + for (size_t mi = 
0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pb[1] = _mm256_loadu_si256((const __m256i*)(bBuffer + j + packK)); + psum[0] = DPBUSD(psum[0], pa[0], _mm256_shuffle_epi32(pb[0], 0x00)); + psum[0] = DPBUSD(psum[0], pa[1], _mm256_shuffle_epi32(pb[0], 0x55)); + psum[0] = DPBUSD(psum[0], pa[2], _mm256_shuffle_epi32(pb[0], 0xAA)); + psum[0] = DPBUSD(psum[0], pa[3], _mm256_shuffle_epi32(pb[0], 0xFF)); + psum[1] = DPBUSD(psum[1], pa[0], _mm256_shuffle_epi32(pb[1], 0x00)); + psum[1] = DPBUSD(psum[1], pa[1], _mm256_shuffle_epi32(pb[1], 0x55)); + psum[1] = DPBUSD(psum[1], pa[2], _mm256_shuffle_epi32(pb[1], 0xAA)); + psum[1] = DPBUSD(psum[1], pa[3], _mm256_shuffle_epi32(pb[1], 0xFF)); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 2] = aScale[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 2] = aBias[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm256_add_epi32(psum[0], _mm256_castsi128_si256(_mm256_extracti128_si256(psum[0], 1))); + psum[1] = _mm256_add_epi32(psum[1], _mm256_castsi128_si256(_mm256_extracti128_si256(psum[1], 1))); + + // 00, 01, 10, 11, ... + pt[0] = _mm256_unpacklo_epi32(psum[0], psum[1]); + // 20, 21, 30, 31, ... 
+ pt[1] = _mm256_unpackhi_epi32(psum[0], psum[1]); + + // 00, 01, 10, 11, 20, 21, 30, 31 + pr = _mm256_inserti128_si256(pt[0], _mm256_castsi256_si128(pt[1]), 1); + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + + aIdx += microM; + c += microM * 2; + } + } + + inline int32_t reduce_sum(__m128i x) + { + __m128i hi64 = _mm_unpackhi_epi64(x, x); + __m128i sum64 = _mm_add_epi32(hi64, x); + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i sum32 = _mm_add_epi32(sum64, hi32); + return _mm_cvtsi128_si32(sum32); + } + + inline int32_t reduce_sum(__m256i x) + { + __m128i sum128 = _mm_add_epi32( + _mm256_castsi256_si128(x), + _mm256_extracti128_si256(x, 1)); + return reduce_sum(sum128); + } + + template + inline void scatteredGEMMSmall_256(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + static_assert(m <= 3, "m should be less than or equal to 3"); + static_assert(n <= 3, "n should be less than or equal to 3"); + __m256i pa[3], pb[3], psum[3][3]; + const uint8_t* aPtr[3]; + const int8_t* bPtr[3]; + + psum[0][0] = _mm256_setzero_si256(); + if (m > 1) psum[1][0] = _mm256_setzero_si256(); + if (m > 2) psum[2][0] = _mm256_setzero_si256(); + if (n > 1) psum[0][1] = _mm256_setzero_si256(); + if (m > 1 && n > 1) psum[1][1] = _mm256_setzero_si256(); + if (m > 2 && n > 1) psum[2][1] = _mm256_setzero_si256(); + if (n > 2) psum[0][2] = _mm256_setzero_si256(); + if (m > 1 && n > 2) psum[1][2] = _mm256_setzero_si256(); + if (m > 2 && n > 2) psum[2][2] = _mm256_setzero_si256(); + + aPtr[0] = aBase + aIdx[0] * aIdxScale; + if (m > 1) aPtr[1] = aBase + aIdx[1] * aIdxScale; + if (m > 2) aPtr[2] = aBase + aIdx[2] * aIdxScale; + + bPtr[0] = bBase + bIdx[0] * bIdxScale; + if (n > 1) bPtr[1] = bBase + 
bIdx[1] * bIdxScale; + if (n > 2) bPtr[2] = bBase + bIdx[2] * bIdxScale; + + for (size_t x = 0; x < k; x += 32) + { + if (m > 0) + { + pa[0] = _mm256_loadu_si256((const __m256i*)aPtr[0]); + aPtr[0] += 32; + } + if (m > 1) + { + pa[1] = _mm256_loadu_si256((const __m256i*)aPtr[1]); + aPtr[1] += 32; + } + if (m > 2) + { + pa[2] = _mm256_loadu_si256((const __m256i*)aPtr[2]); + aPtr[2] += 32; + } + + if (n > 0) + { + pb[0] = _mm256_loadu_si256((const __m256i*)bPtr[0]); + bPtr[0] += 32; + } + if (n > 1) + { + pb[1] = _mm256_loadu_si256((const __m256i*)bPtr[1]); + bPtr[1] += 32; + } + if (n > 2) + { + pb[2] = _mm256_loadu_si256((const __m256i*)bPtr[2]); + bPtr[2] += 32; + } + + psum[0][0] = DPBUSD(psum[0][0], pa[0], pb[0]); + if (m > 1) psum[1][0] = DPBUSD(psum[1][0], pa[1], pb[0]); + if (m > 2) psum[2][0] = DPBUSD(psum[2][0], pa[2], pb[0]); + if (n > 1) psum[0][1] = DPBUSD(psum[0][1], pa[0], pb[1]); + if (m > 1 && n > 1) psum[1][1] = DPBUSD(psum[1][1], pa[1], pb[1]); + if (m > 2 && n > 1) psum[2][1] = DPBUSD(psum[2][1], pa[2], pb[1]); + if (n > 2) psum[0][2] = DPBUSD(psum[0][2], pa[0], pb[2]); + if (m > 1 && n > 2) psum[1][2] = DPBUSD(psum[1][2], pa[1], pb[2]); + if (m > 2 && n > 2) psum[2][2] = DPBUSD(psum[2][2], pa[2], pb[2]); + } + + float contextScale[3], outputScale[3], contextBias[3]; + int32_t hsum[3]; + + if (m > 0) + { + contextScale[0] = *reinterpret_cast(aPtr[0]); + contextBias[0] = *reinterpret_cast(aPtr[0] + 4); + } + if (m > 1) + { + contextScale[1] = *reinterpret_cast(aPtr[1]); + contextBias[1] = *reinterpret_cast(aPtr[1] + 4); + } + if (m > 2) + { + contextScale[2] = *reinterpret_cast(aPtr[2]); + contextBias[2] = *reinterpret_cast(aPtr[2] + 4); + } + + + if (n > 0) + { + outputScale[0] = *reinterpret_cast(bPtr[0]); + hsum[0] = *reinterpret_cast(bPtr[0] + 4); + } + if (n > 1) + { + outputScale[1] = *reinterpret_cast(bPtr[1]); + hsum[1] = *reinterpret_cast(bPtr[1] + 4); + } + if (n > 2) + { + outputScale[2] = *reinterpret_cast(bPtr[2]); + hsum[2] = 
*reinterpret_cast(bPtr[2] + 4); + } + + { + int32_t acc = reduce_sum(psum[0][0]); + c[0] = (acc - hsum[0]) * contextScale[0] * outputScale[0] + contextBias[0]; + } + if (m > 1) + { + int32_t acc = reduce_sum(psum[1][0]); + c[ldc] = (acc - hsum[0]) * contextScale[1] * outputScale[0] + contextBias[1]; + } + if (m > 2) + { + int32_t acc = reduce_sum(psum[2][0]); + c[ldc * 2] = (acc - hsum[0]) * contextScale[2] * outputScale[0] + contextBias[2]; + } + if (n > 1) + { + int32_t acc = reduce_sum(psum[0][1]); + c[1] = (acc - hsum[1]) * contextScale[0] * outputScale[1] + contextBias[0]; + } + if (m > 1 && n > 1) + { + int32_t acc = reduce_sum(psum[1][1]); + c[ldc + 1] = (acc - hsum[1]) * contextScale[1] * outputScale[1] + contextBias[1]; + } + if (m > 2 && n > 1) + { + int32_t acc = reduce_sum(psum[2][1]); + c[ldc * 2 + 1] = (acc - hsum[1]) * contextScale[2] * outputScale[1] + contextBias[2]; + } + if (n > 2) + { + int32_t acc = reduce_sum(psum[0][2]); + c[2] = (acc - hsum[2]) * contextScale[0] * outputScale[2] + contextBias[0]; + } + if (m > 1 && n > 2) + { + int32_t acc = reduce_sum(psum[1][2]); + c[ldc + 2] = (acc - hsum[2]) * contextScale[1] * outputScale[2] + contextBias[1]; + } + if (m > 2 && n > 2) + { + int32_t acc = reduce_sum(psum[2][2]); + c[ldc * 2 + 2] = (acc - hsum[2]) * contextScale[2] * outputScale[2] + contextBias[2]; + } + } + } +} diff --git a/src/archImpl/avx512_qgemm.hpp b/src/archImpl/avx512_qgemm.hpp new file mode 100644 index 00000000..55f557f4 --- /dev/null +++ b/src/archImpl/avx512_qgemm.hpp @@ -0,0 +1,673 @@ +#pragma once +#include "../qgemm.hpp" +#include + +#define UNROLL4() do { {LOOP_BODY(0)} {LOOP_BODY(1)} {LOOP_BODY(2)} {LOOP_BODY(3)} } while(0) + +namespace kiwi +{ + namespace qgemm + { + inline void pack4x64to4x16x4( + const void* a0, const void* a1, const void* a2, const void* a3, + __m512i& p0, __m512i& p1, __m512i& p2, __m512i& p3 + ) + { + __m512i q0, q1, q2, q3; + // 00, 01, 02, 03, 04, 05, 06, 07, ... 
+ p0 = _mm512_loadu_si512(a0); + p1 = _mm512_loadu_si512(a1); + p2 = _mm512_loadu_si512(a2); + p3 = _mm512_loadu_si512(a3); + + // 00, 10, 01, 11, 04, 14, 05, 15, ... + q0 = _mm512_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 06, 16, 07, 17, ... + q1 = _mm512_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 24, 34, 25, 35, ... + q2 = _mm512_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 26, 36, 27, 37, ... + q3 = _mm512_unpackhi_epi32(p2, p3); + + // 00, 10, 20, 30, 04, 14, 24, 34, ... + p0 = _mm512_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 05, 15, 25, 35, ... + p1 = _mm512_unpackhi_epi64(q0, q2); + // 02, 12, 22, 32, 06, 16, 26, 36, ... + p2 = _mm512_unpacklo_epi64(q1, q3); + // 03, 13, 23, 33, 07, 17, 27, 37, ... + p3 = _mm512_unpackhi_epi64(q1, q3); + } + + inline void scatteredGEMV_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 16, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m512i pa[4], pb, pbs, psum, pbSum, pr = _mm512_setzero_si512(); + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_set1_ps(bScale); + pbSum = _mm512_set1_epi32(-bSum / 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); +#define LOOP_BODY(mj) \ + const int32_t aOffsets[4] = {\ + mj * 4 < microM ? aIdx[0] * aIdxScale : 0,\ + mj * 4 + 1 < microM ? aIdx[1] * aIdxScale : 0,\ + mj * 4 + 2 < microM ? aIdx[2] * aIdxScale : 0,\ + mj * 4 + 3 < microM ? 
aIdx[3] * aIdxScale : 0,\ + };\ + auto* aPtr = aBase;\ + psum = pbSum;\ + for (size_t j = 0; j < k; j += 64)\ + {\ + pack4x64to4x16x4(aPtr + aOffsets[0],\ + aPtr + aOffsets[1],\ + aPtr + aOffsets[2],\ + aPtr + aOffsets[3],\ + pa[0], pa[1], pa[2], pa[3]);\ + pb = _mm512_loadu_si512(bBuffer + j);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_AAAA);\ + psum = DPBUSD(psum, pa[0], pbs);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_BBBB);\ + psum = DPBUSD(psum, pa[1], pbs);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_CCCC);\ + psum = DPBUSD(psum, pa[2], pbs);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_DDDD);\ + psum = DPBUSD(psum, pa[3], pbs);\ + aPtr += 64;\ + }\ + for (size_t i = 0; i < 4; ++i)\ + {\ + aScale[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i]);\ + aBias[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i] + 4);\ + }\ + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4));\ + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8));\ + pr = _mm512_inserti32x4(pr, _mm512_castsi512_si128(psum), mj);\ + aIdx += 4; + + UNROLL4(); +#undef LOOP_BODY + + paScale = _mm512_loadu_ps(aScale); + paBias = _mm512_loadu_ps(aBias); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm512_storeu_ps(c, r); + c += microM; + } + } + + inline void scatteredGEMV8x1_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 8, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m512i pa[4], pb, pbs, psum, pbSum; + __m256 paScale, paBias, pbScale, r; + __m256i pr = _mm256_setzero_si256(); + pbScale = _mm256_set1_ps(bScale); + pbSum = 
_mm512_set1_epi32(-bSum / 4); + + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm512_loadu_si512(bBuffer + j); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_AAAA); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_BBBB); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_CCCC); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_DDDD); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); + pr = _mm512_castsi512_si256(psum); + aIdx += 4; + } + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm512_loadu_si512(bBuffer + j); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_AAAA); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_BBBB); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_CCCC); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_DDDD); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); + pr = 
_mm256_inserti32x4(pr, _mm512_castsi512_si128(psum), 1); + aIdx += 4; + } + + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + c += 8; + } + + inline void scatteredGEMV2_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 2, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM * 2; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + float bScale[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) + }; + int32_t bSum[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) + }; + + __m512i pa[4], pb[2], psum[2], pbSum[2], pt[2]; + __m256 paScale, paBias, pbScale, r; + __m256i pr; + pbScale = _mm256_castsi256_ps(_mm256_set1_epi64x(*reinterpret_cast(bScale))); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? 
aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + psum[0] = DPBUSD(psum[0], pa[0], _mm512_shuffle_epi32(pb[0], _MM_PERM_AAAA)); + psum[0] = DPBUSD(psum[0], pa[1], _mm512_shuffle_epi32(pb[0], _MM_PERM_BBBB)); + psum[0] = DPBUSD(psum[0], pa[2], _mm512_shuffle_epi32(pb[0], _MM_PERM_CCCC)); + psum[0] = DPBUSD(psum[0], pa[3], _mm512_shuffle_epi32(pb[0], _MM_PERM_DDDD)); + psum[1] = DPBUSD(psum[1], pa[0], _mm512_shuffle_epi32(pb[1], _MM_PERM_AAAA)); + psum[1] = DPBUSD(psum[1], pa[1], _mm512_shuffle_epi32(pb[1], _MM_PERM_BBBB)); + psum[1] = DPBUSD(psum[1], pa[2], _mm512_shuffle_epi32(pb[1], _MM_PERM_CCCC)); + psum[1] = DPBUSD(psum[1], pa[3], _mm512_shuffle_epi32(pb[1], _MM_PERM_DDDD)); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 2] = aScale[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 2] = aBias[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + + // 00, 01, 10, 11, ... + pt[0] = _mm512_unpacklo_epi32(psum[0], psum[1]); + // 20, 21, 30, 31, ... 
+ pt[1] = _mm512_unpackhi_epi32(psum[0], psum[1]); + + // 00, 01, 10, 11, 20, 21, 30, 31 + pr = _mm256_permute2x128_si256(_mm512_castsi512_si256(pt[0]), _mm512_castsi512_si256(pt[1]), 0x20); + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + + aIdx += microM; + c += microM * 2; + } + } + + inline void scatteredGEMV3_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 3, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); + float bScale[4] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), + 0 + }; + int32_t bSum[3] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4) + }; + __m512i pa[4], pb[3], psum[3], pbSum[3]; + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_permutexvar_ps( + _mm512_setr_epi32(0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 3, 3, 3), + _mm512_castps128_ps512(_mm_loadu_ps(bScale))); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); + __m512i shfIdxT = _mm512_setr_epi32(0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + 
const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + psum[2] = pbSum[2]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); + psum[0] = DPBUSD(psum[0], pa[0], _mm512_shuffle_epi32(pb[0], _MM_PERM_AAAA)); + psum[0] = DPBUSD(psum[0], pa[1], _mm512_shuffle_epi32(pb[0], _MM_PERM_BBBB)); + psum[0] = DPBUSD(psum[0], pa[2], _mm512_shuffle_epi32(pb[0], _MM_PERM_CCCC)); + psum[0] = DPBUSD(psum[0], pa[3], _mm512_shuffle_epi32(pb[0], _MM_PERM_DDDD)); + psum[1] = DPBUSD(psum[1], pa[0], _mm512_shuffle_epi32(pb[1], _MM_PERM_AAAA)); + psum[1] = DPBUSD(psum[1], pa[1], _mm512_shuffle_epi32(pb[1], _MM_PERM_BBBB)); + psum[1] = DPBUSD(psum[1], pa[2], _mm512_shuffle_epi32(pb[1], _MM_PERM_CCCC)); + psum[1] = DPBUSD(psum[1], pa[3], _mm512_shuffle_epi32(pb[1], _MM_PERM_DDDD)); + psum[2] = DPBUSD(psum[2], pa[0], _mm512_shuffle_epi32(pb[2], _MM_PERM_AAAA)); + psum[2] = DPBUSD(psum[2], pa[1], _mm512_shuffle_epi32(pb[2], _MM_PERM_BBBB)); + psum[2] = DPBUSD(psum[2], pa[2], _mm512_shuffle_epi32(pb[2], _MM_PERM_CCCC)); + psum[2] = DPBUSD(psum[2], pa[3], _mm512_shuffle_epi32(pb[2], _MM_PERM_DDDD)); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = 
_mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); + + // 00, 10, 20, 30, 01, 11, 21, 31 + psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); + + // 00, 01, 02, 10, 11, 12, 20, 21, 22, 30, 31, 32, ... + psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( + 0, 4, 16, 1, 5, 17, 2, 6, 18, 3, 7, 19, 0, 0, 0, 0 + ), psum[2]); + + paScale = _mm512_castps128_ps512(_mm_loadu_ps(aScale)); + paScale = _mm512_permutexvar_ps(shfIdxT, paScale); + paBias = _mm512_castps128_ps512(_mm_loadu_ps(aBias)); + paBias = _mm512_permutexvar_ps(shfIdxT, paBias); + + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); + _mm512_mask_storeu_ps(c, 0x0FFF, r); + + aIdx += microM; + c += microM * 3; + } + } + + inline void scatteredGEMV4_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 4, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM * 4; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); + memcpy(bBuffer + packK * 3, bBase + bIdx[3] * bIdxScale, k); + float bScale[4] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k) + }; + int32_t bSum[4] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + 
k + 4), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k + 4) + }; + __m512i pa[4], pb[4], psum[4], pbSum[4]; + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_broadcast_f32x4(_mm_loadu_ps(bScale)); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); + pbSum[3] = _mm512_set1_epi32(-bSum[3] / 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + psum[2] = pbSum[2]; + psum[3] = pbSum[3]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); + pb[3] = _mm512_loadu_si512(bBuffer + packK * 3 + j); + psum[0] = DPBUSD(psum[0], pa[0], _mm512_shuffle_epi32(pb[0], _MM_PERM_AAAA)); + psum[0] = DPBUSD(psum[0], pa[1], _mm512_shuffle_epi32(pb[0], _MM_PERM_BBBB)); + psum[0] = DPBUSD(psum[0], pa[2], _mm512_shuffle_epi32(pb[0], _MM_PERM_CCCC)); + psum[0] = DPBUSD(psum[0], pa[3], _mm512_shuffle_epi32(pb[0], _MM_PERM_DDDD)); + psum[1] = DPBUSD(psum[1], pa[0], _mm512_shuffle_epi32(pb[1], _MM_PERM_AAAA)); + psum[1] = DPBUSD(psum[1], pa[1], _mm512_shuffle_epi32(pb[1], _MM_PERM_BBBB)); + psum[1] = DPBUSD(psum[1], pa[2], _mm512_shuffle_epi32(pb[1], _MM_PERM_CCCC)); + psum[1] = DPBUSD(psum[1], pa[3], _mm512_shuffle_epi32(pb[1], _MM_PERM_DDDD)); + psum[2] = DPBUSD(psum[2], pa[0], _mm512_shuffle_epi32(pb[2], _MM_PERM_AAAA)); + psum[2] = DPBUSD(psum[2], pa[1], 
_mm512_shuffle_epi32(pb[2], _MM_PERM_BBBB)); + psum[2] = DPBUSD(psum[2], pa[2], _mm512_shuffle_epi32(pb[2], _MM_PERM_CCCC)); + psum[2] = DPBUSD(psum[2], pa[3], _mm512_shuffle_epi32(pb[2], _MM_PERM_DDDD)); + psum[3] = DPBUSD(psum[3], pa[0], _mm512_shuffle_epi32(pb[3], _MM_PERM_AAAA)); + psum[3] = DPBUSD(psum[3], pa[1], _mm512_shuffle_epi32(pb[3], _MM_PERM_BBBB)); + psum[3] = DPBUSD(psum[3], pa[2], _mm512_shuffle_epi32(pb[3], _MM_PERM_CCCC)); + psum[3] = DPBUSD(psum[3], pa[3], _mm512_shuffle_epi32(pb[3], _MM_PERM_DDDD)); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 4] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 4] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); + psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 4)); + psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 8)); + + // 00, 10, 20, 30, 01, 11, 21, 31 + psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); + // 02, 12, 22, 32, 03, 13, 23, 33 + psum[2] = _mm512_inserti32x4(psum[2], _mm512_castsi512_si128(psum[3]), 1); + + // 00, 01, 02, 03, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33 + psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( + 0, 4, 16, 20, 1, 5, 17, 21, 2, 6, 18, 22, 3, 7, 19, 23 + ), psum[2]); + + paScale = _mm512_loadu_ps(aScale); + paScale = _mm512_shuffle_ps(paScale, paScale, 0); + paBias = _mm512_loadu_ps(aBias); + paBias = _mm512_shuffle_ps(paBias, paBias, 0); + + r = 
_mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); + _mm512_storeu_ps(c, r); + + aIdx += microM; + c += microM * 4; + } + } + + template + inline void scatteredGEMMSmall_512(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + static_assert(m <= 3, "m should be less than or equal to 3"); + static_assert(n <= 3, "n should be less than or equal to 3"); + __m512i pa[3], pb[3], psum[3][3]; + const uint8_t* aPtr[3]; + const int8_t* bPtr[3]; + + psum[0][0] = _mm512_setzero_si512(); + if (m > 1) psum[1][0] = _mm512_setzero_si512(); + if (m > 2) psum[2][0] = _mm512_setzero_si512(); + if (n > 1) psum[0][1] = _mm512_setzero_si512(); + if (m > 1 && n > 1) psum[1][1] = _mm512_setzero_si512(); + if (m > 2 && n > 1) psum[2][1] = _mm512_setzero_si512(); + if (n > 2) psum[0][2] = _mm512_setzero_si512(); + if (m > 1 && n > 2) psum[1][2] = _mm512_setzero_si512(); + if (m > 2 && n > 2) psum[2][2] = _mm512_setzero_si512(); + + aPtr[0] = aBase + aIdx[0] * aIdxScale; + if (m > 1) aPtr[1] = aBase + aIdx[1] * aIdxScale; + if (m > 2) aPtr[2] = aBase + aIdx[2] * aIdxScale; + + bPtr[0] = bBase + bIdx[0] * bIdxScale; + if (n > 1) bPtr[1] = bBase + bIdx[1] * bIdxScale; + if (n > 2) bPtr[2] = bBase + bIdx[2] * bIdxScale; + + for (size_t x = 0; x < k; x += 64) + { + if (m > 0) + { + pa[0] = _mm512_loadu_si512(aPtr[0]); + aPtr[0] += 64; + } + if (m > 1) + { + pa[1] = _mm512_loadu_si512(aPtr[1]); + aPtr[1] += 64; + } + if (m > 2) + { + pa[2] = _mm512_loadu_si512(aPtr[2]); + aPtr[2] += 64; + } + + if (n > 0) + { + pb[0] = _mm512_loadu_si512(bPtr[0]); + bPtr[0] += 64; + } + if (n > 1) + { + pb[1] = _mm512_loadu_si512(bPtr[1]); + bPtr[1] += 64; + } + if (n > 2) + { + pb[2] = _mm512_loadu_si512(bPtr[2]); + bPtr[2] += 64; + } + + psum[0][0] = DPBUSD(psum[0][0], pa[0], pb[0]); + if (m > 1) psum[1][0] = DPBUSD(psum[1][0], pa[1], pb[0]); + if 
(m > 2) psum[2][0] = DPBUSD(psum[2][0], pa[2], pb[0]); + if (n > 1) psum[0][1] = DPBUSD(psum[0][1], pa[0], pb[1]); + if (m > 1 && n > 1) psum[1][1] = DPBUSD(psum[1][1], pa[1], pb[1]); + if (m > 2 && n > 1) psum[2][1] = DPBUSD(psum[2][1], pa[2], pb[1]); + if (n > 2) psum[0][2] = DPBUSD(psum[0][2], pa[0], pb[2]); + if (m > 1 && n > 2) psum[1][2] = DPBUSD(psum[1][2], pa[1], pb[2]); + if (m > 2 && n > 2) psum[2][2] = DPBUSD(psum[2][2], pa[2], pb[2]); + } + + float contextScale[3], outputScale[3], contextBias[3]; + int32_t hsum[3]; + + if (m > 0) + { + contextScale[0] = *reinterpret_cast(aPtr[0]); + contextBias[0] = *reinterpret_cast(aPtr[0] + 4); + } + if (m > 1) + { + contextScale[1] = *reinterpret_cast(aPtr[1]); + contextBias[1] = *reinterpret_cast(aPtr[1] + 4); + } + if (m > 2) + { + contextScale[2] = *reinterpret_cast(aPtr[2]); + contextBias[2] = *reinterpret_cast(aPtr[2] + 4); + } + + + if (n > 0) + { + outputScale[0] = *reinterpret_cast(bPtr[0]); + hsum[0] = *reinterpret_cast(bPtr[0] + 4); + } + if (n > 1) + { + outputScale[1] = *reinterpret_cast(bPtr[1]); + hsum[1] = *reinterpret_cast(bPtr[1] + 4); + } + if (n > 2) + { + outputScale[2] = *reinterpret_cast(bPtr[2]); + hsum[2] = *reinterpret_cast(bPtr[2] + 4); + } + + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][0]); + c[0] = (acc - hsum[0]) * contextScale[0] * outputScale[0] + contextBias[0]; + } + if (m > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][0]); + c[ldc] = (acc - hsum[0]) * contextScale[1] * outputScale[0] + contextBias[1]; + } + if (m > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][0]); + c[ldc * 2] = (acc - hsum[0]) * contextScale[2] * outputScale[0] + contextBias[2]; + } + if (n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][1]); + c[1] = (acc - hsum[1]) * contextScale[0] * outputScale[1] + contextBias[0]; + } + if (m > 1 && n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][1]); + c[ldc + 1] = (acc - hsum[1]) * contextScale[1] * outputScale[1] + 
contextBias[1]; + } + if (m > 2 && n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][1]); + c[ldc * 2 + 1] = (acc - hsum[1]) * contextScale[2] * outputScale[1] + contextBias[2]; + } + if (n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][2]); + c[2] = (acc - hsum[2]) * contextScale[0] * outputScale[2] + contextBias[0]; + } + if (m > 1 && n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][2]); + c[ldc + 2] = (acc - hsum[2]) * contextScale[1] * outputScale[2] + contextBias[1]; + } + if (m > 2 && n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][2]); + c[ldc * 2 + 2] = (acc - hsum[2]) * contextScale[2] * outputScale[2] + contextBias[2]; + } + } + } +} \ No newline at end of file diff --git a/src/archImpl/avx512bw.cpp b/src/archImpl/avx512bw.cpp index 867418cd..9797f315 100644 --- a/src/archImpl/avx512bw.cpp +++ b/src/archImpl/avx512bw.cpp @@ -1,4 +1,27 @@ #include "../SkipBigramModelImpl.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenAVX512 +#include + +namespace kiwi +{ + namespace qgemm + { + // emulate _mm512_dpbusd_epi32 using AVX512BW + static FORCE_INLINE __m512i dpbusd(__m512i src, __m512i a, __m512i b) + { + __m512i one16 = _mm512_set1_epi16(1); + __m512i t0 = _mm512_maddubs_epi16(a, b); + __m512i t1 = _mm512_madd_epi16(t0, one16); + return _mm512_add_epi32(src, t1); + } + } +} + +#define DPBUSD dpbusd +#include "avx512_qgemm.hpp" namespace kiwi { @@ -9,9 +32,136 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template<> + float logSumExp(const float* arr, size_t size) + { + if (size == 8) return LogSumExp()(arr, std::integral_constant()); + if (size == 16) return LogSumExp()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template<> + void logSoftmax(float* arr, size_t size) + { + if (size 
== 8) return LogSoftmax()(arr, std::integral_constant()); + if (size == 16) return LogSoftmax()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV2_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV3( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV3_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV4( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV4_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct 
ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMSmall_512(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/archImpl/avx512vnni.cpp b/src/archImpl/avx512vnni.cpp index f879d3ef..ef910ed4 100644 --- a/src/archImpl/avx512vnni.cpp +++ b/src/archImpl/avx512vnni.cpp @@ -1,5 +1,9 @@ #include "../SkipBigramModelImpl.hpp" -#include "../qgemm.h" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define DPBUSD _mm512_dpbusd_epi32 +#include "avx512_qgemm.hpp" namespace kiwi { @@ -9,10 +13,131 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template<> + float logSumExp(const float* arr, size_t size) + { + if (size == 8) return LogSumExp()(arr, std::integral_constant()); + if (size == 16) return LogSumExp()(arr, std::integral_constant()); + throw 
std::runtime_error("Unsupported size"); + } + + template<> + void logSoftmax(float* arr, size_t size) + { + if (size == 8) return LogSoftmax()(arr, std::integral_constant()); + if (size == 16) return LogSoftmax()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); } namespace qgemm { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV2_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV3( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV3_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV4( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return 
scatteredGEMV4_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMSmall_512(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + return gemm(m, n, k, aT, strideA, b, strideB, c, strideC); + } + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + return gemv(m, k, aT, strideA, b, c); + } } } diff --git a/src/archImpl/avx_vnni.cpp b/src/archImpl/avx_vnni.cpp index 1f8add2e..3c59f6f6 100644 --- a/src/archImpl/avx_vnni.cpp +++ b/src/archImpl/avx_vnni.cpp @@ -1,4 +1,9 @@ #include "../SkipBigramModelImpl.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define DPBUSD _mm256_dpbusd_epi32 +#include "avx2_qgemm.hpp" namespace kiwi { @@ -8,5 +13,93 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void 
scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV2_256(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMSmall_256(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + return gemm(m, n, k, aT, strideA, b, strideB, c, strideC); + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + return gemv(m, k, aT, strideA, b, c); + } } } diff --git a/src/archImpl/none.cpp b/src/archImpl/none.cpp index 223f1d4a..b5ff3d82 100644 --- a/src/archImpl/none.cpp +++ b/src/archImpl/none.cpp @@ -1,4 +1,7 @@ #include "../SkipBigramModelImpl.hpp" +#include "../gemm.h" + +#include namespace 
kiwi { @@ -9,9 +12,72 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + return gemm(m, n, k, aT, strideA, b, strideB, c, strideC); + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + return gemv(m, k, aT, strideA, b, c); + } } } diff --git a/src/archImpl/sse2.cpp 
b/src/archImpl/sse2.cpp index 8df7a094..86edbdc4 100644 --- a/src/archImpl/sse2.cpp +++ b/src/archImpl/sse2.cpp @@ -1,4 +1,8 @@ #include "../SkipBigramModelImpl.hpp" +#include "../gemm.h" + +#define Eigen EigenSSE2 +#include namespace kiwi { @@ -8,5 +12,42 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } + diff --git a/src/archImpl/sse4_1.cpp b/src/archImpl/sse4_1.cpp index bff67239..14d2800e 100644 --- a/src/archImpl/sse4_1.cpp +++ b/src/archImpl/sse4_1.cpp @@ -1,4 +1,9 @@ #include "../SkipBigramModelImpl.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenSSE41 +#include namespace kiwi { @@ -8,5 +13,169 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void 
logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + static FORCE_INLINE __m128i dpbusd(__m128i src, __m128i a, __m128i b) + { + __m128i one16 = _mm_set1_epi16(1); + __m128i t0 = _mm_maddubs_epi16(a, b); + __m128i t1 = _mm_madd_epi16(t0, one16); + return _mm_add_epi32(src, t1); + } + + inline void pack4x16to4x4x4( + const void* a0, const void* a1, const void* a2, const void* a3, + __m128i& p0, __m128i& p1, __m128i& p2, __m128i& p3 + ) + { + __m128i q0, q1, q2, q3; + // 00, 01, 02, 03, 04, 05, 06, 07, ... + p0 = _mm_loadu_si128((const __m128i*)a0); + p1 = _mm_loadu_si128((const __m128i*)a1); + p2 = _mm_loadu_si128((const __m128i*)a2); + p3 = _mm_loadu_si128((const __m128i*)a3); + + // 00, 10, 01, 11, 04, 14, 05, 15, ... + q0 = _mm_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 06, 16, 07, 17, ... + q1 = _mm_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 24, 34, 25, 35, ... + q2 = _mm_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 26, 36, 27, 37, ... + q3 = _mm_unpackhi_epi32(p2, p3); + + // 00, 10, 20, 30, 04, 14, 24, 34, ... + p0 = _mm_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 05, 15, 25, 35, ... + p1 = _mm_unpackhi_epi64(q0, q2); + // 02, 12, 22, 32, 06, 16, 26, 36, ... + p2 = _mm_unpacklo_epi64(q1, q3); + // 03, 13, 23, 33, 07, 17, 27, 37, ... 
+ p3 = _mm_unpackhi_epi64(q1, q3); + } + + inline void scatteredGEMV_128( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 4, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m128i pa[4], pb, pbs, psum, pbSum; + __m128 paScale, paBias, pbScale, r; + pbScale = _mm_set1_ps(bScale); + pbSum = _mm_set1_epi32(-bSum); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + aIdx[0] * aIdxScale, + 1 < microM ? aIdx[1] * aIdxScale : 0, + 2 < microM ? aIdx[2] * aIdxScale : 0, + 3 < microM ? aIdx[3] * aIdxScale : 0, + }; + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 16) + { + pack4x16to4x4x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb = _mm_loadu_si128((const __m128i*)(bBuffer + j)); + pbs = _mm_shuffle_epi32(pb, 0x00); + psum = dpbusd(psum, pa[0], pbs); + pbs = _mm_shuffle_epi32(pb, 0x55); + psum = dpbusd(psum, pa[1], pbs); + pbs = _mm_shuffle_epi32(pb, 0xAA); + psum = dpbusd(psum, pa[2], pbs); + pbs = _mm_shuffle_epi32(pb, 0xFF); + psum = dpbusd(psum, pa[3], pbs); + aPtr += 16; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + aIdx += 4; + + paScale = _mm_loadu_ps(aScale); + paBias = _mm_loadu_ps(aBias); + r = _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_mm_cvtepi32_ps(psum), pbScale), paScale), paBias); + _mm_storeu_ps(c, r); + c += microM; + } + } + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const 
uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_128(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/gemm.h b/src/gemm.h new file mode 100644 index 00000000..7ade297e --- /dev/null +++ b/src/gemm.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace kiwi +{ + namespace gemm + { + // c += a.transpose() * b + template + void gemm(size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ); + + // c += a.transpose() * b + template + void gemv(size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ); + } +} diff --git a/src/qgemm.cpp b/src/qgemm.cpp index 5b66c917..b750b0d4 100644 --- a/src/qgemm.cpp +++ b/src/qgemm.cpp @@ -2,1089 +2,13 @@ #include #include #include -#include "qgemm.h" +#include "qgemm.hpp" #include "SIMD.hpp" -#define UNROLL4() do { {LOOP_BODY(0)} {LOOP_BODY(1)} {LOOP_BODY(2)} {LOOP_BODY(3)} } while(0) - namespace kiwi { namespace qgemm { - static constexpr size_t 
TLBSize = 32768; - - template - struct SharedThreadLocalBuffer - { - thread_local static uint8_t buffer[size]; - static uint8_t* get() - { - return buffer; - } - }; - - template - thread_local uint8_t SharedThreadLocalBuffer::buffer[size]; - - template - int32_t dotprod( - const uint8_t* a, const int8_t* b, size_t n - ) - { - simd::Operator op; - return op.dotprod(a, b, n); - } - template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); - template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); - template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); - template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); - template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); - - inline void packScatteredAPanel(uint8_t* out, size_t ld, const uint8_t* base, const int32_t* idx, size_t scale, size_t m, size_t k) - { - for (size_t i = 0; i < m; ++i) - { - const uint8_t* src = base + idx[i] * scale; - memcpy(out + i * ld, src, k); - } - } - - template - inline void packScatteredBPanel(int8_t* out, size_t ld, int32_t* sum, - const int8_t* base, const int32_t* sumBase, const int32_t* idx, - size_t scale, size_t sumScale, size_t n, size_t k) - { - int32_t* pout = reinterpret_cast(out); - - for (size_t i = 0; i < n; i += blockSize) - { - const size_t innerN = std::min(blockSize, n - i); - for (size_t j = 0; j < k; j += 4) - { - for (size_t x = 0; x < innerN; ++x) - { - const int8_t* src = base + idx[i + x] * scale; - *pout++ = *reinterpret_cast(&src[j]); - } - pout += (blockSize - innerN); - } - - for (size_t x = 0; x < innerN; ++x) - { - sum[i + x] = sumBase[idx[i + x] * sumScale]; - } - } - } - - template - inline void qgemmKernel( - size_t m, size_t n, size_t k, - const uint8_t* a, const int8_t* b, - const float* aScale, const float* bScale, - const float* aBias, const int32_t* sumBuffer, - float* out, size_t ld) - { - // quantized sub-block gemm(m=4, n=64) - static constexpr size_t blockNStride = blockNSize * 
4; - __m512i pa, pb[4], psum[16]; - __m512 paScale, paBias, pbScale[4], r; - - for (size_t i = 0; i < n; n += blockNSize * 4) - { - psum[0] = psum[4] = psum[8] = psum[12] = _mm512_loadu_si512(sumBuffer); - psum[1] = psum[5] = psum[9] = psum[13] = _mm512_loadu_si512(sumBuffer + blockNSize); - psum[2] = psum[6] = psum[10] = psum[14] = _mm512_loadu_si512(sumBuffer + blockNSize * 2); - psum[3] = psum[7] = psum[11] = psum[15] = _mm512_loadu_si512(sumBuffer + blockNSize * 3); - - for (size_t j = 0; j < k; j += 4) - { - pb[0] = _mm512_loadu_si512(b); - pb[1] = _mm512_loadu_si512(b + blockNStride * 1); - pb[2] = _mm512_loadu_si512(b + blockNStride * 2); - pb[3] = _mm512_loadu_si512(b + blockNStride * 3); - - pa = _mm512_set1_epi32(*reinterpret_cast(a)); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa, pb[0]); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa, pb[1]); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa, pb[2]); - psum[3] = _mm512_dpbusd_epi32(psum[3], pa, pb[3]); - - pa = _mm512_set1_epi32(*reinterpret_cast(a + k)); - psum[4] = _mm512_dpbusd_epi32(psum[4], pa, pb[0]); - psum[5] = _mm512_dpbusd_epi32(psum[5], pa, pb[1]); - psum[6] = _mm512_dpbusd_epi32(psum[6], pa, pb[2]); - psum[7] = _mm512_dpbusd_epi32(psum[7], pa, pb[3]); - - pa = _mm512_set1_epi32(*reinterpret_cast(a + k * 2)); - psum[8] = _mm512_dpbusd_epi32(psum[8], pa, pb[0]); - psum[9] = _mm512_dpbusd_epi32(psum[9], pa, pb[1]); - psum[10] = _mm512_dpbusd_epi32(psum[10], pa, pb[2]); - psum[11] = _mm512_dpbusd_epi32(psum[11], pa, pb[3]); - - pa = _mm512_set1_epi32(*reinterpret_cast(a + k * 3)); - psum[12] = _mm512_dpbusd_epi32(psum[12], pa, pb[0]); - psum[13] = _mm512_dpbusd_epi32(psum[13], pa, pb[1]); - psum[14] = _mm512_dpbusd_epi32(psum[14], pa, pb[2]); - psum[15] = _mm512_dpbusd_epi32(psum[15], pa, pb[3]); - - a += 4; - b += blockNStride * 4; - } - pbScale[0] = _mm512_loadu_ps(bScale); - pbScale[1] = _mm512_loadu_ps(bScale + blockNSize); - pbScale[2] = _mm512_loadu_ps(bScale + blockNSize * 2); - pbScale[3] = 
_mm512_loadu_ps(bScale + blockNSize * 3); - - paScale = _mm512_set1_ps(*aScale++); - paBias = _mm512_set1_ps(*aBias++); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale[0]), paScale, paBias); - _mm512_storeu_ps(out, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[1]), pbScale[1]), paScale, paBias); - _mm512_storeu_ps(out + blockNSize, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[2]), pbScale[2]), paScale, paBias); - _mm512_storeu_ps(out + blockNSize * 2, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[3]), pbScale[3]), paScale, paBias); - _mm512_storeu_ps(out + blockNSize * 3, r); - - paScale = _mm512_set1_ps(*aScale++); - paBias = _mm512_set1_ps(*aBias++); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[4]), pbScale[0]), paScale, paBias); - _mm512_storeu_ps(out + ld, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[5]), pbScale[1]), paScale, paBias); - _mm512_storeu_ps(out + ld + blockNSize, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[6]), pbScale[2]), paScale, paBias); - _mm512_storeu_ps(out + ld + blockNSize * 2, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[7]), pbScale[3]), paScale, paBias); - _mm512_storeu_ps(out + ld + blockNSize * 3, r); - - paScale = _mm512_set1_ps(*aScale++); - paBias = _mm512_set1_ps(*aBias++); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[8]), pbScale[0]), paScale, paBias); - _mm512_storeu_ps(out + ld * 2, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[9]), pbScale[1]), paScale, paBias); - _mm512_storeu_ps(out + ld * 2 + blockNSize, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[10]), pbScale[2]), paScale, paBias); - _mm512_storeu_ps(out + ld * 2 + blockNSize * 2, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[11]), pbScale[3]), paScale, paBias); - _mm512_storeu_ps(out + ld * 2 + blockNSize * 3, r); - - paScale = 
_mm512_set1_ps(*aScale++); - paBias = _mm512_set1_ps(*aBias++); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[12]), pbScale[0]), paScale, paBias); - _mm512_storeu_ps(out + ld * 3, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[13]), pbScale[1]), paScale, paBias); - _mm512_storeu_ps(out + ld * 3 + blockNSize, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[14]), pbScale[2]), paScale, paBias); - _mm512_storeu_ps(out + ld * 3 + blockNSize * 2, r); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[15]), pbScale[3]), paScale, paBias); - _mm512_storeu_ps(out + ld * 3 + blockNSize * 3, r); - sumBuffer += blockNSize * 4; - out += blockNSize * 4; - } - } - - template - void scatteredGEMM( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ) - { - // assert k <= 384 - constexpr size_t packM = 48, packN = 256, packK = 384; - thread_local uint8_t buffer[packM * packK + packN * packK]; - thread_local int32_t sumBuffer[packN]; - uint8_t* aBuffer = buffer; - int8_t* bBuffer = reinterpret_cast(buffer + packM * packK); - - for (size_t ni = 0; ni < n; ni += packN) - { - const size_t microN = std::min(packN, n - ni); - packScatteredBPanel(bBuffer, packK, sumBuffer, bBase, reinterpret_cast(bBase + k + 4), bIdx + ni, bIdxScale, bIdxScale / 4, microN, k); - - for (size_t mi = 0; mi < m; mi += packM) - { - const size_t microM = std::min(packM, m - mi); - packScatteredAPanel(aBuffer, packK, aBase, aIdx + mi, aIdxScale, microM, k); - - //qgemmKernel<16>(microM, microN, k, aBuffer, bBuffer, sumBuffer, nullptr, n); - } - } - } - - template void scatteredGEMM( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ); - - template - void scatteredGEMMBaseline( - size_t m, size_t n, 
size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ) - { - thread_local Vector buffer; - buffer.resize((m + n) * (k + 8)); - uint8_t* aBuffer = buffer.data(); - int8_t* bBuffer = reinterpret_cast(aBuffer + m * (k + 8)); - simd::Operator op; - - for (size_t i = 0; i < m; ++i) - { - std::memcpy(aBuffer + i * (k + 8), &aBase[aIdx[i] * aIdxScale], k + 8); - } - for (size_t i = 0; i < n; ++i) - { - std::memcpy(bBuffer + i * (k + 8), &bBase[bIdx[i] * bIdxScale], k + 8); - } - - for (size_t i = 0; i < m; ++i) - { - for (size_t j = 0; j < n; ++j) - { - const auto* aPtr = aBuffer + i * (k + 8); - const auto* bPtr = bBuffer + j * (k + 8); - int32_t acc = op.dotprod(aPtr, bPtr, k); - const float contextScale = *reinterpret_cast(aPtr + k), - outputScale = *reinterpret_cast(bPtr + k), - contextBias = *reinterpret_cast(aPtr + k + 4); - const int32_t hsum = *reinterpret_cast(bPtr + k + 4); - c[i * ldc + j] = (acc - hsum) * contextScale * outputScale + contextBias; - } - } - } - - inline void pack16x4( - void* out, - const void* a0, - const void* a1, - const void* a2, - const void* a3, - const void* a4, - const void* a5, - const void* a6, - const void* a7, - const void* a8, - const void* a9, - const void* a10, - const void* a11, - const void* a12, - const void* a13, - const void* a14, - const void* a15 - ) - { - // 00, 01, 02, 03, 40, 41, 42, 43 - auto p0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a0)), _mm_loadu_epi32(a4), 1); - auto p1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a1)), _mm_loadu_epi32(a5), 1); - auto p2 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a2)), _mm_loadu_epi32(a6), 1); - auto p3 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a3)), _mm_loadu_epi32(a7), 1); - auto p4 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a8)), _mm_loadu_epi32(a12), 1); - auto 
p5 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a9)), _mm_loadu_epi32(a13), 1); - auto p6 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a10)), _mm_loadu_epi32(a14), 1); - auto p7 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_loadu_epi32(a11)), _mm_loadu_epi32(a15), 1); - - // 00, 10, 01, 11, 40, 50, 41, 51 - auto q0 = _mm256_unpacklo_epi32(p0, p1); - // 02, 12, 03, 13, 42, 52, 43, 53 - auto q1 = _mm256_unpackhi_epi32(p0, p1); - // 20, 30, 21, 31, 60, 70, 61, 71 - auto q2 = _mm256_unpacklo_epi32(p2, p3); - // 22, 32, 23, 33, 62, 72, 63, 73 - auto q3 = _mm256_unpackhi_epi32(p2, p3); - auto q4 = _mm256_unpacklo_epi32(p4, p5); - auto q5 = _mm256_unpackhi_epi32(p4, p5); - auto q6 = _mm256_unpacklo_epi32(p6, p7); - auto q7 = _mm256_unpackhi_epi32(p6, p7); - - // 00, 10, 20, 30, 40, 50, 60, 70 - p0 = _mm256_unpacklo_epi64(q0, q2); - // 01, 11, 21, 31, 41, 51, 61, 71 - p1 = _mm256_unpackhi_epi64(q0, q2); - p2 = _mm256_unpacklo_epi64(q1, q3); - p3 = _mm256_unpackhi_epi64(q1, q3); - p4 = _mm256_unpacklo_epi64(q4, q6); - p5 = _mm256_unpackhi_epi64(q4, q6); - p6 = _mm256_unpacklo_epi64(q5, q7); - p7 = _mm256_unpackhi_epi64(q5, q7); - - auto* pout = reinterpret_cast<__m256i*>(out); - _mm256_storeu_si256(pout++, p0); - _mm256_storeu_si256(pout++, p4); - _mm256_storeu_si256(pout++, p1); - _mm256_storeu_si256(pout++, p5); - _mm256_storeu_si256(pout++, p2); - _mm256_storeu_si256(pout++, p6); - _mm256_storeu_si256(pout++, p3); - _mm256_storeu_si256(pout++, p7); - } - - inline void pack4x64to4x16x4( - const void* a0, const void* a1, const void* a2, const void* a3, - __m512i& p0, __m512i& p1, __m512i& p2, __m512i& p3 - ) - { - __m512i q0, q1, q2, q3; - // 00, 01, 02, 03, 04, 05, 06, 07, ... - p0 = _mm512_loadu_si512(a0); - p1 = _mm512_loadu_si512(a1); - p2 = _mm512_loadu_si512(a2); - p3 = _mm512_loadu_si512(a3); - - // 00, 10, 01, 11, 04, 14, 05, 15, ... - q0 = _mm512_unpacklo_epi32(p0, p1); - // 02, 12, 03, 13, 06, 16, 07, 17, ... 
- q1 = _mm512_unpackhi_epi32(p0, p1); - // 20, 30, 21, 31, 24, 34, 25, 35, ... - q2 = _mm512_unpacklo_epi32(p2, p3); - // 22, 32, 23, 33, 26, 36, 27, 37, ... - q3 = _mm512_unpackhi_epi32(p2, p3); - - // 00, 10, 20, 30, 04, 14, 24, 34, ... - p0 = _mm512_unpacklo_epi64(q0, q2); - // 01, 11, 21, 31, 05, 15, 25, 35, ... - p1 = _mm512_unpackhi_epi64(q0, q2); - // 02, 12, 22, 32, 06, 16, 26, 36, ... - p2 = _mm512_unpacklo_epi64(q1, q3); - // 03, 13, 23, 33, 07, 17, 27, 37, ... - p3 = _mm512_unpackhi_epi64(q1, q3); - } - - void scatteredGEMV( - size_t m, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* b, - float* c - ) - { - constexpr size_t packM = 16, packN = 1, packK = 384; - auto* buffer = SharedThreadLocalBuffer<>::get(); - int8_t* bBuffer = reinterpret_cast(buffer); - float* aScale = reinterpret_cast(bBuffer + packN * packK); - float* aBias = aScale + packM; - memcpy(bBuffer, b, k); - float bScale = *reinterpret_cast(b + k); - int32_t bSum = *reinterpret_cast(b + k + 4); - - __m512i pa[4], pb, pbs, psum, pbSum, pr = _mm512_setzero_si512(); - __m512 paScale, paBias, pbScale, r; - pbScale = _mm512_set1_ps(bScale); - pbSum = _mm512_set1_epi32(-bSum / 4); - __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), - shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), - shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), - shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); - - for (size_t mi = 0; mi < m; mi += packM) - { - const size_t microM = std::min(packM, m - mi); -#define LOOP_BODY(mj) \ - const int32_t aOffsets[4] = {\ - mj * 4 < microM ? aIdx[0] * aIdxScale : 0,\ - mj * 4 + 1 < microM ? aIdx[1] * aIdxScale : 0,\ - mj * 4 + 2 < microM ? aIdx[2] * aIdxScale : 0,\ - mj * 4 + 3 < microM ? 
aIdx[3] * aIdxScale : 0,\ - };\ - auto* aPtr = aBase;\ - psum = pbSum;\ - for (size_t j = 0; j < k; j += 64)\ - {\ - pack4x64to4x16x4(aPtr + aOffsets[0],\ - aPtr + aOffsets[1],\ - aPtr + aOffsets[2],\ - aPtr + aOffsets[3],\ - pa[0], pa[1], pa[2], pa[3]);\ - pb = _mm512_loadu_si512(bBuffer + j);\ - pbs = _mm512_permutexvar_epi32(shfIdx0, pb);\ - psum = _mm512_dpbusd_epi32(psum, pa[0], pbs);\ - pbs = _mm512_permutexvar_epi32(shfIdx1, pb);\ - psum = _mm512_dpbusd_epi32(psum, pa[1], pbs);\ - pbs = _mm512_permutexvar_epi32(shfIdx2, pb);\ - psum = _mm512_dpbusd_epi32(psum, pa[2], pbs);\ - pbs = _mm512_permutexvar_epi32(shfIdx3, pb);\ - psum = _mm512_dpbusd_epi32(psum, pa[3], pbs);\ - aPtr += 64;\ - }\ - for (size_t i = 0; i < 4; ++i)\ - {\ - aScale[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i]);\ - aBias[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i] + 4);\ - }\ - psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4));\ - psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8));\ - pr = _mm512_inserti32x4(pr, _mm512_castsi512_si128(psum), mj);\ - aIdx += 4; - - UNROLL4(); -#undef LOOP_BODY - - paScale = _mm512_loadu_ps(aScale); - paBias = _mm512_loadu_ps(aBias); - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(pr), pbScale), paScale, paBias); - _mm512_storeu_ps(c, r); - c += microM; - } - } - - void scatteredGEMV8x1( - size_t m, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* b, - float* c - ) - { - constexpr size_t packM = 8, packN = 1, packK = 384; - auto* buffer = SharedThreadLocalBuffer<>::get(); - int8_t* bBuffer = reinterpret_cast(buffer); - float* aScale = reinterpret_cast(bBuffer + packN * packK); - float* aBias = aScale + packM; - memcpy(bBuffer, b, k); - float bScale = *reinterpret_cast(b + k); - int32_t bSum = *reinterpret_cast(b + k + 4); - - __m512i pa[4], pb, pbs, psum, pbSum; - __m256 paScale, paBias, pbScale, r; - __m256i pr = _mm256_setzero_si256(); - pbScale = 
_mm256_set1_ps(bScale); - pbSum = _mm512_set1_epi32(-bSum / 4); - __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), - shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), - shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), - shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); - - { - auto* aPtr = aBase; - psum = pbSum; - for (size_t j = 0; j < k; j += 64) - { - pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, - aPtr + aIdx[1] * aIdxScale, - aPtr + aIdx[2] * aIdxScale, - aPtr + aIdx[3] * aIdxScale, - pa[0], pa[1], pa[2], pa[3]); - pb = _mm512_loadu_si512(bBuffer + j); - pbs = _mm512_permutexvar_epi32(shfIdx0, pb); - psum = _mm512_dpbusd_epi32(psum, pa[0], pbs); - pbs = _mm512_permutexvar_epi32(shfIdx1, pb); - psum = _mm512_dpbusd_epi32(psum, pa[1], pbs); - pbs = _mm512_permutexvar_epi32(shfIdx2, pb); - psum = _mm512_dpbusd_epi32(psum, pa[2], pbs); - pbs = _mm512_permutexvar_epi32(shfIdx3, pb); - psum = _mm512_dpbusd_epi32(psum, pa[3], pbs); - aPtr += 64; - } - for (size_t i = 0; i < 4; ++i) - { - aScale[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); - aBias[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); - } - psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); - psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); - pr = _mm512_castsi512_si256(psum); - aIdx += 4; - } - { - auto* aPtr = aBase; - psum = pbSum; - for (size_t j = 0; j < k; j += 64) - { - pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, - aPtr + aIdx[1] * aIdxScale, - aPtr + aIdx[2] * aIdxScale, - aPtr + aIdx[3] * aIdxScale, - pa[0], pa[1], pa[2], pa[3]); - pb = _mm512_loadu_si512(bBuffer + j); - pbs = _mm512_permutexvar_epi32(shfIdx0, pb); - psum = _mm512_dpbusd_epi32(psum, pa[0], pbs); - pbs = _mm512_permutexvar_epi32(shfIdx1, pb); - psum = _mm512_dpbusd_epi32(psum, pa[1], pbs); - pbs = _mm512_permutexvar_epi32(shfIdx2, pb); - psum = 
_mm512_dpbusd_epi32(psum, pa[2], pbs); - pbs = _mm512_permutexvar_epi32(shfIdx3, pb); - psum = _mm512_dpbusd_epi32(psum, pa[3], pbs); - aPtr += 64; - } - for (size_t i = 0; i < 4; ++i) - { - aScale[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); - aBias[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); - } - psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); - psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); - pr = _mm256_inserti32x4(pr, _mm512_castsi512_si128(psum), 1); - aIdx += 4; - } - - paScale = _mm256_loadu_ps(aScale); - paBias = _mm256_loadu_ps(aBias); - r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); - _mm256_storeu_ps(c, r); - c += 8; - } - - void scatteredGEMV2( - size_t m, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c - ) - { - constexpr size_t packM = 4, packN = 2, packK = 384; - auto* buffer = SharedThreadLocalBuffer<>::get(); - int8_t* bBuffer = reinterpret_cast(buffer); - float* aScale = reinterpret_cast(bBuffer + packN * packK); - float* aBias = aScale + packM * 2; - memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); - memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); - float bScale[2] = { - *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) - }; - int32_t bSum[2] = { - *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) - }; - - __m512i pa[4], pb[2], psum[2], pbSum[2], pt[2]; - __m256 paScale, paBias, pbScale, r; - __m256i pr; - pbScale = _mm256_castsi256_ps(_mm256_set1_epi64x(*reinterpret_cast(bScale))); - pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); - pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); - __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), - shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 
9, 9, 13, 13, 13, 13), - shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), - shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); - - for (size_t mi = 0; mi < m; mi += packM) - { - const size_t microM = std::min(packM, m - mi); - const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, - }; - auto* aPtr = aBase; - psum[0] = pbSum[0]; - psum[1] = pbSum[1]; - for (size_t j = 0; j < k; j += 64) - { - pack4x64to4x16x4(aPtr + aOffsets[0], - aPtr + aOffsets[1], - aPtr + aOffsets[2], - aPtr + aOffsets[3], - pa[0], pa[1], pa[2], pa[3]); - pb[0] = _mm512_loadu_si512(bBuffer + j); - pb[1] = _mm512_loadu_si512(bBuffer + packK + j); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[0])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[1])); - aPtr += 64; - } - for (size_t i = 0; i < 4; ++i) - { - aScale[i * 2] = aScale[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i]); - aBias[i * 2] = aBias[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i] + 4); - } - psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); - psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); - psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); - 
psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); - - // 00, 01, 10, 11, ... - pt[0] = _mm512_unpacklo_epi32(psum[0], psum[1]); - // 20, 21, 30, 31, ... - pt[1] = _mm512_unpackhi_epi32(psum[0], psum[1]); - - // 00, 01, 10, 11, 20, 21, 30, 31 - pr = _mm256_permute2x128_si256(_mm512_castsi512_si256(pt[0]), _mm512_castsi512_si256(pt[1]), 0x20); - paScale = _mm256_loadu_ps(aScale); - paBias = _mm256_loadu_ps(aBias); - r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); - _mm256_storeu_ps(c, r); - - aIdx += microM; - c += microM * 2; - } - } - - void scatteredGEMV3( - size_t m, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c - ) - { - constexpr size_t packM = 4, packN = 3, packK = 384; - auto* buffer = SharedThreadLocalBuffer<>::get(); - int8_t* bBuffer = reinterpret_cast(buffer); - float* aScale = reinterpret_cast(bBuffer + packN * packK); - float* aBias = aScale + packM; - memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); - memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); - memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); - float bScale[4] = { - *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), - *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), - 0 - }; - int32_t bSum[3] = { - *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4), - *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4) - }; - __m512i pa[4], pb[3], psum[3], pbSum[3]; - __m512 paScale, paBias, pbScale, r; - pbScale = _mm512_permutexvar_ps( - _mm512_setr_epi32(0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 3, 3, 3), - _mm512_castps128_ps512(_mm_loadu_ps(bScale))); - pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); - pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); - pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); - __m512i shfIdx0 = 
_mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), - shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), - shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), - shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15), - shfIdxT = _mm512_setr_epi32(0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4); - - for (size_t mi = 0; mi < m; mi += packM) - { - const size_t microM = std::min(packM, m - mi); - const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, - }; - auto* aPtr = aBase; - psum[0] = pbSum[0]; - psum[1] = pbSum[1]; - psum[2] = pbSum[2]; - for (size_t j = 0; j < k; j += 64) - { - pack4x64to4x16x4(aPtr + aOffsets[0], - aPtr + aOffsets[1], - aPtr + aOffsets[2], - aPtr + aOffsets[3], - pa[0], pa[1], pa[2], pa[3]); - pb[0] = _mm512_loadu_si512(bBuffer + j); - pb[1] = _mm512_loadu_si512(bBuffer + packK + j); - pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[0])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[1])); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[2])); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa[1], _mm512_permutexvar_epi32(shfIdx1, 
pb[2])); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[2])); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[2])); - aPtr += 64; - } - for (size_t i = 0; i < 4; ++i) - { - aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); - aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); - } - psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); - psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); - psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); - psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); - psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); - psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); - - // 00, 10, 20, 30, 01, 11, 21, 31 - psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); - - // 00, 01, 02, 10, 11, 12, 20, 21, 22, 30, 31, 32, ... 
- psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( - 0, 4, 16, 1, 5, 17, 2, 6, 18, 3, 7, 19, 0, 0, 0, 0 - ), psum[2]); - - paScale = _mm512_castps128_ps512(_mm_loadu_ps(aScale)); - paScale = _mm512_permutexvar_ps(shfIdxT, paScale); - paBias = _mm512_castps128_ps512(_mm_loadu_ps(aBias)); - paBias = _mm512_permutexvar_ps(shfIdxT, paBias); - - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); - _mm512_mask_storeu_ps(c, 0x0FFF, r); - - aIdx += microM; - c += microM * 3; - } - } - - void scatteredGEMV4( - size_t m, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c - ) - { - constexpr size_t packM = 4, packN = 4, packK = 384; - auto* buffer = SharedThreadLocalBuffer<>::get(); - int8_t* bBuffer = reinterpret_cast(buffer); - float* aScale = reinterpret_cast(bBuffer + packN * packK); - float* aBias = aScale + packM * 4; - memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); - memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); - memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); - memcpy(bBuffer + packK * 3, bBase + bIdx[3] * bIdxScale, k); - float bScale[4] = { - *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), - *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), - *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k) - }; - int32_t bSum[4] = { - *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), - *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4), - *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4), - *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k + 4) - }; - __m512i pa[4], pb[4], psum[4], pbSum[4]; - __m512 paScale, paBias, pbScale, r; - pbScale = _mm512_broadcast_f32x4(_mm_loadu_ps(bScale)); - pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); - pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); - pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); - pbSum[3] 
= _mm512_set1_epi32(-bSum[3] / 4); - __m512i shfIdx0 = _mm512_setr_epi32(0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12), - shfIdx1 = _mm512_setr_epi32(1, 1, 1, 1, 5, 5, 5, 5, 9, 9, 9, 9, 13, 13, 13, 13), - shfIdx2 = _mm512_setr_epi32(2, 2, 2, 2, 6, 6, 6, 6, 10, 10, 10, 10, 14, 14, 14, 14), - shfIdx3 = _mm512_setr_epi32(3, 3, 3, 3, 7, 7, 7, 7, 11, 11, 11, 11, 15, 15, 15, 15); - - for (size_t mi = 0; mi < m; mi += packM) - { - const size_t microM = std::min(packM, m - mi); - const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, - }; - auto* aPtr = aBase; - psum[0] = pbSum[0]; - psum[1] = pbSum[1]; - psum[2] = pbSum[2]; - psum[3] = pbSum[3]; - for (size_t j = 0; j < k; j += 64) - { - pack4x64to4x16x4(aPtr + aOffsets[0], - aPtr + aOffsets[1], - aPtr + aOffsets[2], - aPtr + aOffsets[3], - pa[0], pa[1], pa[2], pa[3]); - pb[0] = _mm512_loadu_si512(bBuffer + j); - pb[1] = _mm512_loadu_si512(bBuffer + packK + j); - pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); - pb[3] = _mm512_loadu_si512(bBuffer + packK * 3 + j); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[0])); - psum[0] = _mm512_dpbusd_epi32(psum[0], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[0])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[1])); - psum[1] = _mm512_dpbusd_epi32(psum[1], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[1])); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[2])); - psum[2] = 
_mm512_dpbusd_epi32(psum[2], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[2])); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[2])); - psum[2] = _mm512_dpbusd_epi32(psum[2], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[2])); - psum[3] = _mm512_dpbusd_epi32(psum[3], pa[0], _mm512_permutexvar_epi32(shfIdx0, pb[3])); - psum[3] = _mm512_dpbusd_epi32(psum[3], pa[1], _mm512_permutexvar_epi32(shfIdx1, pb[3])); - psum[3] = _mm512_dpbusd_epi32(psum[3], pa[2], _mm512_permutexvar_epi32(shfIdx2, pb[3])); - psum[3] = _mm512_dpbusd_epi32(psum[3], pa[3], _mm512_permutexvar_epi32(shfIdx3, pb[3])); - aPtr += 64; - } - for (size_t i = 0; i < 4; ++i) - { - aScale[i * 4] = *reinterpret_cast(aPtr + aOffsets[i]); - aBias[i * 4] = *reinterpret_cast(aPtr + aOffsets[i] + 4); - } - psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); - psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); - psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); - psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); - psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); - psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); - psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 4)); - psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 8)); - - // 00, 10, 20, 30, 01, 11, 21, 31 - psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); - // 02, 12, 22, 32, 03, 13, 23, 33 - psum[2] = _mm512_inserti32x4(psum[2], _mm512_castsi512_si128(psum[3]), 1); - - // 00, 01, 02, 03, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33 - psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( - 0, 4, 16, 20, 1, 5, 17, 21, 2, 6, 18, 22, 3, 7, 19, 23 - ), psum[2]); - - paScale = _mm512_loadu_ps(aScale); - paScale = _mm512_shuffle_ps(paScale, paScale, 0); - paBias = 
_mm512_loadu_ps(aBias); - paBias = _mm512_shuffle_ps(paBias, paBias, 0); - - r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); - _mm512_storeu_ps(c, r); - - aIdx += microM; - c += microM * 4; - } - } - - template - void scatteredGEMMSmall( - size_t, size_t, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ) - { - static_assert(m <= 3, "m should be less than or equal to 3"); - static_assert(n <= 3, "n should be less than or equal to 3"); - __m512i pa[3], pb[3], psum[3][3]; - const uint8_t* aPtr[3]; - const int8_t* bPtr[3]; - - psum[0][0] = _mm512_setzero_si512(); - if (m > 1) psum[1][0] = _mm512_setzero_si512(); - if (m > 2) psum[2][0] = _mm512_setzero_si512(); - if (n > 1) psum[0][1] = _mm512_setzero_si512(); - if (m > 1 && n > 1) psum[1][1] = _mm512_setzero_si512(); - if (m > 2 && n > 1) psum[2][1] = _mm512_setzero_si512(); - if (n > 2) psum[0][2] = _mm512_setzero_si512(); - if (m > 1 && n > 2) psum[1][2] = _mm512_setzero_si512(); - if (m > 2 && n > 2) psum[2][2] = _mm512_setzero_si512(); - - aPtr[0] = aBase + aIdx[0] * aIdxScale; - if (m > 1) aPtr[1] = aBase + aIdx[1] * aIdxScale; - if (m > 2) aPtr[2] = aBase + aIdx[2] * aIdxScale; - - bPtr[0] = bBase + bIdx[0] * bIdxScale; - if (n > 1) bPtr[1] = bBase + bIdx[1] * bIdxScale; - if (n > 2) bPtr[2] = bBase + bIdx[2] * bIdxScale; - - for (size_t x = 0; x < k; x += 64) - { - if (m > 0) - { - pa[0] = _mm512_loadu_si512(aPtr[0]); - aPtr[0] += 64; - } - if (m > 1) - { - pa[1] = _mm512_loadu_si512(aPtr[1]); - aPtr[1] += 64; - } - if (m > 2) - { - pa[2] = _mm512_loadu_si512(aPtr[2]); - aPtr[2] += 64; - } - - if (n > 0) - { - pb[0] = _mm512_loadu_si512(bPtr[0]); - bPtr[0] += 64; - } - if (n > 1) - { - pb[1] = _mm512_loadu_si512(bPtr[1]); - bPtr[1] += 64; - } - if (n > 2) - { - pb[2] = _mm512_loadu_si512(bPtr[2]); - bPtr[2] += 64; - } - - psum[0][0] = 
_mm512_dpbusd_epi32(psum[0][0], pa[0], pb[0]); - if (m > 1) psum[1][0] = _mm512_dpbusd_epi32(psum[1][0], pa[1], pb[0]); - if (m > 2) psum[2][0] = _mm512_dpbusd_epi32(psum[2][0], pa[2], pb[0]); - if (n > 1) psum[0][1] = _mm512_dpbusd_epi32(psum[0][1], pa[0], pb[1]); - if (m > 1 && n > 1) psum[1][1] = _mm512_dpbusd_epi32(psum[1][1], pa[1], pb[1]); - if (m > 2 && n > 1) psum[2][1] = _mm512_dpbusd_epi32(psum[2][1], pa[2], pb[1]); - if (n > 2) psum[0][2] = _mm512_dpbusd_epi32(psum[0][2], pa[0], pb[2]); - if (m > 1 && n > 2) psum[1][2] = _mm512_dpbusd_epi32(psum[1][2], pa[1], pb[2]); - if (m > 2 && n > 2) psum[2][2] = _mm512_dpbusd_epi32(psum[2][2], pa[2], pb[2]); - } - - float contextScale[3], outputScale[3], contextBias[3]; - int32_t hsum[3]; - - if (m > 0) - { - contextScale[0] = *reinterpret_cast(aPtr[0]); - contextBias[0] = *reinterpret_cast(aPtr[0] + 4); - } - if (m > 1) - { - contextScale[1] = *reinterpret_cast(aPtr[1]); - contextBias[1] = *reinterpret_cast(aPtr[1] + 4); - } - if (m > 2) - { - contextScale[2] = *reinterpret_cast(aPtr[2]); - contextBias[2] = *reinterpret_cast(aPtr[2] + 4); - } - - - if (n > 0) - { - outputScale[0] = *reinterpret_cast(bPtr[0]); - hsum[0] = *reinterpret_cast(bPtr[0] + 4); - } - if (n > 1) - { - outputScale[1] = *reinterpret_cast(bPtr[1]); - hsum[1] = *reinterpret_cast(bPtr[1] + 4); - } - if (n > 2) - { - outputScale[2] = *reinterpret_cast(bPtr[2]); - hsum[2] = *reinterpret_cast(bPtr[2] + 4); - } - - { - int32_t acc = _mm512_reduce_add_epi32(psum[0][0]); - c[0] = (acc - hsum[0]) * contextScale[0] * outputScale[0] + contextBias[0]; - } - if (m > 1) - { - int32_t acc = _mm512_reduce_add_epi32(psum[1][0]); - c[ldc] = (acc - hsum[0]) * contextScale[1] * outputScale[0] + contextBias[1]; - } - if (m > 2) - { - int32_t acc = _mm512_reduce_add_epi32(psum[2][0]); - c[ldc * 2] = (acc - hsum[0]) * contextScale[2] * outputScale[0] + contextBias[2]; - } - if (n > 1) - { - int32_t acc = _mm512_reduce_add_epi32(psum[0][1]); - c[1] = (acc - hsum[1]) 
* contextScale[0] * outputScale[1] + contextBias[0]; - } - if (m > 1 && n > 1) - { - int32_t acc = _mm512_reduce_add_epi32(psum[1][1]); - c[ldc + 1] = (acc - hsum[1]) * contextScale[1] * outputScale[1] + contextBias[1]; - } - if (m > 2 && n > 1) - { - int32_t acc = _mm512_reduce_add_epi32(psum[2][1]); - c[ldc * 2 + 1] = (acc - hsum[1]) * contextScale[2] * outputScale[1] + contextBias[2]; - } - if (n > 2) - { - int32_t acc = _mm512_reduce_add_epi32(psum[0][2]); - c[2] = (acc - hsum[2]) * contextScale[0] * outputScale[2] + contextBias[0]; - } - if (m > 1 && n > 2) - { - int32_t acc = _mm512_reduce_add_epi32(psum[1][2]); - c[ldc + 2] = (acc - hsum[2]) * contextScale[1] * outputScale[2] + contextBias[1]; - } - if (m > 2 && n > 2) - { - int32_t acc = _mm512_reduce_add_epi32(psum[2][2]); - c[ldc * 2 + 2] = (acc - hsum[2]) * contextScale[2] * outputScale[2] + contextBias[2]; - } - } - - template - void scatteredGEMMOpt( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ) - { - using Fn = decltype(&scatteredGEMMBaseline); - static constexpr Fn fnTable[] = { - scatteredGEMMBaseline, - scatteredGEMMSmall, - scatteredGEMMSmall, - scatteredGEMMSmall, - scatteredGEMMSmall, - scatteredGEMMSmall, - scatteredGEMMSmall, - scatteredGEMMSmall, - scatteredGEMMSmall - }; - - if (m <= 3 && n <= 3) - { - return (*fnTable[(m - 1) * 3 + (n - 1)])(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); - } - - if (n == 1 && ldc == 1) - { - if (m == 8) - { - return scatteredGEMV8x1(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); - } - else - { - return scatteredGEMV(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); - } - } - - if (m >= 4) - { - if (n == 2 && ldc == 2) return scatteredGEMV2(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); - if (n == 3 && ldc == 3) return scatteredGEMV3(m, k, aBase, aIdx, aIdxScale, 
bBase, bIdx, bIdxScale, c); - if (n == 4 && ldc == 4) return scatteredGEMV4(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); - } - return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); - } - - template void scatteredGEMMOpt( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ); - template void scatteredGEMMOpt( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ); - template void scatteredGEMMOpt( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ); - template void scatteredGEMMOpt( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ); - template void scatteredGEMMOpt( - size_t m, size_t n, size_t k, - const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, - const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, - float* c, size_t ldc - ); } } diff --git a/src/qgemm.hpp b/src/qgemm.hpp new file mode 100644 index 00000000..765af5a0 --- /dev/null +++ b/src/qgemm.hpp @@ -0,0 +1,190 @@ +#pragma once +#include "qgemm.h" +#include "SIMD.hpp" + +namespace kiwi +{ + namespace qgemm + { + static constexpr size_t TLBSize = 32768; + + template + struct SharedThreadLocalBuffer + { + thread_local static uint8_t buffer[size]; + static uint8_t* get() + { + return buffer; + } + }; + + template + thread_local uint8_t SharedThreadLocalBuffer::buffer[size]; + + template + int32_t dotprod( + const uint8_t* a, const int8_t* b, size_t n + ) + { + simd::Operator op; + return op.dotprod(a, b, n); + } + + template 
+ void scatteredGEMMBaseline( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + thread_local Vector buffer; + buffer.resize((m + n) * (k + 8)); + uint8_t* aBuffer = buffer.data(); + int8_t* bBuffer = reinterpret_cast(aBuffer + m * (k + 8)); + simd::Operator op; + + for (size_t i = 0; i < m; ++i) + { + std::memcpy(aBuffer + i * (k + 8), &aBase[aIdx[i] * aIdxScale], k + 8); + } + for (size_t i = 0; i < n; ++i) + { + std::memcpy(bBuffer + i * (k + 8), &bBase[bIdx[i] * bIdxScale], k + 8); + } + + for (size_t i = 0; i < m; ++i) + { + for (size_t j = 0; j < n; ++j) + { + const auto* aPtr = aBuffer + i * (k + 8); + const auto* bPtr = bBuffer + j * (k + 8); + int32_t acc = op.dotprod(aPtr, bPtr, k); + const float contextScale = *reinterpret_cast(aPtr + k), + outputScale = *reinterpret_cast(bPtr + k), + contextBias = *reinterpret_cast(aPtr + k + 4); + const int32_t hsum = *reinterpret_cast(bPtr + k + 4); + c[i * ldc + j] = (acc - hsum) * contextScale * outputScale + contextBias; + } + } + } + + template + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + const int32_t bIdx[1] = { 0 }; + return scatteredGEMMBaseline(m, 1, k, aBase, aIdx, aIdxScale, b, bIdx, 0, c, 1); + } + + template + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + const int32_t bIdx[1] = { 0 }; + return scatteredGEMMBaseline(m, 1, k, aBase, aIdx, aIdxScale, b, bIdx, 0, c, 1); + } + + template + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMMBaseline(m, 2, k, aBase, aIdx, aIdxScale, bBase, bIdx, 
bIdxScale, c, 2); + } + + template + inline void scatteredGEMV3( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMMBaseline(m, 3, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, 3); + } + + template + inline void scatteredGEMV4( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMMBaseline(m, 4, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, 4); + } + + template + struct ScatteredGEMMSmall + { + template + static void op( + size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template + void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + using Fn = decltype(&scatteredGEMMBaseline); + static constexpr Fn fnTable[] = { + scatteredGEMMBaseline, + ScatteredGEMMSmall::template op<1, 2>, + ScatteredGEMMSmall::template op<1, 3>, + ScatteredGEMMSmall::template op<2, 1>, + ScatteredGEMMSmall::template op<2, 2>, + ScatteredGEMMSmall::template op<2, 3>, + ScatteredGEMMSmall::template op<3, 1>, + ScatteredGEMMSmall::template op<3, 2>, + ScatteredGEMMSmall::template op<3, 3> + }; + + if (m <= 3 && n <= 3) + { + return (*fnTable[(m - 1) * 3 + (n - 1)])(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + + if (n == 1 && ldc == 1) + { + if (m == 8) + { + return scatteredGEMV8x1(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); + } + else + { + return scatteredGEMV(m, k, aBase, 
aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); + } + } + + if (m >= 4) + { + if (n == 2 && ldc == 2) return scatteredGEMV2(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + if (n == 3 && ldc == 3) return scatteredGEMV3(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + if (n == 4 && ldc == 4) return scatteredGEMV4(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + } +} diff --git a/src/search.cpp b/src/search.cpp index 2aca6f11..e3249872 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -178,7 +178,7 @@ namespace kiwi template ValueTy detail::searchKVImpl(const void* kv, size_t size, IntTy target) { - return OptimizedImpl::searchKV(kv, size, target); + return OptimizedImpl::template searchKV(kv, size, target); } template @@ -300,7 +300,7 @@ namespace kiwi { template ARCH_TARGET("sse2") - FORCE_INLINE bool testEq(__m128i p, size_t offset, size_t size, size_t& ret) + inline bool testEq(__m128i p, size_t offset, size_t size, size_t& ret) { uint32_t m = _mm_movemask_epi8(p); uint32_t b = utils::countTrailingZeroes(m); @@ -314,7 +314,7 @@ namespace kiwi template ARCH_TARGET("sse2") - FORCE_INLINE bool nstSearchSSE2(const IntTy* keys, size_t size, IntTy target, size_t& ret) + inline bool nstSearchSSE2(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -476,7 +476,7 @@ namespace kiwi template ARCH_TARGET("sse2") - FORCE_INLINE ValueTy nstSearchKVSSE2(const uint8_t* kv, size_t size, IntTy target) + inline ValueTy nstSearchKVSSE2(const uint8_t* kv, size_t size, IntTy target) { size_t i = 0, r; @@ -636,7 +636,7 @@ namespace kiwi { template ARCH_TARGET("avx2") - FORCE_INLINE bool testEq(__m256i p, size_t offset, size_t size, size_t& ret) + inline bool testEq(__m256i p, size_t offset, size_t size, size_t& ret) { uint32_t m = _mm256_movemask_epi8(p); uint32_t b = utils::countTrailingZeroes(m); @@ -648,7 +648,7 @@ 
namespace kiwi return false; } - FORCE_INLINE bool testEqMask(uint64_t m, size_t offset, size_t size, size_t& ret) + inline bool testEqMask(uint64_t m, size_t offset, size_t size, size_t& ret) { uint32_t b = utils::countTrailingZeroes(m); if (m && (offset + b) < size) @@ -661,7 +661,7 @@ namespace kiwi template ARCH_TARGET("avx2") - FORCE_INLINE bool nstSearchAVX2(const IntTy* keys, size_t size, IntTy target, size_t& ret) + inline bool nstSearchAVX2(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -850,7 +850,7 @@ namespace kiwi template ARCH_TARGET("avx2") - FORCE_INLINE ValueTy nstSearchKVAVX2(const uint8_t* kv, size_t size, IntTy target) + inline ValueTy nstSearchKVAVX2(const uint8_t* kv, size_t size, IntTy target) { if (size < (n + 1) / 2) { @@ -962,8 +962,8 @@ namespace kiwi } template - ARCH_TARGET("avx512bw") - FORCE_INLINE bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) + ARCH_TARGET("avx,avx2,avx512f,avx512bw,avx512dq") + inline bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -1152,8 +1152,8 @@ namespace kiwi } template - ARCH_TARGET("avx512bw") - FORCE_INLINE ValueTy nstSearchKVAVX512(const uint8_t* kv, size_t size, IntTy target) + ARCH_TARGET("avx,avx2,avx512f,avx512bw,avx512dq") + inline ValueTy nstSearchKVAVX512(const uint8_t* kv, size_t size, IntTy target) { if (size < (n + 1) / 2) { @@ -1377,7 +1377,7 @@ namespace kiwi { template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int8_t* keys, size_t size, int8_t target, size_t& ret) + inline bool nstSearchNeon(const int8_t* keys, size_t size, int8_t target, size_t& ret) { size_t i = 0; @@ -1410,7 +1410,7 @@ namespace kiwi template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int16_t* keys, size_t size, int16_t target, size_t& ret) + inline bool nstSearchNeon(const int16_t* keys, size_t size, int16_t target, size_t& ret) { size_t i = 0; @@ -1440,7 +1440,7 @@ namespace 
kiwi template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int32_t* keys, size_t size, int32_t target, size_t& ret) + inline bool nstSearchNeon(const int32_t* keys, size_t size, int32_t target, size_t& ret) { size_t i = 0; @@ -1470,7 +1470,7 @@ namespace kiwi template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int64_t* keys, size_t size, int64_t target, size_t& ret) + inline bool nstSearchNeon(const int64_t* keys, size_t size, int64_t target, size_t& ret) { size_t i = 0; diff --git a/src/search.h b/src/search.h index 6ec129f3..1e153452 100644 --- a/src/search.h +++ b/src/search.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include From 5e4eb305c10ce20ffc181ce677094a1bf010e603 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 3 Mar 2025 19:54:04 +0900 Subject: [PATCH 25/53] Implement tools --- tools/diff_tokens.cpp | 245 +++++++++++++++++++++++++++++++++++++++ tools/evaluator_main.cpp | 39 ++----- tools/runner.cpp | 2 +- tools/toolUtils.h | 42 ++++++- 4 files changed, 294 insertions(+), 34 deletions(-) create mode 100644 tools/diff_tokens.cpp diff --git a/tools/diff_tokens.cpp b/tools/diff_tokens.cpp new file mode 100644 index 00000000..436077ad --- /dev/null +++ b/tools/diff_tokens.cpp @@ -0,0 +1,245 @@ +#include +#include + +#include +#include + +#include "toolUtils.h" + +using namespace std; +using namespace kiwi; +using namespace TCLAP; + +Kiwi loadKiwiFromArg(const string& model, const string& modelType, size_t numThreads = 2) +{ + ModelType kiwiModelType = tutils::parseModelType(modelType); + BuildOption opt = BuildOption::default_; + opt &= ~BuildOption::loadMultiDict; + KiwiBuilder builder{ model, numThreads < 2 ? 
2 : numThreads, opt, kiwiModelType }; + return builder.build(); +} + +inline bool isEqual(const TokenInfo* a, size_t aSize, const TokenInfo* b, size_t bSize, bool ignoreTag = false) +{ + if (aSize != bSize) return false; + for (size_t i = 0; i < aSize; ++i) + { + if (a[i].str != b[i].str) return false; + if (!ignoreTag && a[i].tag != b[i].tag) return false; + } + return true; +} + +inline ostream& operator<<(ostream& ostr, const TokenInfo& token) +{ + return ostr << utf16To8(token.str) << '/' << tagToString(token.tag); +} + +bool printDiffTokens(ostream& ostr, const string& raw, const TokenInfo* a, size_t aSize, const TokenInfo* b, size_t bSize, bool ignoreTag = false, bool showSame = false) +{ + if (isEqual(a, aSize, b, bSize, ignoreTag) == showSame) return false; + ostr << raw << '\t'; + for (size_t i = 0; i < aSize; ++i) + { + if (i) ostr << ' '; + ostr << a[i]; + } + if (!showSame || ignoreTag) + { + ostr << '\t'; + for (size_t i = 0; i < bSize; ++i) + { + if (i) ostr << ' '; + ostr << b[i]; + } + } + ostr << endl; + return true; +} + +pair diffTokens(ostream& ostr, const string& raw, const TokenResult& a, const TokenResult& b, bool sentenceLevel, bool ignoreTag = false, bool showSame = false) +{ + size_t diff = 0, total = 0; + if (sentenceLevel) + { + thread_local vector> aBounds, bBounds, sentBounds; + aBounds.clear(); + bBounds.clear(); + sentBounds.clear(); + auto& aTokens = a.first; + auto& bTokens = b.first; + for (size_t i = 1; i < aTokens.size(); ++i) + { + if (aTokens[i - 1].sentPosition != aTokens[i].sentPosition) + { + aBounds.emplace_back(aTokens[i - 1].endPos(), i); + } + } + + for (size_t i = 1; i < bTokens.size(); ++i) + { + if (bTokens[i - 1].sentPosition != bTokens[i].sentPosition) + { + bBounds.emplace_back(bTokens[i - 1].endPos(), i); + } + } + + // find intersection between aBounds and bBounds and store in sentBounds + sentBounds.emplace_back(0, 0); + auto aIt = aBounds.begin(); + auto bIt = bBounds.begin(); + while (aIt != aBounds.end() && 
bIt != bBounds.end()) + { + if (aIt->first < bIt->first) + { + ++aIt; + } + else if (aIt->first > bIt->first) + { + ++bIt; + } + else + { + sentBounds.emplace_back(aIt->second, bIt->second); + ++aIt; + ++bIt; + } + } + sentBounds.emplace_back(aTokens.size(), bTokens.size()); + + const u16string u16raw = utf8To16(raw); + + for (size_t i = 1; i < sentBounds.size(); ++i) + { + const auto aStart = sentBounds[i - 1].first; + const auto aEnd = sentBounds[i].first; + const auto bStart = sentBounds[i - 1].second; + const auto bEnd = sentBounds[i].second; + const auto rawSent = u16raw.substr(aTokens[aStart].position, aTokens[aEnd - 1].endPos() - aTokens[aStart].position); + const bool isDiff = printDiffTokens(ostr, utf16To8(rawSent), aTokens.data() + aStart, aEnd - aStart, bTokens.data() + bStart, bEnd - bStart, ignoreTag, showSame); + if (isDiff) ++diff; + ++total; + } + } + else + { + const bool isDiff = printDiffTokens(ostr, raw, a.first.data(), a.first.size(), b.first.data(), b.first.size(), ignoreTag, showSame); + if (isDiff) ++diff; + ++total; + } + return { diff, total }; +} + +pair diffInputs(Kiwi& kiwiA, Kiwi& kiwiB, const string& inputs, ostream& ostr, bool sentenceLevel, bool ignoreTag = false, bool showSame = false) +{ + ifstream ifs{ inputs }; + if (!ifs) + { + cerr << "Cannot open " << inputs << endl; + return { 0, 0 }; + } + string line; + deque, future>> futures; + auto* poolA = kiwiA.getThreadPool(); + auto* poolB = kiwiB.getThreadPool(); + size_t diff = 0, total = 0; + + while (getline(ifs, line)) + { + while (futures.size() > kiwiA.getNumThreads() * 2) + { + auto rawInput = move(get<0>(futures.front())); + auto resultA = get<1>(futures.front()).get(); + auto resultB = get<2>(futures.front()).get(); + futures.pop_front(); + + auto p = diffTokens(ostr, rawInput, resultA, resultB, sentenceLevel, ignoreTag, showSame); + diff += p.first; + total += p.second; + } + + futures.emplace_back( + line, + poolA->enqueue([&, line](size_t tid) { return 
kiwiA.analyze(line, Match::allWithNormalizing);}), + poolB->enqueue([&, line](size_t tid) { return kiwiB.analyze(line, Match::allWithNormalizing);}) + ); + } + + while (!futures.empty()) + { + auto rawInput = move(get<0>(futures.front())); + auto resultA = get<1>(futures.front()).get(); + auto resultB = get<2>(futures.front()).get(); + futures.pop_front(); + + auto p = diffTokens(ostr, rawInput, resultA, resultB, sentenceLevel, ignoreTag, showSame); + diff += p.first; + total += p.second; + } + return { diff, total }; +} + +int main(int argc, const char* argv[]) +{ + tutils::setUTF8Output(); + + CmdLine cmd{ "Kiwi Diff Tokenizations" }; + + ValueArg modelA{ "", "model-a", "Model A path", true, "", "string" }; + ValueArg modelAType{ "", "model-a-type", "Model A Type", false, "none", "string" }; + ValueArg modelB{ "", "model-b", "Model B path", true, "", "string" }; + ValueArg modelBType{ "", "model-b-type", "Model B Type", false, "none", "string" }; + ValueArg output{ "o", "output", "output path", false, "", "string" }; + ValueArg numThreads{ "t", "threads", "number of threads", false, 2, "int" }; + SwitchArg sentence{ "", "sentence", "diff in sentence level", false }; + SwitchArg ignoreTag{ "i", "ignore-tag", "ignore tag", false }; + SwitchArg showSame{ "s", "show-same", "show the same result only", false }; + SwitchArg noNormCoda{ "", "no-normcoda", "without normalizing coda", false }; + SwitchArg noZCoda{ "", "no-zcoda", "without z-coda", false }; + SwitchArg noMulti{ "", "no-multi", "turn off multi dict", false }; + ValueArg typoWeight{ "", "typo", "typo weight", false, 0.f, "float" }; + SwitchArg bTypo{ "", "btypo", "make basic-typo-tolerant model", false }; + SwitchArg cTypo{ "", "ctypo", "make continual-typo-tolerant model", false }; + SwitchArg lTypo{ "", "ltypo", "make lengthening-typo-tolerant model", false }; + UnlabeledMultiArg inputs{ "inputs", "targets", false, "string" }; + + cmd.add(modelA); + cmd.add(modelAType); + cmd.add(modelB); + 
cmd.add(modelBType); + cmd.add(output); + cmd.add(inputs); + cmd.add(numThreads); + cmd.add(sentence); + cmd.add(ignoreTag); + cmd.add(showSame); + + try + { + cmd.parse(argc, argv); + } + catch (const ArgException& e) + { + cerr << "error: " << e.error() << " for arg " << e.argId() << endl; + return -1; + } + + Kiwi kiwiA = loadKiwiFromArg(modelA, modelAType, numThreads); + Kiwi kiwiB = loadKiwiFromArg(modelB, modelBType, numThreads); + + unique_ptr ofstr; + ostream* ostr = &cout; + if (!output.getValue().empty()) + { + ofstr = std::make_unique(output); + ostr = ofstr.get(); + } + + for (auto& input : inputs) + { + cout << "input: " << input << " "; + cout.flush(); + auto p = diffInputs(kiwiA, kiwiB, input, *ostr, sentence, ignoreTag, showSame); + cout << "(diff: " << p.first << " / " << p.second << ")" << endl; + } +} diff --git a/tools/evaluator_main.cpp b/tools/evaluator_main.cpp index 483771e8..1e967e97 100644 --- a/tools/evaluator_main.cpp +++ b/tools/evaluator_main.cpp @@ -14,6 +14,8 @@ using namespace TCLAP; int main(int argc, const char* argv[]) { + tutils::setUTF8Output(); + CmdLine cmd{ "Kiwi evaluator" }; ValueArg model{ "m", "model", "Kiwi model path", false, "models/base", "string" }; @@ -21,7 +23,7 @@ int main(int argc, const char* argv[]) SwitchArg noNormCoda{ "", "no-normcoda", "without normalizing coda", false }; SwitchArg noZCoda{ "", "no-zcoda", "without z-coda", false }; SwitchArg noMulti{ "", "no-multi", "turn off multi dict", false }; - ValueArg modelType{ "t", "type", "model type", false, "knlm", "string" }; + ValueArg modelType{ "t", "type", "model type", false, "none", "string" }; ValueArg typoWeight{ "", "typo", "typo weight", false, 0.f, "float"}; SwitchArg bTypo{ "", "btypo", "make basic-typo-tolerant model", false }; SwitchArg cTypo{ "", "ctypo", "make continual-typo-tolerant model", false }; @@ -52,41 +54,14 @@ int main(int argc, const char* argv[]) return -1; } ModelType kiwiModelType = ModelType::none; - { - auto v = 
modelType.getValue(); - if (v == "knlm") - { - kiwiModelType = ModelType::knlm; - } - else if (v == "sbg") - { - kiwiModelType = ModelType::sbg; - } - else if (v == "knlm-transposed") - { - kiwiModelType = ModelType::knlmTransposed; - } - else if (v == "cong") - { - kiwiModelType = ModelType::cong; - } - else if (v == "cong-global") - { - kiwiModelType = ModelType::congGlobal; - } - else if (v == "cong-fp32") - { - kiwiModelType = ModelType::congFp32; - } - else if (v == "cong-global-fp32") + try { - kiwiModelType = ModelType::congGlobalFp32; + kiwiModelType = tutils::parseModelType(modelType); } - else + catch (const exception& e) { - cerr << "Invalid model type" << endl; + cerr << e.what() << endl; return -1; - } } vector morphInputs, disambInputs; diff --git a/tools/runner.cpp b/tools/runner.cpp index 3b8f8b7e..96e1f549 100644 --- a/tools/runner.cpp +++ b/tools/runner.cpp @@ -46,7 +46,7 @@ int run(const string& modelPath, bool benchmark, const string& output, const str { cout << "Loading Time : " << timer.getElapsed() << " ms" << endl; cout << "ArchType : " << archToStr(kw.archType()) << endl; - cout << "LM Size : " << (kw.getKnLM()->getMemory().size() / 1024. / 1024.) << " MB" << endl; + cout << "LM Size : " << (kw.getLangModel()->getMemorySize() / 1024. / 1024.) << " MB" << endl; cout << "Mem Usage : " << (tutils::getCurrentPhysicalMemoryUsage() / 1024.) << " MB" << endl; cout << "ModelType : " << (sbg ? 
"sbg" : "knlm") << endl; } diff --git a/tools/toolUtils.h b/tools/toolUtils.h index 04b002ae..327a9c5c 100644 --- a/tools/toolUtils.h +++ b/tools/toolUtils.h @@ -102,4 +102,44 @@ namespace tutils { } #endif -} \ No newline at end of file + + inline kiwi::ModelType parseModelType(const std::string& v) + { + if (v == "none") + { + return kiwi::ModelType::none; + } + else if (v == "knlm") + { + return kiwi::ModelType::knlm; + } + else if (v == "sbg") + { + return kiwi::ModelType::sbg; + } + else if (v == "knlm-transposed") + { + return kiwi::ModelType::knlmTransposed; + } + else if (v == "cong") + { + return kiwi::ModelType::cong; + } + else if (v == "cong-global") + { + return kiwi::ModelType::congGlobal; + } + else if (v == "cong-fp32") + { + return kiwi::ModelType::congFp32; + } + else if (v == "cong-global-fp32") + { + return kiwi::ModelType::congGlobalFp32; + } + else + { + throw std::invalid_argument{ "Invalid model type" }; + } + } +} From 6a9b3299198515f5cc0dd54e53000b05f12dc20e Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 3 Mar 2025 19:55:36 +0900 Subject: [PATCH 26/53] Update CMakeLists.txt --- CMakeLists.txt | 9 ++------- third_party/cpuinfo | 2 +- vsproj/kiwi_shared_library.vcxproj | 6 ++++++ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 79092974..3385aae2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,9 +106,6 @@ if(KIWI_USE_CPUINFO) set ( ADDITIONAL_FLAGS ${ADDITIONAL_FLAGS} "-DKIWI_USE_CPUINFO" ) if(MSVC) - target_compile_options("clog" PUBLIC - /MT - ) target_compile_options("cpuinfo" PUBLIC /MT ) @@ -118,11 +115,9 @@ if(KIWI_USE_CPUINFO) endif() set ( CPUINFO_OBJECTS_STATIC - $ $ ) set ( CPUINFO_OBJECTS_SHARED - $ $ ) endif() @@ -184,8 +179,8 @@ if (KIWI_CPU_ARCH MATCHES "x86_64") if (KIWI_USE_CPUINFO) set_source_files_properties(src/archImpl/avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma") set_source_files_properties(src/archImpl/avx_vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx 
-mavx2 -mfma -mavxvnni") - set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw") - set_source_files_properties(src/archImpl/avx512vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vnni") + set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavx512f -mavx512vl -mavx512dq -mavx512bw") + set_source_files_properties(src/archImpl/avx512vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavx512f -mavx512vl -mavx512dq -mavx512bw -mavx512vnni") endif() endif() elseif (KIWI_CPU_ARCH MATCHES "arm64") diff --git a/third_party/cpuinfo b/third_party/cpuinfo index 866ae6e5..05dd959f 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit 866ae6e5ffe93a1f63be738078da94cf3005cce2 +Subproject commit 05dd959fa26c7e68fa229495a35f55e06a3b9655 diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index dc2a5e0c..89abd67d 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -56,11 +56,14 @@ + + + @@ -69,12 +72,15 @@ + + + From af4f11695d52b801822c5b6ac36011c839e6087b Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 3 Mar 2025 20:58:16 +0900 Subject: [PATCH 27/53] Implement `nounAugmentingProb` of `HSDataset` --- include/kiwi/Dataset.h | 11 +++++- include/kiwi/Kiwi.h | 3 ++ src/Dataset.cpp | 81 ++++++++++++++++++++++++++++++++++-------- src/KiwiBuilder.cpp | 40 +++++++++++++-------- 4 files changed, 106 insertions(+), 29 deletions(-) diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index 9d97d089..5e6931de 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -51,6 +51,7 @@ namespace kiwi std::shared_ptr dummyBuilder; std::discrete_distribution<> dropout; std::bernoulli_distribution dropoutOnHistory; + std::discrete_distribution<> nounAugmentor; std::mt19937_64 rng; Vector locals; Vector shuffledIdx; @@ -68,6 +69,7 @@ namespace kiwi size_t 
totalTokens = 0; size_t passedSents = 0; size_t passedWorkItems = 0; + std::array(Kiwi::SpecialMorph::max)> specialMorphIds = { { 0, } }; size_t numValidTokensInSent(size_t sentId) const; @@ -75,7 +77,14 @@ namespace kiwi size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut); public: - HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, bool _exclusiveWindow = true, size_t _workers = 0, double _dropoutProb = 0, double _dropoutProbOnHistory = 0); + HSDataset(size_t _batchSize = 0, + size_t _causalContextSize = 0, + size_t _windowSize = 0, + bool _exclusiveWindow = true, + size_t _workers = 0, + double _dropoutProb = 0, + double _dropoutProbOnHistory = 0, + double _nounAugmentingProb = 0); ~HSDataset(); HSDataset(const HSDataset&) = delete; HSDataset(HSDataset&&) /*noexcept*/; diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 49834c6c..9a013c13 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -628,6 +628,8 @@ namespace kiwi void addAllomorphsToRule(); + std::array(Kiwi::SpecialMorph::max)> getSpecialMorphs() const; + public: /** @@ -823,6 +825,7 @@ namespace kiwi size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb = 0, double dropoutProbOnHistory = 0, + double nounAugmentingProb = 0, const TokenFilter& tokenFilter = {}, const TokenFilter& windowFilter = {}, double splitRatio = 0, diff --git a/src/Dataset.cpp b/src/Dataset.cpp index 1b17add8..3a30a553 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -5,13 +5,18 @@ using namespace kiwi; -HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, - size_t _windowSize, bool _exclusiveWindow, +HSDataset::HSDataset(size_t _batchSize, + size_t _causalContextSize, + size_t _windowSize, + bool _exclusiveWindow, size_t _workers, - double _dropoutProb, double _dropoutProbOnHistory) + double _dropoutProb, + double _dropoutProbOnHistory, + double 
_nounAugmentingProb) : workers{ _workers ? make_unique(_workers) : nullptr }, - dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} }, + dropout{ {1 - _dropoutProb, _dropoutProb / 3, _dropoutProb / 3, _dropoutProb / 6, _dropoutProb / 6} }, dropoutOnHistory{ _dropoutProbOnHistory }, + nounAugmentor{ {1 - _nounAugmentingProb, _nounAugmentingProb / 9, _nounAugmentingProb / 9, _nounAugmentingProb / 9, _nounAugmentingProb / 3, _nounAugmentingProb / 3} }, locals( _workers ? workers->size() : 1), batchSize{ _batchSize }, causalContextSize{ _causalContextSize }, @@ -101,20 +106,64 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, tokens.emplace_back(sent[0]); for (auto p = sent.begin() + 1; p != sent.end() - 1; ++p) { - auto t = *p; - switch (dropout(local.rng)) + const auto t = *p; + const auto nounAugment = ((*morphemes)[t].tag == POSTag::nnp && !isSpecialClass((*morphemes)[*(p + 1)].tag)) ? nounAugmentor(local.rng) : 0; + + switch (nounAugment) { - case 0: // no dropout - tokens.emplace_back(t); + case 1: // circumfix with sso and ssc + tokens.emplace_back(getDefaultMorphemeId(POSTag::sso)); + break; + case 2: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::singleQuoteOpen]); break; - case 1: // replacement - tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + case 3: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::doubleQuoteOpen]); break; - case 2: // deletion + case 4: // circumfix with sw + tokens.emplace_back(getDefaultMorphemeId(POSTag::sw)); break; - case 3: // insertion - tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); - tokens.emplace_back(t); + case 5: // replace with w_hashtag + tokens.emplace_back(getDefaultMorphemeId(POSTag::w_hashtag)); + break; + } + + if (nounAugment != 5) + { + switch (dropout(local.rng)) + { + case 0: // no dropout + tokens.emplace_back(t); + break; + case 1: // replacement + 
tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + break; + case 2: // deletion + break; + case 3: // insertion + tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + tokens.emplace_back(t); + break; + case 4: // insertion + tokens.emplace_back(t); + tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + break; + } + } + + switch (nounAugment) + { + case 1: // circumfix with sso and ssc + tokens.emplace_back(getDefaultMorphemeId(POSTag::ssc)); + break; + case 2: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::singleQuoteClose]); + break; + case 3: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::doubleQuoteClose]); + break; + case 4: // circumfix with sw + tokens.emplace_back(getDefaultMorphemeId(POSTag::sw)); break; } } @@ -434,6 +483,10 @@ std::vector HSDataset::getAugmentedSent(size_t idx) ret.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); ret.emplace_back(t); break; + case 4: // insertion + ret.emplace_back(t); + ret.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + break; } } ret.emplace_back(*sent.rbegin()); diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 75bf48c0..f15efe4a 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -2205,22 +2205,29 @@ Kiwi KiwiBuilder::build(const TypoTransformer& typos, float typoCostThreshold) c ret.formTrie = freezeTrie(move(formTrie), archType); - for (auto& m : ret.morphemes) + ret.specialMorphIds = getSpecialMorphs(); + return ret; +} + +std::array(Kiwi::SpecialMorph::max)> KiwiBuilder::getSpecialMorphs() const +{ + std::array(Kiwi::SpecialMorph::max)> specialMorphIds = { {0,} }; + for (auto& m : morphemes) { - if (m.kform && *m.kform == u"'") + if (forms[m.kform].form == u"'") { - if (m.tag == POSTag::sso) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteOpen)] = &m - ret.morphemes.data(); - else if (m.tag == POSTag::ssc) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteClose)] = &m 
- ret.morphemes.data(); - else if (m.tag == POSTag::ss) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteNA)] = &m - ret.morphemes.data(); + if (m.tag == POSTag::sso) specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteOpen)] = &m - morphemes.data(); + else if (m.tag == POSTag::ssc) specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteClose)] = &m - morphemes.data(); + else if (m.tag == POSTag::ss) specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteNA)] = &m - morphemes.data(); } - else if (m.kform && *m.kform == u"\"") + else if (forms[m.kform].form == u"\"") { - if (m.tag == POSTag::sso) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteOpen)] = &m - ret.morphemes.data(); - else if (m.tag == POSTag::ssc) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteClose)] = &m - ret.morphemes.data(); - else if (m.tag == POSTag::ss) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteNA)] = &m - ret.morphemes.data(); + if (m.tag == POSTag::sso) specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteOpen)] = &m - morphemes.data(); + else if (m.tag == POSTag::ssc) specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteClose)] = &m - morphemes.data(); + else if (m.tag == POSTag::ss) specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteNA)] = &m - morphemes.data(); } } - return ret; + return specialMorphIds; } vector KiwiBuilder::extractWords(const U16MultipleReader& reader, size_t minCnt, size_t maxWordLen, float minScore, float posThreshold, bool lmFilter) const @@ -2371,6 +2378,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb, double dropoutProbOnHistory, + double nounAugmentingProb, const TokenFilter& tokenFilter, const TokenFilter& windowFilter, double splitRatio, @@ -2381,14 +2389,17 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, HSDataset* 
splitDataset ) const { - HSDataset dataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, dropoutProbOnHistory }; + HSDataset dataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, dropoutProbOnHistory, nounAugmentingProb }; auto& sents = dataset.sents.get(); const KiwiBuilder* srcBuilder = this; MorphemeMap realMorph; size_t maxTokenId = 0; + shared_ptr knlm; + if (morphemeDefPath.empty()) { realMorph = restoreMorphemeMap(separateDefaultMorpheme); + knlm = dynamic_pointer_cast(langMdl); } else { @@ -2407,18 +2418,19 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, } } - auto knlm = dynamic_pointer_cast(langMdl); dataset.knlm = knlm; dataset.morphemes = &srcBuilder->morphemes; dataset.forms = &srcBuilder->forms; + dataset.specialMorphIds = getSpecialMorphs(); if (splitDataset) { *splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb }; splitDataset->dummyBuilder = dataset.dummyBuilder; splitDataset->knlm = knlm; - splitDataset->morphemes = &srcBuilder->morphemes; - splitDataset->forms = &srcBuilder->forms; + splitDataset->morphemes = dataset.morphemes; + splitDataset->forms = dataset.forms; + splitDataset->specialMorphIds = dataset.specialMorphIds; } for (auto& path : inputPathes) From 595478851d6badd99774756147a1d32545b9ff0b Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 4 Mar 2025 02:31:06 +0900 Subject: [PATCH 28/53] Fix compilation errors on Arm NEON --- include/kiwi/Kiwi.h | 2 +- src/CoNgramModel.cpp | 2 +- src/Knlm.cpp | 2 +- src/MathFunc.hpp | 16 ++- src/PathEvaluator.h | 4 +- src/PathEvaluator.hpp | 2 +- src/SIMD.hpp | 28 ++-- src/SkipBigramModel.cpp | 2 +- src/archImpl/neon.cpp | 53 ++++++++ src/qgemm.cpp | 14 -- src/qgemm.hpp | 2 + src/search.cpp | 284 ++++++++++++++++++++++++++++++++++++++++ 12 files changed, 374 insertions(+), 37 deletions(-) delete mode 100644 src/qgemm.cpp diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 
9a013c13..86942c79 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -58,7 +58,7 @@ namespace kiwi class Kiwi { friend class KiwiBuilder; - friend struct BestPathFinder; + template friend struct BestPathFinder; template friend struct PathEvaluator; template friend struct MorphemeEvaluator; friend class cmb::AutoJoiner; diff --git a/src/CoNgramModel.cpp b/src/CoNgramModel.cpp index 40dc38a7..43ef3392 100644 --- a/src/CoNgramModel.cpp +++ b/src/CoNgramModel.cpp @@ -1663,7 +1663,7 @@ namespace kiwi template void* CoNgramModel::getFindBestPathFn() const { - return (void*)&BestPathFinder::findBestPath>; + return (void*)&BestPathFinder>::findBestPath; } template diff --git a/src/Knlm.cpp b/src/Knlm.cpp index 58ffd0f0..4c09c230 100644 --- a/src/Knlm.cpp +++ b/src/Knlm.cpp @@ -330,7 +330,7 @@ namespace kiwi template void* KnLangModel::getFindBestPathFn() const { - return (void*)&BestPathFinder::findBestPath>; + return (void*)&BestPathFinder>::findBestPath; } template diff --git a/src/MathFunc.hpp b/src/MathFunc.hpp index fbf68602..1645bbc8 100644 --- a/src/MathFunc.hpp +++ b/src/MathFunc.hpp @@ -198,12 +198,16 @@ namespace kiwi }; template<> - struct LogSoftmaxTransposed : public LogSoftmaxTransposed + struct LogSoftmaxTransposed { + void operator()(float* arr, size_t batchSize, size_t stride) + { + throw std::runtime_error("Unsupported architecture"); + } }; template<> - struct LogSoftmaxTransposed : public LogSoftmaxTransposed + struct LogSoftmaxTransposed : public LogSoftmaxTransposed { }; @@ -271,12 +275,16 @@ namespace kiwi }; template<> - struct LogSumExpTransposed : public LogSumExpTransposed + struct LogSumExpTransposed { + void operator()(float* arr, size_t batchSize, size_t stride) + { + throw std::runtime_error("Unsupported architecture"); + } }; template<> - struct LogSumExpTransposed : public LogSumExpTransposed + struct LogSumExpTransposed : public LogSumExpTransposed { }; diff --git a/src/PathEvaluator.h b/src/PathEvaluator.h index 
b6d4b1d0..ef1eae65 100644 --- a/src/PathEvaluator.h +++ b/src/PathEvaluator.h @@ -83,9 +83,9 @@ namespace kiwi } }; + template struct BestPathFinder { - template static Vector findBestPath(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, @@ -99,5 +99,5 @@ namespace kiwi ); }; - using FnFindBestPath = decltype(&BestPathFinder::findBestPath); + using FnFindBestPath = decltype(&BestPathFinder::findBestPath); } diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index fccfb34d..68cf16ef 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -1115,7 +1115,7 @@ namespace kiwi } template - Vector BestPathFinder::findBestPath(const Kiwi* kw, + Vector BestPathFinder::findBestPath(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, const size_t graphSize, diff --git a/src/SIMD.hpp b/src/SIMD.hpp index 6b9c2a71..88e795bb 100644 --- a/src/SIMD.hpp +++ b/src/SIMD.hpp @@ -813,28 +813,32 @@ namespace kiwi static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) { - int32x4_t pa, pb, sum = vdupq_n_s32(0); + int32x4_t sum = vdupq_n_s32(0); + uint16x8_t pa; + int8x16_t pb; for (size_t i = 0; i < size; i += 16) { - pa = vreinterpretq_s32_u32(vmovl_u16(vld1_u16(reinterpret_cast(a + i)))); - pb = vreinterpretq_s32_s8(vld1_s8(reinterpret_cast(b + i))); - sum = vpadalq_s16(sum, vmull_s8(vget_low_s8(vreinterpretq_s8_s32(pa)), vget_low_s8(pb))); - sum = vpadalq_s16(sum, vmull_s8(vget_high_s8(vreinterpretq_s8_s32(pa)), vget_high_s8(pb))); + // } - return vgetq_lane_s32(vpadd_s32(vpadd_s32(sum, sum), sum), 0); + sum = vpaddq_s32(sum, sum); + sum = vpaddq_s32(sum, sum); + return vgetq_lane_s32(sum, 0); } static STRONG_INLINE int32_t dotprod(const int8_t* a, const int8_t* b, size_t size) { - int32x4_t pa, pb, sum = vdupq_n_s32(0); + int32x4_t sum = vdupq_n_s32(0); + int8x16_t pa, pb; for (size_t i = 0; i < size; i += 16) { - pa = vreinterpretq_s32_s8(vld1q_s8(a + i)); - pb = vreinterpretq_s32_s8(vld1q_s8(b + 
i)); - sum = vpadalq_s16(sum, vmull_s8(vget_low_s8(pa), vget_low_s8(pb))); - sum = vpadalq_s16(sum, vmull_s8(vget_high_s8(pa), vget_high_s8(pb))); + pa = vld1q_s8(a + i); + pb = vld1q_s8(b + i); + sum = vpadalq_s16(sum, vmull_s8(vget_low_s8(pb), vget_low_s8(pa))); + sum = vpadalq_s16(sum, vmull_s8(vget_high_s8(pb), vget_high_s8(pa))); } - return vgetq_lane_s32(vpadd_s32(vpadd_s32(sum, sum), sum), 0); + sum = vpaddq_s32(sum, sum); + sum = vpaddq_s32(sum, sum); + return vgetq_lane_s32(sum, 0); } }; diff --git a/src/SkipBigramModel.cpp b/src/SkipBigramModel.cpp index 2fd4b0d7..36d731fb 100644 --- a/src/SkipBigramModel.cpp +++ b/src/SkipBigramModel.cpp @@ -63,7 +63,7 @@ namespace kiwi template void* SkipBigramModel::getFindBestPathFn() const { - return (void*)&BestPathFinder::findBestPath>; + return (void*)&BestPathFinder>::findBestPath; } template diff --git a/src/archImpl/neon.cpp b/src/archImpl/neon.cpp index 06c3a5ad..4ffe4c2f 100644 --- a/src/archImpl/neon.cpp +++ b/src/archImpl/neon.cpp @@ -1,4 +1,9 @@ #include "../SkipBigramModelImpl.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenNEON +#include namespace kiwi { @@ -8,5 +13,53 @@ namespace kiwi template class SkipBigramModel; template class SkipBigramModel; template class SkipBigramModel; + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, 
+ const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/qgemm.cpp b/src/qgemm.cpp deleted file mode 100644 index b750b0d4..00000000 --- a/src/qgemm.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include -#include -#include -#include -#include "qgemm.hpp" -#include "SIMD.hpp" - -namespace kiwi -{ - namespace qgemm - { - - } -} diff --git a/src/qgemm.hpp b/src/qgemm.hpp index 765af5a0..0c34d246 100644 --- a/src/qgemm.hpp +++ b/src/qgemm.hpp @@ -186,5 +186,7 @@ namespace kiwi } return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); } + + // real implementations are in `archImpl/.cpp` } } diff --git a/src/search.cpp b/src/search.cpp index e3249872..140a18cf 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -1498,6 +1498,278 @@ namespace kiwi return false; } + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const uint8_t* kv, size_t size, int8_t target) + { + using IntTy = int8_t; + size_t i = 0; + int8x16_t ptarget = vdupq_n_s8(target), pkey; + uint8x16_t peq, pgt, pmasked; + + static const uint8x16_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128 }; + + if (size < n) + { + pkey = vld1q_s8(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s8(ptarget, pkey); + pmasked = vandq_u8(peq, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = 
vaddv_u8(vget_high_u8(pmasked)); + uint32_t mm = mm0 | ((uint32_t)mm1 << 8); + uint32_t r = utils::countTrailingZeroes(mm); + + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s8(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s8(ptarget, pkey); + pgt = vcgtq_s8(ptarget, pkey); + pmasked = vandq_u8(peq, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + uint32_t mm = mm0 | ((uint32_t)mm1 << 8); + uint32_t r = utils::countTrailingZeroes(mm); + + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u8(vandq_u8(pgt, vdupq_n_u8(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const uint8_t* kv, size_t size, int16_t target) + { + using IntTy = int16_t; + size_t i = 0; + int16x8_t ptarget = vdupq_n_s16(target), pkey; + uint16x8_t peq, pgt; + + static const uint16x8_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8, 16, 32, 64, 128 }; + + if (size < n) + { + pkey = vld1q_s16(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s16(ptarget, pkey); + uint32_t mm = vaddvq_u16(vandq_u16(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s16(reinterpret_cast(&kv[i * (sizeof(IntTy) + 
sizeof(ValueTy))])); + peq = vceqq_s16(ptarget, pkey); + pgt = vcgtq_s16(ptarget, pkey); + uint32_t mm = vaddvq_u16(vandq_u16(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u16(vandq_u16(pgt, vdupq_n_u16(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const uint8_t* kv, size_t size, int32_t target) + { + using IntTy = int32_t; + size_t i = 0; + int32x4_t ptarget = vdupq_n_s32(target), pkey; + uint32x4_t peq, pgt; + + static const uint32x4_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8 }; + + if (size < n) + { + pkey = vld1q_s32(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s32(ptarget, pkey); + uint32_t mm = vaddvq_u32(vandq_u32(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s32(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s32(ptarget, pkey); + pgt = vcgtq_s32(ptarget, pkey); + uint32_t mm = vaddvq_u32(vandq_u32(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u32(vandq_u32(pgt, vdupq_n_u32(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const 
uint8_t* kv, size_t size, int64_t target) + { + using IntTy = int64_t; + size_t i = 0; + int64x2_t ptarget = vdupq_n_s64(target), pkey; + uint64x2_t peq, pgt; + + static const uint64x2_t __attribute__((aligned(16))) mask = { 1, 2 }; + + if (size < n) + { + pkey = vld1q_s64(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s64(ptarget, pkey); + uint32_t mm = vaddvq_u64(vandq_u64(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s64(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s64(ptarget, pkey); + pgt = vcgtq_s64(ptarget, pkey); + uint32_t mm = vaddvq_u64(vandq_u64(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u64(vandq_u64(pgt, vdupq_n_u64(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + ARCH_TARGET("arch=armv8-a") + inline size_t findAllNeon(const uint8_t* arr, size_t size, uint8_t key) + { + int8x16_t pkey = vdupq_n_s8(key); + static const uint8x16_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128 }; + uint8x16_t pmasked; + if (size <= 16) + { + int8x16_t parr = vld1q_s8(reinterpret_cast(arr)); + uint8x16_t pcmp = vceqq_s8(pkey, parr); + pmasked = vandq_u8(pcmp, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + return (mm0 | ((size_t)mm1 << 8)) & (((size_t)1 << size) - 1); + } + else if (size <= 32) + { + int8x16_t parr0 = vld1q_s8(reinterpret_cast(arr)); 
+ int8x16_t parr1 = vld1q_s8(reinterpret_cast(arr + 16)); + uint8x16_t pcmp0 = vceqq_s8(pkey, parr0); + uint8x16_t pcmp1 = vceqq_s8(pkey, parr1); + pmasked = vandq_u8(pcmp0, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r0 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp1, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r1 = mm0 | ((size_t)mm1 << 8); + return (r0 | (r1 << 16)) & (((size_t)1 << size) - 1); + } + else if (size <= 48) + { + int8x16_t parr0 = vld1q_s8(reinterpret_cast(arr)); + int8x16_t parr1 = vld1q_s8(reinterpret_cast(arr + 16)); + int8x16_t parr2 = vld1q_s8(reinterpret_cast(arr + 32)); + uint8x16_t pcmp0 = vceqq_s8(pkey, parr0); + uint8x16_t pcmp1 = vceqq_s8(pkey, parr1); + uint8x16_t pcmp2 = vceqq_s8(pkey, parr2); + pmasked = vandq_u8(pcmp0, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r0 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp1, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r1 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp2, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r2 = mm0 | ((size_t)mm1 << 8); + return (r0 | (r1 << 16) | (r2 << 32)) & (((size_t)1 << size) - 1); + } + else + { + int8x16_t parr0 = vld1q_s8(reinterpret_cast(arr)); + int8x16_t parr1 = vld1q_s8(reinterpret_cast(arr + 16)); + int8x16_t parr2 = vld1q_s8(reinterpret_cast(arr + 32)); + int8x16_t parr3 = vld1q_s8(reinterpret_cast(arr + 48)); + uint8x16_t pcmp0 = vceqq_s8(pkey, parr0); + uint8x16_t pcmp1 = vceqq_s8(pkey, parr1); + uint8x16_t pcmp2 = vceqq_s8(pkey, parr2); + uint8x16_t pcmp3 = vceqq_s8(pkey, parr3); + pmasked = vandq_u8(pcmp0, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r0 = mm0 | ((size_t)mm1 
<< 8); + pmasked = vandq_u8(pcmp1, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r1 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp2, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r2 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp3, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r3 = mm0 | ((size_t)mm1 << 8); + return (r0 | (r1 << 16) | (r2 << 32) | (r3 << 48)) & (((size_t)1 << size) - 1); + } + } + + template<> struct OptimizedImpl { @@ -1516,6 +1788,18 @@ namespace kiwi using SignedIntTy = typename SignedType::type; return nstSearchNeon((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); } + + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVNeon((const uint8_t*)kv, size, (SignedIntTy)target); + } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllNeon(arr, size, key); + } }; INSTANTIATE_IMPL(ArchType::neon); } From 463778663318412d6bfbaa06ec367964763dacf1 Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 4 Mar 2025 02:33:18 +0900 Subject: [PATCH 29/53] Remove `qgemm.cpp` --- CMakeLists.txt | 1 - vsproj/kiwi_shared_library.vcxproj | 1 - 2 files changed, 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3385aae2..62568eae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,7 +56,6 @@ set ( CORE_SRCS src/Knlm.cpp src/KTrie.cpp src/PatternMatcher.cpp - src/qgemm.cpp src/search.cpp src/ScriptType.cpp src/SkipBigramModel.cpp diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index 89abd67d..b2f49933 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -141,7 +141,6 @@ - From db1329a925a87234dfcf7af3680039171acc3e39 Mon Sep 17 00:00:00 2001 From: bab2min Date: Fri, 21 
Mar 2025 01:05:34 +0900 Subject: [PATCH 30/53] Fix missing form filtering logic of CoNgramModel --- src/CoNgramModel.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/CoNgramModel.cpp b/src/CoNgramModel.cpp index 43ef3392..5639bd0d 100644 --- a/src/CoNgramModel.cpp +++ b/src/CoNgramModel.cpp @@ -122,7 +122,7 @@ namespace kiwi nextWids.insert(nextWids.end(), nextDistantWids.begin(), nextDistantWids.end()); } - if (nextWids.size() > 0) + if (prevLmStates.size() > 0 && nextWids.size() > 0) { if (prevLmStates.size() == 1 && nextWids.size() == 1) { @@ -278,6 +278,10 @@ namespace kiwi else continue; } Wid firstWid = morphBase[prevPath->wid].getCombined()->lmMorphemeId; + + FormEvaluator formEvaluator{ *prevPath, ownForms, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, score)) continue; + auto state = prevPath->lmState; score += state.next(langMdl, firstWid); From ad17bfbeda60578887fd41cd6d053a87482b3d32 Mon Sep 17 00:00:00 2001 From: bab2min Date: Fri, 21 Mar 2025 01:07:41 +0900 Subject: [PATCH 31/53] Implement `generateUnlikelihoods` --- include/kiwi/Dataset.h | 38 ++- include/kiwi/Kiwi.h | 25 +- include/kiwi/SubstringExtractor.h | 1 + src/Dataset.cpp | 445 ++++++++++++++++++++---------- src/KiwiBuilder.cpp | 129 +++++++-- src/SubstringExtractor.cpp | 5 + 6 files changed, 468 insertions(+), 175 deletions(-) diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index 5e6931de..69e32c73 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -31,31 +31,36 @@ namespace kiwi struct ThreadLocal { std::mt19937_64 rng; - Vector tokenBuf, contextualTokenBuf; + Vector tokenBuf; Vector lmLProbsBuf; Vector outNgramNodeBuf; - Deque historyBuf; Deque inData; Deque outData; Deque lmLProbsData; Deque outNgramNodeData; Deque restLmLProbsData; Deque restLmLProbsCntData; + Vector> unlikelihoodBuf; + Deque unlikelihoodInData; + Deque unlikelihoodOutData; }; static constexpr int32_t nonVocab = -1; - HiddenMember, sizeof(Vector) 
* 2> sents; - std::shared_ptr knlm; + HiddenMember, sizeof(Vector) * 2> sents; + std::shared_ptr langModel; + std::shared_ptr kiwiInst; + std::shared_ptr>> oovDict; std::unique_ptr workers; std::shared_ptr dummyBuilder; std::discrete_distribution<> dropout; - std::bernoulli_distribution dropoutOnHistory; + float dropoutProbOnHistory = 0; std::discrete_distribution<> nounAugmentor; std::mt19937_64 rng; Vector locals; Vector shuffledIdx; Vector tokenToVocab, vocabToToken; + Vector windowTokenValidness; Deque> futures; const Vector* morphemes = nullptr; @@ -66,6 +71,7 @@ namespace kiwi size_t causalContextSize = 0; size_t windowSize = 0; bool exclusiveWindow = true; + size_t generateUnlikelihoods = -1; size_t totalTokens = 0; size_t passedSents = 0; size_t passedWorkItems = 0; @@ -73,8 +79,14 @@ namespace kiwi size_t numValidTokensInSent(size_t sentId) const; - template - size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut); + template + void prepareInOutData(Deque& inData, Deque& outData, const Vector& tokens, std::mt19937_64& rng) const; + + bool tokenizeUnlikely(Vector>& out, int32_t prefix, int32_t target, int32_t suffix, std::mt19937_64& rng) const; + + template + size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + UlInTy unlikelihoodIn, UlOutTy unlikelihoodOut, size_t* unlikelihoodSize); public: HSDataset(size_t _batchSize = 0, @@ -84,7 +96,8 @@ namespace kiwi size_t _workers = 0, double _dropoutProb = 0, double _dropoutProbOnHistory = 0, - double _nounAugmentingProb = 0); + double _nounAugmentingProb = 0, + size_t _generateUnlikelihoods = -1); ~HSDataset(); HSDataset(const HSDataset&) = delete; HSDataset(HSDataset&&) /*noexcept*/; @@ -94,6 +107,7 @@ namespace kiwi size_t numEstimBatches() const; size_t numSents() const; size_t numTokens() const; + bool doesGenerateUnlikelihoods() const { return generateUnlikelihoods < (size_t)-1; } size_t 
getBatchSize() const { return batchSize; } size_t getCausalContextSize() const { return causalContextSize; } @@ -102,8 +116,10 @@ namespace kiwi void seed(size_t newSeed); void reset(); - size_t next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut); - size_t next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut); + size_t next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int32_t* unlikelihoodIn = nullptr, int32_t* unlikelihoodOut = nullptr, size_t* unlikelihoodSize = nullptr); + size_t next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int64_t* unlikelihoodIn = nullptr, int64_t* unlikelihoodOut = nullptr, size_t* unlikelihoodSize = nullptr); size_t vocabSize() const { return vocabToToken.size(); } size_t getKnlmVocabSize() const; @@ -112,7 +128,7 @@ namespace kiwi std::u16string vocabForm(uint32_t vocab) const; std::vector estimVocabFrequency() const; - Range::const_iterator> getSent(size_t idx) const; + Range::const_iterator> getSent(size_t idx) const; std::vector getAugmentedSent(size_t idx); std::vector, size_t>> extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers = 1, bool exclusiveCnt = false) const; diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 86942c79..49516a7d 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -577,11 +577,22 @@ namespace kiwi MorphemeMap restoreMorphemeMap(bool separateDefaultMorpheme = false) const; template - void _addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const; - - void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; - void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& 
morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; - void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; + void _addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, + UnorderedMap, size_t>* oovDict = nullptr) const; + + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; void updateForms(); void updateMorphemes(size_t vocabSize = 0); @@ -816,7 +827,8 @@ namespace kiwi const std::vector& inputPathes, const std::string& outputPath, const std::string& morphemeDefPath = {}, - size_t morphemeDefMinCnt = 0 + size_t morphemeDefMinCnt = 0, + bool generateOovDict = false ) const; using TokenFilter = std::function; @@ -826,6 +838,7 @@ namespace kiwi double dropoutProb = 0, double dropoutProbOnHistory = 0, double nounAugmentingProb = 0, + size_t generateUnlikelihoods = -1, const TokenFilter& tokenFilter = {}, const TokenFilter& windowFilter = {}, double splitRatio = 0, diff --git a/include/kiwi/SubstringExtractor.h b/include/kiwi/SubstringExtractor.h index 67115867..0b07a6db 100644 --- a/include/kiwi/SubstringExtractor.h +++ b/include/kiwi/SubstringExtractor.h @@ -39,6 +39,7 @@ namespace kiwi ); void 
addArray(const uint16_t* first, const uint16_t* last); void addArray(const uint32_t* first, const uint32_t* last); + void addArray(const int32_t* first, const int32_t* last); void addArray(const uint64_t* first, const uint64_t* last); utils::FrozenTrie count() const; std::unique_ptr buildLM( diff --git a/src/Dataset.cpp b/src/Dataset.cpp index 3a30a553..dcc0de38 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include "FrozenTrie.hpp" #include "RaggedVector.hpp" @@ -12,16 +12,25 @@ HSDataset::HSDataset(size_t _batchSize, size_t _workers, double _dropoutProb, double _dropoutProbOnHistory, - double _nounAugmentingProb) + double _nounAugmentingProb, + size_t _generateUnlikelihoods) : workers{ _workers ? make_unique(_workers) : nullptr }, dropout{ {1 - _dropoutProb, _dropoutProb / 3, _dropoutProb / 3, _dropoutProb / 6, _dropoutProb / 6} }, - dropoutOnHistory{ _dropoutProbOnHistory }, - nounAugmentor{ {1 - _nounAugmentingProb, _nounAugmentingProb / 9, _nounAugmentingProb / 9, _nounAugmentingProb / 9, _nounAugmentingProb / 3, _nounAugmentingProb / 3} }, + dropoutProbOnHistory{ (float)_dropoutProbOnHistory }, + nounAugmentor{ { + 1 - _nounAugmentingProb, + _nounAugmentingProb / 12, + _nounAugmentingProb / 12, + _nounAugmentingProb / 12, + _nounAugmentingProb / 4, + _nounAugmentingProb / 4, + _nounAugmentingProb / 4} }, locals( _workers ? 
workers->size() : 1), batchSize{ _batchSize }, causalContextSize{ _causalContextSize }, windowSize{ _windowSize }, - exclusiveWindow{ _exclusiveWindow } + exclusiveWindow{ _exclusiveWindow }, + generateUnlikelihoods{ _generateUnlikelihoods } { } @@ -84,20 +93,228 @@ size_t HSDataset::numValidTokensInSent(size_t sentId) const size_t c = 0; for (auto t : sents.get()[sentId]) { + if (oovDict && t < 0) + { + POSTag tag = (*oovDict)[-t - 1].second; + t = getDefaultMorphemeId(clearIrregular(tag)); + } + if (tokenToVocab[t] == nonVocab) continue; ++c; } return c; } -template -size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut) +bool HSDataset::tokenizeUnlikely(Vector>& out, int32_t prefix, int32_t target, int32_t suffix, std::mt19937_64& rng) const +{ + auto form = (oovDict && target < 0) ? (*oovDict)[-target - 1].first : joinHangul((*forms)[(*morphemes)[target].kform].form); + + if (oovDict && prefix < 0) prefix = getDefaultMorphemeId((*oovDict)[-prefix - 1].second); + if (oovDict && suffix < 0) suffix = getDefaultMorphemeId((*oovDict)[-suffix - 1].second); + auto prefixForm = joinHangul((*forms)[(*morphemes)[prefix].kform].form); + auto suffixForm = joinHangul((*forms)[(*morphemes)[suffix].kform].form); + if (form.size() < 2) return false; + auto blocklist = kiwiInst->findMorpheme(form); + std::unordered_set blockset(blocklist.begin(), blocklist.end()); + + thread_local std::vector pretokenized; + pretokenized.clear(); + pretokenized.emplace_back(0, 1, std::vector{ BasicToken(prefixForm, -1, -1, (*morphemes)[prefix].tag) }); + pretokenized.emplace_back(form.size() + 1, form.size() + 2, std::vector{ BasicToken(suffixForm, -1, -1, (*morphemes)[suffix].tag) }); + + form.insert(form.begin(), ' '); + form.push_back(' '); + auto res = kiwiInst->analyze(form, 8, Match::allWithNormalizing, &blockset, pretokenized); + thread_local Vector validResIdx; + validResIdx.clear(); + for (size_t i = 0; i < res.size(); 
++i) + { + auto& tokens = res[i].first; + if (tokens.size() <= 3) continue; + if (std::all_of(tokens.begin() + 1, tokens.end() - 1, [&](const TokenInfo& t) + { + return t.morph && !t.morph->getForm().empty() /*&& t.morph->lmMorphemeId != getDefaultMorphemeId(t.morph->tag)*/; + })) + { + validResIdx.emplace_back(i); + } + } + if (validResIdx.empty()) return false; + const float r = std::generate_canonical(rng); + auto& tokens = res[validResIdx[(size_t)(r * (float)validResIdx.size())]].first; + for (size_t i = 1; i < tokens.size() - 1; ++i) + { + out.emplace_back(tokens[i].morph->lmMorphemeId, tokens[i].morph->lmMorphemeId); + } + return true; +} + +inline int32_t getInput(int32_t t, const Vector>* oovDict) +{ + if (oovDict && t < 0) + { + POSTag tag = (*oovDict)[-t - 1].second; + return getDefaultMorphemeId(clearIrregular(tag)); + } + return t; +} + +inline int32_t getOutput(int32_t t, const Vector>* oovDict) +{ + return getInput(t, oovDict); +} + +inline int32_t getInput(const std::pair& t, const Vector>* oovDict) +{ + return getInput(t.first, oovDict); +} + +inline int32_t getOutput(const std::pair& t, const Vector>* oovDict) +{ + return getOutput(t.second, oovDict); +} + +template +void HSDataset::prepareInOutData(Deque& inData, Deque& outData, const Vector& tokens, std::mt19937_64& rng) const +{ + thread_local Deque history; + thread_local Vector contextualTokens; + if (windowSize) + { + history.clear(); + history.resize(windowSize, -1); + if (windowTokenValidness[getInput(tokens[0], oovDict.get())]) + { + history.back() = tokenToVocab[getInput(tokens[0], oovDict.get())]; + } + } + + if (causalContextSize && contextualMapper.size()) + { + auto* node = contextualMapper.root(); + contextualTokens.clear(); + contextualTokens.reserve(tokens.size()); + for (size_t i = 0; i < tokens.size(); ++i) + { + const int32_t v = tokenToVocab[getInput(tokens[i], oovDict.get())]; + auto* next = node->template nextOpt(contextualMapper, v); + while (!next) + { + node = 
node->fail(); + if (!node) break; + next = node->template nextOpt(contextualMapper, v); + } + if (next) + { + auto val = next->val(contextualMapper); + if (contextualMapper.hasMatch(val)) + { + contextualTokens.emplace_back(val - 1); + } + else if (contextualMapper.hasSubmatch(val)) + { + auto sub = next->fail(); + for (; sub; sub = sub->fail()) + { + val = sub->val(contextualMapper); + if (contextualMapper.hasMatch(val)) + { + break; + } + } + if (sub) contextualTokens.emplace_back(val - 1); + else contextualTokens.emplace_back(nonVocab); + } + node = next; + } + else + { + contextualTokens.emplace_back(nonVocab); + node = contextualMapper.root(); + } + } + } + + int32_t lastV = nonVocab; + for (size_t i = 1; i < tokens.size(); ++i) + { + const int32_t v = tokenToVocab[getInput(tokens[i], oovDict.get())]; + if (v == nonVocab) + { + continue; + } + const int32_t outV = getOutput(tokens[i], oovDict.get()) == 0 ? nonVocab : tokenToVocab[getOutput(tokens[i], oovDict.get())]; + + if (causalContextSize) + { + for (size_t j = 0; j < causalContextSize; ++j) + { + if (i + j < causalContextSize) + { + if (outV != nonVocab) inData.emplace_back(nonVocab); + } + else if (contextualMapper.size()) + { + if (outV != nonVocab) inData.emplace_back(contextualTokens[i + j - causalContextSize]); + } + else + { + auto t = getInput(tokens[i + j - causalContextSize], oovDict.get()); + if (dropoutProbOnHistory > 0 && std::generate_canonical(rng) < dropoutProbOnHistory) + { + t = getDefaultMorphemeId((*morphemes)[t].tag); + } + if (outV != nonVocab) inData.emplace_back(tokenToVocab[t]); + } + } + } + if (windowSize) + { + if (windowTokenValidness[v]) + { + if (outV != nonVocab) std::copy(history.begin(), history.end(), std::back_inserter(inData)); + if (exclusiveWindow) + { + if (lastV != nonVocab) + { + history.pop_front(); + history.push_back(lastV); + } + lastV = v; + } + else + { + history.pop_front(); + history.push_back(v); + } + } + else + { + if (outV != nonVocab) 
inData.resize(inData.size() + windowSize, -1); + if (exclusiveWindow) + { + if (lastV != nonVocab) + { + history.pop_front(); + history.push_back(lastV); + } + lastV = nonVocab; + } + } + } + + if (outV != nonVocab) outData.emplace_back(v); + } +} + +template +size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + UlInTy unlikelihoodIn, UlOutTy unlikelihoodOut, size_t* unlikelihoodSize) { const auto& prepareNext = [&](size_t, size_t localId, size_t sentFirst, size_t sentLast) { auto& local = locals[localId]; auto& tokens = local.tokenBuf; - auto& contextualTokens = local.contextualTokenBuf; tokens.reserve(sents.get()[shuffledIdx[sentFirst]].size()); for (size_t s = sentFirst; s < sentLast; ++s) { @@ -106,8 +323,18 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, tokens.emplace_back(sent[0]); for (auto p = sent.begin() + 1; p != sent.end() - 1; ++p) { - const auto t = *p; - const auto nounAugment = ((*morphemes)[t].tag == POSTag::nnp && !isSpecialClass((*morphemes)[*(p + 1)].tag)) ? nounAugmentor(local.rng) : 0; + int32_t t = *p; + int32_t tWithOOV = *p; + if (oovDict && t < 0) + { + t = getDefaultMorphemeId((*oovDict)[-t - 1].second); + } + int32_t t1 = *(p + 1); + if (oovDict && t1 < 0) + { + t1 = getDefaultMorphemeId((*oovDict)[-t1 - 1].second); + } + const auto nounAugment = ((*morphemes)[t].tag == POSTag::nnp && !isSpecialClass((*morphemes)[t1].tag)) ? 
nounAugmentor(local.rng) : 0; switch (nounAugment) { @@ -126,14 +353,17 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, case 5: // replace with w_hashtag tokens.emplace_back(getDefaultMorphemeId(POSTag::w_hashtag)); break; + case 6: // replace with sh + tokens.emplace_back(getDefaultMorphemeId(POSTag::sh)); + break; } - if (nounAugment != 5) + if (nounAugment < 5) { switch (dropout(local.rng)) { case 0: // no dropout - tokens.emplace_back(t); + tokens.emplace_back(tWithOOV); break; case 1: // replacement tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); @@ -142,10 +372,10 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, break; case 3: // insertion tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); - tokens.emplace_back(t); + tokens.emplace_back(tWithOOV); break; case 4: // insertion - tokens.emplace_back(t); + tokens.emplace_back(tWithOOV); tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); break; } @@ -168,78 +398,26 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, } } tokens.emplace_back(sent[sent.size() - 1]); + const size_t offset = local.outData.size(); + prepareInOutData(local.inData, local.outData, tokens, local.rng); local.lmLProbsBuf.resize(tokens.size()); local.outNgramNodeBuf.resize(tokens.size()); - if (knlm) + if (auto knlm = std::dynamic_pointer_cast(langModel)) { knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin()); } - - auto& history = local.historyBuf; - history.clear(); - if (windowSize) - { - history.resize(windowSize, -1); - if (windowTokenValidness[tokens[0]]) - { - history.back() = tokenToVocab[tokens[0]]; - } - } - - if (causalContextSize && contextualMapper.size()) + for (size_t i = 1; i < tokens.size(); ++i) { - auto* node = contextualMapper.root(); - contextualTokens.clear(); - contextualTokens.reserve(tokens.size()); - for (size_t i = 0; i < 
tokens.size(); ++i) + int32_t t = tokens[i]; + if (oovDict && t < 0) { - const int32_t v = tokenToVocab[tokens[i]]; - auto* next = node->template nextOpt(contextualMapper, v); - while (!next) - { - node = node->fail(); - if (!node) break; - next = node->template nextOpt(contextualMapper, v); - } - if (next) - { - auto val = next->val(contextualMapper); - if (contextualMapper.hasMatch(val)) - { - contextualTokens.emplace_back(val - 1); - } - else if (contextualMapper.hasSubmatch(val)) - { - auto sub = next->fail(); - for (; sub; sub = sub->fail()) - { - val = sub->val(contextualMapper); - if (contextualMapper.hasMatch(val)) - { - break; - } - } - if (sub) contextualTokens.emplace_back(val - 1); - else contextualTokens.emplace_back(nonVocab); - } - node = next; - } - else - { - contextualTokens.emplace_back(nonVocab); - node = contextualMapper.root(); - } + t = getDefaultMorphemeId((*oovDict)[-t - 1].second); } - } - - int32_t lastV = nonVocab; - for (size_t i = 1; i < tokens.size(); ++i) - { - int32_t v = tokenToVocab[tokens[i]]; + int32_t v = tokenToVocab[t]; if (v == nonVocab) { - size_t r = local.outData.size() / batchSize; + size_t r = (offset + i - 1) / batchSize; if (local.restLmLProbsData.size() <= r) { local.restLmLProbsData.resize(r + 1); @@ -250,65 +428,6 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, continue; } - if (causalContextSize) - { - for (size_t j = 0; j < causalContextSize; ++j) - { - if (i + j < causalContextSize) - { - local.inData.emplace_back(nonVocab); - } - else if (contextualMapper.size()) - { - local.inData.emplace_back(contextualTokens[i + j - causalContextSize]); - } - else - { - auto t = tokens[i + j - causalContextSize]; - if (dropoutOnHistory.p() > 0 && dropoutOnHistory(local.rng)) - { - t = getDefaultMorphemeId((*morphemes)[t].tag); - } - local.inData.emplace_back(tokenToVocab[t]); - } - } - } - if (windowSize) - { - if (windowTokenValidness[v]) - { - std::copy(history.begin(), history.end(), 
std::back_inserter(local.inData)); - if (exclusiveWindow) - { - if (lastV != nonVocab) - { - history.pop_front(); - history.push_back(lastV); - } - lastV = v; - } - else - { - history.pop_front(); - history.push_back(v); - } - } - else - { - local.inData.resize(local.inData.size() + windowSize, -1); - if (exclusiveWindow) - { - if (lastV != nonVocab) - { - history.pop_front(); - history.push_back(lastV); - } - lastV = nonVocab; - } - } - } - - local.outData.emplace_back(v); local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]); local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]); } @@ -319,6 +438,35 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, local.restLmLProbsData.resize(r + 1); local.restLmLProbsCntData.resize(r + 1); } + + if (doesGenerateUnlikelihoods()) + { + local.unlikelihoodBuf.clear(); + local.unlikelihoodBuf.emplace_back(tokens[0], 0); + for (size_t i = 1; i < tokens.size() - 1; ++i) + { + if (oovDict && tokens[i] < 0) + { + if (!tokenizeUnlikely(local.unlikelihoodBuf, tokens[i - 1], tokens[i], tokens[i + 1], local.rng)) + { + local.unlikelihoodBuf.emplace_back(tokens[i], 0); + } + continue; + } + + auto& morph = (*morphemes)[tokens[i]]; + if (tokens[i] < generateUnlikelihoods + || !(morph.tag == POSTag::nng || morph.tag == POSTag::nnp) + || getDefaultMorphemeId(morph.tag) == tokens[i] + || !tokenizeUnlikely(local.unlikelihoodBuf, tokens[i - 1], tokens[i], tokens[i + 1], local.rng)) + { + local.unlikelihoodBuf.emplace_back(tokens[i], 0); + } + } + local.unlikelihoodBuf.emplace_back(tokens.back(), 0); + + prepareInOutData(local.unlikelihoodInData, local.unlikelihoodOutData, local.unlikelihoodBuf, local.rng); + } } return localId; }; @@ -386,13 +534,20 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, auto& l = locals[localId]; - size_t rest = std::min(l.outData.size(), batchSize); + const size_t rest = std::min(l.outData.size(), batchSize); + const size_t unlikelihoodRest = 
std::min(l.unlikelihoodOutData.size(), batchSize); std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in); std::copy(l.outData.begin(), l.outData.begin() + rest, out); std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs); std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode); restLmOut = l.restLmLProbsData.front(); restLmCntOut = l.restLmLProbsCntData.front(); + if (doesGenerateUnlikelihoods() && unlikelihoodIn && unlikelihoodOut) + { + std::copy(l.unlikelihoodInData.begin(), l.unlikelihoodInData.begin() + unlikelihoodRest * (causalContextSize + windowSize), unlikelihoodIn); + std::copy(l.unlikelihoodOutData.begin(), l.unlikelihoodOutData.begin() + unlikelihoodRest, unlikelihoodOut); + if (unlikelihoodSize) *unlikelihoodSize = unlikelihoodRest; + } l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize)); l.outData.erase(l.outData.begin(), l.outData.begin() + rest); @@ -400,21 +555,29 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest); l.restLmLProbsData.pop_front(); l.restLmLProbsCntData.pop_front(); + if (doesGenerateUnlikelihoods() && unlikelihoodIn && unlikelihoodOut) + { + l.unlikelihoodInData.erase(l.unlikelihoodInData.begin(), l.unlikelihoodInData.begin() + unlikelihoodRest * (causalContextSize + windowSize)); + l.unlikelihoodOutData.erase(l.unlikelihoodOutData.begin(), l.unlikelihoodOutData.begin() + unlikelihoodRest); + } return rest; } -size_t HSDataset::next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut) +size_t HSDataset::next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int32_t* unlikelihoodIn, int32_t* unlikelihoodOut, size_t* unlikelihoodSize) { - return _next(in, out, lmLProbs, 
outNgramNode, restLmOut, restLmCntOut); + return _next(in, out, lmLProbs, outNgramNode, restLmOut, restLmCntOut, unlikelihoodIn, unlikelihoodOut, unlikelihoodSize); } -size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut) +size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int64_t* unlikelihoodIn, int64_t* unlikelihoodOut, size_t* unlikelihoodSize) { - return _next(in, out, lmLProbs, outNgramNode, restLmOut, restLmCntOut); + return _next(in, out, lmLProbs, outNgramNode, restLmOut, restLmCntOut, unlikelihoodIn, unlikelihoodOut, unlikelihoodSize); } size_t HSDataset::ngramNodeSize() const { + auto knlm = std::dynamic_pointer_cast(langModel); return knlm ? knlm->nonLeafNodeSize() : 0; } @@ -454,7 +617,7 @@ std::vector kiwi::HSDataset::estimVocabFrequency() const return ret; } -Range::const_iterator> HSDataset::getSent(size_t idx) const +Range::const_iterator> HSDataset::getSent(size_t idx) const { return sents.get()[idx]; } @@ -493,7 +656,7 @@ std::vector HSDataset::getAugmentedSent(size_t idx) return ret; } -std::vector, size_t>> kiwi::HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers, bool exclusiveCnt) const +std::vector, size_t>> HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers, bool exclusiveCnt) const { using Pair = std::pair, size_t>; std::vector ret; diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index f15efe4a..dfc92f34 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -571,7 +571,8 @@ void KiwiBuilder::_addCorpusTo( std::istream& is, MorphemeMap& morphMap, double splitRatio, - RaggedVector* splitOut + RaggedVector* splitOut, + UnorderedMap, size_t>* oovDict ) const { Vector wids; @@ -690,8 +691,16 @@ void KiwiBuilder::_addCorpusTo( } if (t < POSTag::p && t != POSTag::unknown) + { + if (oovDict && (t == POSTag::nng || t == 
POSTag::nnp)) + { + auto oovId = oovDict->emplace(make_pair(f, t), oovDict->size()).first->second; + wids.emplace_back(-(ptrdiff_t)(oovId + 1)); + } + else { wids.emplace_back(getDefaultMorphemeId(t)); + } continue; } @@ -700,19 +709,28 @@ void KiwiBuilder::_addCorpusTo( } } -void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const { - return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); } -void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const { - return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); } -void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const { - return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); +} + +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const +{ + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); } void KiwiBuilder::updateForms() @@ -2341,7 +2359,8 @@ void KiwiBuilder::convertHSData( const vector& 
inputPathes, const string& outputPath, const string& morphemeDefPath, - size_t morphemeDefMinCnt + size_t morphemeDefMinCnt, + bool generateOovDict ) const { unique_ptr dummyBuilder; @@ -2363,15 +2382,33 @@ void KiwiBuilder::convertHSData( srcBuilder = dummyBuilder.get(); } - RaggedVector sents; + UnorderedMap, size_t> oovDict; + RaggedVector sents; for (auto& path : inputPathes) { ifstream ifs; - srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph); + srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, 0, nullptr, generateOovDict ? &oovDict: nullptr); } ofstream ofs; sents.write_to_memory(openFile(ofs, outputPath, ios_base::binary)); + if (generateOovDict) + { + Vector> oovDictStr(oovDict.size()); + for (auto& p : oovDict) + { + oovDictStr[p.second] = make_pair(joinHangul(p.first.first), p.first.second); + } + + const uint32_t size = oovDictStr.size(); + ofs.write((const char*)&size, sizeof(uint32_t)); + for (auto& p : oovDictStr) + { + const uint32_t tagAndSize = (uint32_t)p.second | ((uint32_t)p.first.size() << 8); + ofs.write((const char*)&tagAndSize, sizeof(uint32_t)); + ofs.write((const char*)p.first.data(), p.first.size() * sizeof(char16_t)); + } + } } HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, @@ -2379,6 +2416,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, double dropoutProb, double dropoutProbOnHistory, double nounAugmentingProb, + size_t generateUnlikelihoods, const TokenFilter& tokenFilter, const TokenFilter& windowFilter, double splitRatio, @@ -2389,20 +2427,32 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, HSDataset* splitDataset ) const { - HSDataset dataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, dropoutProbOnHistory, nounAugmentingProb }; + HSDataset dataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, dropoutProbOnHistory, nounAugmentingProb, generateUnlikelihoods }; auto& sents = dataset.sents.get(); 
const KiwiBuilder* srcBuilder = this; MorphemeMap realMorph; size_t maxTokenId = 0; shared_ptr knlm; + const bool doesGenerateUnlikelihoods = generateUnlikelihoods != (size_t)-1; + if (morphemeDefPath.empty()) { realMorph = restoreMorphemeMap(separateDefaultMorpheme); - knlm = dynamic_pointer_cast(langMdl); + dataset.langModel = langMdl; + if (doesGenerateUnlikelihoods) + { + dataset.kiwiInst = make_unique(build()); + dataset.kiwiInst->setMaxUnkFormSize(2); + } } else { + if (doesGenerateUnlikelihoods) + { + throw invalid_argument{ "cannot generate unlikelihoods with morpheme definition file" }; + } + dataset.dummyBuilder = make_shared(); dataset.dummyBuilder->initMorphemes(); ifstream ifs; @@ -2418,41 +2468,76 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, } } - dataset.knlm = knlm; dataset.morphemes = &srcBuilder->morphemes; dataset.forms = &srcBuilder->forms; dataset.specialMorphIds = getSpecialMorphs(); if (splitDataset) { - *splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb }; + *splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, 0, 0, generateUnlikelihoods }; splitDataset->dummyBuilder = dataset.dummyBuilder; - splitDataset->knlm = knlm; + splitDataset->langModel = dataset.langModel; + splitDataset->kiwiInst = dataset.kiwiInst; splitDataset->morphemes = dataset.morphemes; splitDataset->forms = dataset.forms; splitDataset->specialMorphIds = dataset.specialMorphIds; } + UnorderedMap, size_t> oovDict; for (auto& path : inputPathes) { try { ifstream ifs; - auto cvtSents = RaggedVector::from_memory(openFile(ifs, path, ios_base::binary)); + auto cvtSents = RaggedVector::from_memory(openFile(ifs, path, ios_base::binary)); + uint32_t oovDictSize = 0; + Vector oovDictMap; + if (ifs.read((char*)&oovDictSize, sizeof(uint32_t))) + { + for (uint32_t i = 0; i < oovDictSize; ++i) + { + uint32_t tagAndSize = 0; + ifs.read((char*)&tagAndSize, sizeof(uint32_t)); 
+ u16string form(tagAndSize >> 8, 0); + ifs.read((char*)form.data(), form.size() * sizeof(char16_t)); + const POSTag tag = (POSTag)(tagAndSize & 0xff); + if (doesGenerateUnlikelihoods) + { + KString kform = normalizeHangul(form); + const auto oovId = (int32_t)oovDict.emplace(make_pair(kform, tag), oovDict.size()).first->second; + oovDictMap.emplace_back(-oovId - 1); + } + else + { + oovDictMap.emplace_back(getDefaultMorphemeId(tag)); + } + } + } + double splitCnt = 0; for (auto s : cvtSents) { splitCnt += splitRatio; auto& o = splitDataset && splitCnt >= 1 ? splitDataset->sents.get() : sents; o.emplace_back(); + if (oovDictMap.empty()) + { o.insert_data(s.begin(), s.end()); + } + else + { + for (auto i : s) + { + o.add_data(i < 0 ? oovDictMap[-i - 1] : i); + } + } splitCnt = fmod(splitCnt, 1.); } } catch (const runtime_error&) { ifstream ifs; - srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr); + srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr, doesGenerateUnlikelihoods ? &oovDict : nullptr); } } size_t tokenSize = sents.raw().empty() ? 0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1; @@ -2463,6 +2548,16 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, tokenSize = max(tokenSize, sents.raw().empty() ? (size_t)0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1); } + if (doesGenerateUnlikelihoods) + { + dataset.oovDict = make_unique>>(oovDict.size()); + for (auto& p : oovDict) + { + (*dataset.oovDict)[p.second] = make_pair(joinHangul(p.first.first), p.first.second); + } + if (splitDataset) splitDataset->oovDict = dataset.oovDict; + } + const size_t knlmVocabSize = knlm ? 
knlm->getHeader().vocab_size : maxTokenId; tokenSize = max(tokenSize, knlmVocabSize); size_t filteredKnlmVocabSize = 0; diff --git a/src/SubstringExtractor.cpp b/src/SubstringExtractor.cpp index 452c6b24..ac238df1 100644 --- a/src/SubstringExtractor.cpp +++ b/src/SubstringExtractor.cpp @@ -266,6 +266,11 @@ namespace kiwi _addArray(first, last); } + void PrefixCounter::addArray(const int32_t* first, const int32_t* last) + { + _addArray(first, last); + } + void PrefixCounter::addArray(const uint64_t* first, const uint64_t* last) { _addArray(first, last); From 2a4e9d7813bf2ab7735d7442c3090cd4a29e6a63 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 22 Mar 2025 03:20:07 +0900 Subject: [PATCH 32/53] Fix wrong calculation of dataset's vocab size --- src/KiwiBuilder.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index dfc92f34..7d9d89e9 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -2432,7 +2432,6 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, const KiwiBuilder* srcBuilder = this; MorphemeMap realMorph; size_t maxTokenId = 0; - shared_ptr knlm; const bool doesGenerateUnlikelihoods = generateUnlikelihoods != (size_t)-1; @@ -2558,7 +2557,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, if (splitDataset) splitDataset->oovDict = dataset.oovDict; } - const size_t knlmVocabSize = knlm ? knlm->getHeader().vocab_size : maxTokenId; + const size_t knlmVocabSize = dataset.langModel ? 
dataset.langModel->vocabSize() : maxTokenId; tokenSize = max(tokenSize, knlmVocabSize); size_t filteredKnlmVocabSize = 0; for (size_t i = 0; i < tokenSize; ++i) From d1e570e34fd92c23e047f25663a2573172055b3a Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 14:46:02 +0900 Subject: [PATCH 33/53] Fix a bug related to diff_tokens' showSame option --- tools/diff_tokens.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/diff_tokens.cpp b/tools/diff_tokens.cpp index 436077ad..41d94aff 100644 --- a/tools/diff_tokens.cpp +++ b/tools/diff_tokens.cpp @@ -37,7 +37,7 @@ inline ostream& operator<<(ostream& ostr, const TokenInfo& token) bool printDiffTokens(ostream& ostr, const string& raw, const TokenInfo* a, size_t aSize, const TokenInfo* b, size_t bSize, bool ignoreTag = false, bool showSame = false) { - if (isEqual(a, aSize, b, bSize, ignoreTag) == showSame) return false; + if (isEqual(a, aSize, b, bSize, ignoreTag) != showSame) return false; ostr << raw << '\t'; for (size_t i = 0; i < aSize; ++i) { From 16f07a6e2d088d8bfb3af13d6c5384a12541ced7 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 14:49:19 +0900 Subject: [PATCH 34/53] Update test cases to fit recent Kiwi API --- test/test_cpp.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 17906b87..64a7972a 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -50,7 +50,7 @@ constexpr std::vector> toPair(const ATy(&init)[n]) Kiwi& reuseKiwiInstance() { - static Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::default_, }.build(); + static Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::default_, ModelType::knlm }.build(); return kiwi; } @@ -147,7 +147,7 @@ TEST(KiwiCpp, SingleConsonantMorpheme) TEST(KiwiCpp, SpecialTokenErrorOnContinualTypo) { - KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_, }; + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_, ModelType::knlm }; 
Kiwi typoKiwi = builder.build(DefaultTypoSet::continualTypoSet); auto res = typoKiwi.analyze(u"감사합니다 -친구들과", Match::allWithNormalizing).first; @@ -368,7 +368,7 @@ TEST(KiwiCpp, TagRoundTrip) TEST(KiwiCpp, UserTag) { - KiwiBuilder kw{ MODEL_PATH, 0, BuildOption::default_, }; + KiwiBuilder kw{ MODEL_PATH, 0, BuildOption::default_, ModelType::knlm, }; EXPECT_TRUE(kw.addWord(u"사용자태그", POSTag::user0).second); EXPECT_TRUE(kw.addWord(u"이것도유저", POSTag::user1).second); EXPECT_TRUE(kw.addWord(u"특수한표지", POSTag::user2).second); @@ -436,7 +436,7 @@ TEST(KiwiCpp, HSDataset) }; HSDataset trainset, devset; - trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., 0., tokenFilter, {}, 0.1, false, {}, 0, {}, &devset); + trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., 0., 0., false, tokenFilter, {}, 0.1, false, {}, 0, {}, &devset); for (size_t i = 0; i < 2; ++i) { { @@ -1196,7 +1196,7 @@ TEST(KiwiCpp, IssueP111_SentenceSplitError) auto res = kiwi.splitIntoSents(text); EXPECT_GT(res.size(), 1); - KiwiBuilder builder{ MODEL_PATH, 1 }; + KiwiBuilder builder{ MODEL_PATH, 1, BuildOption::default_, ModelType::knlm }; EXPECT_TRUE(builder.addWord(u"모", POSTag::nng).second); Kiwi kiwi2 = builder.build(); auto res2 = kiwi2.splitIntoSents(text); @@ -1246,7 +1246,7 @@ TEST(KiwiCpp, AddRule) auto ores = okiwi.analyze(u"했어요! 하잖아요! 할까요? 
좋아요!", Match::allWithNormalizing); { - KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict }; + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict, ModelType::knlm }; auto inserted = builder.addRule(POSTag::ef, [](std::u16string input) { if (input.back() == u'요') @@ -1263,7 +1263,7 @@ TEST(KiwiCpp, AddRule) } { - KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict }; + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict, ModelType::knlm }; auto inserted = builder.addRule(POSTag::ef, [](std::u16string input) { if (input.back() == u'요') From 94d85b395c1aba660a09b67df4ce8b4e69d97111 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 15:11:16 +0900 Subject: [PATCH 35/53] Update code base to C++17 and remove `nonstd::string_view` --- CMakeLists.txt | 2 +- bindings/java/CMakeLists.txt | 1 - include/kiwi/Types.h | 15 +- src/Combiner.cpp | 8 +- src/Combiner.h | 1 - src/Joiner.cpp | 4 +- src/KiwiBuilder.cpp | 20 +- src/StrUtils.h | 52 +- src/SwTokenizer.cpp | 4 +- src/Utils.cpp | 8 +- src/WordDetector.cpp | 2 +- src/string_view.hpp | 1773 ---------------------------------- test/test_combiner.cpp | 18 +- 13 files changed, 65 insertions(+), 1843 deletions(-) delete mode 100644 src/string_view.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 62568eae..63575909 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.12) project(kiwi VERSION 0.20.3 DESCRIPTION "Kiwi, Korean Intelligent Word Identifier") -set ( CMAKE_CXX_STANDARD 14 ) +set ( CMAKE_CXX_STANDARD 17 ) set ( CMAKE_VERBOSE_MAKEFILE true ) option(KIWI_USE_MIMALLOC "Use mimalloc for faster memory allocation" ON) diff --git a/bindings/java/CMakeLists.txt b/bindings/java/CMakeLists.txt index 6c26dea3..ffbc188d 100644 --- a/bindings/java/CMakeLists.txt +++ b/bindings/java/CMakeLists.txt @@ -9,7 +9,6 @@ 
set(CMAKE_JAVA_COMPILE_FLAGS -source 8 -target 8 -encoding utf-8) set(pkg_name "KiwiJava-${PROJECT_VERSION}") add_library (${pkg_name} SHARED kiwi_java.cpp $ - $ $ ) if(UNIX AND NOT APPLE) diff --git a/include/kiwi/Types.h b/include/kiwi/Types.h index e4f86f05..f7c4c11c 100644 --- a/include/kiwi/Types.h +++ b/include/kiwi/Types.h @@ -1,4 +1,4 @@ -/** +/** * @file Types.h * @author bab2min (bab2min@gmail.com) * @brief Kiwi C++ API에 쓰이는 주요 타입들을 모아놓은 헤더 파일 @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -62,16 +63,6 @@ inline Type operator^=(Type& a, Type b)\ return reinterpret_cast(reinterpret_cast::type&>(a) ^= static_cast::type>(b));\ } -namespace nonstd -{ - namespace sv_lite - { - template class basic_string_view; - } - - using string_view = sv_lite::basic_string_view>; - using u16string_view = sv_lite::basic_string_view>; -} namespace kiwi { @@ -187,7 +178,7 @@ namespace kiwi using KcScores = Vector>; #endif - using U16StringView = nonstd::u16string_view; + using U16StringView = std::u16string_view; /** * @brief 형태소 품사 태그와 관련된 열거형 diff --git a/src/Combiner.cpp b/src/Combiner.cpp index 8c72606c..ade620d8 100644 --- a/src/Combiner.cpp +++ b/src/Combiner.cpp @@ -789,8 +789,8 @@ void RuleSet::loadRules(istream& istr) } else if (fields.size() == 2) { - lTag = fields[0].to_string(); - rTag = fields[1].to_string(); + lTag = fields[0]; + rTag = fields[1]; } else { @@ -807,13 +807,13 @@ void RuleSet::loadRules(istream& istr) "+ignorercond", }; - transform(fields[3].begin(), fields[3].end(), const_cast(fields[3].begin()), static_cast(tolower)); + transform(fields[3].begin(), fields[3].end(), const_cast(fields[3].data()), static_cast(tolower)); for (auto f : split(fields[3], ',')) { size_t t = find(fs.begin(), fs.end(), f) - fs.begin(); if (t >= fs.size()) { - throw runtime_error{ "invalid feature value: " + f.to_string()}; + throw runtime_error{ "invalid feature value: " + string{ f } }; } switch (t) diff --git a/src/Combiner.h 
b/src/Combiner.h index 0d407d83..4ded2ae5 100644 --- a/src/Combiner.h +++ b/src/Combiner.h @@ -4,7 +4,6 @@ #include #include #include -#include "string_view.hpp" #include "bitset.hpp" namespace kiwi diff --git a/src/Joiner.cpp b/src/Joiner.cpp index 21c4c007..c0899453 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -162,7 +162,7 @@ namespace kiwi void Joiner::add(const u16string& form, POSTag tag, Space space) { - return add(nonstd::to_string_view(form), tag, space); + return add(toStringView(form), tag, space); } void Joiner::add(const char16_t* form, POSTag tag, Space space) @@ -229,7 +229,7 @@ namespace kiwi void AutoJoiner::add(const u16string& form, POSTag tag, bool inferRegularity, Space space) { - return (*dfAdd2)(this, nonstd::to_string_view(form), tag, inferRegularity, space, candBuf.get>>()); + return (*dfAdd2)(this, toStringView(form), tag, inferRegularity, space, candBuf.get>>()); } void AutoJoiner::add(const char16_t* form, POSTag tag, bool inferRegularity, Space space) diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 7d9d89e9..e74fa85d 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -213,11 +213,11 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem if (cpolar != CondPolarity::none) throw FormatException{ "wrong line: " + line }; cpolar = CondPolarity::non_adj; } - else if (f.starts_with(u"complex ")) + else if (f.size() >= 8 && f.substr(0, 8) == u"complex ") { if (complex) throw FormatException{ "wrong line: " + line }; complex = true; - complexStr = f.substr(8).to_string(); + complexStr = u16string{ f.substr(8) }; } else if (f[0] == u'=') { @@ -1371,7 +1371,7 @@ pair KiwiBuilder::addWord(U16StringView newForm, POSTag tag, flo pair KiwiBuilder::addWord(const std::u16string& newForm, POSTag tag, float score, size_t origMorphemeId, size_t lmMorphemeId) { - return addWord(nonstd::to_string_view(newForm), tag, score, origMorphemeId, lmMorphemeId); + return addWord(toStringView(newForm), tag, 
score, origMorphemeId, lmMorphemeId); } void KiwiBuilder::addCombinedMorpheme( @@ -1678,7 +1678,7 @@ pair KiwiBuilder::addWord(U16StringView form, POSTag tag, float pair KiwiBuilder::addWord(const u16string& form, POSTag tag, float score) { - return addWord(nonstd::to_string_view(form), tag, score); + return addWord(toStringView(form), tag, score); } pair KiwiBuilder::addWord(const char16_t* form, POSTag tag, float score) @@ -1716,7 +1716,7 @@ pair KiwiBuilder::addWord(U16StringView newForm, POSTag tag, flo pair KiwiBuilder::addWord(const u16string& newForm, POSTag tag, float score, const u16string& origForm) { - return addWord(nonstd::to_string_view(newForm), tag, score, origForm); + return addWord(toStringView(newForm), tag, score, origForm); } pair KiwiBuilder::addWord(const char16_t* newForm, POSTag tag, float score, const char16_t* origForm) @@ -1780,7 +1780,7 @@ bool KiwiBuilder::addPreAnalyzedWord(U16StringView form, const vector>& analyzed, vector> positions, float score) { - return addPreAnalyzedWord(nonstd::to_string_view(form), analyzed, positions, score); + return addPreAnalyzedWord(toStringView(form), analyzed, positions, score); } bool KiwiBuilder::addPreAnalyzedWord(const char16_t* form, const vector>& analyzed, vector> positions, float score) @@ -1798,7 +1798,7 @@ size_t KiwiBuilder::loadDictionary(const string& dictPath) u16string wstr; for (size_t lineNo = 1; getline(ifs, line); ++lineNo) { - utf8To16(nonstd::to_string_view(line), wstr); + utf8To16(toStringView(line), wstr); while (!wstr.empty() && kiwi::identifySpecialChr(wstr.back()) == POSTag::unknown) wstr.pop_back(); if (wstr.empty()) continue; if (wstr[0] == u'#') continue; @@ -1852,9 +1852,9 @@ size_t KiwiBuilder::loadDictionary(const string& dictPath) auto suffix = fields[0].substr(0, fields[0].size() - 1); addedCnt += addRule(morphemes[0].second, [&](const u16string& str) { - auto strv = nonstd::to_string_view(str); - if (!strv.ends_with(suffix)) return str; - return strv.substr(0, 
strv.size() - suffix.size()).to_string() + morphemes[0].first.to_string(); + auto strv = toStringView(str); + if (!(strv.size() >= suffix.size() && strv.substr(strv.size() - suffix.size()) == suffix)) return str; + return u16string{ strv.substr(0, strv.size() - suffix.size()) } + u16string{ morphemes[0].first }; }, score).size(); } else diff --git a/src/StrUtils.h b/src/StrUtils.h index 73b6390a..142bc9ad 100644 --- a/src/StrUtils.h +++ b/src/StrUtils.h @@ -3,7 +3,7 @@ #include #include #include -#include "string_view.hpp" +#include namespace kiwi { @@ -76,7 +76,7 @@ namespace kiwi size_t t = s.find(delim, p); if (t == s.npos) { - *(result++) = nonstd::basic_string_view{ &s[e] , s.size() - e}; + *(result++) = std::basic_string_view{ &s[e] , s.size() - e}; return result; } else @@ -91,28 +91,28 @@ namespace kiwi } else { - *(result++) = nonstd::basic_string_view{ &s[e] , t - e }; + *(result++) = std::basic_string_view{ &s[e] , t - e }; p = t + 1; e = t + 1; } } } - *(result++) = nonstd::basic_string_view{ &s[e] , s.size() - e }; + *(result++) = std::basic_string_view{ &s[e] , s.size() - e }; return result; } template - inline std::vector> split(nonstd::basic_string_view s, BaseChr delim, BaseChr delimEscape = 0) + inline std::vector> split(std::basic_string_view s, BaseChr delim, BaseChr delimEscape = 0) { - std::vector> ret; + std::vector> ret; split(s, delim, std::back_inserter(ret), -1, delimEscape); return ret; } template - inline std::vector> split(const std::basic_string& s, BaseChr delim, BaseChr delimEscape = 0) + inline std::vector> split(const std::basic_string& s, BaseChr delim, BaseChr delimEscape = 0) { - std::vector> ret; + std::vector> ret; split(s, delim, std::back_inserter(ret), -1, delimEscape); return ret; } @@ -141,9 +141,9 @@ namespace kiwi template> inline std::basic_string replace( - nonstd::basic_string_view s, - nonstd::basic_string_view from, - nonstd::basic_string_view to) + std::basic_string_view s, + std::basic_string_view from, + 
std::basic_string_view to) { std::basic_string ret; ret.reserve(s.size()); @@ -153,15 +153,15 @@ namespace kiwi template> inline std::basic_string replace( - nonstd::basic_string_view s, + std::basic_string_view s, const BaseChr(&from)[fromSize], const BaseChr(&to)[toSize]) { - return replace(s, nonstd::basic_string_view{ from, fromSize - 1 }, nonstd::basic_string_view{ to, toSize - 1 }); + return replace(s, std::basic_string_view{ from, fromSize - 1 }, std::basic_string_view{ to, toSize - 1 }); } - inline void utf8To16(nonstd::string_view str, std::u16string& ret) + inline void utf8To16(std::string_view str, std::u16string& ret) { ret.clear(); for (auto it = str.begin(); it != str.end(); ++it) @@ -224,7 +224,7 @@ namespace kiwi } } - inline std::u16string utf8To16(nonstd::string_view str) + inline std::u16string utf8To16(std::string_view str) { std::u16string ret; utf8To16(str, ret); @@ -232,7 +232,7 @@ namespace kiwi } template - inline std::u16string utf8To16(nonstd::string_view str, std::vector& bytePositions) + inline std::u16string utf8To16(std::string_view str, std::vector& bytePositions) { std::u16string ret; bytePositions.clear(); @@ -302,7 +302,7 @@ namespace kiwi } template - inline std::u16string utf8To16ChrPoisition(nonstd::string_view str, std::vector& chrPositions) + inline std::u16string utf8To16ChrPoisition(std::string_view str, std::vector& chrPositions) { std::u16string ret; size_t chrPosition = 0; @@ -371,7 +371,7 @@ namespace kiwi return ret; } - inline std::string utf16To8(nonstd::u16string_view str) + inline std::string utf16To8(std::u16string_view str) { std::string ret; for (auto it = str.begin(); it != str.end(); ++it) @@ -417,7 +417,7 @@ namespace kiwi } template - inline std::string utf16To8(nonstd::u16string_view str, std::vector& positions) + inline std::string utf16To8(std::u16string_view str, std::vector& positions) { std::string ret; positions.clear(); @@ -504,7 +504,7 @@ namespace kiwi return normalizeHangul(hangul.begin(), 
hangul.end()); } - inline KString normalizeHangul(nonstd::u16string_view hangul) + inline KString normalizeHangul(std::u16string_view hangul) { return normalizeHangul(hangul.begin(), hangul.end()); } @@ -553,7 +553,7 @@ namespace kiwi return normalizeHangulWithPosition(hangul.begin(), hangul.end()); } - inline std::pair> normalizeHangulWithPosition(nonstd::u16string_view hangul) + inline std::pair> normalizeHangulWithPosition(std::u16string_view hangul) { return normalizeHangulWithPosition(hangul.begin(), hangul.end()); } @@ -563,12 +563,12 @@ namespace kiwi return normalizeHangul(utf8To16(hangul)); } - inline KString normalizeHangul(nonstd::string_view hangul) + inline KString normalizeHangul(std::string_view hangul) { return normalizeHangul(utf8To16(hangul)); } - inline POSTag toPOSTag(nonstd::u16string_view tagStr) + inline POSTag toPOSTag(std::u16string_view tagStr) { if (tagStr == u"NNG") return POSTag::nng; if (tagStr == u"NNP") return POSTag::nnp; @@ -745,4 +745,10 @@ namespace kiwi || (0x2E80 <= c && c <= 0x2EFF) ; } + + template + inline std::basic_string_view toStringView(const std::basic_string& str) + { + return std::basic_string_view{ str.data(), str.size() }; + } } diff --git a/src/SwTokenizer.cpp b/src/SwTokenizer.cpp index ad948989..6ef3d29b 100644 --- a/src/SwTokenizer.cpp +++ b/src/SwTokenizer.cpp @@ -576,7 +576,7 @@ namespace kiwi }; }; - inline void utf8To16IgnoringErrors(nonstd::string_view str, std::u16string& ret) + inline void utf8To16IgnoringErrors(std::string_view str, std::u16string& ret) { ret.clear(); for (auto it = str.begin(); it != str.end(); ++it) @@ -675,7 +675,7 @@ namespace kiwi } } - inline std::u16string utf8To16IgnoringErrors(nonstd::string_view str) + inline std::u16string utf8To16IgnoringErrors(std::string_view str) { std::u16string ret; utf8To16IgnoringErrors(str, ret); diff --git a/src/Utils.cpp b/src/Utils.cpp index a1a4f17c..0976377d 100644 --- a/src/Utils.cpp +++ b/src/Utils.cpp @@ -6,12 +6,12 @@ namespace kiwi { 
std::u16string utf8To16(const std::string & str) { - return utf8To16(nonstd::to_string_view(str)); + return utf8To16(toStringView(str)); } std::u16string utf8To16(const std::string& str, std::vector& bytePositions) { - return utf8To16(nonstd::to_string_view(str), bytePositions); + return utf8To16(toStringView(str), bytePositions); } size_t utf8FromCode(std::string& ret, char32_t code) @@ -54,7 +54,7 @@ namespace kiwi std::string utf16To8(const std::u16string & str) { - return utf16To8(nonstd::to_string_view(str)); + return utf16To8(toStringView(str)); } /** @@ -292,7 +292,7 @@ namespace kiwi POSTag toPOSTag(const std::u16string& tagStr) { - return toPOSTag(nonstd::to_string_view(tagStr)); + return toPOSTag(toStringView(tagStr)); } const char* tagToString(POSTag t) diff --git a/src/WordDetector.cpp b/src/WordDetector.cpp index 61dd759e..01552f13 100644 --- a/src/WordDetector.cpp +++ b/src/WordDetector.cpp @@ -461,7 +461,7 @@ void WordDetector::loadNounTailModelFromTxt(std::istream & is) auto fields = split(utf8To16(line), u'\t'); if (fields.size() < 4) continue; float p = stof(fields[1].begin(), fields[1].end()); - nounTailScore[fields[0].to_string()] = p; + nounTailScore[u16string{ fields[0] }] = p; } } diff --git a/src/string_view.hpp b/src/string_view.hpp deleted file mode 100644 index 6f4f724b..00000000 --- a/src/string_view.hpp +++ /dev/null @@ -1,1773 +0,0 @@ -// Copyright 2017-2020 by Martin Moene -// -// string-view lite, a C++17-like string_view for C++98 and later. -// For more information see https://github.com/martinmoene/string-view-lite -// -// Distributed under the Boost Software License, Version 1.0. 
-// (See accompanying file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt) - -#pragma once - -#ifndef NONSTD_SV_LITE_H_INCLUDED -#define NONSTD_SV_LITE_H_INCLUDED - -#define string_view_lite_MAJOR 1 -#define string_view_lite_MINOR 6 -#define string_view_lite_PATCH 0 - -#define string_view_lite_VERSION nssv_STRINGIFY(string_view_lite_MAJOR) "." nssv_STRINGIFY(string_view_lite_MINOR) "." nssv_STRINGIFY(string_view_lite_PATCH) - -#define nssv_STRINGIFY( x ) nssv_STRINGIFY_( x ) -#define nssv_STRINGIFY_( x ) #x - -// string-view lite configuration: - -#define nssv_STRING_VIEW_DEFAULT 0 -#define nssv_STRING_VIEW_NONSTD 1 -#define nssv_STRING_VIEW_STD 2 - -// tweak header support: - -#ifdef __has_include -# if __has_include() -# include -# endif -#define nssv_HAVE_TWEAK_HEADER 1 -#else -#define nssv_HAVE_TWEAK_HEADER 0 -//# pragma message("string_view.hpp: Note: Tweak header not supported.") -#endif - -// string_view selection and configuration: - -#if !defined( nssv_CONFIG_SELECT_STRING_VIEW ) -# define nssv_CONFIG_SELECT_STRING_VIEW ( nssv_HAVE_STD_STRING_VIEW ? 
nssv_STRING_VIEW_STD : nssv_STRING_VIEW_NONSTD ) -#endif - -#ifndef nssv_CONFIG_STD_SV_OPERATOR -# define nssv_CONFIG_STD_SV_OPERATOR 0 -#endif - -#ifndef nssv_CONFIG_USR_SV_OPERATOR -# define nssv_CONFIG_USR_SV_OPERATOR 1 -#endif - -#ifdef nssv_CONFIG_CONVERSION_STD_STRING -# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS nssv_CONFIG_CONVERSION_STD_STRING -# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS nssv_CONFIG_CONVERSION_STD_STRING -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS -# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS 1 -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS -# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS 1 -#endif - -#ifndef nssv_CONFIG_NO_STREAM_INSERTION -# define nssv_CONFIG_NO_STREAM_INSERTION 0 -#endif - -// Control presence of exception handling (try and auto discover): - -#ifndef nssv_CONFIG_NO_EXCEPTIONS -# if defined(_MSC_VER) -# include // for _HAS_EXCEPTIONS -# endif -# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (_HAS_EXCEPTIONS) -# define nssv_CONFIG_NO_EXCEPTIONS 0 -# else -# define nssv_CONFIG_NO_EXCEPTIONS 1 -# endif -#endif - -// C++ language version detection (C++20 is speculative): -// Note: VC14.0/1900 (VS2015) lacks too much from C++14. - -#ifndef nssv_CPLUSPLUS -# if defined(_MSVC_LANG ) && !defined(__clang__) -# define nssv_CPLUSPLUS (_MSC_VER == 1900 ? 
201103L : _MSVC_LANG ) -# else -# define nssv_CPLUSPLUS __cplusplus -# endif -#endif - -#define nssv_CPP98_OR_GREATER ( nssv_CPLUSPLUS >= 199711L ) -#define nssv_CPP11_OR_GREATER ( nssv_CPLUSPLUS >= 201103L ) -#define nssv_CPP11_OR_GREATER_ ( nssv_CPLUSPLUS >= 201103L ) -#define nssv_CPP14_OR_GREATER ( nssv_CPLUSPLUS >= 201402L ) -#define nssv_CPP17_OR_GREATER ( nssv_CPLUSPLUS >= 201703L ) -#define nssv_CPP20_OR_GREATER ( nssv_CPLUSPLUS >= 202000L ) - -// use C++17 std::string_view if available and requested: - -#if nssv_CPP17_OR_GREATER && defined(__has_include ) -# if __has_include( ) -# define nssv_HAVE_STD_STRING_VIEW 1 -# else -# define nssv_HAVE_STD_STRING_VIEW 0 -# endif -#else -# define nssv_HAVE_STD_STRING_VIEW 0 -#endif - -#define nssv_USES_STD_STRING_VIEW ( (nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_STD) || ((nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_DEFAULT) && nssv_HAVE_STD_STRING_VIEW) ) - -#define nssv_HAVE_STARTS_WITH ( nssv_CPP20_OR_GREATER || !nssv_USES_STD_STRING_VIEW ) -#define nssv_HAVE_ENDS_WITH nssv_HAVE_STARTS_WITH - -// -// Use C++17 std::string_view: -// - -#if nssv_USES_STD_STRING_VIEW - -#include - -// Extensions for std::string: - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - - template< class CharT, class Traits, class Allocator = std::allocator > - std::basic_string - to_string(std::basic_string_view v, Allocator const& a = Allocator()) - { - return std::basic_string(v.begin(), v.end(), a); - } - - template< class CharT, class Traits, class Allocator > - std::basic_string_view - to_string_view(std::basic_string const& s) - { - return std::basic_string_view(s.data(), s.size()); - } - - // Literal operators sv and _sv: - -#if nssv_CONFIG_STD_SV_OPERATOR - - using namespace std::literals::string_view_literals; - -#endif - -#if nssv_CONFIG_USR_SV_OPERATOR - - inline namespace literals { - inline namespace string_view_literals { - - - constexpr std::string_view operator "" _sv(const char* 
str, size_t len) noexcept // (1) - { - return std::string_view{ str, len }; - } - - constexpr std::u16string_view operator "" _sv(const char16_t* str, size_t len) noexcept // (2) - { - return std::u16string_view{ str, len }; - } - - constexpr std::u32string_view operator "" _sv(const char32_t* str, size_t len) noexcept // (3) - { - return std::u32string_view{ str, len }; - } - - constexpr std::wstring_view operator "" _sv(const wchar_t* str, size_t len) noexcept // (4) - { - return std::wstring_view{ str, len }; - } - - } - } // namespace literals::string_view_literals - -#endif // nssv_CONFIG_USR_SV_OPERATOR - -} // namespace nonstd - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - - using std::string_view; - using std::wstring_view; - using std::u16string_view; - using std::u32string_view; - using std::basic_string_view; - - // literal "sv" and "_sv", see above - - using std::operator==; - using std::operator!=; - using std::operator<; - using std::operator<=; - using std::operator>; - using std::operator>=; - - using std::operator<<; - -} // namespace nonstd - -#else // nssv_HAVE_STD_STRING_VIEW - -// -// Before C++17: use string_view lite: -// - -// Compiler versions: -// -// MSVC++ 6.0 _MSC_VER == 1200 nssv_COMPILER_MSVC_VERSION == 60 (Visual Studio 6.0) -// MSVC++ 7.0 _MSC_VER == 1300 nssv_COMPILER_MSVC_VERSION == 70 (Visual Studio .NET 2002) -// MSVC++ 7.1 _MSC_VER == 1310 nssv_COMPILER_MSVC_VERSION == 71 (Visual Studio .NET 2003) -// MSVC++ 8.0 _MSC_VER == 1400 nssv_COMPILER_MSVC_VERSION == 80 (Visual Studio 2005) -// MSVC++ 9.0 _MSC_VER == 1500 nssv_COMPILER_MSVC_VERSION == 90 (Visual Studio 2008) -// MSVC++ 10.0 _MSC_VER == 1600 nssv_COMPILER_MSVC_VERSION == 100 (Visual Studio 2010) -// MSVC++ 11.0 _MSC_VER == 1700 nssv_COMPILER_MSVC_VERSION == 110 (Visual Studio 2012) -// MSVC++ 12.0 _MSC_VER == 1800 nssv_COMPILER_MSVC_VERSION == 120 (Visual Studio 2013) -// MSVC++ 14.0 _MSC_VER == 1900 nssv_COMPILER_MSVC_VERSION == 
140 (Visual Studio 2015) -// MSVC++ 14.1 _MSC_VER >= 1910 nssv_COMPILER_MSVC_VERSION == 141 (Visual Studio 2017) -// MSVC++ 14.2 _MSC_VER >= 1920 nssv_COMPILER_MSVC_VERSION == 142 (Visual Studio 2019) - -#if defined(_MSC_VER ) && !defined(__clang__) -# define nssv_COMPILER_MSVC_VER (_MSC_VER ) -# define nssv_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + (_MSC_VER < 1900 ) ) ) -#else -# define nssv_COMPILER_MSVC_VER 0 -# define nssv_COMPILER_MSVC_VERSION 0 -#endif - -#define nssv_COMPILER_VERSION( major, minor, patch ) ( 10 * ( 10 * (major) + (minor) ) + (patch) ) - -#if defined( __apple_build_version__ ) -# define nssv_COMPILER_APPLECLANG_VERSION nssv_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) -# define nssv_COMPILER_CLANG_VERSION 0 -#elif defined( __clang__ ) -# define nssv_COMPILER_APPLECLANG_VERSION 0 -# define nssv_COMPILER_CLANG_VERSION nssv_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) -#else -# define nssv_COMPILER_APPLECLANG_VERSION 0 -# define nssv_COMPILER_CLANG_VERSION 0 -#endif - -#if defined(__GNUC__) && !defined(__clang__) -# define nssv_COMPILER_GNUC_VERSION nssv_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#else -# define nssv_COMPILER_GNUC_VERSION 0 -#endif - -// half-open range [lo..hi): -#define nssv_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) ) - -// Presence of language and library features: - -#ifdef _HAS_CPP0X -# define nssv_HAS_CPP0X _HAS_CPP0X -#else -# define nssv_HAS_CPP0X 0 -#endif - -// Unless defined otherwise below, consider VC14 as C++11 for variant-lite: - -#if nssv_COMPILER_MSVC_VER >= 1900 -# undef nssv_CPP11_OR_GREATER -# define nssv_CPP11_OR_GREATER 1 -#endif - -#define nssv_CPP11_90 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1500) -#define nssv_CPP11_100 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1600) -#define nssv_CPP11_110 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1700) -#define nssv_CPP11_120 
(nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1800) -#define nssv_CPP11_140 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1900) -#define nssv_CPP11_141 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1910) - -#define nssv_CPP14_000 (nssv_CPP14_OR_GREATER) -#define nssv_CPP17_000 (nssv_CPP17_OR_GREATER) - -// Presence of C++11 language features: - -#define nssv_HAVE_CONSTEXPR_11 nssv_CPP11_140 -#define nssv_HAVE_EXPLICIT_CONVERSION nssv_CPP11_140 -#define nssv_HAVE_INLINE_NAMESPACE nssv_CPP11_140 -#define nssv_HAVE_IS_DEFAULT nssv_CPP11_140 -#define nssv_HAVE_IS_DELETE nssv_CPP11_140 -#define nssv_HAVE_NOEXCEPT nssv_CPP11_140 -#define nssv_HAVE_NULLPTR nssv_CPP11_100 -#define nssv_HAVE_REF_QUALIFIER nssv_CPP11_140 -#define nssv_HAVE_UNICODE_LITERALS nssv_CPP11_140 -#define nssv_HAVE_USER_DEFINED_LITERALS nssv_CPP11_140 -#define nssv_HAVE_WCHAR16_T nssv_CPP11_100 -#define nssv_HAVE_WCHAR32_T nssv_CPP11_100 - -#if ! ( ( nssv_CPP11_OR_GREATER && nssv_COMPILER_CLANG_VERSION ) || nssv_BETWEEN( nssv_COMPILER_CLANG_VERSION, 300, 400 ) ) -# define nssv_HAVE_STD_DEFINED_LITERALS nssv_CPP11_140 -#else -# define nssv_HAVE_STD_DEFINED_LITERALS 0 -#endif - -// Presence of C++14 language features: - -#define nssv_HAVE_CONSTEXPR_14 nssv_CPP14_000 - -// Presence of C++17 language features: - -#define nssv_HAVE_NODISCARD nssv_CPP17_000 - -// Presence of C++ library features: - -#define nssv_HAVE_STD_HASH nssv_CPP11_120 - -// Presence of compiler intrinsics: - -// Providing char-type specializations for compare() and length() that -// use compiler intrinsics can improve compile- and run-time performance. -// -// The challenge is in using the right combinations of builtin availability -// and its constexpr-ness. -// -// | compiler | __builtin_memcmp (constexpr) | memcmp (constexpr) | -// |----------|------------------------------|---------------------| -// | clang | 4.0 (>= 4.0 ) | any (? ) | -// | clang-a | 9.0 (>= 9.0 ) | any (? 
) | -// | gcc | any (constexpr) | any (? ) | -// | msvc | >= 14.2 C++17 (>= 14.2 ) | any (? ) | - -#define nssv_HAVE_BUILTIN_VER ( (nssv_CPP17_000 && nssv_COMPILER_MSVC_VERSION >= 142) || nssv_COMPILER_GNUC_VERSION > 0 || nssv_COMPILER_CLANG_VERSION >= 400 || nssv_COMPILER_APPLECLANG_VERSION >= 900 ) -#define nssv_HAVE_BUILTIN_CE ( nssv_HAVE_BUILTIN_VER ) - -#define nssv_HAVE_BUILTIN_MEMCMP ( (nssv_HAVE_CONSTEXPR_14 && nssv_HAVE_BUILTIN_CE) || !nssv_HAVE_CONSTEXPR_14 ) -#define nssv_HAVE_BUILTIN_STRLEN ( (nssv_HAVE_CONSTEXPR_11 && nssv_HAVE_BUILTIN_CE) || !nssv_HAVE_CONSTEXPR_11 ) - -#ifdef __has_builtin -# define nssv_HAVE_BUILTIN( x ) __has_builtin( x ) -#else -# define nssv_HAVE_BUILTIN( x ) 0 -#endif - -#if nssv_HAVE_BUILTIN(__builtin_memcmp) || nssv_HAVE_BUILTIN_VER -# define nssv_BUILTIN_MEMCMP __builtin_memcmp -#else -# define nssv_BUILTIN_MEMCMP memcmp -#endif - -#if nssv_HAVE_BUILTIN(__builtin_strlen) || nssv_HAVE_BUILTIN_VER -# define nssv_BUILTIN_STRLEN __builtin_strlen -#else -# define nssv_BUILTIN_STRLEN strlen -#endif - -// C++ feature usage: - -#if nssv_HAVE_CONSTEXPR_11 -# define nssv_constexpr constexpr -#else -# define nssv_constexpr /*constexpr*/ -#endif - -#if nssv_HAVE_CONSTEXPR_14 -# define nssv_constexpr14 constexpr -#else -# define nssv_constexpr14 /*constexpr*/ -#endif - -#if nssv_HAVE_EXPLICIT_CONVERSION -# define nssv_explicit explicit -#else -# define nssv_explicit /*explicit*/ -#endif - -#if nssv_HAVE_INLINE_NAMESPACE -# define nssv_inline_ns inline -#else -# define nssv_inline_ns /*inline*/ -#endif - -#if nssv_HAVE_NOEXCEPT -# define nssv_noexcept noexcept -#else -# define nssv_noexcept /*noexcept*/ -#endif - -//#if nssv_HAVE_REF_QUALIFIER -//# define nssv_ref_qual & -//# define nssv_refref_qual && -//#else -//# define nssv_ref_qual /*&*/ -//# define nssv_refref_qual /*&&*/ -//#endif - -#if nssv_HAVE_NULLPTR -# define nssv_nullptr nullptr -#else -# define nssv_nullptr NULL -#endif - -#if nssv_HAVE_NODISCARD -# define nssv_nodiscard 
[[nodiscard]] -#else -# define nssv_nodiscard /*[[nodiscard]]*/ -#endif - -// Additional includes: - -#include -#include -#include -#include -#include // std::char_traits<> - -#if ! nssv_CONFIG_NO_STREAM_INSERTION -# include -#endif - -#if ! nssv_CONFIG_NO_EXCEPTIONS -# include -#endif - -#if nssv_CPP11_OR_GREATER -# include -#endif - -// Clang, GNUC, MSVC warning suppression macros: - -#if defined(__clang__) -# pragma clang diagnostic ignored "-Wreserved-user-defined-literal" -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wuser-defined-literals" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wliteral-suffix" -#endif // __clang__ - -#if nssv_COMPILER_MSVC_VERSION >= 140 -# define nssv_SUPPRESS_MSGSL_WARNING(expr) [[gsl::suppress(expr)]] -# define nssv_SUPPRESS_MSVC_WARNING(code, descr) __pragma(warning(suppress: code) ) -# define nssv_DISABLE_MSVC_WARNINGS(codes) __pragma(warning(push)) __pragma(warning(disable: codes)) -#else -# define nssv_SUPPRESS_MSGSL_WARNING(expr) -# define nssv_SUPPRESS_MSVC_WARNING(code, descr) -# define nssv_DISABLE_MSVC_WARNINGS(codes) -#endif - -#if defined(__clang__) -# define nssv_RESTORE_WARNINGS() _Pragma("clang diagnostic pop") -#elif defined(__GNUC__) -# define nssv_RESTORE_WARNINGS() _Pragma("GCC diagnostic pop") -#elif nssv_COMPILER_MSVC_VERSION >= 140 -# define nssv_RESTORE_WARNINGS() __pragma(warning(pop )) -#else -# define nssv_RESTORE_WARNINGS() -#endif - -// Suppress the following MSVC (GSL) warnings: -// - C4455, non-gsl : 'operator ""sv': literal suffix identifiers that do not -// start with an underscore are reserved -// - C26472, gsl::t.1 : don't use a static_cast for arithmetic conversions; -// use brace initialization, gsl::narrow_cast or gsl::narow -// - C26481: gsl::b.1 : don't use pointer arithmetic. 
Use span instead - -nssv_DISABLE_MSVC_WARNINGS(4455 26481 26472) -//nssv_DISABLE_CLANG_WARNINGS( "-Wuser-defined-literals" ) -//nssv_DISABLE_GNUC_WARNINGS( -Wliteral-suffix ) - -namespace nonstd { - namespace sv_lite { - - // - // basic_string_view declaration: - // - - template - < - class CharT, - class Traits = std::char_traits - > - class basic_string_view; - - namespace detail { - - // support constexpr comparison in C++14; - // for C++17 and later, use provided traits: - - template< typename CharT > - inline nssv_constexpr14 int compare(CharT const* s1, CharT const* s2, std::size_t count) - { - while (count-- != 0) - { - if (*s1 < *s2) return -1; - if (*s1 > *s2) return +1; - ++s1; ++s2; - } - return 0; - } - -#if nssv_HAVE_BUILTIN_MEMCMP - - // specialization of compare() for char, see also generic compare() above: - - inline nssv_constexpr14 int compare(char const* s1, char const* s2, std::size_t count) - { - return nssv_BUILTIN_MEMCMP(s1, s2, count); - } - -#endif - -#if nssv_HAVE_BUILTIN_STRLEN - - // specialization of length() for char, see also generic length() further below: - - inline nssv_constexpr std::size_t length(char const* s) - { - return nssv_BUILTIN_STRLEN(s); - } - -#endif - -#if defined(__OPTIMIZE__) - - // gcc, clang provide __OPTIMIZE__ - // Expect tail call optimization to make length() non-recursive: - - template< typename CharT > - inline nssv_constexpr std::size_t length(CharT* s, std::size_t result = 0) - { - return *s == '\0' ? result : length(s + 1, result + 1); - } - -#else // OPTIMIZE - - // non-recursive: - - template< typename CharT > - inline nssv_constexpr14 std::size_t length(CharT* s) - { - std::size_t result = 0; - while (*s++ != '\0') - { - ++result; - } - return result; - } - -#endif // OPTIMIZE - -#if nssv_CPP11_OR_GREATER && ! 
nssv_CPP17_OR_GREATER -#if defined(__OPTIMIZE__) - - // gcc, clang provide __OPTIMIZE__ - // Expect tail call optimization to make search() non-recursive: - - template< class CharT, class Traits = std::char_traits > - constexpr const CharT* search(basic_string_view haystack, basic_string_view needle) - { - return haystack.starts_with(needle) ? haystack.begin() : - haystack.empty() ? haystack.end() : search(haystack.substr(1), needle); - } - -#else // OPTIMIZE - - // non-recursive: - - template< class CharT, class Traits = std::char_traits > - constexpr const CharT* search(basic_string_view haystack, basic_string_view needle) - { - return std::search(haystack.begin(), haystack.end(), needle.begin(), needle.end()); - } - -#endif // OPTIMIZE -#endif // nssv_CPP11_OR_GREATER && ! nssv_CPP17_OR_GREATER - - } // namespace detail - - // - // basic_string_view: - // - - template - < - class CharT, - class Traits /* = std::char_traits */ - > - class basic_string_view - { - public: - // Member types: - - typedef Traits traits_type; - typedef CharT value_type; - - typedef CharT* pointer; - typedef CharT const* const_pointer; - typedef CharT& reference; - typedef CharT const& const_reference; - - typedef const_pointer iterator; - typedef const_pointer const_iterator; - typedef std::reverse_iterator< const_iterator > reverse_iterator; - typedef std::reverse_iterator< const_iterator > const_reverse_iterator; - - typedef std::size_t size_type; - typedef std::ptrdiff_t difference_type; - - // 24.4.2.1 Construction and assignment: - - nssv_constexpr basic_string_view() nssv_noexcept - : data_(nssv_nullptr) - , size_(0) - {} - -#if nssv_CPP11_OR_GREATER - nssv_constexpr basic_string_view(basic_string_view const& other) nssv_noexcept = default; -#else - nssv_constexpr basic_string_view(basic_string_view const& other) nssv_noexcept - : data_(other.data_) - , size_(other.size_) - {} -#endif - - nssv_constexpr basic_string_view(CharT const* s, size_type count) nssv_noexcept // 
non-standard noexcept - : data_(s) - , size_(count) - {} - - nssv_constexpr basic_string_view(CharT const* s) nssv_noexcept // non-standard noexcept - : data_(s) -#if nssv_CPP17_OR_GREATER - , size_(Traits::length(s)) -#elif nssv_CPP11_OR_GREATER - , size_(detail::length(s)) -#else - , size_(Traits::length(s)) -#endif - {} - -#if nssv_HAVE_NULLPTR -# if nssv_HAVE_IS_DELETE - nssv_constexpr basic_string_view(std::nullptr_t) nssv_noexcept = delete; -# else - private: nssv_constexpr basic_string_view(std::nullptr_t) nssv_noexcept; public: -# endif -#endif - - // Assignment: - -#if nssv_CPP11_OR_GREATER - nssv_constexpr14 basic_string_view& operator=(basic_string_view const& other) nssv_noexcept = default; -#else - nssv_constexpr14 basic_string_view& operator=(basic_string_view const& other) nssv_noexcept - { - data_ = other.data_; - size_ = other.size_; - return *this; - } -#endif - - // 24.4.2.2 Iterator support: - - nssv_constexpr const_iterator begin() const nssv_noexcept { return data_; } - nssv_constexpr const_iterator end() const nssv_noexcept { return data_ + size_; } - - nssv_constexpr const_iterator cbegin() const nssv_noexcept { return begin(); } - nssv_constexpr const_iterator cend() const nssv_noexcept { return end(); } - - nssv_constexpr const_reverse_iterator rbegin() const nssv_noexcept { return const_reverse_iterator(end()); } - nssv_constexpr const_reverse_iterator rend() const nssv_noexcept { return const_reverse_iterator(begin()); } - - nssv_constexpr const_reverse_iterator crbegin() const nssv_noexcept { return rbegin(); } - nssv_constexpr const_reverse_iterator crend() const nssv_noexcept { return rend(); } - - // 24.4.2.3 Capacity: - - nssv_constexpr size_type size() const nssv_noexcept { return size_; } - nssv_constexpr size_type length() const nssv_noexcept { return size_; } - nssv_constexpr size_type max_size() const nssv_noexcept { return (std::numeric_limits< size_type >::max)(); } - - // since C++20 - nssv_nodiscard nssv_constexpr bool 
empty() const nssv_noexcept - { - return 0 == size_; - } - - // 24.4.2.4 Element access: - - nssv_constexpr const_reference operator[](size_type pos) const - { - return data_at(pos); - } - - nssv_constexpr14 const_reference at(size_type pos) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos < size()); -#else - if (pos >= size()) - { - throw std::out_of_range("nonstd::string_view::at()"); - } -#endif - return data_at(pos); - } - - nssv_constexpr const_reference front() const { return data_at(0); } - nssv_constexpr const_reference back() const { return data_at(size() - 1); } - - nssv_constexpr const_pointer data() const nssv_noexcept { return data_; } - - // 24.4.2.5 Modifiers: - - nssv_constexpr14 void remove_prefix(size_type n) - { - assert(n <= size()); - data_ += n; - size_ -= n; - } - - nssv_constexpr14 void remove_suffix(size_type n) - { - assert(n <= size()); - size_ -= n; - } - - nssv_constexpr14 void swap(basic_string_view& other) nssv_noexcept - { - const basic_string_view tmp(other); - other = *this; - *this = tmp; - } - - // 24.4.2.6 String operations: - - size_type copy(CharT* dest, size_type n, size_type pos = 0) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos <= size()); -#else - if (pos > size()) - { - throw std::out_of_range("nonstd::string_view::copy()"); - } -#endif - const size_type rlen = (std::min)(n, size() - pos); - - (void)Traits::copy(dest, data() + pos, rlen); - - return rlen; - } - - nssv_constexpr14 basic_string_view substr(size_type pos = 0, size_type n = npos) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos <= size()); -#else - if (pos > size()) - { - throw std::out_of_range("nonstd::string_view::substr()"); - } -#endif - return basic_string_view(data() + pos, (std::min)(n, size() - pos)); - } - - // compare(), 6x: - - nssv_constexpr14 int compare(basic_string_view other) const nssv_noexcept // (1) - { -#if nssv_CPP17_OR_GREATER - if (const int result = Traits::compare(data(), other.data(), (std::min)(size(), 
other.size()))) -#else - if (const int result = detail::compare(data(), other.data(), (std::min)(size(), other.size()))) -#endif - { - return result; - } - - return size() == other.size() ? 0 : size() < other.size() ? -1 : 1; - } - - nssv_constexpr int compare(size_type pos1, size_type n1, basic_string_view other) const // (2) - { - return substr(pos1, n1).compare(other); - } - - nssv_constexpr int compare(size_type pos1, size_type n1, basic_string_view other, size_type pos2, size_type n2) const // (3) - { - return substr(pos1, n1).compare(other.substr(pos2, n2)); - } - - nssv_constexpr int compare(CharT const* s) const // (4) - { - return compare(basic_string_view(s)); - } - - nssv_constexpr int compare(size_type pos1, size_type n1, CharT const* s) const // (5) - { - return substr(pos1, n1).compare(basic_string_view(s)); - } - - nssv_constexpr int compare(size_type pos1, size_type n1, CharT const* s, size_type n2) const // (6) - { - return substr(pos1, n1).compare(basic_string_view(s, n2)); - } - - // 24.4.2.7 Searching: - - // starts_with(), 3x, since C++20: - - nssv_constexpr bool starts_with(basic_string_view v) const nssv_noexcept // (1) - { - return size() >= v.size() && compare(0, v.size(), v) == 0; - } - - nssv_constexpr bool starts_with(CharT c) const nssv_noexcept // (2) - { - return starts_with(basic_string_view(&c, 1)); - } - - nssv_constexpr bool starts_with(CharT const* s) const // (3) - { - return starts_with(basic_string_view(s)); - } - - // ends_with(), 3x, since C++20: - - nssv_constexpr bool ends_with(basic_string_view v) const nssv_noexcept // (1) - { - return size() >= v.size() && compare(size() - v.size(), npos, v) == 0; - } - - nssv_constexpr bool ends_with(CharT c) const nssv_noexcept // (2) - { - return ends_with(basic_string_view(&c, 1)); - } - - nssv_constexpr bool ends_with(CharT const* s) const // (3) - { - return ends_with(basic_string_view(s)); - } - - // find(), 4x: - - nssv_constexpr size_type find(basic_string_view v, size_type pos 
= 0) const nssv_noexcept // (1) - { - return assert(v.size() == 0 || v.data() != nssv_nullptr) - , pos >= size() - ? npos : to_pos( -#if nssv_CPP11_OR_GREATER && ! nssv_CPP17_OR_GREATER - detail::search(substr(pos), v) -#else - std::search(cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq) -#endif - ); - } - - nssv_constexpr size_type find(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find(CharT const* s, size_type pos, size_type n) const // (3) - { - return find(basic_string_view(s, n), pos); - } - - nssv_constexpr size_type find(CharT const* s, size_type pos = 0) const // (4) - { - return find(basic_string_view(s), pos); - } - - // rfind(), 4x: - - nssv_constexpr14 size_type rfind(basic_string_view v, size_type pos = npos) const nssv_noexcept // (1) - { - if (size() < v.size()) - { - return npos; - } - - if (v.empty()) - { - return (std::min)(size(), pos); - } - - const_iterator last = cbegin() + (std::min)(size() - v.size(), pos) + v.size(); - const_iterator result = std::find_end(cbegin(), last, v.cbegin(), v.cend(), Traits::eq); - - return result != last ? size_type(result - cbegin()) : npos; - } - - nssv_constexpr14 size_type rfind(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return rfind(basic_string_view(&c, 1), pos); - } - - nssv_constexpr14 size_type rfind(CharT const* s, size_type pos, size_type n) const // (3) - { - return rfind(basic_string_view(s, n), pos); - } - - nssv_constexpr14 size_type rfind(CharT const* s, size_type pos = npos) const // (4) - { - return rfind(basic_string_view(s), pos); - } - - // find_first_of(), 4x: - - nssv_constexpr size_type find_first_of(basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return pos >= size() - ? 
npos - : to_pos(std::find_first_of(cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq)); - } - - nssv_constexpr size_type find_first_of(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find_first_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_first_of(CharT const* s, size_type pos, size_type n) const // (3) - { - return find_first_of(basic_string_view(s, n), pos); - } - - nssv_constexpr size_type find_first_of(CharT const* s, size_type pos = 0) const // (4) - { - return find_first_of(basic_string_view(s), pos); - } - - // find_last_of(), 4x: - - nssv_constexpr size_type find_last_of(basic_string_view v, size_type pos = npos) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? find_last_of(v, size() - 1) - : to_pos(std::find_first_of(const_reverse_iterator(cbegin() + pos + 1), crend(), v.cbegin(), v.cend(), Traits::eq)); - } - - nssv_constexpr size_type find_last_of(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return find_last_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_last_of(CharT const* s, size_type pos, size_type count) const // (3) - { - return find_last_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type find_last_of(CharT const* s, size_type pos = npos) const // (4) - { - return find_last_of(basic_string_view(s), pos); - } - - // find_first_not_of(), 4x: - - nssv_constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return pos >= size() - ? 
npos - : to_pos(std::find_if(cbegin() + pos, cend(), not_in_view(v))); - } - - nssv_constexpr size_type find_first_not_of(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find_first_not_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_first_not_of(CharT const* s, size_type pos, size_type count) const // (3) - { - return find_first_not_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type find_first_not_of(CharT const* s, size_type pos = 0) const // (4) - { - return find_first_not_of(basic_string_view(s), pos); - } - - // find_last_not_of(), 4x: - - nssv_constexpr size_type find_last_not_of(basic_string_view v, size_type pos = npos) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? find_last_not_of(v, size() - 1) - : to_pos(std::find_if(const_reverse_iterator(cbegin() + pos + 1), crend(), not_in_view(v))); - } - - nssv_constexpr size_type find_last_not_of(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return find_last_not_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_last_not_of(CharT const* s, size_type pos, size_type count) const // (3) - { - return find_last_not_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type find_last_not_of(CharT const* s, size_type pos = npos) const // (4) - { - return find_last_not_of(basic_string_view(s), pos); - } - - // Constants: - -#if nssv_CPP17_OR_GREATER - static nssv_constexpr size_type npos = size_type(-1); -#elif nssv_CPP11_OR_GREATER - enum : size_type { npos = size_type(-1) }; -#else - enum { npos = size_type(-1) }; -#endif - - private: - struct not_in_view - { - const basic_string_view v; - - nssv_constexpr explicit not_in_view(basic_string_view v_) : v(v_) {} - - nssv_constexpr bool operator()(CharT c) const - { - return npos == v.find_first_of(c); - } - }; - - nssv_constexpr size_type to_pos(const_iterator it) const - { - return it == cend() ? 
npos : size_type(it - cbegin()); - } - - nssv_constexpr size_type to_pos(const_reverse_iterator it) const - { - return it == crend() ? npos : size_type(crend() - it - 1); - } - - nssv_constexpr const_reference data_at(size_type pos) const - { -#if nssv_BETWEEN( nssv_COMPILER_GNUC_VERSION, 1, 500 ) - return data_[pos]; -#else - return assert(pos < size()), data_[pos]; -#endif - } - - private: - const_pointer data_; - size_type size_; - - public: -#if nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS - - template< class Allocator > - basic_string_view(std::basic_string const& s) nssv_noexcept - : data_(s.data()) - , size_(s.size()) - {} - -#if nssv_HAVE_EXPLICIT_CONVERSION - - template< class Allocator > - explicit operator std::basic_string() const - { - return to_string(Allocator()); - } - -#endif // nssv_HAVE_EXPLICIT_CONVERSION - -#if nssv_CPP11_OR_GREATER - - template< class Allocator = std::allocator > - std::basic_string - to_string(Allocator const& a = Allocator()) const - { - return std::basic_string(begin(), end(), a); - } - -#else - - std::basic_string - to_string() const - { - return std::basic_string(begin(), end()); - } - - template< class Allocator > - std::basic_string - to_string(Allocator const& a) const - { - return std::basic_string(begin(), end(), a); - } - -#endif // nssv_CPP11_OR_GREATER - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS - }; - - // - // Non-member functions: - // - - // 24.4.3 Non-member comparison functions: - // lexicographically compare two string views (function template): - - template< class CharT, class Traits > - nssv_constexpr bool operator== ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator!= ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits > - nssv_constexpr bool 
operator< ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator<= ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator> ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator>= ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - // Let S be basic_string_view, and sv be an instance of S. - // Implementations shall provide sufficient additional overloads marked - // constexpr and noexcept so that an object t with an implicit conversion - // to S can be compared according to Table 67. - -#if ! nssv_CPP11_OR_GREATER || nssv_BETWEEN( nssv_COMPILER_MSVC_VERSION, 100, 141 ) - -// accommodate for older compilers: - -// == - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.size() == detail::length(rhs) && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return detail::length(lhs) == rhs.size() && rhs.compare(lhs) == 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - // != - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - basic_string_view lhs, 
- CharT const* rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return !(lhs == rhs); - } - - // < - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) > 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) > 0; - } - - // <= - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) >= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) >= 0; - } - 
- // > - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) < 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) < 0; - } - - // >= - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) <= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) <= 0; - } - -#else // newer compilers: - -#define nssv_BASIC_STRING_VIEW_I(T,U) typename std::decay< basic_string_view >::type - -#if defined(_MSC_VER) // issue 40 -# define nssv_MSVC_ORDER(x) , int=x -#else -# define nssv_MSVC_ORDER(x) /*, int=x*/ -#endif - -// == - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator==( - basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - 
nssv_constexpr bool operator==( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - // != - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator!= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator!= ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - // < - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator< ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator< ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - // <= - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator<= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator<= ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - // > - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator> ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator> ( - 
nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - // >= - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator>= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator>= ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - -#undef nssv_MSVC_ORDER -#undef nssv_BASIC_STRING_VIEW_I - -#endif // compiler-dependent approach to comparisons - - // 24.4.4 Inserters and extractors: - -#if ! nssv_CONFIG_NO_STREAM_INSERTION - - namespace detail { - - template< class Stream > - void write_padding(Stream& os, std::streamsize n) - { - for (std::streamsize i = 0; i < n; ++i) - os.rdbuf()->sputc(os.fill()); - } - - template< class Stream, class View > - Stream& write_to_stream(Stream& os, View const& sv) - { - typename Stream::sentry sentry(os); - - if (!sentry) - return os; - - const std::streamsize length = static_cast(sv.length()); - - // Whether, and how, to pad: - const bool pad = (length < os.width()); - const bool left_pad = pad && (os.flags() & std::ios_base::adjustfield) == std::ios_base::right; - - if (left_pad) - write_padding(os, os.width() - length); - - // Write span characters: - os.rdbuf()->sputn(sv.begin(), length); - - if (pad && !left_pad) - write_padding(os, os.width() - length); - - // Reset output stream width: - os.width(0); - - return os; - } - - } // namespace detail - - template< class CharT, class Traits > - std::basic_ostream& - operator<<( - std::basic_ostream& os, - basic_string_view sv) - { - return detail::write_to_stream(os, sv); - } - -#endif // nssv_CONFIG_NO_STREAM_INSERTION - - // Several typedefs for common character types are provided: - - 
typedef basic_string_view string_view; - typedef basic_string_view wstring_view; -#if nssv_HAVE_WCHAR16_T - typedef basic_string_view u16string_view; - typedef basic_string_view u32string_view; -#endif - - } -} // namespace nonstd::sv_lite - -// -// 24.4.6 Suffix for basic_string_view literals: -// - -#if nssv_HAVE_USER_DEFINED_LITERALS - -namespace nonstd { - nssv_inline_ns namespace literals { - nssv_inline_ns namespace string_view_literals { - -#if nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - - nssv_constexpr nonstd::sv_lite::string_view operator "" sv(const char* str, size_t len) nssv_noexcept // (1) - { - return nonstd::sv_lite::string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u16string_view operator "" sv(const char16_t* str, size_t len) nssv_noexcept // (2) - { - return nonstd::sv_lite::u16string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u32string_view operator "" sv(const char32_t* str, size_t len) nssv_noexcept // (3) - { - return nonstd::sv_lite::u32string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::wstring_view operator "" sv(const wchar_t* str, size_t len) nssv_noexcept // (4) - { - return nonstd::sv_lite::wstring_view{ str, len }; - } - -#endif // nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - -#if nssv_CONFIG_USR_SV_OPERATOR - - nssv_constexpr nonstd::sv_lite::string_view operator "" _sv(const char* str, size_t len) nssv_noexcept // (1) - { - return nonstd::sv_lite::string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u16string_view operator "" _sv(const char16_t* str, size_t len) nssv_noexcept // (2) - { - return nonstd::sv_lite::u16string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u32string_view operator "" _sv(const char32_t* str, size_t len) nssv_noexcept // (3) - { - return nonstd::sv_lite::u32string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::wstring_view operator "" _sv(const wchar_t* str, size_t len) nssv_noexcept // (4) - 
{ - return nonstd::sv_lite::wstring_view{ str, len }; - } - -#endif // nssv_CONFIG_USR_SV_OPERATOR - - } - } -} // namespace nonstd::literals::string_view_literals - -#endif - -// -// Extensions for std::string: -// - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - namespace sv_lite { - - // Exclude MSVC 14 (19.00): it yields ambiguous to_string(): - -#if nssv_CPP11_OR_GREATER && nssv_COMPILER_MSVC_VERSION != 140 - - template< class CharT, class Traits, class Allocator = std::allocator > - std::basic_string - to_string(basic_string_view v, Allocator const& a = Allocator()) - { - return std::basic_string(v.begin(), v.end(), a); - } - -#else - - template< class CharT, class Traits > - std::basic_string - to_string(basic_string_view v) - { - return std::basic_string(v.begin(), v.end()); - } - - template< class CharT, class Traits, class Allocator > - std::basic_string - to_string(basic_string_view v, Allocator const& a) - { - return std::basic_string(v.begin(), v.end(), a); - } - -#endif // nssv_CPP11_OR_GREATER - - template< class CharT, class Traits, class Allocator > - basic_string_view - to_string_view(std::basic_string const& s) - { - return basic_string_view(s.data(), s.size()); - } - - } -} // namespace nonstd::sv_lite - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -// -// make types and algorithms available in namespace nonstd: -// - -namespace nonstd { - - using sv_lite::basic_string_view; - using sv_lite::string_view; - using sv_lite::wstring_view; - -#if nssv_HAVE_WCHAR16_T - using sv_lite::u16string_view; -#endif -#if nssv_HAVE_WCHAR32_T - using sv_lite::u32string_view; -#endif - - // literal "sv" - - using sv_lite::operator==; - using sv_lite::operator!=; - using sv_lite::operator<; - using sv_lite::operator<=; - using sv_lite::operator>; - using sv_lite::operator>=; - -#if ! 
nssv_CONFIG_NO_STREAM_INSERTION - using sv_lite::operator<<; -#endif - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - using sv_lite::to_string; - using sv_lite::to_string_view; -#endif - -} // namespace nonstd - -// 24.4.5 Hash support (C++11): - -// Note: The hash value of a string view object is equal to the hash value of -// the corresponding string object. - -#if nssv_HAVE_STD_HASH - -#include - -namespace std { - - template<> - struct hash< nonstd::string_view > - { - public: - std::size_t operator()(nonstd::string_view v) const nssv_noexcept - { - return std::hash()(std::string(v.data(), v.size())); - } - }; - - template<> - struct hash< nonstd::wstring_view > - { - public: - std::size_t operator()(nonstd::wstring_view v) const nssv_noexcept - { - return std::hash()(std::wstring(v.data(), v.size())); - } - }; - - template<> - struct hash< nonstd::u16string_view > - { - public: - std::size_t operator()(nonstd::u16string_view v) const nssv_noexcept - { - return std::hash()(std::u16string(v.data(), v.size())); - } - }; - - template<> - struct hash< nonstd::u32string_view > - { - public: - std::size_t operator()(nonstd::u32string_view v) const nssv_noexcept - { - return std::hash()(std::u32string(v.data(), v.size())); - } - }; - -} // namespace std - -#endif // nssv_HAVE_STD_HASH - -nssv_RESTORE_WARNINGS() - -#endif // nssv_HAVE_STD_STRING_VIEW -#endif // NONSTD_SV_LITE_H_INCLUDED \ No newline at end of file diff --git a/test/test_combiner.cpp b/test/test_combiner.cpp index c21763b9..9c4c88b5 100644 --- a/test/test_combiner.cpp +++ b/test/test_combiner.cpp @@ -133,22 +133,22 @@ TEST(KiwiCppCombiner, Joiner) TEST(KiwiCppCombiner, Allomorph) { - using Tuple = std::tuple; + using Tuple = std::tuple; auto& rule = getCompiledRule(); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"를"}, CondVowel::vowel, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"을"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"를"}, CondVowel::vowel, (uint8_t)0}, 
+ Tuple{ std::u16string_view{u"을"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::jko); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"가"}, CondVowel::vowel, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"이"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"가"}, CondVowel::vowel, (uint8_t)0}, + Tuple{ std::u16string_view{u"이"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::jks); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"로"}, CondVowel::vocalic, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"으로"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"로"}, CondVowel::vocalic, (uint8_t)0}, + Tuple{ std::u16string_view{u"으로"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::jkb); auto joiner = rule.newJoiner(); @@ -182,8 +182,8 @@ TEST(KiwiCppCombiner, Allomorph) EXPECT_EQ(joiner.getU16(), u"북으로"); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"면"}, CondVowel::vocalic, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"으면"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"면"}, CondVowel::vocalic, (uint8_t)0}, + Tuple{ std::u16string_view{u"으면"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::ec); joiner = rule.newJoiner(); From 743f362b38e740fb9e82b9c41296b1dd7f74b7f9 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 15:11:47 +0900 Subject: [PATCH 36/53] Fix deprecated `std::result_of` --- include/kiwi/ThreadPool.h | 6 +++--- src/sais/mp_utils.hpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/kiwi/ThreadPool.h b/include/kiwi/ThreadPool.h index 848f976f..5925904c 100644 --- a/include/kiwi/ThreadPool.h +++ b/include/kiwi/ThreadPool.h @@ -27,7 +27,7 @@ namespace kiwi template auto enqueue(F&& f, Args&&... args) - ->std::future::type>; + ->std::future::type>; size_t size() const { return workers.size(); } size_t numEnqueued() const { return tasks.size(); } @@ -67,9 +67,9 @@ namespace kiwi template auto ThreadPool::enqueue(F&& f, Args&&... 
args) - -> std::future::type> + -> std::future::type> { - using return_type = typename std::result_of::type; + using return_type = typename std::invoke_result::type; auto task = std::make_shared< std::packaged_task >( std::bind(std::forward(f), std::placeholders::_1, std::forward(args)...)); diff --git a/src/sais/mp_utils.hpp b/src/sais/mp_utils.hpp index ccec344a..60e2d912 100644 --- a/src/sais/mp_utils.hpp +++ b/src/sais/mp_utils.hpp @@ -58,7 +58,7 @@ namespace mp ThreadPool(size_t threads = 0); template auto runParallel(size_t workers, F&& f, Args&&... args) - -> std::vector::type>>; + -> std::vector::type>>; ~ThreadPool(); size_t size() const { return workers.size(); } size_t limitedSize() const { return std::min(size(), _limitedSize); }; @@ -106,9 +106,9 @@ namespace mp template auto ThreadPool::runParallel(size_t workers, F&& f, Args&&... args) - -> std::vector::type>> + -> std::vector::type>> { - using return_type = typename std::result_of::type; + using return_type = typename std::invoke_result::type; std::vector> ret; { auto b = std::make_shared(getBarrier(workers)); From c24364f1be11dbbd5e6fc297838c3891ea89a337 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 15:21:38 +0900 Subject: [PATCH 37/53] remove submodule `mapbox/variant` --- .gitmodules | 3 --- third_party/variant | 1 - 2 files changed, 4 deletions(-) delete mode 160000 third_party/variant diff --git a/.gitmodules b/.gitmodules index e76d2d4b..c5c18bee 100644 --- a/.gitmodules +++ b/.gitmodules @@ -15,9 +15,6 @@ [submodule "third_party/cpuinfo"] path = third_party/cpuinfo url = https://github.com/pytorch/cpuinfo -[submodule "third_party/variant"] - path = third_party/variant - url = https://github.com/mapbox/variant [submodule "third_party/eigen"] path = third_party/eigen url = https://gitlab.com/libeigen/eigen diff --git a/third_party/variant b/third_party/variant deleted file mode 160000 index f87fcbda..00000000 --- a/third_party/variant +++ /dev/null @@ -1 +0,0 @@ -Subproject 
commit f87fcbda9daf13fba47a6a889696b0ad23fc098d From 6560cccad5fc9287dd350baf7006776ecf990816 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 15:22:49 +0900 Subject: [PATCH 38/53] replace `mapbox/variant` to `std::variant` --- src/Combiner.cpp | 16 ++++++++-------- src/Combiner.h | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Combiner.cpp b/src/Combiner.cpp index ade620d8..19d8a055 100644 --- a/src/Combiner.cpp +++ b/src/Combiner.cpp @@ -1115,7 +1115,7 @@ Vector CompiledRule::combineImpl( auto it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { ret.emplace_back(move(p.str)); } @@ -1131,7 +1131,7 @@ Vector CompiledRule::combineImpl( it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { ret.emplace_back(move(p.str)); } @@ -1161,7 +1161,7 @@ tuple CompiledRule::combineOneImpl( auto it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { if(p.score >= 0) return make_tuple(p.str, p.leftEnd, p.rightBegin); KString ret; @@ -1181,7 +1181,7 @@ tuple CompiledRule::combineOneImpl( it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { return make_tuple(p.str, p.leftEnd, p.rightBegin); } @@ -1210,13 +1210,13 @@ tuple CompiledRule::combineOneImpl( Vector> 
CompiledRule::testLeftPattern(U16StringView leftForm, size_t ruleId) const { - return mapbox::util::apply_visitor(SearchLeftVisitor{ leftForm, true }, dfa[ruleId]); + return visit(SearchLeftVisitor{ leftForm, true }, dfa[ruleId]); } Vector> CompiledRule::testRightPattern(U16StringView rightForm, size_t ruleId) const { - return mapbox::util::apply_visitor(SearchLeftVisitor{ rightForm, false }, dfaRight[ruleId]); + return visit(SearchLeftVisitor{ rightForm, false }, dfaRight[ruleId]); } vector> CompiledRule::testLeftPattern(U16StringView leftForm, POSTag leftTag, POSTag rightTag, CondVowel cv, CondPolarity cp) const @@ -1231,7 +1231,7 @@ vector> CompiledRule::testLeftPattern(U16Str auto it = findRule(leftTag, rightTag, cv, cp); if (it == map.end()) return ret; - auto p = mapbox::util::apply_visitor(SearchLeftVisitor{ l, true }, dfa[it->second]); + auto p = visit(SearchLeftVisitor{ l, true }, dfa[it->second]); ret.insert(ret.end(), p.begin(), p.end()); return ret; } @@ -1270,7 +1270,7 @@ UnorderedMap> CompiledRule::getRuleIdsByRightTag() const Vector CompiledRule::combine(U16StringView leftForm, U16StringView rightForm, size_t ruleId) const { - return mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[ruleId]); + return visit(CombineVisitor{ leftForm, rightForm }, dfa[ruleId]); } vector CompiledRule::combine(U16StringView leftForm, POSTag leftTag, U16StringView rightForm, POSTag rightTag, CondVowel cv, CondPolarity cp) const diff --git a/src/Combiner.h b/src/Combiner.h index 4ded2ae5..ff83dc9b 100644 --- a/src/Combiner.h +++ b/src/Combiner.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -117,7 +117,7 @@ namespace kiwi template struct VariantFromTuple> { - using type = mapbox::util::variant; + using type = std::variant; }; } From 6155ab0d6d6420411721dce023dc9c1de049d8e2 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 15:23:27 +0900 Subject: [PATCH 39/53] Update evaluator to c++17 --- tools/Evaluator.cpp 
| 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/Evaluator.cpp b/tools/Evaluator.cpp index 12511a85..0d1a7c95 100644 --- a/tools/Evaluator.cpp +++ b/tools/Evaluator.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -171,10 +171,10 @@ auto MorphEvaluator::loadTestset(const string& testSetFile) const -> vector tokens; for (size_t i = 1; i < fd.size(); ++i) { - for (auto s : split(fd[i], u' ')) tokens.emplace_back(s.to_string()); + for (auto s : split(fd[i], u' ')) tokens.emplace_back(s); } TestResult tr; - tr.q = fd[0].to_string(); + tr.q = u16string{ fd[0] }; for (auto& t : tokens) tr.a.emplace_back(parseWordPOS(t)); ret.emplace_back(std::move(tr)); } @@ -329,8 +329,8 @@ auto DisambEvaluator::loadTestset(const string& testSetFile) const -> vector Date: Sun, 23 Mar 2025 15:23:51 +0900 Subject: [PATCH 40/53] Fix C API to support `kiwi::ModelType` --- include/kiwi/capi.h | 7 +++++-- src/capi/kiwi_c.cpp | 9 +++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/include/kiwi/capi.h b/include/kiwi/capi.h index 5bfa6e11..3f33d35a 100644 --- a/include/kiwi/capi.h +++ b/include/kiwi/capi.h @@ -98,8 +98,11 @@ enum KIWI_BUILD_LOAD_TYPO_DICT = 4, KIWI_BUILD_LOAD_MULTI_DICT = 8, KIWI_BUILD_DEFAULT = 15, - KIWI_BUILD_MODEL_TYPE_KNLM = 0x0000, - KIWI_BUILD_MODEL_TYPE_SBG = 0x0100, + KIWI_BUILD_MODEL_TYPE_DEFAULT = 0x0000, + KIWI_BUILD_MODEL_TYPE_KNLM = 0x0100, + KIWI_BUILD_MODEL_TYPE_SBG = 0x0200, + KIWI_BUILD_MODEL_TYPE_CONG = 0x0300, + KIWI_BUILD_MODEL_TYPE_CONG_GLOBAL = 0x0400, }; enum diff --git a/src/capi/kiwi_c.cpp b/src/capi/kiwi_c.cpp index f7d61ec5..858c04ef 100644 --- a/src/capi/kiwi_c.cpp +++ b/src/capi/kiwi_c.cpp @@ -110,8 +110,13 @@ kiwi_builder_h kiwi_builder_init(const char* model_path, int num_threads, int op try { BuildOption buildOption = (BuildOption)(options & 0xFF); - bool useSBG = !!(options & KIWI_BUILD_MODEL_TYPE_SBG); - return (kiwi_builder_h)new KiwiBuilder{ model_path, (size_t)num_threads, 
buildOption, useSBG ? ModelType::sbg : ModelType::knlm }; + const auto mtMask = options & (KIWI_BUILD_MODEL_TYPE_KNLM | KIWI_BUILD_MODEL_TYPE_SBG | KIWI_BUILD_MODEL_TYPE_CONG | KIWI_BUILD_MODEL_TYPE_CONG_GLOBAL); + const ModelType modelType = mtMask == KIWI_BUILD_MODEL_TYPE_KNLM ? ModelType::knlm + : mtMask == KIWI_BUILD_MODEL_TYPE_SBG ? ModelType::sbg + : mtMask == KIWI_BUILD_MODEL_TYPE_CONG ? ModelType::cong + : mtMask == KIWI_BUILD_MODEL_TYPE_CONG_GLOBAL ? ModelType::congGlobal + : ModelType::none; + return (kiwi_builder_h)new KiwiBuilder{ model_path, (size_t)num_threads, buildOption, modelType }; } catch (...) { From 3457d37095a63b813f3bfa2b155e24c18ce21b93 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 15:56:16 +0900 Subject: [PATCH 41/53] Fix compilation errors --- include/kiwi/TemplateUtils.hpp | 8 ++++++-- src/Combiner.cpp | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/include/kiwi/TemplateUtils.hpp b/include/kiwi/TemplateUtils.hpp index 8c69710b..8c9de09e 100644 --- a/include/kiwi/TemplateUtils.hpp +++ b/include/kiwi/TemplateUtils.hpp @@ -56,7 +56,10 @@ namespace kiwi }; template - struct SeqMax; + struct SeqMax + { + static constexpr std::ptrdiff_t value = 0; + }; template struct SeqMax> @@ -130,7 +133,7 @@ namespace kiwi template class Table { - ValTy table[SeqMax::value + 1]; + std::array::value + 1> table; template void set(seq<>) @@ -153,6 +156,7 @@ namespace kiwi constexpr ValTy operator[](std::ptrdiff_t idx) const { + if (idx < 0 || (size_t)idx >= table.size()) return ValTy{}; return table[idx]; } }; diff --git a/src/Combiner.cpp b/src/Combiner.cpp index 19d8a055..31d20d7c 100644 --- a/src/Combiner.cpp +++ b/src/Combiner.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include From c31bda1343a41632fee5a66555aa82ace868ff3e Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 20:53:21 +0900 Subject: [PATCH 42/53] Fix compilation errors --- src/MathFunc.hpp | 2 ++ src/PathEvaluator.hpp | 4 
++-- src/SkipBigramModel.hpp | 30 +++++++++++++++++++++++++++- src/SkipBigramModelImpl.hpp | 40 ------------------------------------- src/archImpl/avx2.cpp | 7 +------ src/archImpl/avx512bw.cpp | 7 +------ src/archImpl/avx512vnni.cpp | 7 +------ src/archImpl/avx_vnni.cpp | 7 +------ src/archImpl/neon.cpp | 7 +------ src/archImpl/none.cpp | 12 +---------- src/archImpl/sse2.cpp | 7 +------ src/archImpl/sse4_1.cpp | 7 +------ src/qgemm.hpp | 2 ++ tools/Evaluator.cpp | 2 +- 14 files changed, 44 insertions(+), 97 deletions(-) delete mode 100644 src/SkipBigramModelImpl.hpp diff --git a/src/MathFunc.hpp b/src/MathFunc.hpp index 1645bbc8..a990b1c5 100644 --- a/src/MathFunc.hpp +++ b/src/MathFunc.hpp @@ -1,5 +1,7 @@ #pragma once #include +#include +#include #include "MathFunc.h" #include "SIMD.hpp" diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index 89b57ecf..33799b1b 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -358,8 +358,8 @@ namespace kiwi { for (auto& curMorph : cands) { - if (splitComplex && curMorph->getCombined()->complex) continue; - if (blocklist && blocklist->count(curMorph->getCombined())) continue; + if (splitComplex && curMorph->hasComplex()) continue; + if (blocklist && curMorph->hasMorpheme(*blocklist)) continue; // 덧붙은 받침(zCoda)을 위한 지름길 if (curMorph->tag == POSTag::z_coda) diff --git a/src/SkipBigramModel.hpp b/src/SkipBigramModel.hpp index 8ee2c31a..598b6355 100644 --- a/src/SkipBigramModel.hpp +++ b/src/SkipBigramModel.hpp @@ -6,6 +6,7 @@ #include #include "ArchAvailable.h" #include "Knlm.hpp" +#include "MathFunc.h" #include "search.h" namespace kiwi @@ -109,7 +110,34 @@ namespace kiwi return !!vocabValidness[k]; } - float evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const; + float evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const + { + if (!cnt) return base; + if (!vocabValidness[next]) return base; + +#if defined(__GNUC__) && __GNUC__ < 5 + alignas(256) float 
arr[windowSize * 2]; +#else + alignas(ArchInfo::alignment) float arr[windowSize * 2]; +#endif + std::fill(arr, arr + windowSize, base); + std::fill(arr + windowSize, arr + windowSize * 2, -INFINITY); + + size_t b = ptrs[next], e = ptrs[next + 1]; + size_t size = e - b; + + for (size_t i = 0; i < cnt; ++i) + { + arr[i] = discnts[history[i]] + base; + float out; + if (nst::search(&keyData[b], &compensations[b], size, history[i], out)) + { + arr[i + windowSize] = out; + } + } + return logSumExp(arr, windowSize * 2) - logWindowSize; + } + }; template diff --git a/src/SkipBigramModelImpl.hpp b/src/SkipBigramModelImpl.hpp deleted file mode 100644 index 2310e1d0..00000000 --- a/src/SkipBigramModelImpl.hpp +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once - -#include -#include "SkipBigramModel.hpp" -#include "MathFunc.hpp" - -namespace kiwi -{ - namespace lm - { - template - float SkipBigramModel::evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const - { - if (!cnt) return base; - if (!vocabValidness[next]) return base; - -#if defined(__GNUC__) && __GNUC__ < 5 - alignas(256) float arr[windowSize * 2]; -#else - alignas(ArchInfo::alignment) float arr[windowSize * 2]; -#endif - std::fill(arr, arr + windowSize, base); - std::fill(arr + windowSize, arr + windowSize * 2, -INFINITY); - - size_t b = ptrs[next], e = ptrs[next + 1]; - size_t size = e - b; - - for (size_t i = 0; i < cnt; ++i) - { - arr[i] = discnts[history[i]] + base; - float out; - if (nst::search(&keyData[b], &compensations[b], size, history[i], out)) - { - arr[i + windowSize] = out; - } - } - return LogSumExp{}(arr, std::integral_constant{}) - logWindowSize; - } - } -} diff --git a/src/archImpl/avx2.cpp b/src/archImpl/avx2.cpp index 0bc09f8a..ec7c1b5e 100644 --- a/src/archImpl/avx2.cpp +++ b/src/archImpl/avx2.cpp @@ -1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../qgemm.hpp" #include "../gemm.h" @@ -27,11 +27,6 @@ namespace kiwi { namespace lm { - 
template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template float logSumExp(const float* arr, size_t size); template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); template void logSoftmax(float* arr, size_t size); diff --git a/src/archImpl/avx512bw.cpp b/src/archImpl/avx512bw.cpp index 9797f315..df2cd031 100644 --- a/src/archImpl/avx512bw.cpp +++ b/src/archImpl/avx512bw.cpp @@ -1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../qgemm.hpp" #include "../gemm.h" @@ -27,11 +27,6 @@ namespace kiwi { namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template<> float logSumExp(const float* arr, size_t size) { diff --git a/src/archImpl/avx512vnni.cpp b/src/archImpl/avx512vnni.cpp index ef910ed4..60fc86bd 100644 --- a/src/archImpl/avx512vnni.cpp +++ b/src/archImpl/avx512vnni.cpp @@ -1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../qgemm.hpp" #include "../gemm.h" @@ -9,11 +9,6 @@ namespace kiwi { namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template<> float logSumExp(const float* arr, size_t size) { diff --git a/src/archImpl/avx_vnni.cpp b/src/archImpl/avx_vnni.cpp index 3c59f6f6..178c4b8b 100644 --- a/src/archImpl/avx_vnni.cpp +++ b/src/archImpl/avx_vnni.cpp @@ -1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../qgemm.hpp" #include "../gemm.h" @@ -9,11 +9,6 @@ namespace kiwi { namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template float logSumExp(const float* arr, size_t size); template void logSumExpTransposed(float* 
arr, size_t size, size_t batchSize, size_t stride); template void logSoftmax(float* arr, size_t size); diff --git a/src/archImpl/neon.cpp b/src/archImpl/neon.cpp index 4ffe4c2f..e37b9ca4 100644 --- a/src/archImpl/neon.cpp +++ b/src/archImpl/neon.cpp @@ -1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../qgemm.hpp" #include "../gemm.h" @@ -9,11 +9,6 @@ namespace kiwi { namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template float logSumExp(const float* arr, size_t size); template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); template void logSoftmax(float* arr, size_t size); diff --git a/src/archImpl/none.cpp b/src/archImpl/none.cpp index b5ff3d82..95060632 100644 --- a/src/archImpl/none.cpp +++ b/src/archImpl/none.cpp @@ -1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../gemm.h" #include @@ -7,21 +7,11 @@ namespace kiwi { namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template float logSumExp(const float* arr, size_t size); template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); template void logSoftmax(float* arr, size_t size); template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template float logSumExp(const float* arr, size_t size); template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); template void logSoftmax(float* arr, size_t size); diff --git a/src/archImpl/sse2.cpp b/src/archImpl/sse2.cpp index 86edbdc4..6669bd7a 100644 --- a/src/archImpl/sse2.cpp +++ b/src/archImpl/sse2.cpp @@ 
-1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../gemm.h" #define Eigen EigenSSE2 @@ -8,11 +8,6 @@ namespace kiwi { namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template float logSumExp(const float* arr, size_t size); template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); template void logSoftmax(float* arr, size_t size); diff --git a/src/archImpl/sse4_1.cpp b/src/archImpl/sse4_1.cpp index 14d2800e..555c5345 100644 --- a/src/archImpl/sse4_1.cpp +++ b/src/archImpl/sse4_1.cpp @@ -1,4 +1,4 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" #include "../qgemm.hpp" #include "../gemm.h" @@ -9,11 +9,6 @@ namespace kiwi { namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template float logSumExp(const float* arr, size_t size); template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); template void logSoftmax(float* arr, size_t size); diff --git a/src/qgemm.hpp b/src/qgemm.hpp index 0c34d246..0874b440 100644 --- a/src/qgemm.hpp +++ b/src/qgemm.hpp @@ -1,4 +1,6 @@ #pragma once +#include +#include #include "qgemm.h" #include "SIMD.hpp" diff --git a/tools/Evaluator.cpp b/tools/Evaluator.cpp index 0d1a7c95..21412c51 100644 --- a/tools/Evaluator.cpp +++ b/tools/Evaluator.cpp @@ -44,7 +44,7 @@ inline TokenInfo parseWordPOS(const u16string& str) { auto p = str.rfind('/'); if (p == str.npos) return {}; - u16string form = replace(nonstd::u16string_view(str.data(), p), u"_", u" "); + u16string form = replace(std::u16string_view(str.data(), p), u"_", u" "); if (str[p + 1] == 'E') { if (form[0] == u'아' || form[0] == u'여') form[0] = u'어'; From dbb75821b0052af5502ff6768b1e0b0873c2a464 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 
21:23:56 +0900 Subject: [PATCH 43/53] Fix compilation errors --- include/kiwi/Joiner.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/kiwi/Joiner.h b/include/kiwi/Joiner.h index 6cd4af89..a5d24e99 100644 --- a/include/kiwi/Joiner.h +++ b/include/kiwi/Joiner.h @@ -82,7 +82,7 @@ namespace kiwi static void destructImpl(ErasedVector* self) { auto* target = reinterpret_cast*>(&self->vec); - target->~Vector(); + std::destroy_at(target); } template From dbd90bc14254ccd6d33d0576827f6bc43167d71c Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 21:36:49 +0900 Subject: [PATCH 44/53] Fix compilation errors --- src/CoNgramModel.cpp | 17 +++++++++++------ src/CoNgramModel.hpp | 6 ++---- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/CoNgramModel.cpp b/src/CoNgramModel.cpp index 5639bd0d..cc436edc 100644 --- a/src/CoNgramModel.cpp +++ b/src/CoNgramModel.cpp @@ -135,7 +135,7 @@ namespace kiwi { nextLmStates.resize(prevLmStates.size() * nextWids.size()); scores.resize(prevLmStates.size() * nextWids.size()); - langMdl->template progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); + langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); } } @@ -1211,11 +1211,12 @@ namespace kiwi } template - template void CoNgramModel::progressMatrix( - const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, + const LmStateType* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const + { + if constexpr (windowSize > 0) { thread_local TLSForProgressMatrix tls; if (prevStateSize <= (quantized ? 
16 : 8) && nextIdSize <= 16) @@ -1225,13 +1226,17 @@ namespace kiwi else { return progressMatrixWSort(tls, prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores); + } + } + else + { + return progressMatrixNoWindow(prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores); } } template - template - void CoNgramModel::progressMatrix( - const typename std::enable_if<_windowSize == 0, LmStateType>::type* prevStates, const KeyType* nextIds, + void CoNgramModel::progressMatrixNoWindow( + const LmStateType* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const { diff --git a/src/CoNgramModel.hpp b/src/CoNgramModel.hpp index 0906863c..891cff02 100644 --- a/src/CoNgramModel.hpp +++ b/src/CoNgramModel.hpp @@ -311,8 +311,7 @@ namespace kiwi * 새 상태값은 outStates에 저장되고, 각 상태별 확률값은 outScores에 저장된다. * nextIdSize개의 다음 토큰 중 마지막 numValidDistantTokens개의 토큰은 유효한 distant 토큰으로 처리된다. 
*/ - template - void progressMatrix(const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds, + void progressMatrix(const LmStateType* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const; @@ -326,8 +325,7 @@ namespace kiwi size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const; - template - void progressMatrix(const typename std::enable_if<(_windowSize == 0), LmStateType>::type* prevStates, const KeyType* nextIds, + void progressMatrixNoWindow(const LmStateType* prevStates, const KeyType* nextIds, size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, LmStateType* outStates, float* outScores) const; }; From 00d7a2eeb51e5ee24d10273b605f9982e6e13370 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sun, 23 Mar 2025 22:12:48 +0900 Subject: [PATCH 45/53] Fix compilation errors --- src/SIMD.hpp | 2 +- src/archImpl/avx2_qgemm.hpp | 16 ++++++++-------- src/archImpl/avx512_qgemm.hpp | 32 ++++++++++++++++---------------- src/archImpl/sse4_1.cpp | 8 ++++---- src/sais/mp_utils.hpp | 6 +++--- 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/SIMD.hpp b/src/SIMD.hpp index 88e795bb..7d395cef 100644 --- a/src/SIMD.hpp +++ b/src/SIMD.hpp @@ -571,7 +571,7 @@ namespace kiwi // reduce sum of eight int32_t to one int32_t __m256i sum = _mm256_hadd_epi32(acc, acc); sum = _mm256_hadd_epi32(sum, sum); - return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); + return _mm_cvtsi128_si32(_mm256_castsi256_si128(sum)) + _mm256_extract_epi32(sum, 4); } }; diff --git a/src/archImpl/avx2_qgemm.hpp b/src/archImpl/avx2_qgemm.hpp index 59fb89ee..7963df37 100644 --- a/src/archImpl/avx2_qgemm.hpp +++ b/src/archImpl/avx2_qgemm.hpp @@ -63,10 +63,10 @@ namespace kiwi { const size_t microM = std::min(packM, m - mi); const int32_t aOffsets[4] = { - aIdx[0] * 
aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, }; auto* aPtr = aBase; psum = pbSum; @@ -227,10 +227,10 @@ namespace kiwi { const size_t microM = std::min(packM, m - mi); const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, }; auto* aPtr = aBase; psum[0] = pbSum[0]; diff --git a/src/archImpl/avx512_qgemm.hpp b/src/archImpl/avx512_qgemm.hpp index 55f557f4..f5eedce4 100644 --- a/src/archImpl/avx512_qgemm.hpp +++ b/src/archImpl/avx512_qgemm.hpp @@ -65,10 +65,10 @@ namespace kiwi const size_t microM = std::min(packM, m - mi); #define LOOP_BODY(mj) \ const int32_t aOffsets[4] = {\ - mj * 4 < microM ? aIdx[0] * aIdxScale : 0,\ - mj * 4 + 1 < microM ? aIdx[1] * aIdxScale : 0,\ - mj * 4 + 2 < microM ? aIdx[2] * aIdxScale : 0,\ - mj * 4 + 3 < microM ? aIdx[3] * aIdxScale : 0,\ + mj * 4 < microM ? (int32_t)(aIdx[0] * aIdxScale) : 0,\ + mj * 4 + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0,\ + mj * 4 + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0,\ + mj * 4 + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0,\ };\ auto* aPtr = aBase;\ psum = pbSum;\ @@ -237,10 +237,10 @@ namespace kiwi { const size_t microM = std::min(packM, m - mi); const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? 
(int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, }; auto* aPtr = aBase; psum[0] = pbSum[0]; @@ -331,10 +331,10 @@ namespace kiwi { const size_t microM = std::min(packM, m - mi); const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, }; auto* aPtr = aBase; psum[0] = pbSum[0]; @@ -437,10 +437,10 @@ namespace kiwi { const size_t microM = std::min(packM, m - mi); const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, }; auto* aPtr = aBase; psum[0] = pbSum[0]; diff --git a/src/archImpl/sse4_1.cpp b/src/archImpl/sse4_1.cpp index 555c5345..cf19986d 100644 --- a/src/archImpl/sse4_1.cpp +++ b/src/archImpl/sse4_1.cpp @@ -83,10 +83,10 @@ namespace kiwi { const size_t microM = std::min(packM, m - mi); const int32_t aOffsets[4] = { - aIdx[0] * aIdxScale, - 1 < microM ? aIdx[1] * aIdxScale : 0, - 2 < microM ? aIdx[2] * aIdxScale : 0, - 3 < microM ? aIdx[3] * aIdxScale : 0, + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? 
(int32_t)(aIdx[3] * aIdxScale) : 0, }; auto* aPtr = aBase; psum = pbSum; diff --git a/src/sais/mp_utils.hpp b/src/sais/mp_utils.hpp index 60e2d912..bd562cb1 100644 --- a/src/sais/mp_utils.hpp +++ b/src/sais/mp_utils.hpp @@ -305,7 +305,7 @@ namespace mp } template::type, void>::value, int>::type = 0> + typename std::enable_if::type, void>::value, int>::type = 0> inline auto runParallel(ThreadPool* pool, Fn&& func, Args&&... args) -> std::vector { static_assert(detail::AllOfType, detail::IsRunParallelArg>::value, "`runParallel` receives arguments of wrong type."); @@ -331,7 +331,7 @@ namespace mp } template::type, void>::value, int>::type = 0> + typename std::enable_if::type, void>::value, int>::type = 0> inline void runParallel(ThreadPool* pool, Fn&& func, Args&&... args) { static_assert(detail::AllOfType, detail::IsRunParallelArg>::value, "`runParallel` receives arguments of wrong type."); @@ -360,7 +360,7 @@ namespace mp } template::type, void>::value, int>::type = 0> + typename std::enable_if::type, void>::value, int>::type = 0> inline void forParallel(ThreadPool* pool, ptrdiff_t start, ptrdiff_t stop, ptrdiff_t step, Fn&& func, Args&&... 
args) { static_assert(detail::AllOfType, detail::IsRunParallelArg>::value, "`forParallel` receives arguments of wrong type."); From 51dfac482ecfbc054163e8fe810ee258b854c4af Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 24 Mar 2025 02:46:47 +0900 Subject: [PATCH 46/53] Update KiwiJava binding to 0.21.0 --- CMakeLists.txt | 17 +++++++++++++---- bindings/java/JniUtils.hpp | 1 + bindings/java/kiwi_java.cpp | 19 ++++++++++++++++++- bindings/java/kr/pe/bab2min/Kiwi.java | 6 +++--- bindings/java/kr/pe/bab2min/KiwiBuilder.java | 20 ++++++++++++++------ 5 files changed, 49 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5959c7b0..78ed5f57 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.12) -project(kiwi VERSION 0.20.4 DESCRIPTION "Kiwi, Korean Intelligent Word Identifier") +project(kiwi VERSION 0.21.0 DESCRIPTION "Kiwi, Korean Intelligent Word Identifier") set ( CMAKE_CXX_STANDARD 17 ) set ( CMAKE_VERBOSE_MAKEFILE true ) @@ -83,7 +83,6 @@ endif() include_directories( include/ ) include_directories( third_party/tclap/include ) include_directories( third_party/cpp-btree ) -include_directories( third_party/variant/include ) include_directories( third_party/eigen ) include_directories( third_party/json/include ) include_directories( third_party/streamvbyte/include ) @@ -158,11 +157,19 @@ if (KIWI_CPU_ARCH MATCHES "x86_64") set( CORE_SRCS ${CORE_SRCS} src/archImpl/avx2.cpp - src/archImpl/avx_vnni.cpp src/archImpl/avx512bw.cpp src/archImpl/avx512vnni.cpp ) + # If AVX-VNNI is supported (MSVC, GCC 11+ or Clang 11+) + set ( AVX_VNNI_SUPPORTED (MSVC OR (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) OR (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11))) + if (AVX_VNNI_SUPPORTED) + set( CORE_SRCS + ${CORE_SRCS} + src/archImpl/avx_vnni.cpp + ) + endif() endif() + if(MSVC) 
set_source_files_properties(src/archImpl/sse2.cpp PROPERTIES COMPILE_FLAGS "/arch:SSE2") set_source_files_properties(src/archImpl/sse4_1.cpp PROPERTIES COMPILE_FLAGS "/arch:SSE2") @@ -177,9 +184,11 @@ if (KIWI_CPU_ARCH MATCHES "x86_64") set_source_files_properties(src/archImpl/sse4_1.cpp PROPERTIES COMPILE_FLAGS "-msse2 -msse4.1") if (KIWI_USE_CPUINFO) set_source_files_properties(src/archImpl/avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma") - set_source_files_properties(src/archImpl/avx_vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavxvnni") set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavx512f -mavx512vl -mavx512dq -mavx512bw") set_source_files_properties(src/archImpl/avx512vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavx512f -mavx512vl -mavx512dq -mavx512bw -mavx512vnni") + if (AVX_VNNI_SUPPORTED) + set_source_files_properties(src/archImpl/avx_vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavxvnni") + endif() endif() endif() elseif (KIWI_CPU_ARCH MATCHES "arm64") diff --git a/bindings/java/JniUtils.hpp b/bindings/java/JniUtils.hpp index f0cb47d9..55513981 100644 --- a/bindings/java/JniUtils.hpp +++ b/bindings/java/JniUtils.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include diff --git a/bindings/java/kiwi_java.cpp b/bindings/java/kiwi_java.cpp index 9013d9fe..8b19f07c 100644 --- a/bindings/java/kiwi_java.cpp +++ b/bindings/java/kiwi_java.cpp @@ -95,6 +95,23 @@ namespace jni } }; + template<> + struct ValueBuilder : public ValueBuilder + { + using CppType = kiwi::ModelType; + using JniType = jint; + + CppType fromJava(JNIEnv* env, JniType v) + { + return (CppType)v; + } + + JniType toJava(JNIEnv* env, CppType v) + { + return (JniType)v; + } + }; + template<> struct ValueBuilder : public ValueBuilder { @@ -564,7 +581,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) .template method<&JTypoTransformer::scaleCost>("_scaleCost"), jni::define() - 
.template ctor() + .template ctor() .template method<&JKiwiBuilder::addWord>("addWord") .template method<&JKiwiBuilder::addWord2>("addWord") .template method<&JKiwiBuilder::addPreAnalyzedWord>("addPreAnalyzedWord") diff --git a/bindings/java/kr/pe/bab2min/Kiwi.java b/bindings/java/kr/pe/bab2min/Kiwi.java index 5cfbf60a..802ffb44 100644 --- a/bindings/java/kr/pe/bab2min/Kiwi.java +++ b/bindings/java/kr/pe/bab2min/Kiwi.java @@ -12,7 +12,7 @@ public class Kiwi implements AutoCloseable { private long _inst; - final private static String _version = "0.20.4"; + final private static String _version = "0.21.0"; public static class Match { final static public int none = 0, @@ -345,8 +345,8 @@ public Kiwi(long _inst) { this._inst = _inst; } - public static Kiwi init(String modelPath, int numWorkers, int buildOptions, boolean useSBG) throws Exception { - try(KiwiBuilder b = new KiwiBuilder(modelPath, numWorkers, buildOptions, useSBG)) { + public static Kiwi init(String modelPath, int numWorkers, int buildOptions, int modelType) throws Exception { + try(KiwiBuilder b = new KiwiBuilder(modelPath, numWorkers, buildOptions, modelType)) { return b.build(); } } diff --git a/bindings/java/kr/pe/bab2min/KiwiBuilder.java b/bindings/java/kr/pe/bab2min/KiwiBuilder.java index 5cfdfd17..a09a83a7 100644 --- a/bindings/java/kr/pe/bab2min/KiwiBuilder.java +++ b/bindings/java/kr/pe/bab2min/KiwiBuilder.java @@ -12,6 +12,14 @@ public static class BuildOption { default_ = integrateAllomorph | loadDefaultDict | loadTypoDict | loadMultiDict; } + public static class ModelType { + final static public int none = 0, + knlm = 1, + sbg = 2, + cong = 3, + congGlobal = 4; + } + public static class AnalyzedMorph { public String form; public byte tag = Kiwi.POSTag.nng; @@ -113,20 +121,20 @@ public KiwiBuilder(long _inst) { this._inst = _inst; } - public KiwiBuilder(String modelPath, int numWorkers, int buildOptions, boolean useSBG) { - ctor(modelPath, numWorkers, buildOptions, useSBG); + public 
KiwiBuilder(String modelPath, int numWorkers, int buildOptions, int modelType) { + ctor(modelPath, numWorkers, buildOptions, modelType); } public KiwiBuilder(String modelPath, int numWorkers, int buildOptions) { - ctor(modelPath, numWorkers, buildOptions, false); + ctor(modelPath, numWorkers, buildOptions, ModelType.none); } public KiwiBuilder(String modelPath, int numWorkers) { - ctor(modelPath, numWorkers, BuildOption.default_, false); + ctor(modelPath, numWorkers, BuildOption.default_, ModelType.none); } public KiwiBuilder(String modelPath) { - ctor(modelPath, 1, BuildOption.default_, false); + ctor(modelPath, 1, BuildOption.default_, ModelType.none); } protected void finalize() throws Exception { @@ -137,7 +145,7 @@ public boolean isAlive() { return _inst != 0; } - private native void ctor(String modelPath, int numWorkers, int buildOptions, boolean useSBG); + private native void ctor(String modelPath, int numWorkers, int buildOptions, int modelType); @Override public native void close() throws Exception; From 6543a1f6429ed8976680259dcd66cafe9900d923 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 24 Mar 2025 14:50:49 +0900 Subject: [PATCH 47/53] Fix compilation errors --- src/KiwiBuilder.cpp | 2 +- src/Knlm.cpp | 198 -------------------------------------------- src/Knlm.hpp | 198 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+), 199 deletions(-) diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 24612bfc..fcecc445 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -4,11 +4,11 @@ #include #include #include +#include #include "ArchAvailable.h" #include "KTrie.h" #include "StrUtils.h" #include "FrozenTrie.hpp" -#include "Knlm.hpp" #include "serializer.hpp" #include "count.hpp" #include "FeatureTestor.h" diff --git a/src/Knlm.cpp b/src/Knlm.cpp index 4c09c230..f4683c9b 100644 --- a/src/Knlm.cpp +++ b/src/Knlm.cpp @@ -7,204 +7,6 @@ namespace kiwi { namespace lm { - template - template - void 
KnLangModel::dequantizeDispatch( - tp::seq, - size_t bits, - Vector& restored_floats, Vector& restored_leaf_ll, - const char* llq_data, size_t llq_size, - const char* gammaq_data, size_t gammaq_size, - const float* ll_table, - const float* gamma_table, - size_t num_non_leaf_nodes, - size_t num_leaf_nodes - ) - { - using Fn = void(*)(Vector&, Vector&, - const char*, size_t, - const char*, size_t, - const float*, - const float*, - size_t, - size_t); - static constexpr Fn table[] = { - &dequantize... - }; - return table[bits - 1](restored_floats, restored_leaf_ll, - llq_data, llq_size, - gammaq_data, gammaq_size, - ll_table, gamma_table, - num_non_leaf_nodes, num_leaf_nodes - ); - } - - template - KnLangModel::KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move(mem) } - { - auto* ptr = reinterpret_cast(base.get()); - auto& header = getHeader(); - const size_t quantized = header.quantized & 0x1F; - const bool compressed = header.quantized & 0x80; - - Vector d_node_size; - auto* node_sizes = reinterpret_cast(ptr + header.node_offset); - key_data = make_unique((header.ll_offset - header.key_offset) / sizeof(KeyType)); - std::memcpy(&key_data[0], ptr + header.key_offset, header.ll_offset - header.key_offset); - size_t num_leaf_nodes = 0; - if (compressed) - { - d_node_size.resize(header.num_nodes); - auto qc_header = reinterpret_cast(ptr + header.node_offset); - auto qc_body = reinterpret_cast(qc_header + (header.num_nodes + 3) / 4); - QCode::template decode<8>((uint16_t*)d_node_size.data(), qc_header, qc_body, 0, header.num_nodes); - node_sizes = d_node_size.data(); - } - - for (size_t i = 0; i < header.num_nodes; ++i) - { - if (node_sizes[i]) num_non_leaf_nodes++; - else num_leaf_nodes++; - } - - // restore ll & gamma data - Vector restored_leaf_ll, restored_floats; - const float* ll_data = nullptr; - const float* gamma_data = nullptr; - const float* leaf_ll_data = nullptr; - if (quantized) - { - if (quantized > 16) - { - throw std::runtime_error{ "16+ 
bits quantization not supported." }; - } - - restored_floats.resize(num_non_leaf_nodes * 2); - restored_leaf_ll.resize(num_leaf_nodes); - leaf_ll_data = restored_leaf_ll.data(); - ll_data = &restored_floats[0]; - gamma_data = &restored_floats[num_non_leaf_nodes]; - - const float* ll_table = reinterpret_cast(ptr + header.qtable_offset); - const float* gamma_table = ll_table + ((size_t)1 << quantized); - - dequantizeDispatch(tp::gen_seq<16>{}, quantized, restored_floats, restored_leaf_ll, - ptr + header.ll_offset, header.gamma_offset - header.ll_offset, - ptr + header.gamma_offset, header.qtable_offset - header.gamma_offset, - ll_table, - gamma_table, - num_non_leaf_nodes, - num_leaf_nodes - ); - extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized)); - } - else - { - ll_data = reinterpret_cast(ptr + header.ll_offset); - gamma_data = reinterpret_cast(ptr + header.gamma_offset); - leaf_ll_data = ll_data + num_non_leaf_nodes; - extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes); - } - - size_t htx_vocab_size = header.vocab_size; - if (header.htx_offset) - { - htx_data = reinterpret_cast(ptr + header.htx_offset); - htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1; - extra_buf = toAlignedPtr(htx_data + header.vocab_size); - } - - if (!header.extra_buf_size) - { - extra_buf = nullptr; - } - - // restore node's data - node_data = make_unique(num_non_leaf_nodes); - all_value_data = make_unique(header.num_nodes - 1 + htx_vocab_size); - value_data = &all_value_data[htx_vocab_size]; - std::fill(&all_value_data[0], value_data, 0); - - size_t non_leaf_idx = 0, leaf_idx = 0, next_offset = 0; - Vector> key_ranges; - for (size_t i = 0; i < header.num_nodes; ++i) - { - if (node_sizes[i]) - { - auto& node = node_data[non_leaf_idx]; - if (!key_ranges.empty()) - { - auto& back = key_ranges.back(); - value_data[back[1]] = non_leaf_idx - back[0]; - } - node.num_nexts = node_sizes[i]; - node.next_offset = next_offset; - node.ll = 
ll_data[non_leaf_idx]; - node.gamma = gamma_data[non_leaf_idx]; - next_offset += node_sizes[i]; - key_ranges.emplace_back(std::array{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) }); - non_leaf_idx++; - } - else - { - auto& back = key_ranges.back(); - reinterpret_cast(value_data[back[1]]) = leaf_ll_data[leaf_idx]; - back[1]++; - while (key_ranges.back()[1] == key_ranges.back()[2]) - { - key_ranges.pop_back(); - if (key_ranges.empty()) break; - key_ranges.back()[1]++; - } - leaf_idx++; - } - } - - for (size_t i = 0; i < node_data[0].num_nexts; ++i) - { - auto k = key_data[i]; - auto v = value_data[i]; - all_value_data[k] = v; - } - - Vector tempBuf; - for (size_t i = 0; i < non_leaf_idx; ++i) - { - auto& node = node_data[i]; - nst::prepare(&key_data[node.next_offset], &value_data[node.next_offset], node.num_nexts, tempBuf); - } - - if (htx_data) - { - ptrdiff_t node = 0; - progress(node, (KeyType)header.bos_id); - unk_ll = getLL(node, (KeyType)header.unk_id); - bos_node_idx = 0; - progress(bos_node_idx, htx_data[(KeyType)header.bos_id]); - } - else - { - unk_ll = getLL(0, (KeyType)header.unk_id); - bos_node_idx = 0; - progress(bos_node_idx, (KeyType)header.bos_id); - } - - Deque dq; - for (dq.emplace_back(&node_data[0]); !dq.empty(); dq.pop_front()) - { - auto p = dq.front(); - for (size_t i = 0; i < p->num_nexts; ++i) - { - auto k = key_data[p->next_offset + i]; - auto v = value_data[p->next_offset + i]; - if (v <= 0) continue; - auto* child = &p[v]; - child->lower = findLowerNode(p, k) - child; - dq.emplace_back(child); - } - } - } - template float KnLangModel::getLL(ptrdiff_t node_idx, KeyType next) const { diff --git a/src/Knlm.hpp b/src/Knlm.hpp index 18531153..180e507f 100644 --- a/src/Knlm.hpp +++ b/src/Knlm.hpp @@ -967,6 +967,204 @@ namespace kiwi extra_buf, extra_buf_size); } } + + template + template + void KnLangModel::dequantizeDispatch( + tp::seq, + size_t bits, + Vector& restored_floats, Vector& restored_leaf_ll, 
+ const char* llq_data, size_t llq_size, + const char* gammaq_data, size_t gammaq_size, + const float* ll_table, + const float* gamma_table, + size_t num_non_leaf_nodes, + size_t num_leaf_nodes + ) + { + using Fn = void(*)(Vector&, Vector&, + const char*, size_t, + const char*, size_t, + const float*, + const float*, + size_t, + size_t); + static constexpr Fn table[] = { + &dequantize... + }; + return table[bits - 1](restored_floats, restored_leaf_ll, + llq_data, llq_size, + gammaq_data, gammaq_size, + ll_table, gamma_table, + num_non_leaf_nodes, num_leaf_nodes + ); + } + + template + KnLangModel::KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move(mem) } + { + auto* ptr = reinterpret_cast(base.get()); + auto& header = getHeader(); + const size_t quantized = header.quantized & 0x1F; + const bool compressed = header.quantized & 0x80; + + Vector d_node_size; + auto* node_sizes = reinterpret_cast(ptr + header.node_offset); + key_data = make_unique((header.ll_offset - header.key_offset) / sizeof(KeyType)); + std::memcpy(&key_data[0], ptr + header.key_offset, header.ll_offset - header.key_offset); + size_t num_leaf_nodes = 0; + if (compressed) + { + d_node_size.resize(header.num_nodes); + auto qc_header = reinterpret_cast(ptr + header.node_offset); + auto qc_body = reinterpret_cast(qc_header + (header.num_nodes + 3) / 4); + QCode::template decode<8>((uint16_t*)d_node_size.data(), qc_header, qc_body, 0, header.num_nodes); + node_sizes = d_node_size.data(); + } + + for (size_t i = 0; i < header.num_nodes; ++i) + { + if (node_sizes[i]) num_non_leaf_nodes++; + else num_leaf_nodes++; + } + + // restore ll & gamma data + Vector restored_leaf_ll, restored_floats; + const float* ll_data = nullptr; + const float* gamma_data = nullptr; + const float* leaf_ll_data = nullptr; + if (quantized) + { + if (quantized > 16) + { + throw std::runtime_error{ "16+ bits quantization not supported." 
}; + } + + restored_floats.resize(num_non_leaf_nodes * 2); + restored_leaf_ll.resize(num_leaf_nodes); + leaf_ll_data = restored_leaf_ll.data(); + ll_data = &restored_floats[0]; + gamma_data = &restored_floats[num_non_leaf_nodes]; + + const float* ll_table = reinterpret_cast(ptr + header.qtable_offset); + const float* gamma_table = ll_table + ((size_t)1 << quantized); + + dequantizeDispatch(tp::gen_seq<16>{}, quantized, restored_floats, restored_leaf_ll, + ptr + header.ll_offset, header.gamma_offset - header.ll_offset, + ptr + header.gamma_offset, header.qtable_offset - header.gamma_offset, + ll_table, + gamma_table, + num_non_leaf_nodes, + num_leaf_nodes + ); + extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized)); + } + else + { + ll_data = reinterpret_cast(ptr + header.ll_offset); + gamma_data = reinterpret_cast(ptr + header.gamma_offset); + leaf_ll_data = ll_data + num_non_leaf_nodes; + extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes); + } + + size_t htx_vocab_size = header.vocab_size; + if (header.htx_offset) + { + htx_data = reinterpret_cast(ptr + header.htx_offset); + htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1; + extra_buf = toAlignedPtr(htx_data + header.vocab_size); + } + + if (!header.extra_buf_size) + { + extra_buf = nullptr; + } + + // restore node's data + node_data = make_unique(num_non_leaf_nodes); + all_value_data = make_unique(header.num_nodes - 1 + htx_vocab_size); + value_data = &all_value_data[htx_vocab_size]; + std::fill(&all_value_data[0], value_data, 0); + + size_t non_leaf_idx = 0, leaf_idx = 0, next_offset = 0; + Vector> key_ranges; + for (size_t i = 0; i < header.num_nodes; ++i) + { + if (node_sizes[i]) + { + auto& node = node_data[non_leaf_idx]; + if (!key_ranges.empty()) + { + auto& back = key_ranges.back(); + value_data[back[1]] = non_leaf_idx - back[0]; + } + node.num_nexts = node_sizes[i]; + node.next_offset = next_offset; + node.ll = ll_data[non_leaf_idx]; + node.gamma = 
gamma_data[non_leaf_idx]; + next_offset += node_sizes[i]; + key_ranges.emplace_back(std::array{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) }); + non_leaf_idx++; + } + else + { + auto& back = key_ranges.back(); + reinterpret_cast(value_data[back[1]]) = leaf_ll_data[leaf_idx]; + back[1]++; + while (key_ranges.back()[1] == key_ranges.back()[2]) + { + key_ranges.pop_back(); + if (key_ranges.empty()) break; + key_ranges.back()[1]++; + } + leaf_idx++; + } + } + + for (size_t i = 0; i < node_data[0].num_nexts; ++i) + { + auto k = key_data[i]; + auto v = value_data[i]; + all_value_data[k] = v; + } + + Vector tempBuf; + for (size_t i = 0; i < non_leaf_idx; ++i) + { + auto& node = node_data[i]; + nst::prepare(&key_data[node.next_offset], &value_data[node.next_offset], node.num_nexts, tempBuf); + } + + if (htx_data) + { + ptrdiff_t node = 0; + progress(node, (KeyType)header.bos_id); + unk_ll = getLL(node, (KeyType)header.unk_id); + bos_node_idx = 0; + progress(bos_node_idx, htx_data[(KeyType)header.bos_id]); + } + else + { + unk_ll = getLL(0, (KeyType)header.unk_id); + bos_node_idx = 0; + progress(bos_node_idx, (KeyType)header.bos_id); + } + + Deque dq; + for (dq.emplace_back(&node_data[0]); !dq.empty(); dq.pop_front()) + { + auto p = dq.front(); + for (size_t i = 0; i < p->num_nexts; ++i) + { + auto k = key_data[p->next_offset + i]; + auto v = value_data[p->next_offset + i]; + if (v <= 0) continue; + auto* child = &p[v]; + child->lower = findLowerNode(p, k) - child; + dq.emplace_back(child); + } + } + } } template From 596905810778a77dc221dc079c25340e4137abe1 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 24 Mar 2025 21:08:56 +0900 Subject: [PATCH 48/53] Add support for AVX-VNNI and update SIMD implementation --- CMakeLists.txt | 13 ++++++++++++- src/ArchAvailable.h | 8 ++++++++ src/SIMD.hpp | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 78ed5f57..2143ad9e 100644 --- 
a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,13 @@ if(NOT KIWI_CPU_ARCH) set(KIWI_CPU_ARCH "${KIWI_CPU_ARCH}" PARENT_SCOPE) endif() +set( AVX_VNNI_SUPPORTED (KIWI_USE_CPUINFO AND + (MSVC OR + (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) OR + (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) + ) +)) + if(APPLE) set(CMAKE_OSX_ARCHITECTURES "${KIWI_CPU_ARCH}") endif() @@ -120,6 +127,11 @@ if(KIWI_USE_CPUINFO) ) endif() +if (AVX_VNNI_SUPPORTED) + message(STATUS "AVX-VNNI is supported") + set ( ADDITIONAL_FLAGS ${ADDITIONAL_FLAGS} "-DKIWI_AVX_VNNI_SUPPORTED" ) +endif() + if(MSVC) set ( CMAKE_C_FLAGS_DEBUG "-DDEBUG -DC_FLAGS -Zi -Od /utf-8 /bigobj" ) set ( CMAKE_CXX_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}" ) @@ -161,7 +173,6 @@ if (KIWI_CPU_ARCH MATCHES "x86_64") src/archImpl/avx512vnni.cpp ) # If AVX-VNNI is supported (MSVC, GCC 11+ or Clang 11+) - set ( AVX_VNNI_SUPPORTED (MSVC OR (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) OR (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11))) if (AVX_VNNI_SUPPORTED) set( CORE_SRCS ${CORE_SRCS} diff --git a/src/ArchAvailable.h b/src/ArchAvailable.h index c13f4ff6..d141e1e4 100644 --- a/src/ArchAvailable.h +++ b/src/ArchAvailable.h @@ -14,7 +14,9 @@ namespace kiwi #if CPUINFO_ARCH_X86_64 static_cast(ArchType::avx512vnni), static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED static_cast(ArchType::avx_vnni), +#endif static_cast(ArchType::avx2), static_cast(ArchType::sse4_1), #endif @@ -28,7 +30,9 @@ namespace kiwi #ifdef KIWI_ARCH_X86_64 static_cast(ArchType::avx512vnni), static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED static_cast(ArchType::avx_vnni), +#endif static_cast(ArchType::avx2), static_cast(ArchType::sse4_1), #endif @@ -48,7 +52,9 @@ namespace kiwi #if CPUINFO_ARCH_X86_64 static_cast(ArchType::avx512vnni), 
static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED static_cast(ArchType::avx_vnni), +#endif static_cast(ArchType::avx2), static_cast(ArchType::sse4_1) #endif @@ -59,7 +65,9 @@ namespace kiwi #ifdef KIWI_ARCH_X86_64 static_cast(ArchType::avx512vnni), static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED static_cast(ArchType::avx_vnni), +#endif static_cast(ArchType::avx2), static_cast(ArchType::sse4_1) #endif diff --git a/src/SIMD.hpp b/src/SIMD.hpp index 7d395cef..6e894c00 100644 --- a/src/SIMD.hpp +++ b/src/SIMD.hpp @@ -600,7 +600,7 @@ namespace kiwi // reduce sum of eight int32_t to one int32_t __m256i sum = _mm256_hadd_epi32(acc, acc); sum = _mm256_hadd_epi32(sum, sum); - return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); + return _mm_cvtsi128_si32(_mm256_castsi256_si128(sum)) + _mm256_extract_epi32(sum, 4); } }; From 221ff5947c4cacbd89da81ca3a641cec57e143d7 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 24 Mar 2025 21:34:19 +0900 Subject: [PATCH 49/53] Fix workflows --- .github/workflows/arm64_centos7.yml | 2 +- .github/workflows/centos7.yml | 4 ++-- .github/workflows/macos.yml | 4 ++-- .github/workflows/ppc64le_centos7.yml | 2 +- .github/workflows/release.yml | 2 +- .github/workflows/ubuntu.yml | 4 ++-- .github/workflows/windows.yml | 4 ++-- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/arm64_centos7.yml b/.github/workflows/arm64_centos7.yml index 28f93c70..4ece78d0 100644 --- a/.github/workflows/arm64_centos7.yml +++ b/.github/workflows/arm64_centos7.yml @@ -34,7 +34,7 @@ jobs: - name: Test run: | ./build/test/kiwi-test - mkdir eval_results && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + mkdir eval_results && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o 
eval_results/ cp -r build /artifacts/ cp -r eval_results /artifacts/ - name: Benchmark diff --git a/.github/workflows/centos7.yml b/.github/workflows/centos7.yml index 75c45bc3..1cf15169 100644 --- a/.github/workflows/centos7.yml +++ b/.github/workflows/centos7.yml @@ -40,8 +40,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ - run: tar -zcvf arts.tgz build/*kiwi* build/test/*kiwi* eval_results/*.txt build/bindings/java/*.jar - name: Archive binaries uses: actions/upload-artifact@v4 diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index d3d7468d..db732322 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -60,8 +60,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ - name: Run Benchmark run: | curl -OL https://latina.bab2min.pe.kr/_data/kowiki1000.txt diff --git a/.github/workflows/ppc64le_centos7.yml b/.github/workflows/ppc64le_centos7.yml index 00fa49db..88fd5994 100644 --- a/.github/workflows/ppc64le_centos7.yml +++ b/.github/workflows/ppc64le_centos7.yml @@ -28,7 +28,7 @@ jobs: mkdir build && pushd build && cmake -DCMAKE_BUILD_TYPE=Release -DKIWI_USE_MIMALLOC=0 -DKIWI_JAVA_BINDING=1 .. 
make -j2 && popd ./build/test/kiwi-test - mkdir eval_results && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + mkdir eval_results && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ cp -r build /artifacts/ cp -r eval_results /artifacts/ - name: Archive binaries diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index fe6d1c3d..ab45f6b1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -220,7 +220,7 @@ jobs: - name: Test run: | ./build/test/kiwi-test - mkdir eval_results && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ + mkdir eval_results && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -o eval_results/ - name: Release run: | cd build diff --git a/.github/workflows/ubuntu.yml b/.github/workflows/ubuntu.yml index 34742212..61de6496 100644 --- a/.github/workflows/ubuntu.yml +++ b/.github/workflows/ubuntu.yml @@ -60,8 +60,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ - name: Run Benchmark run: | curl -OL https://latina.bab2min.pe.kr/_data/kowiki1000.txt diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index ac4ea5cf..62414018 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -35,8 +35,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - .\build\Release\kiwi-evaluator.exe -m .\models\base (Get-ChildItem eval_data\*.txt | 
Select-Object -Expand FullName) -o eval_results\ - .\build\Release\kiwi-evaluator.exe -m .\models\base --sbg (Get-ChildItem eval_data\*.txt | Select-Object -Expand FullName) -o eval_results\ + .\build\Release\kiwi-evaluator.exe -m .\models\base -t knlm --morph (Get-ChildItem eval_data\*.txt | Select-Object -Expand FullName) -o eval_results\ + .\build\Release\kiwi-evaluator.exe -m .\models\base -t sbg --morph (Get-ChildItem eval_data\*.txt | Select-Object -Expand FullName) -o eval_results\ - name: Archive binaries uses: actions/upload-artifact@v4 with: From 5a5e076dc7b3524f4e538887fed45da9350140d9 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 24 Mar 2025 21:44:09 +0900 Subject: [PATCH 50/53] Fix CMakeLists.txt --- CMakeLists.txt | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2143ad9e..66f64f49 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,12 +38,16 @@ if(NOT KIWI_CPU_ARCH) set(KIWI_CPU_ARCH "${KIWI_CPU_ARCH}" PARENT_SCOPE) endif() -set( AVX_VNNI_SUPPORTED (KIWI_USE_CPUINFO AND + +if (KIWI_USE_CPUINFO AND (MSVC OR - (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) OR - (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 11) + ((CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") AND CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 11) ) -)) +) + set ( AVX_VNNI_SUPPORTED ON ) +else() + set ( AVX_VNNI_SUPPORTED OFF ) +endif() if(APPLE) set(CMAKE_OSX_ARCHITECTURES "${KIWI_CPU_ARCH}") From 8be959b349c94c928599f407b92a329ae4afb576 Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 25 Mar 2025 01:00:04 +0900 Subject: [PATCH 51/53] Fix CMakeLists.txt --- CMakeLists.txt | 9 ++++++--- bindings/java/CMakeLists.txt | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 66f64f49..5e44cf08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ 
-98,7 +98,7 @@ include_directories( third_party/eigen ) include_directories( third_party/json/include ) include_directories( third_party/streamvbyte/include ) add_subdirectory( third_party/streamvbyte ) -set ( STREAMBYTE_OBJECTS +set ( STREAMVBYTE_OBJECTS $ ) if(KIWI_USE_CPUINFO) @@ -223,14 +223,14 @@ add_library( "${PROJECT_NAME}_static" STATIC ${CORE_SRCS} src/capi/kiwi_c.cpp ${CPUINFO_OBJECTS_STATIC} - ${STREAMBYTE_OBJECTS} + ${STREAMVBYTE_OBJECTS} ) add_library( "${PROJECT_NAME}" SHARED ${CORE_SRCS} src/capi/kiwi_c.cpp ${CPUINFO_OBJECTS_SHARED} - ${STREAMBYTE_OBJECTS} + ${STREAMVBYTE_OBJECTS} ) # Install the kiwi library as well as header files to (`include/kiwi` directory) @@ -299,6 +299,9 @@ if(MSVC) target_compile_options("${PROJECT_NAME}_static" PUBLIC /MT ) + target_compile_options("streamvbyte" PUBLIC + /MT + ) endif() target_compile_options("${PROJECT_NAME}" PUBLIC diff --git a/bindings/java/CMakeLists.txt b/bindings/java/CMakeLists.txt index ffbc188d..8dcf8f01 100644 --- a/bindings/java/CMakeLists.txt +++ b/bindings/java/CMakeLists.txt @@ -10,6 +10,7 @@ set(pkg_name "KiwiJava-${PROJECT_VERSION}") add_library (${pkg_name} SHARED kiwi_java.cpp $ $ + $ ) if(UNIX AND NOT APPLE) target_link_libraries( ${pkg_name} From dcd7516f65a86ebcc5d9cfc62330a1e5df7ea897 Mon Sep 17 00:00:00 2001 From: bab2min Date: Tue, 25 Mar 2025 01:16:42 +0900 Subject: [PATCH 52/53] Add 64-bit mode check for Issue205 test --- test/test_cpp.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 141ad563..0c74452c 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -1679,6 +1679,12 @@ TEST(KiwiCpp, IssueP189) TEST(KiwiCpp, Issue205) { + if (sizeof(void*) != 8) + { + std::cerr << "This test is only available in 64-bit mode" << std::endl; + return; + } + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_, }; builder.addWord(u"함박 스테이크"); auto kiwi1 = builder.build(); From 3d586d362f0e19e04cc182e63f8808ce4149068a Mon Sep 17 
00:00:00 2001 From: bab2min Date: Tue, 25 Mar 2025 11:16:26 +0900 Subject: [PATCH 53/53] Remove warnings --- CMakeLists.txt | 6 ++++++ src/Combiner.cpp | 2 +- src/sais/mp_utils.hpp | 2 +- test/test_cpp.cpp | 6 +++--- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5e44cf08..8e3e639f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -160,6 +160,12 @@ else() set ( CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELEASE} -g3") set ( CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO}") set ( CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_EXE_LINKER_FLAGS_RELEASE}" ) + + if (APPLE) + set ( CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wno-unqualified-std-cast-call" ) + set ( CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Wno-unqualified-std-cast-call" ) + set ( CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -Wno-unqualified-std-cast-call" ) + endif() endif() if (KIWI_CPU_ARCH MATCHES "x86_64") diff --git a/src/Combiner.cpp b/src/Combiner.cpp index 31d20d7c..9d21608e 100644 --- a/src/Combiner.cpp +++ b/src/Combiner.cpp @@ -780,7 +780,7 @@ void RuleSet::loadRules(istream& istr) while (getline(istr, line)) { if (line[0] == '#') continue; - while (!line.empty() && line.back() < 0x80 && isSpace(line.back())) line.pop_back(); + while (!line.empty() && ((uint8_t)line.back() < 0x80) && isSpace(line.back())) line.pop_back(); if (line.empty()) continue; auto fields = split(line, '\t'); diff --git a/src/sais/mp_utils.hpp b/src/sais/mp_utils.hpp index bd562cb1..f176edbe 100644 --- a/src/sais/mp_utils.hpp +++ b/src/sais/mp_utils.hpp @@ -396,7 +396,7 @@ namespace mp ThreadPool* pool; public: OverrideLimitedSize(ThreadPool* _pool, size_t newSize) - : pool{ _pool }, prevSize{ pool ? pool->limitedSize() : -1 } + : pool{ _pool }, prevSize{ _pool ? 
_pool->limitedSize() : -1 } { if (pool) pool->_limitedSize = newSize; } diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 0c74452c..481ef8ba 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -432,7 +432,7 @@ TEST(KiwiCpp, HSDataset) { size_t totalBatchCnt = 0, totalTokenCnt = 0, s; dataset.reset(); - while (s = dataset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt)) + while ((s = dataset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt))) { EXPECT_LE(s, batchSize); totalTokenCnt += s; @@ -456,7 +456,7 @@ TEST(KiwiCpp, HSDataset) { size_t totalBatchCnt = 0, totalTokenCnt = 0, s; trainset.reset(); - while (s = trainset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt)) + while ((s = trainset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt))) { EXPECT_LE(s, batchSize); totalTokenCnt += s; @@ -468,7 +468,7 @@ TEST(KiwiCpp, HSDataset) { size_t totalBatchCnt = 0, totalTokenCnt = 0, s; devset.reset(); - while (s = devset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt)) + while ((s = devset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt))) { EXPECT_LE(s, batchSize); totalTokenCnt += s;