diff --git a/.github/workflows/arm64_centos7.yml b/.github/workflows/arm64_centos7.yml index 28f93c70..4ece78d0 100644 --- a/.github/workflows/arm64_centos7.yml +++ b/.github/workflows/arm64_centos7.yml @@ -34,7 +34,7 @@ jobs: - name: Test run: | ./build/test/kiwi-test - mkdir eval_results && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + mkdir eval_results && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ cp -r build /artifacts/ cp -r eval_results /artifacts/ - name: Benchmark diff --git a/.github/workflows/centos7.yml b/.github/workflows/centos7.yml index 75c45bc3..1cf15169 100644 --- a/.github/workflows/centos7.yml +++ b/.github/workflows/centos7.yml @@ -40,8 +40,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ - run: tar -zcvf arts.tgz build/*kiwi* build/test/*kiwi* eval_results/*.txt build/bindings/java/*.jar - name: Archive binaries uses: actions/upload-artifact@v4 diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index d3d7468d..db732322 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -60,8 +60,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ - name: Run Benchmark run: | curl -OL https://latina.bab2min.pe.kr/_data/kowiki1000.txt diff --git a/.github/workflows/ppc64le_centos7.yml b/.github/workflows/ppc64le_centos7.yml index 00fa49db..88fd5994 100644 --- a/.github/workflows/ppc64le_centos7.yml +++ b/.github/workflows/ppc64le_centos7.yml @@ -28,7 +28,7 @@ jobs: mkdir build && pushd build && cmake -DCMAKE_BUILD_TYPE=Release -DKIWI_USE_MIMALLOC=0 -DKIWI_JAVA_BINDING=1 .. make -j2 && popd ./build/test/kiwi-test - mkdir eval_results && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + mkdir eval_results && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ cp -r build /artifacts/ cp -r eval_results /artifacts/ - name: Archive binaries diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index fe6d1c3d..ab45f6b1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -220,7 +220,7 @@ jobs: - name: Test run: | ./build/test/kiwi-test - mkdir eval_results && ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ + mkdir eval_results && ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -o eval_results/ - name: Release run: | cd build diff --git a/.github/workflows/ubuntu.yml b/.github/workflows/ubuntu.yml index 34742212..61de6496 100644 --- a/.github/workflows/ubuntu.yml +++ b/.github/workflows/ubuntu.yml @@ -60,8 +60,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt -o eval_results/ - ./build/kiwi-evaluator -m ./models/base eval_data/*.txt --sbg -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t knlm -o eval_results/ + ./build/kiwi-evaluator -m ./models/base --morph eval_data/*.txt -t sbg -o eval_results/ - name: Run Benchmark run: | curl -OL https://latina.bab2min.pe.kr/_data/kowiki1000.txt diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index ac4ea5cf..62414018 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -35,8 +35,8 @@ jobs: - name: Run Evaluator run: | mkdir eval_results - .\build\Release\kiwi-evaluator.exe -m .\models\base (Get-ChildItem eval_data\*.txt | Select-Object -Expand FullName) -o eval_results\ - .\build\Release\kiwi-evaluator.exe -m .\models\base --sbg (Get-ChildItem eval_data\*.txt | Select-Object -Expand FullName) -o eval_results\ + .\build\Release\kiwi-evaluator.exe -m .\models\base -t knlm --morph (Get-ChildItem eval_data\*.txt | Select-Object -Expand FullName) -o eval_results\ + .\build\Release\kiwi-evaluator.exe -m .\models\base -t sbg --morph (Get-ChildItem eval_data\*.txt | Select-Object -Expand FullName) -o eval_results\ - name: Archive binaries uses: actions/upload-artifact@v4 with: diff --git a/.gitmodules b/.gitmodules index a9fcfa31..c5c18bee 100644 --- a/.gitmodules +++ b/.gitmodules @@ -15,12 +15,12 @@ [submodule "third_party/cpuinfo"] path = third_party/cpuinfo url = https://github.com/pytorch/cpuinfo -[submodule "third_party/variant"] - path = third_party/variant - url = https://github.com/mapbox/variant [submodule "third_party/eigen"] path = third_party/eigen url = https://gitlab.com/libeigen/eigen [submodule "third_party/json"] path = third_party/json url = https://github.com/nlohmann/json +[submodule "third_party/streamvbyte"] + path = third_party/streamvbyte + url = https://github.com/fast-pack/streamvbyte diff --git a/CMakeLists.txt b/CMakeLists.txt index 28779f64..8e3e639f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,8 @@ cmake_minimum_required(VERSION 3.12) -project(kiwi VERSION 0.20.4 DESCRIPTION "Kiwi, Korean Intelligent Word Identifier") +project(kiwi VERSION 0.21.0 DESCRIPTION "Kiwi, Korean Intelligent Word Identifier") -set ( CMAKE_CXX_STANDARD 14 ) +set ( CMAKE_CXX_STANDARD 17 ) set ( CMAKE_VERBOSE_MAKEFILE true ) option(KIWI_USE_MIMALLOC "Use mimalloc for faster memory allocation" ON) @@ -38,6 +38,17 @@ if(NOT KIWI_CPU_ARCH) set(KIWI_CPU_ARCH "${KIWI_CPU_ARCH}" PARENT_SCOPE) endif() + +if (KIWI_USE_CPUINFO AND + (MSVC OR + ((CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") AND CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 11) + ) +) + set ( AVX_VNNI_SUPPORTED ON ) +else() + set ( AVX_VNNI_SUPPORTED OFF ) +endif() + if(APPLE) set(CMAKE_OSX_ARCHITECTURES "${KIWI_CPU_ARCH}") endif() @@ -45,10 +56,11 @@ endif() set ( CORE_SRCS src/ArchUtils.cpp src/Combiner.cpp + src/CoNgramModel.cpp + src/Dataset.cpp src/Form.cpp src/FeatureTestor.cpp src/FileUtils.cpp - src/Dataset.cpp src/Joiner.cpp src/Kiwi.cpp src/KiwiBuilder.cpp @@ -57,6 +69,7 @@ set ( CORE_SRCS src/PatternMatcher.cpp src/search.cpp src/ScriptType.cpp + src/SkipBigramModel.cpp src/SubstringExtractor.cpp src/SwTokenizer.cpp src/TagUtils.cpp @@ -81,9 +94,13 @@ endif() include_directories( include/ ) include_directories( third_party/tclap/include ) include_directories( third_party/cpp-btree ) -include_directories( third_party/variant/include ) include_directories( third_party/eigen ) include_directories( third_party/json/include ) +include_directories( third_party/streamvbyte/include ) +add_subdirectory( third_party/streamvbyte ) +set ( STREAMVBYTE_OBJECTS + $ +) if(KIWI_USE_CPUINFO) message(STATUS "Use cpuinfo") include_directories( third_party/cpuinfo/include ) @@ -98,9 +115,6 @@ if(KIWI_USE_CPUINFO) set ( ADDITIONAL_FLAGS ${ADDITIONAL_FLAGS} "-DKIWI_USE_CPUINFO" ) if(MSVC) - target_compile_options("clog" PUBLIC - /MT - ) target_compile_options("cpuinfo" PUBLIC /MT ) @@ -110,15 +124,18 @@ if(KIWI_USE_CPUINFO) endif() set ( CPUINFO_OBJECTS_STATIC - $ $ ) set ( CPUINFO_OBJECTS_SHARED - $ $ ) endif() +if (AVX_VNNI_SUPPORTED) + message(STATUS "AVX-VNNI is supported") + set ( ADDITIONAL_FLAGS ${ADDITIONAL_FLAGS} "-DKIWI_AVX_VNNI_SUPPORTED" ) +endif() + if(MSVC) set ( CMAKE_C_FLAGS_DEBUG "-DDEBUG -DC_FLAGS -Zi -Od /utf-8 /bigobj" ) set ( CMAKE_CXX_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG}" ) @@ -143,6 +160,12 @@ else() set ( CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELEASE} -g3") set ( CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO}") set ( CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO "${CMAKE_EXE_LINKER_FLAGS_RELEASE}" ) + + if (APPLE) + set ( CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wno-unqualified-std-cast-call" ) + set ( CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Wno-unqualified-std-cast-call" ) + set ( CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -Wno-unqualified-std-cast-call" ) + endif() endif() if (KIWI_CPU_ARCH MATCHES "x86_64") @@ -157,21 +180,36 @@ if (KIWI_CPU_ARCH MATCHES "x86_64") ${CORE_SRCS} src/archImpl/avx2.cpp src/archImpl/avx512bw.cpp + src/archImpl/avx512vnni.cpp ) + # If AVX-VNNI is supported (MSVC, GCC 11+ or Clang 11+) + if (AVX_VNNI_SUPPORTED) + set( CORE_SRCS + ${CORE_SRCS} + src/archImpl/avx_vnni.cpp + ) + endif() endif() + if(MSVC) set_source_files_properties(src/archImpl/sse2.cpp PROPERTIES COMPILE_FLAGS "/arch:SSE2") set_source_files_properties(src/archImpl/sse4_1.cpp PROPERTIES COMPILE_FLAGS "/arch:SSE2") if (KIWI_USE_CPUINFO) set_source_files_properties(src/archImpl/avx2.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX2") + set_source_files_properties(src/archImpl/avx_vnni.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX2") set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX512") + set_source_files_properties(src/archImpl/avx512vnni.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX512") endif() else() set_source_files_properties(src/archImpl/sse2.cpp PROPERTIES COMPILE_FLAGS "-msse2") set_source_files_properties(src/archImpl/sse4_1.cpp PROPERTIES COMPILE_FLAGS "-msse2 -msse4.1") if (KIWI_USE_CPUINFO) set_source_files_properties(src/archImpl/avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma") - set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw") + set_source_files_properties(src/archImpl/avx512bw.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavx512f -mavx512vl -mavx512dq -mavx512bw") + set_source_files_properties(src/archImpl/avx512vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavx512f -mavx512vl -mavx512dq -mavx512bw -mavx512vnni") + if (AVX_VNNI_SUPPORTED) + set_source_files_properties(src/archImpl/avx_vnni.cpp PROPERTIES COMPILE_FLAGS "-mavx -mavx2 -mfma -mavxvnni") + endif() endif() endif() elseif (KIWI_CPU_ARCH MATCHES "arm64") @@ -191,12 +229,14 @@ add_library( "${PROJECT_NAME}_static" STATIC ${CORE_SRCS} src/capi/kiwi_c.cpp ${CPUINFO_OBJECTS_STATIC} + ${STREAMVBYTE_OBJECTS} ) add_library( "${PROJECT_NAME}" SHARED ${CORE_SRCS} src/capi/kiwi_c.cpp ${CPUINFO_OBJECTS_SHARED} + ${STREAMVBYTE_OBJECTS} ) # Install the kiwi library as well as header files to (`include/kiwi` directory) @@ -265,6 +305,9 @@ if(MSVC) target_compile_options("${PROJECT_NAME}_static" PUBLIC /MT ) + target_compile_options("streamvbyte" PUBLIC + /MT + ) endif() target_compile_options("${PROJECT_NAME}" PUBLIC diff --git a/bindings/java/CMakeLists.txt b/bindings/java/CMakeLists.txt index 6c26dea3..8dcf8f01 100644 --- a/bindings/java/CMakeLists.txt +++ b/bindings/java/CMakeLists.txt @@ -9,8 +9,8 @@ set(CMAKE_JAVA_COMPILE_FLAGS -source 8 -target 8 -encoding utf-8) set(pkg_name "KiwiJava-${PROJECT_VERSION}") add_library (${pkg_name} SHARED kiwi_java.cpp $ - $ $ + $ ) if(UNIX AND NOT APPLE) target_link_libraries( ${pkg_name} diff --git a/bindings/java/JniUtils.hpp b/bindings/java/JniUtils.hpp index f0cb47d9..55513981 100644 --- a/bindings/java/JniUtils.hpp +++ b/bindings/java/JniUtils.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include diff --git a/bindings/java/kiwi_java.cpp b/bindings/java/kiwi_java.cpp index 9013d9fe..8b19f07c 100644 --- a/bindings/java/kiwi_java.cpp +++ b/bindings/java/kiwi_java.cpp @@ -95,6 +95,23 @@ namespace jni } }; + template<> + struct ValueBuilder : public ValueBuilder + { + using CppType = kiwi::ModelType; + using JniType = jint; + + CppType fromJava(JNIEnv* env, JniType v) + { + return (CppType)v; + } + + JniType toJava(JNIEnv* env, CppType v) + { + return (JniType)v; + } + }; + template<> struct ValueBuilder : public ValueBuilder { @@ -564,7 +581,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) .template method<&JTypoTransformer::scaleCost>("_scaleCost"), jni::define() - .template ctor() + .template ctor() .template method<&JKiwiBuilder::addWord>("addWord") .template method<&JKiwiBuilder::addWord2>("addWord") .template method<&JKiwiBuilder::addPreAnalyzedWord>("addPreAnalyzedWord") diff --git a/bindings/java/kr/pe/bab2min/Kiwi.java b/bindings/java/kr/pe/bab2min/Kiwi.java index 5cfbf60a..802ffb44 100644 --- a/bindings/java/kr/pe/bab2min/Kiwi.java +++ b/bindings/java/kr/pe/bab2min/Kiwi.java @@ -12,7 +12,7 @@ public class Kiwi implements AutoCloseable { private long _inst; - final private static String _version = "0.20.4"; + final private static String _version = "0.21.0"; public static class Match { final static public int none = 0, @@ -345,8 +345,8 @@ public Kiwi(long _inst) { this._inst = _inst; } - public static Kiwi init(String modelPath, int numWorkers, int buildOptions, boolean useSBG) throws Exception { - try(KiwiBuilder b = new KiwiBuilder(modelPath, numWorkers, buildOptions, useSBG)) { + public static Kiwi init(String modelPath, int numWorkers, int buildOptions, int modelType) throws Exception { + try(KiwiBuilder b = new KiwiBuilder(modelPath, numWorkers, buildOptions, modelType)) { return b.build(); } } diff --git a/bindings/java/kr/pe/bab2min/KiwiBuilder.java b/bindings/java/kr/pe/bab2min/KiwiBuilder.java index 5cfdfd17..a09a83a7 100644 --- a/bindings/java/kr/pe/bab2min/KiwiBuilder.java +++ b/bindings/java/kr/pe/bab2min/KiwiBuilder.java @@ -12,6 +12,14 @@ public static class BuildOption { default_ = integrateAllomorph | loadDefaultDict | loadTypoDict | loadMultiDict; } + public static class ModelType { + final static public int none = 0, + knlm = 1, + sbg = 2, + cong = 3, + congGlobal = 4; + } + public static class AnalyzedMorph { public String form; public byte tag = Kiwi.POSTag.nng; @@ -113,20 +121,20 @@ public KiwiBuilder(long _inst) { this._inst = _inst; } - public KiwiBuilder(String modelPath, int numWorkers, int buildOptions, boolean useSBG) { - ctor(modelPath, numWorkers, buildOptions, useSBG); + public KiwiBuilder(String modelPath, int numWorkers, int buildOptions, int modelType) { + ctor(modelPath, numWorkers, buildOptions, modelType); } public KiwiBuilder(String modelPath, int numWorkers, int buildOptions) { - ctor(modelPath, numWorkers, buildOptions, false); + ctor(modelPath, numWorkers, buildOptions, ModelType.none); } public KiwiBuilder(String modelPath, int numWorkers) { - ctor(modelPath, numWorkers, BuildOption.default_, false); + ctor(modelPath, numWorkers, BuildOption.default_, ModelType.none); } public KiwiBuilder(String modelPath) { - ctor(modelPath, 1, BuildOption.default_, false); + ctor(modelPath, 1, BuildOption.default_, ModelType.none); } protected void finalize() throws Exception { @@ -137,7 +145,7 @@ public boolean isAlive() { return _inst != 0; } - private native void ctor(String modelPath, int numWorkers, int buildOptions, boolean useSBG); + private native void ctor(String modelPath, int numWorkers, int buildOptions, int modelType); @Override public native void close() throws Exception; diff --git a/include/kiwi/ArchUtils.h b/include/kiwi/ArchUtils.h index ba2d0e85..baa27ee5 100644 --- a/include/kiwi/ArchUtils.h +++ b/include/kiwi/ArchUtils.h @@ -10,7 +10,9 @@ namespace kiwi sse2, sse4_1, avx2, + avx_vnni, avx512bw, + avx512vnni, neon, last = neon, }; @@ -25,16 +27,7 @@ namespace kiwi const char* archToStr(ArchType arch); template - struct ArchInfo; - - template<> - struct ArchInfo - { - static constexpr size_t alignment = 4; - }; - - template<> - struct ArchInfo + struct ArchInfo { static constexpr size_t alignment = 4; }; @@ -57,12 +50,24 @@ namespace kiwi static constexpr size_t alignment = 32; }; + template<> + struct ArchInfo + { + static constexpr size_t alignment = 32; + }; + template<> struct ArchInfo { static constexpr size_t alignment = 64; }; + template<> + struct ArchInfo + { + static constexpr size_t alignment = 64; + }; + template<> struct ArchInfo { diff --git a/include/kiwi/CoNgramModel.h b/include/kiwi/CoNgramModel.h new file mode 100644 index 00000000..3cffd91b --- /dev/null +++ b/include/kiwi/CoNgramModel.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ArchUtils.h" +#include "Mmap.h" +#include "LangModel.h" + +namespace kiwi +{ + namespace lm + { + struct CoNgramModelHeader + { + uint64_t vocabSize, contextSize; + uint16_t dim; + uint8_t contextType, outputType; + uint8_t keySize, windowSize, quantize, _reserved; + uint64_t numNodes; + uint64_t nodeOffset, keyOffset, valueOffset, embOffset; + }; + + template + struct Node + { + KeyType numNexts = 0; + ValueType value = 0; + DiffType lower = 0; + uint32_t nextOffset = 0; + }; + + class CoNgramModelBase : public ILangModel + { + protected: + const size_t memorySize = 0; + CoNgramModelHeader header; + + CoNgramModelBase(const utils::MemoryObject& mem) : memorySize{ mem.size() }, header{ *reinterpret_cast(mem.get()) } + { + } + public: + virtual ~CoNgramModelBase() {} + size_t vocabSize() const override { return header.vocabSize; } + size_t getMemorySize() const override { return memorySize; } + + const CoNgramModelHeader& getHeader() const { return header; } + + static utils::MemoryObject build(const std::string& contextDefinition, const std::string& embedding, size_t maxContextLength = -1, bool useVLE = true, bool reorderContextIdx = true); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool useDistantTokens = false, bool quantized = true); + }; + } +} diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index e45738d5..69e32c73 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -1,6 +1,7 @@ #pragma once #include #include "Kiwi.h" +#include "FrozenTrie.h" namespace kiwi { @@ -30,49 +31,73 @@ namespace kiwi struct ThreadLocal { std::mt19937_64 rng; - Vector tokenBuf; + Vector tokenBuf; Vector lmLProbsBuf; Vector outNgramNodeBuf; - Deque historyBuf; Deque inData; Deque outData; Deque lmLProbsData; Deque outNgramNodeData; Deque restLmLProbsData; Deque restLmLProbsCntData; + Vector> unlikelihoodBuf; + Deque unlikelihoodInData; + Deque unlikelihoodOutData; }; static constexpr int32_t nonVocab = -1; - HiddenMember, sizeof(Vector) * 2> sents; - std::shared_ptr knlm; + HiddenMember, sizeof(Vector) * 2> sents; + std::shared_ptr langModel; + std::shared_ptr kiwiInst; + std::shared_ptr>> oovDict; std::unique_ptr workers; std::shared_ptr dummyBuilder; std::discrete_distribution<> dropout; - std::bernoulli_distribution dropoutOnHistory; + float dropoutProbOnHistory = 0; + std::discrete_distribution<> nounAugmentor; std::mt19937_64 rng; Vector locals; Vector shuffledIdx; Vector tokenToVocab, vocabToToken; + Vector windowTokenValidness; Deque> futures; const Vector* morphemes = nullptr; const Vector* forms = nullptr; + utils::FrozenTrie contextualMapper; size_t knlmVocabSize = 0; size_t batchSize = 0; size_t causalContextSize = 0; size_t windowSize = 0; + bool exclusiveWindow = true; + size_t generateUnlikelihoods = -1; size_t totalTokens = 0; size_t passedSents = 0; size_t passedWorkItems = 0; + std::array(Kiwi::SpecialMorph::max)> specialMorphIds = { { 0, } }; size_t numValidTokensInSent(size_t sentId) const; - template - size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut); + template + void prepareInOutData(Deque& inData, Deque& outData, const Vector& tokens, std::mt19937_64& rng) const; + + bool tokenizeUnlikely(Vector>& out, int32_t prefix, int32_t target, int32_t suffix, std::mt19937_64& rng) const; + + template + size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + UlInTy unlikelihoodIn, UlOutTy unlikelihoodOut, size_t* unlikelihoodSize); public: - HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0, double _dropoutProbOnHistory = 0); + HSDataset(size_t _batchSize = 0, + size_t _causalContextSize = 0, + size_t _windowSize = 0, + bool _exclusiveWindow = true, + size_t _workers = 0, + double _dropoutProb = 0, + double _dropoutProbOnHistory = 0, + double _nounAugmentingProb = 0, + size_t _generateUnlikelihoods = -1); ~HSDataset(); HSDataset(const HSDataset&) = delete; HSDataset(HSDataset&&) /*noexcept*/; @@ -82,6 +107,7 @@ namespace kiwi size_t numEstimBatches() const; size_t numSents() const; size_t numTokens() const; + bool doesGenerateUnlikelihoods() const { return generateUnlikelihoods < (size_t)-1; } size_t getBatchSize() const { return batchSize; } size_t getCausalContextSize() const { return causalContextSize; } @@ -90,8 +116,10 @@ namespace kiwi void seed(size_t newSeed); void reset(); - size_t next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut); - size_t next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut); + size_t next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int32_t* unlikelihoodIn = nullptr, int32_t* unlikelihoodOut = nullptr, size_t* unlikelihoodSize = nullptr); + size_t next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int64_t* unlikelihoodIn = nullptr, int64_t* unlikelihoodOut = nullptr, size_t* unlikelihoodSize = nullptr); size_t vocabSize() const { return vocabToToken.size(); } size_t getKnlmVocabSize() const; @@ -100,9 +128,9 @@ namespace kiwi std::u16string vocabForm(uint32_t vocab) const; std::vector estimVocabFrequency() const; - Range::const_iterator> getSent(size_t idx) const; + Range::const_iterator> getSent(size_t idx) const; std::vector getAugmentedSent(size_t idx); - std::vector, size_t>> extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers = 1) const; + std::vector, size_t>> extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers = 1, bool exclusiveCnt = false) const; }; } diff --git a/include/kiwi/Form.h b/include/kiwi/Form.h index 9ae59f55..0a3c71a4 100644 --- a/include/kiwi/Form.h +++ b/include/kiwi/Form.h @@ -167,6 +167,9 @@ namespace kiwi /** 분할된 형태소의 경우 원형 형태소를 반환한다. 그 외에는 자기 자신을 반환한다. */ const Morpheme* getCombined() const { return this + combined; } + /** 현재 인스턴스가 단일 형태소인지 확인한다 */ + bool isSingle() const { return chunks.empty() || complex || saisiot; } + bool hasComplex() const { if (getCombined()->complex) return true; @@ -188,6 +191,7 @@ namespace kiwi } return false; } + }; /** diff --git a/include/kiwi/Joiner.h b/include/kiwi/Joiner.h index 4447d3d1..a5d24e99 100644 --- a/include/kiwi/Joiner.h +++ b/include/kiwi/Joiner.h @@ -1,12 +1,11 @@ #pragma once #include "Types.h" #include "ArchUtils.h" -#include "LmState.h" +#include "LangModel.h" namespace kiwi { class Kiwi; - template class VoidState; struct Form; namespace cmb @@ -57,34 +56,114 @@ namespace kiwi LmState lmState; float score = 0; - Candidate(const CompiledRule& _cr, const LangModel& lm) + Candidate(const CompiledRule& _cr, const lm::ILangModel* lm) : joiner{ _cr }, lmState{ lm } { } }; template - struct Candidate> + struct Candidate> { Joiner joiner; - Candidate(const CompiledRule& _cr, const LangModel& lm) + Candidate(const CompiledRule& _cr, const lm::ILangModel* lm) : joiner{ _cr } { } }; - class AutoJoiner + class ErasedVector { - friend class kiwi::Kiwi; + using FnDestruct = void(*)(ErasedVector*); + using FnCopyConstruct = void(*)(ErasedVector*, const ErasedVector&); + + template + static void destructImpl(ErasedVector* self) + { + auto* target = reinterpret_cast*>(&self->vec); + std::destroy_at(target); + } + + template + static void copyConstructImpl(ErasedVector* self, const ErasedVector& other) + { + auto* target = reinterpret_cast*>(&self->vec); + new (target) Vector{ *reinterpret_cast*>(&other.vec) }; + } - struct AddVisitor; - struct AddVisitor2; - const Kiwi* kiwi = nullptr; union { - typename std::aligned_storage) + sizeof(int), alignof(Vector)>::type candBuf; + Vector vec; }; + FnDestruct destruct = nullptr; + FnCopyConstruct copyConstruct = nullptr; + public: + + template + ErasedVector(Vector&& v) + { + auto* target = reinterpret_cast*>(&vec); + new (target) Vector{ move(v) }; + destruct = &destructImpl; + copyConstruct = ©ConstructImpl; + } + + ~ErasedVector() + { + if (destruct) + { + (*destruct)(this); + destruct = nullptr; + copyConstruct = nullptr; + } + } + + ErasedVector(const ErasedVector& other) + : destruct{ other.destruct }, copyConstruct{ other.copyConstruct } + { + if (!destruct) return; + (*copyConstruct)(this, other); + } + + ErasedVector(ErasedVector&& other) + { + std::swap(vec, other.vec); + std::swap(destruct, other.destruct); + std::swap(copyConstruct, other.copyConstruct); + } + + ErasedVector& operator=(const ErasedVector& other) + { + this->~ErasedVector(); + new (this) ErasedVector{ other }; + return *this; + } + + ErasedVector& operator=(ErasedVector&& other) + { + std::swap(vec, other.vec); + std::swap(destruct, other.destruct); + std::swap(copyConstruct, other.copyConstruct); + return *this; + } + + template + Vector& get() + { + return *reinterpret_cast*>(&vec); + } + + template + const Vector& get() const + { + return *reinterpret_cast*>(&vec); + } + }; + + class AutoJoiner + { + friend class kiwi::Kiwi; template explicit AutoJoiner(const Kiwi& kiwi, Candidate&& state); @@ -93,16 +172,27 @@ namespace kiwi void foreachMorpheme(const Form* formHead, Func&& func) const; template - void add(size_t morphemeId, Space space, Vector>& candidates); + void addImpl(size_t morphemeId, Space space, Vector>& candidates); template - void add(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates); + void addImpl2(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates); template - void addWithoutSearch(size_t morphemeId, Space space, Vector>>& candidates); + void addWithoutSearchImpl(size_t morphemeId, Space space, Vector>>& candidates); template - void addWithoutSearch(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates); + void addWithoutSearchImpl2(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates); + + template + struct Dispatcher; + + using FnAdd = void(*)(AutoJoiner*, size_t, Space, Vector>>&); + using FnAdd2 = void(*)(AutoJoiner*, U16StringView, POSTag, bool, Space, Vector>>&); + + const Kiwi* kiwi = nullptr; + FnAdd dfAdd = nullptr; + FnAdd2 dfAdd2 = nullptr; + ErasedVector candBuf; public: ~AutoJoiner(); diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 1c2ee7d2..49516a7d 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -25,7 +25,7 @@ #include "ThreadPool.h" #include "WordDetector.h" #include "TagUtils.h" -#include "LmState.h" +#include "LangModel.h" #include "Joiner.h" #include "TypoTransformer.h" @@ -58,7 +58,9 @@ namespace kiwi class Kiwi { friend class KiwiBuilder; - friend class PathEvaluator; + template friend struct BestPathFinder; + template friend struct PathEvaluator; + template friend struct MorphemeEvaluator; friend class cmb::AutoJoiner; template class LmState> friend struct NewAutoJoinerGetter; @@ -81,22 +83,17 @@ namespace kiwi Vector typoPtrs; Vector typoForms; utils::FrozenTrie formTrie; - LangModel langMdl; + std::shared_ptr langMdl; std::shared_ptr combiningRule; std::unique_ptr pool; - inline const Morpheme* getDefaultMorpheme(POSTag tag) const; - - template - cmb::AutoJoiner newJoinerImpl() const - { - return cmb::AutoJoiner{ *this, cmb::Candidate{ *combiningRule, langMdl } }; - } + const Morpheme* getDefaultMorpheme(POSTag tag) const; ArchType selectedArch = ArchType::none; void* dfSplitByTrie = nullptr; void* dfFindForm = nullptr; void* dfFindBestPath = nullptr; + void* dfNewJoiner = nullptr; public: enum class SpecialMorph { @@ -129,7 +126,7 @@ namespace kiwi * kiwi::KiwiBuilder 를 통해 생성된 객체만이 형태소 분석에 사용할 수 있다. */ Kiwi(ArchType arch = ArchType::default_, - LangModel _langMdl = {}, + const std::shared_ptr& _langMdl = {}, bool typoTolerant = false, bool continualTypoTolerant = false, bool lengtheningTypoTolerant = false); @@ -156,6 +153,8 @@ namespace kiwi ArchType archType() const { return selectedArch; } + ModelType modelType() const { return langMdl ? langMdl->getType() : ModelType::none; } + /** * @brief 현재 Kiwi 객체가 오타 교정 기능이 켜진 상태로 생성되었는지 알려준다. * @@ -368,6 +367,10 @@ namespace kiwi TokenResult* tokenizedResultOut = nullptr ) const; + + template + cmb::AutoJoiner newJoinerImpl() const; + /** * @brief 형태소들을 결합하여 텍스트로 복원해주는 작업을 수행하는 AutoJoiner를 반환한다. * @@ -378,11 +381,6 @@ namespace kiwi */ cmb::AutoJoiner newJoiner(bool lmSearch = true) const; - /** - * @brief Kiwi에 내장된 언어 모델에 접근할 수 있는 LmObject 객체를 생성한다. - */ - std::unique_ptr newLmObject() const; - /** * @brief `TokenInfo::typoFormId`로부터 실제 오타 형태를 복원한다. * @@ -514,9 +512,9 @@ namespace kiwi integrateAllomorph = v; } - const lm::KnLangModelBase* getKnLM() const + const lm::ILangModel* getLangModel() const { - return langMdl.knlm.get(); + return langMdl.get(); } void findMorpheme(std::vector& out, const std::u16string& s, POSTag tag = POSTag::unknown) const; @@ -533,20 +531,43 @@ namespace kiwi Vector forms; Vector morphemes; UnorderedMap formMap; - LangModel langMdl; + std::shared_ptr langMdl; std::shared_ptr combiningRule; WordDetector detector; - + size_t numThreads = 0; BuildOption options = BuildOption::none; + ModelType modelType = ModelType::none; ArchType archType = ArchType::none; + public: + struct ModelBuildArgs + { + std::string morphemeDef; + std::vector corpora; + size_t minMorphCnt = 10; + size_t lmOrder = 4; + std::vector lmMinCnts = { 1 }; + size_t numWorkers = 1; + size_t sbgSize = 1000000; + bool useLmTagHistory = true; + bool quantizeLm = true; + bool compressLm = true; + float dropoutSampling = 0.05f; + float dropoutProb = 0.15f; + }; + + private: + using MorphemeMap = UnorderedMap, std::pair>; + + template + std::unique_ptr buildKnLM(const ModelBuildArgs& args, size_t lmVocabSize, MorphemeMap& morphMap) const; + void loadMorphBin(std::istream& is); void saveMorphBin(std::ostream& os) const; FormRaw& addForm(const KString& form); size_t addForm(Vector& newForms, UnorderedMap& newFormMap, KString form) const; - using MorphemeMap = UnorderedMap, std::pair>; void initMorphemes(); @@ -556,16 +577,27 @@ namespace kiwi MorphemeMap restoreMorphemeMap(bool separateDefaultMorpheme = false) const; template - void _addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const; - - void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; - void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; - void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector* splitOut = nullptr) const; + void _addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, + UnorderedMap, size_t>* oovDict = nullptr) const; + + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; + void addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio = 0, RaggedVector* splitOut = nullptr, + UnorderedMap, size_t>* oovDict = nullptr) const; void updateForms(); - void updateMorphemes(); + void updateMorphemes(size_t vocabSize = 0); size_t findMorpheme(U16StringView form, POSTag tag) const; - + std::pair addWord(U16StringView newForm, POSTag tag, float score, size_t origMorphemeId, size_t lmMorphemeId); std::pair addWord(const std::u16string& newForm, POSTag tag, float score, size_t origMorphemeId, size_t lmMorphemeId); std::pair addWord(U16StringView form, POSTag tag = POSTag::nnp, float score = 0); @@ -589,44 +621,36 @@ namespace kiwi ) const; void addCombinedMorphemes( - Vector& newForms, - UnorderedMap& newFormMap, - Vector& newMorphemes, - UnorderedMap>& newFormCands, - size_t leftId, - size_t rightId, + Vector& newForms, + UnorderedMap& newFormMap, + Vector& newMorphemes, + UnorderedMap>& newFormCands, + size_t leftId, + size_t rightId, size_t ruleId ) const; void buildCombinedMorphemes( - Vector& newForms, + Vector& newForms, UnorderedMap& newFormMap, - Vector& newMorphemes, + Vector& newMorphemes, UnorderedMap>& newFormCands ) const; void addAllomorphsToRule(); + std::array(Kiwi::SpecialMorph::max)> getSpecialMorphs() const; + public: - struct ModelBuildArgs - { - std::string morphemeDef; - std::vector corpora; - size_t minMorphCnt = 10; - size_t lmOrder = 4; - std::vector lmMinCnts = { 1 }; - size_t numWorkers = 1; - size_t sbgSize = 1000000; - bool useLmTagHistory = true; - bool quantizeLm = true; - bool compressLm = true; - float dropoutSampling = 0.05f; - float dropoutProb = 0.15f; - }; + + /** + * @brief 주어진 모델 경로로부터 모델의 타입을 추정한다. + */ + static ModelType getModelType(const std::string& modelPath); /** * @brief KiwiBuilder의 기본 생성자 - * + * * @note 이 생성자로 생성된 경우 `ready() == false`인 상태이므로 유효한 Kiwi 객체를 생성할 수 없다. */ KiwiBuilder(); @@ -643,9 +667,9 @@ namespace kiwi /** * @brief KiwiBuilder를 raw 데이터로부터 생성한다. - * - * - * @note 이 함수는 현재 내부적으로 기본 모델 구축에 쓰인다. + * + * + * @note 이 함수는 현재 내부적으로 기본 모델 구축에 쓰인다. * 추후 공개 데이터로도 쉽게 직접 모델을 구축할 수 있도록 개선된 API를 제공할 예정. */ KiwiBuilder(const ModelBuildArgs& args); @@ -657,29 +681,29 @@ namespace kiwi /** * @brief KiwiBuilder를 모델 파일로부터 생성한다. - * + * * @param modelPath 모델이 위치한 경로 * @param numThreads 모델 및 형태소 분석에 사용할 스레드 개수 * @param options 생성 옵션. `kiwi::BuildOption`을 참조 */ - KiwiBuilder(const std::string& modelPath, size_t numThreads = 0, BuildOption options = BuildOption::default_, bool useSBG = false); + KiwiBuilder(const std::string& modelPath, size_t numThreads = 0, BuildOption options = BuildOption::default_, ModelType modelType = ModelType::none); /** * @brief 현재 KiwiBuilder 객체가 유효한 분석 모델을 로딩한 상태인지 알려준다. - * + * * @return 유효한 상태면 true를 반환한다. 기본 생성자로 생성한 경우 `ready() == false`이며, * 다른 생성자로 생성한 경우는 `ready() == true`이다. */ bool ready() const { - return !!langMdl.knlm; + return !!langMdl; } void saveModel(const std::string& modelPath) const; /** * @brief 사전에 새로운 형태소를 추가한다. 이미 동일한 형태소가 있는 경우는 무시된다. - * + * * @param form 새로운 형태소의 형태 * @param tag 품사 태그 * @param score 페널티 점수. 이에 대한 자세한 설명은 하단의 note 참조. @@ -687,7 +711,7 @@ namespace kiwi * @note 이 방법으로 추가된 형태소는 언어모델 탐색에서 어휘 사전 외 토큰(OOV 토큰)으로 처리된다. * 이 방법으로 추가된 형태소는 항상 분석 과정에서 최우선으로 탐색되지는 않으므로 최상의 결과를 위해서는 `score` 값을 조절할 필요가 있다. * `score` 값을 높게 설정할수록 다른 후보들과의 경쟁에서 이 형태소가 더 높은 점수를 받아 최종 분석 결과에 노출될 가능성이 높아진다. - * 만약 이 방법으로 추가된 형태소가 원치 않는 상황에서 과도하게 출력되는 경우라면 `score`를 더 작은 값으로, + * 만약 이 방법으로 추가된 형태소가 원치 않는 상황에서 과도하게 출력되는 경우라면 `score`를 더 작은 값으로, * 반대로 원하는 상황에서도 출력되지 않는 경우라면 `score`를 더 큰 값으로 조절하는 게 좋다. */ std::pair addWord(const std::u16string& form, POSTag tag = POSTag::nnp, float score = 0); @@ -716,26 +740,26 @@ namespace kiwi * @param score 페널티 점수. 이에 대한 자세한 설명은 하단의 `addWord`함수의 note 참조. * @exception kiwi::UnknownMorphemeException `analyzed`로 주어진 형태소 중 하나라도 존재하지 않는게 있는 경우 예외를 발생시킨다. * @return 형태소열을 추가하는데 성공했으면 true, 동일한 형태소열이 존재하여 추가에 실패한 경우 false를 반환한다. - * @note 이 함수는 특정 문자열이 어떻게 분석되어야하는지 직접적으로 지정해줄 수 있다. - * 따라서 `addWord` 함수를 사용해도 오분석이 발생하는 경우, 이 함수를 통해 해당 사례들에 대해 정확한 분석 결과를 추가하면 원하는 분석 결과를 얻을 수 있다. + * @note 이 함수는 특정 문자열이 어떻게 분석되어야하는지 직접적으로 지정해줄 수 있다. + * 따라서 `addWord` 함수를 사용해도 오분석이 발생하는 경우, 이 함수를 통해 해당 사례들에 대해 정확한 분석 결과를 추가하면 원하는 분석 결과를 얻을 수 있다. */ - bool addPreAnalyzedWord(const std::u16string& form, - const std::vector>& analyzed, + bool addPreAnalyzedWord(const std::u16string& form, + const std::vector>& analyzed, std::vector> positions = {}, float score = 0 ); - bool addPreAnalyzedWord(const char16_t* form, - const std::vector>& analyzed, + bool addPreAnalyzedWord(const char16_t* form, + const std::vector>& analyzed, std::vector> positions = {}, float score = 0 ); /** * @brief 규칙에 의해 변형된 형태소 목록을 생성하여 자동 추가한다. - * - * @param tag - * @param repl - * @param score + * + * @param tag + * @param repl + * @param score * @return 새로 추가된 변형된 형태소의 ID와 그 형태를 pair로 묶은 목록 */ template @@ -770,24 +794,24 @@ namespace kiwi } /** - * @brief - * - * @param dictPath - * @return + * @brief + * + * @param dictPath + * @return */ size_t loadDictionary(const std::string& dictPath); - std::vector extractWords(const U16MultipleReader& reader, + std::vector extractWords(const U16MultipleReader& reader, size_t minCnt = 10, size_t maxWordLen = 10, float minScore = 0.25, float posThreshold = -3, bool lmFilter = true ) const; - std::vector extractAddWords(const U16MultipleReader& reader, + std::vector extractAddWords(const U16MultipleReader& reader, size_t minCnt = 10, size_t maxWordLen = 10, float minScore = 0.25, float posThreshold = -3, bool lmFilter = true ); /** * @brief 현재 단어 및 사전 설정을 기반으로 Kiwi 객체를 생성한다. - * + * * @param typos * @param typoCostThreshold * @return 형태소 분석 준비가 완료된 Kiwi의 객체. @@ -803,22 +827,31 @@ namespace kiwi const std::vector& inputPathes, const std::string& outputPath, const std::string& morphemeDefPath = {}, - size_t morphemeDefMinCnt = 0 + size_t morphemeDefMinCnt = 0, + bool generateOovDict = false ) const; using TokenFilter = std::function; - HSDataset makeHSDataset(const std::vector& inputPathes, - size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, + HSDataset makeHSDataset(const std::vector& inputPathes, + size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb = 0, double dropoutProbOnHistory = 0, + double nounAugmentingProb = 0, + size_t generateUnlikelihoods = -1, const TokenFilter& tokenFilter = {}, const TokenFilter& windowFilter = {}, double splitRatio = 0, bool separateDefaultMorpheme = false, const std::string& morphemeDefPath = {}, size_t morphemeDefMinCnt = 0, + const std::vector>>& contextualMapper = {}, HSDataset* splitDataset = nullptr ) const; + + BuildOption getOptions() const { return options; } + ModelType getModelType() const { return modelType; } + + static void buildMorphData(const std::string& morphemeDefPath, const std::string& outputPath, size_t minCnt = 10); }; } diff --git a/include/kiwi/Knlm.h b/include/kiwi/Knlm.h index ef54eb9b..7ab8044a 100644 --- a/include/kiwi/Knlm.h +++ b/include/kiwi/Knlm.h @@ -1,23 +1,12 @@ #pragma once -#include -#include -#include -#include -#include -#include - -#include "Utils.h" -#include "Mmap.h" -#include "ArchUtils.h" +#include "LangModel.h" namespace kiwi { namespace lm { - using Vid = uint16_t; - - struct Header + struct KnLangModelHeader { uint64_t num_nodes, node_offset, key_offset, ll_offset, gamma_offset, qtable_offset, htx_offset; uint64_t unk_id, bos_id, eos_id, vocab_size; @@ -26,7 +15,7 @@ namespace kiwi }; template - struct Node + struct KnLangModelNode { KeyType num_nexts = 0; DiffType lower = 0; @@ -34,7 +23,7 @@ namespace kiwi float ll = 0, gamma = 0; }; - class KnLangModelBase + class KnLangModelBase : public ILangModel { protected: utils::MemoryObject base; @@ -52,21 +41,24 @@ namespace kiwi public: virtual ~KnLangModelBase() {} - const Header& getHeader() const { return *reinterpret_cast(base.get()); } + size_t vocabSize() const override { return getHeader().vocab_size; } + size_t getMemorySize() const override { return base.size(); } + + const KnLangModelHeader& getHeader() const { return *reinterpret_cast(base.get()); } virtual ptrdiff_t getLowerNode(ptrdiff_t node_idx) const = 0; virtual size_t nonLeafNodeSize() const = 0; virtual const void* getExtraBuf() const = 0; - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none); + static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none, bool transposed = false); - template> + template> static utils::MemoryOwner build(Trie&& ngram_cf, size_t order, const std::vector& min_cf_by_order, size_t unk_id, size_t bos_id, size_t eos_id, float unigram_alpha, size_t quantize, bool compress, - const std::vector>* bigram_list = nullptr, + const std::vector>* bigram_list = nullptr, const HistoryTx* history_transformer = nullptr, const void* extra_buf = nullptr, size_t extra_buf_size = 0 diff --git a/include/kiwi/LangModel.h b/include/kiwi/LangModel.h new file mode 100644 index 00000000..ec839e63 --- /dev/null +++ b/include/kiwi/LangModel.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "Utils.h" +#include "Mmap.h" +#include "ArchUtils.h" +#include "Types.h" + +namespace kiwi +{ + namespace lm + { + class ILangModel + { + public: + virtual ~ILangModel() = default; + virtual ModelType getType() const = 0; + virtual size_t vocabSize() const = 0; + virtual size_t getMemorySize() const = 0; + + virtual void* getFindBestPathFn() const = 0; + virtual void* getNewJoinerFn() const = 0; + }; + + template + struct LmStateBase + { + float next(const ILangModel* langMdl, typename DerivedLM::VocabType nextToken) + { + using LmStateType = typename DerivedLM::LmStateType; + return static_cast(this)->nextImpl(static_cast(langMdl), nextToken); + } + }; + + template + class VoidLangModel; + + template + struct VoidState : public LmStateBase> + { + bool operator==(const VoidState& other) const + { + return true; + } + + float nextImpl(const VoidLangModel* langMdl, uint32_t nextToken) + { + return 0; + } + }; + + template + class VoidLangModel : public ILangModel + { + public: + using VocabType = uint32_t; + using LmStateType = VoidState; + + ModelType getType() const override { return ModelType::none; } + size_t vocabSize() const override { return 0; } + void* getFindBestPathFn() const override { return nullptr; } + void* getNewJoinerFn() const override { return nullptr; } + }; + } + + template + struct Hash> + { + size_t operator()(const lm::VoidState& state) const + { + return 0; + } + }; +} diff --git a/include/kiwi/LmState.h b/include/kiwi/LmState.h index 1399343d..e69de29b 100644 --- a/include/kiwi/LmState.h +++ b/include/kiwi/LmState.h @@ -1,34 +0,0 @@ -#pragma once - -#include -#include "Utils.h" -#include "Trie.hpp" -#include "Knlm.h" -#include "SkipBigramModel.h" - -namespace kiwi -{ - struct LangModel - { - std::shared_ptr knlm; - std::shared_ptr sbg; - }; - - class LmObjectBase - { - public: - virtual ~LmObjectBase() {} - - virtual size_t vocabSize() const = 0; - - virtual float evalSequence(const uint32_t* seq, size_t length, size_t stride) const = 0; - - virtual void predictNext(const uint32_t* seq, size_t length, size_t stride, float* outScores) const = 0; - - virtual void evalSequences( - const uint32_t* prefix, size_t prefixLength, size_t prefixStride, - const uint32_t* suffix, size_t suffixLength, size_t suffixStride, - size_t seqSize, const uint32_t** seq, const size_t* seqLength, const size_t* seqStride, float* outScores - ) const = 0; - }; -} diff --git a/include/kiwi/Mmap.h b/include/kiwi/Mmap.h index a2c6ef3d..7850812d 100644 --- a/include/kiwi/Mmap.h +++ b/include/kiwi/Mmap.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #ifdef _WIN32 #define NOMINMAX @@ -260,6 +261,16 @@ namespace kiwi const void* get() const { return obj->get(); } size_t size() const { return obj->size(); } + + void writeToFile(const std::string& filepath) const + { + std::ofstream ofs; + if (!openFile(ofs, filepath, std::ios_base::binary)) + { + throw IOException{ "Cannot open file : " + filepath }; + } + ofs.write((const char*)get(), size()); + } }; template diff --git a/include/kiwi/SkipBigramModel.h b/include/kiwi/SkipBigramModel.h index 69e73069..396eb583 100644 --- a/include/kiwi/SkipBigramModel.h +++ b/include/kiwi/SkipBigramModel.h @@ -1,26 +1,18 @@ #pragma once -#include -#include -#include -#include -#include -#include - -#include "ArchUtils.h" -#include "Mmap.h" +#include "Knlm.h" namespace kiwi { - namespace sb + namespace lm { - struct Header + struct SkipBigramModelHeader { uint64_t vocabSize; uint8_t keySize, windowSize, compressed, quantize, _rsv[4]; }; - class SkipBigramModelBase + class SkipBigramModelBase : public ILangModel { protected: utils::MemoryObject base; @@ -30,9 +22,12 @@ namespace kiwi } public: virtual ~SkipBigramModelBase() {} - const Header& getHeader() const { return *reinterpret_cast(base.get()); } + size_t vocabSize() const override { return getHeader().vocabSize; } + ModelType getType() const override { return ModelType::sbg; } + + const SkipBigramModelHeader& getHeader() const { return *reinterpret_cast(base.get()); } - static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none); + static std::unique_ptr create(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem, ArchType archType = ArchType::none); }; } } diff --git a/include/kiwi/SubstringExtractor.h b/include/kiwi/SubstringExtractor.h index 67115867..0b07a6db 100644 --- a/include/kiwi/SubstringExtractor.h +++ b/include/kiwi/SubstringExtractor.h @@ -39,6 +39,7 @@ namespace kiwi ); void addArray(const uint16_t* first, const uint16_t* last); void addArray(const uint32_t* first, const uint32_t* last); + void addArray(const int32_t* first, const int32_t* last); void addArray(const uint64_t* first, const uint64_t* last); utils::FrozenTrie count() const; std::unique_ptr buildLM( diff --git a/include/kiwi/TemplateUtils.hpp b/include/kiwi/TemplateUtils.hpp index 3d2c7038..8c9de09e 100644 --- a/include/kiwi/TemplateUtils.hpp +++ b/include/kiwi/TemplateUtils.hpp @@ -56,7 +56,10 @@ namespace kiwi }; template - struct SeqMax; + struct SeqMax + { + static constexpr std::ptrdiff_t value = 0; + }; template struct SeqMax> @@ -130,7 +133,7 @@ namespace kiwi template class Table { - ValTy table[SeqMax::value + 1]; + std::array::value + 1> table; template void set(seq<>) @@ -153,9 +156,48 @@ namespace kiwi constexpr ValTy operator[](std::ptrdiff_t idx) const { + if (idx < 0 || (size_t)idx >= table.size()) return ValTy{}; return table[idx]; } }; } + + + template + struct SignedType { using type = IntTy; }; + + template<> + struct SignedType { using type = int8_t; }; + + template<> + struct SignedType { using type = int16_t; }; + + template<> + struct SignedType { using type = int32_t; }; + + template<> + struct SignedType { using type = int64_t; }; + + template<> + struct SignedType { using type = int16_t; }; + + + template + struct UnsignedType { using type = IntTy; }; + + template<> + struct UnsignedType { using type = uint8_t; }; + + template<> + struct UnsignedType { using type = uint16_t; }; + + template<> + struct UnsignedType { using type = uint32_t; }; + + template<> + struct UnsignedType { using type = uint64_t; }; + + template<> + struct UnsignedType { using type = uint16_t; }; } diff --git a/include/kiwi/ThreadPool.h b/include/kiwi/ThreadPool.h index 848f976f..5925904c 100644 --- a/include/kiwi/ThreadPool.h +++ b/include/kiwi/ThreadPool.h @@ -27,7 +27,7 @@ namespace kiwi template auto enqueue(F&& f, Args&&... args) - ->std::future::type>; + ->std::future::type>; size_t size() const { return workers.size(); } size_t numEnqueued() const { return tasks.size(); } @@ -67,9 +67,9 @@ namespace kiwi template auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> + -> std::future::type> { - using return_type = typename std::result_of::type; + using return_type = typename std::invoke_result::type; auto task = std::make_shared< std::packaged_task >( std::bind(std::forward(f), std::placeholders::_1, std::forward(args)...)); diff --git a/include/kiwi/Trie.hpp b/include/kiwi/Trie.hpp index 2bd9a42a..49ff5781 100644 --- a/include/kiwi/Trie.hpp +++ b/include/kiwi/Trie.hpp @@ -150,8 +150,8 @@ namespace kiwi } } - template - void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const + template + void traverseWithKeys(_Fn&& fn, std::vector<_CKey, _Alloc>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const { fn((Node*)this, rkeys); @@ -487,8 +487,8 @@ namespace kiwi return nodes[0].traverse(std::forward<_Fn>(fn), rkeys, maxDepth, ignoreNegative); } - template - void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const + template + void traverseWithKeys(_Fn&& fn, std::vector<_CKey, _Alloc>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const { return nodes[0].traverseWithKeys(std::forward<_Fn>(fn), rkeys, maxDepth, ignoreNegative); } diff --git a/include/kiwi/Types.h b/include/kiwi/Types.h index e6258499..f7c4c11c 100644 --- a/include/kiwi/Types.h +++ b/include/kiwi/Types.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -62,16 +63,6 @@ inline Type operator^=(Type& a, Type b)\ return reinterpret_cast(reinterpret_cast::type&>(a) ^= static_cast::type>(b));\ } -namespace nonstd -{ - namespace sv_lite - { - template class basic_string_view; - } - - using string_view = sv_lite::basic_string_view>; - using u16string_view = sv_lite::basic_string_view>; -} namespace kiwi { @@ -187,7 +178,7 @@ namespace kiwi using KcScores = Vector>; #endif - using U16StringView = nonstd::u16string_view; + using U16StringView = std::u16string_view; /** * @brief 형태소 품사 태그와 관련된 열거형 @@ -303,6 +294,18 @@ namespace kiwi default_ = integrateAllomorph | loadDefaultDict | loadTypoDict | loadMultiDict, }; + enum class ModelType + { + none = 0, /**< Select default model */ + knlm = 1, /**< Kneser-Ney Language Model */ + sbg = 2, /**< Skip-Bigram Model */ + cong = 3, /**< Contextual N-gram embedding Language Model (Only local context) */ + congGlobal = 4, /**< Contextual N-gram embedding Language Model (local and global context) */ + congFp32 = 5, /**< Contextual N-gram embedding Language Model (Only local context, non-quantized(slow) version) */ + congGlobalFp32 = 6, /**< Contextual N-gram embedding Language Model (local and global context, non-quantized(slow) version) */ + knlmTransposed, + }; + struct Morpheme; /** diff --git a/include/kiwi/Utils.h b/include/kiwi/Utils.h index b2a76e81..c3b1680d 100644 --- a/include/kiwi/Utils.h +++ b/include/kiwi/Utils.h @@ -388,5 +388,6 @@ namespace kiwi std::ifstream& openFile(std::ifstream& f, const std::string& filePath, std::ios_base::openmode mode = std::ios_base::in); std::ofstream& openFile(std::ofstream& f, const std::string& filePath, std::ios_base::openmode mode = std::ios_base::out); + bool isOpenable(const std::string& filePath); } diff --git a/include/kiwi/capi.h b/include/kiwi/capi.h index 5bfa6e11..3f33d35a 100644 --- a/include/kiwi/capi.h +++ b/include/kiwi/capi.h @@ -98,8 +98,11 @@ enum KIWI_BUILD_LOAD_TYPO_DICT = 4, KIWI_BUILD_LOAD_MULTI_DICT = 8, KIWI_BUILD_DEFAULT = 15, - KIWI_BUILD_MODEL_TYPE_KNLM = 0x0000, - KIWI_BUILD_MODEL_TYPE_SBG = 0x0100, + KIWI_BUILD_MODEL_TYPE_DEFAULT = 0x0000, + KIWI_BUILD_MODEL_TYPE_KNLM = 0x0100, + KIWI_BUILD_MODEL_TYPE_SBG = 0x0200, + KIWI_BUILD_MODEL_TYPE_CONG = 0x0300, + KIWI_BUILD_MODEL_TYPE_CONG_GLOBAL = 0x0400, }; enum diff --git a/src/ArchAvailable.h b/src/ArchAvailable.h index 9a2a59f0..d141e1e4 100644 --- a/src/ArchAvailable.h +++ b/src/ArchAvailable.h @@ -12,7 +12,11 @@ namespace kiwi using AvailableArch = tp::seq< #ifdef KIWI_USE_CPUINFO #if CPUINFO_ARCH_X86_64 + static_cast(ArchType::avx512vnni), static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED + static_cast(ArchType::avx_vnni), +#endif static_cast(ArchType::avx2), static_cast(ArchType::sse4_1), #endif @@ -24,7 +28,11 @@ namespace kiwi #endif #else #ifdef KIWI_ARCH_X86_64 + static_cast(ArchType::avx512vnni), static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED + static_cast(ArchType::avx_vnni), +#endif static_cast(ArchType::avx2), static_cast(ArchType::sse4_1), #endif @@ -38,4 +46,34 @@ namespace kiwi static_cast(ArchType::none), static_cast(ArchType::balanced) >; + + using QuantAvailableArch = tp::seq < +#ifdef KIWI_USE_CPUINFO +#if CPUINFO_ARCH_X86_64 + static_cast(ArchType::avx512vnni), + static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED + static_cast(ArchType::avx_vnni), +#endif + static_cast(ArchType::avx2), + static_cast(ArchType::sse4_1) +#endif +#if CPUINFO_ARCH_ARM64 + static_cast(ArchType::neon) +#endif +#else +#ifdef KIWI_ARCH_X86_64 + static_cast(ArchType::avx512vnni), + static_cast(ArchType::avx512bw), +#ifdef KIWI_AVX_VNNI_SUPPORTED + static_cast(ArchType::avx_vnni), +#endif + static_cast(ArchType::avx2), + static_cast(ArchType::sse4_1) +#endif +#ifdef KIWI_ARCH_ARM64 + static_cast(ArchType::neon) +#endif +#endif + >; } diff --git a/src/ArchUtils.cpp b/src/ArchUtils.cpp index 9342f365..4f553222 100644 --- a/src/ArchUtils.cpp +++ b/src/ArchUtils.cpp @@ -12,7 +12,9 @@ ArchType kiwi::getBestArch() #ifdef KIWI_USE_CPUINFO cpuinfo_initialize(); #if CPUINFO_ARCH_X86_64 + if (cpuinfo_has_x86_avx512vnni()) return ArchType::avx512vnni; if (cpuinfo_has_x86_avx512bw()) return ArchType::avx512bw; + if (cpuinfo_has_x86_avx_vnni_int8()) return ArchType::avx_vnni; if (cpuinfo_has_x86_avx2()) return ArchType::avx2; if (cpuinfo_has_x86_sse4_1()) return ArchType::sse4_1; #endif @@ -24,7 +26,7 @@ ArchType kiwi::getBestArch() #endif #else #ifdef KIWI_ARCH_X86_64 - return ArchType::avx512bw; + return ArchType::avx512vnni; #elif defined(__x86_64__) || defined(KIWI_ARCH_X86) return ArchType::sse2; #elif defined(KIWI_ARCH_ARM64) @@ -43,7 +45,9 @@ namespace kiwi "sse2", "sse4_1", "avx2", + "avx_vnni", "avx512bw", + "avx512vnni", "neon", }; @@ -51,7 +55,7 @@ namespace kiwi { if (arch <= ArchType::balanced) return arch; #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 || KIWI_ARCH_X86_64 || KIWI_ARCH_X86 - if (ArchType::sse2 <= arch && arch <= ArchType::avx512bw && arch <= best) + if (ArchType::sse2 <= arch && arch <= ArchType::avx512vnni && arch <= best) { return arch; } diff --git a/src/BestPathContainer.hpp b/src/BestPathContainer.hpp new file mode 100644 index 00000000..a8b86649 --- /dev/null +++ b/src/BestPathContainer.hpp @@ -0,0 +1,466 @@ +#pragma once + +#include +#include + +namespace kiwi +{ + template + struct WordLL; + + using Wid = uint32_t; + + enum class PathEvaluatingMode + { + topN, + top1Small, + top1Medium, + top1, + }; + + template + struct WordLL + { + LmState lmState; + uint8_t prevRootId = 0; + SpecialState spState; + uint8_t rootId = 0; + + const Morpheme* morpheme = nullptr; + float accScore = 0, accTypoCost = 0; + const WordLL* parent = nullptr; + Wid wid = 0; + uint16_t ownFormId = 0; + uint8_t combineSocket = 0; + + WordLL() = default; + + WordLL(const Morpheme* _morph, float _accScore, float _accTypoCost, const WordLL* _parent, LmState _lmState, SpecialState _spState) + : morpheme{ _morph }, + accScore{ _accScore }, + accTypoCost{ _accTypoCost }, + parent{ _parent }, + lmState{ _lmState }, + spState{ _spState }, + rootId{ _parent ? _parent->rootId : (uint8_t)0 } + { + } + + const WordLL* root() const + { + if (parent) return parent->root(); + else return this; + } + + bool equalTo(const LmState& lmState, uint8_t prevRootId, SpecialState spState) const + { + return ((this->prevRootId == prevRootId) & (this->spState == spState)) && this->lmState == lmState; + } + + bool operator==(const WordLL& o) const + { + return equalTo(o.lmState, o.prevRootId, o.spState); + } + }; + + template + struct Hash> + { + size_t operator()(const WordLL& p) const + { + size_t ret = Hash{}(p.lmState); + ret = *reinterpret_cast(&p.prevRootId) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + return ret; + } + + size_t operator()(const LmState& lmState, uint8_t prevRootId, uint8_t spState) const + { + size_t ret = Hash{}(lmState); + ret = ((uint16_t)(prevRootId) | ((uint16_t)spState << 8)) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + return ret; + } + }; + + static constexpr uint8_t commonRootId = -1; + + template + struct PathHash + { + LmState lmState; + uint8_t rootId, spState; + + PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) + : lmState{ _lmState }, rootId{ _rootId }, spState{ _spState } + { + } + + PathHash(const WordLL& wordLl, const Morpheme* morphBase) + : PathHash{ wordLl.lmState, wordLl.rootId, wordLl.spState } + { + } + + bool operator==(const PathHash& o) const + { + return lmState == o.lmState && rootId == o.rootId && spState == o.spState; + } + }; + + template + struct Hash> + { + size_t operator()(const PathHash& p) const + { + size_t ret = Hash{}(p.lmState); + ret = *reinterpret_cast(&p.rootId) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + return ret; + } + }; + + struct WordLLGreater + { + template + bool operator()(const WordLL& a, const WordLL& b) const + { + return a.accScore > b.accScore; + } + }; + + template + inline std::ostream& printDebugPath(std::ostream& os, const WordLL& src) + { + if (src.parent) + { + printDebugPath(os, *src.parent); + } + + if (src.morpheme) src.morpheme->print(os); + else os << "NULL"; + os << " , "; + return os; + } + + template + class BestPathConatiner; + + template + struct BestPathContainerTraits + { + static constexpr size_t maxSize = -1; + }; + + template + class BestPathConatiner + { + // pair: [index, size] + UnorderedMap, std::pair> bestPathIndex; + Vector> bestPathValues; + public: + + inline void clear() + { + bestPathIndex.clear(); + bestPathValues.clear(); + } + + inline void insert(size_t topN, uint8_t prevRootId, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + { + PathHash ph{ lmState, prevRootId, spState }; + auto inserted = bestPathIndex.emplace(ph, std::make_pair((uint32_t)bestPathValues.size(), 1)); + if (inserted.second) + { + bestPathValues.emplace_back(morph, accScore, accTypoCost, parent, std::move(lmState), spState); + if (rootId != commonRootId) bestPathValues.back().rootId = rootId; + bestPathValues.resize(bestPathValues.size() + topN - 1); + } + else + { + auto bestPathFirst = bestPathValues.begin() + inserted.first->second.first; + auto bestPathLast = bestPathValues.begin() + inserted.first->second.first + inserted.first->second.second; + if (std::distance(bestPathFirst, bestPathLast) < topN) + { + *bestPathLast = WordLL{ morph, accScore, accTypoCost, parent, std::move(lmState), spState }; + if (rootId != commonRootId) bestPathLast->rootId = rootId; + std::push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); + ++inserted.first->second.second; + } + else + { + if (accScore > bestPathFirst->accScore) + { + std::pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + *(bestPathLast - 1) = WordLL{ morph, accScore, accTypoCost, parent, std::move(lmState), spState }; + if (rootId != commonRootId) (*(bestPathLast - 1)).rootId = rootId; + std::push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); + } + } + } + } + + inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + { + for (auto& p : bestPathIndex) + { + const auto index = p.second.first; + const auto size = p.second.second; + for (size_t i = 0; i < size; ++i) + { + resultOut.emplace_back(std::move(bestPathValues[index + i])); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + newPath.wid = lastSeqId; + if (curMorph->isSingle()) + { + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + } + } + } + }; + + template + class BestPathConatiner + { + UnorderedSet> bestPathes; + public: + inline void clear() + { + bestPathes.clear(); + } + + inline void insert(size_t topN, uint8_t prevRootId, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + { + WordLL newPath{ morph, accScore, accTypoCost, parent, std::move(lmState), spState }; + newPath.prevRootId = prevRootId; + if (rootId != commonRootId) newPath.rootId = rootId; + auto inserted = bestPathes.emplace(newPath); + if (!inserted.second) + { + // this is dangerous, but we can update the key safely + // because an equality between the two objects is guaranteed + auto& target = const_cast&>(*inserted.first); + if (accScore > target.accScore) + { + target = newPath; + } + } + } + + inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + { + for (auto& p : bestPathes) + { + resultOut.emplace_back(std::move(p)); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + newPath.wid = lastSeqId; + if (curMorph->isSingle()) + { + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + } + } + }; + + template<> + struct BestPathContainerTraits + { + static constexpr size_t maxSize = (sizeof(size_t) == 8 ? 64 : 32) * 2; + }; + + template<> + struct BestPathContainerTraits + { + static constexpr size_t maxSize = BestPathContainerTraits::maxSize * 4; + }; + + template + class BucketedHashContainer + { + static constexpr size_t bucketSize = 1 << bucketBits; + + std::array::maxSize>, bucketSize> hashes; + std::array>, bucketSize> values; + + public: + BucketedHashContainer() + { + for (auto& v : values) + { + v.reserve(BestPathContainerTraits::maxSize); + } + } + + inline void clear() + { + for (auto& v : values) + { + v.clear(); + } + } + + template + inline void insertOptimized(size_t topN, uint8_t prevRootId, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + { + static constexpr size_t numBits = sizeof(size_t) * 8; + const size_t h = Hash>{}(lmState, prevRootId, spState); + const size_t bucket = (h >> 8) & (bucketSize - 1); + auto& hash = hashes[bucket]; + auto& value = values[bucket]; + + size_t it = value.size(); + size_t bits[2]; + bits[0] = nst::findAll(hash.data(), std::min(value.size(), numBits), (uint8_t)h); + bits[1] = value.size() > numBits ? nst::findAll(hash.data() + numBits, value.size() - numBits, (uint8_t)h) : 0; + while (bits[0]) + { + const size_t i = utils::countTrailingZeroes(bits[0]); + if (value[i].equalTo(lmState, prevRootId, spState)) + { + it = i; + goto breakloop; + } + bits[0] &= ~((size_t)1 << i); + } + while (bits[1]) + { + const size_t i = utils::countTrailingZeroes(bits[1]); + if (value[i].equalTo(lmState, prevRootId, spState)) + { + it = i + numBits; + goto breakloop; + } + bits[1] &= ~((size_t)1 << i); + } + + breakloop:; + if (it >= value.size()) + { + if (value.size() < hash.size()) + { + hash[value.size()] = h; + value.emplace_back(morph, accScore, accTypoCost, parent, std::move(lmState), spState); + value.back().prevRootId = prevRootId; + if (rootId != commonRootId) value.back().rootId = rootId; + } + else + { + // skip insertion if container is full. + // this isn't correct, but it rarely happens + } + } + else + { + auto& target = value[it]; + if (accScore > target.accScore) + { + target.morpheme = morph; + target.accScore = accScore; + target.accTypoCost = accTypoCost; + target.parent = parent; + target.lmState = std::move(lmState); + target.spState = spState; + target.rootId = parent ? parent->rootId : 0; + if (rootId != commonRootId) target.rootId = rootId; + } + } + } + + inline void insert(size_t topN, uint8_t prevRootId, uint8_t rootId, + const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + { + static constexpr ArchType archType = LmState::arch; + if constexpr (archType != ArchType::none && archType != ArchType::balanced) + { + return insertOptimized(topN, prevRootId, rootId, morph, accScore, accTypoCost, parent, std::move(lmState), spState); + } + + const size_t h = Hash>{}(lmState, prevRootId, spState); + const size_t bucket = (h >> 8) & (bucketSize - 1); + auto& hash = hashes[bucket]; + auto& value = values[bucket]; + + const auto hashEnd = hash.begin() + value.size(); + auto it = std::find(hash.begin(), hashEnd, (uint8_t)h); + while (it != hashEnd) + { + if (value[it - hash.begin()].equalTo(lmState, prevRootId, spState)) + { + break; + } + ++it; + it = std::find(it, hashEnd, (uint8_t)h); + } + + if (it == hashEnd) + { + if (value.size() < hash.size()) + { + hash[value.size()] = h; + value.emplace_back(morph, accScore, accTypoCost, parent, std::move(lmState), spState); + value.back().prevRootId = prevRootId; + if (rootId != commonRootId) value.back().rootId = rootId; + } + else + { + // skip insertion if container is full. + // this isn't correct, but it rarely happens + } + } + else + { + auto& target = value[it - hash.begin()]; + if (accScore > target.accScore) + { + target.morpheme = morph; + target.accScore = accScore; + target.accTypoCost = accTypoCost; + target.parent = parent; + target.lmState = std::move(lmState); + target.spState = spState; + target.rootId = parent ? parent->rootId : 0; + if (rootId != commonRootId) target.rootId = rootId; + } + } + } + + inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + { + for (auto& v : values) + { + for (auto& p : v) + { + resultOut.emplace_back(std::move(p)); + auto& newPath = resultOut.back(); + + // fill the rest information of resultOut + newPath.wid = lastSeqId; + if (curMorph->isSingle()) + { + newPath.combineSocket = curMorph->combineSocket; + newPath.ownFormId = ownFormId; + } + } + } + } + }; + + + template + class alignas(BestPathContainerTraits::maxSize) BestPathConatiner + : public BucketedHashContainer + { + }; + + template + class alignas(BestPathContainerTraits::maxSize) BestPathConatiner + : public BucketedHashContainer + { + }; +} diff --git a/src/CoNgramModel.cpp b/src/CoNgramModel.cpp new file mode 100644 index 00000000..cc436edc --- /dev/null +++ b/src/CoNgramModel.cpp @@ -0,0 +1,1753 @@ +#include +#include +#include "PathEvaluator.hpp" +#include "Joiner.hpp" +#include "Kiwi.hpp" +#include "CoNgramModel.hpp" +#include "StrUtils.h" +#include "FrozenTrie.hpp" +#include "qgemm.h" +#include "gemm.h" + +using namespace std; + +namespace kiwi +{ + inline size_t padMultipleOf(size_t n, size_t multiple) + { + return (n + multiple - 1) / multiple * multiple; + } + + template + struct MorphemeEvaluator> + { + using LmState = lm::CoNgramState; + + template + void eval( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + size_t ownFormId, + const Vector& morphs, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const size_t totalPrevPathes, + const float ignoreCondScore, + const float nodeLevelDiscount, + const Vector& prevSpStates + ) const + { + thread_local BestPathConatiner bestPathCont; + thread_local Vector*> regularPrevPathes; + thread_local Vector*>> combiningPrevPathes; + thread_local Vector regularMorphs, regularDistantMorphs, combiningLMorphs, combiningRMorphs; + thread_local Vector prevLmStates, nextLmStates; + thread_local Vector nextWids, nextDistantWids; + thread_local Vector scores; + + const auto* langMdl = static_cast*>(kw->getLangModel()); + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + const size_t langVocabSize = langMdl->vocabSize(); + + regularPrevPathes.clear(); + combiningPrevPathes.clear(); + regularMorphs.clear(); + regularDistantMorphs.clear(); + combiningLMorphs.clear(); + combiningRMorphs.clear(); + prevLmStates.clear(); + nextWids.clear(); + nextDistantWids.clear(); + + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& prevPath : cache[prev - startNode]) + { + if (prevPath.combineSocket) + { + combiningPrevPathes.emplace_back(prev, &prevPath); + continue; + } + regularPrevPathes.emplace_back(&prevPath); + } + } + + prevLmStates.resize(regularPrevPathes.size()); + for (size_t i = 0; i < regularPrevPathes.size(); ++i) + { + prevLmStates[i] = regularPrevPathes[i]->lmState; + } + + for (auto& curMorph : morphs) + { + if (curMorph->combineSocket) + { + (curMorph->isSingle() ? combiningLMorphs : combiningRMorphs).emplace_back(curMorph); + continue; + } + Wid firstWid; + if (curMorph->isSingle()) + { + firstWid = curMorph->lmMorphemeId; + } + else + { + firstWid = curMorph->chunks[0]->lmMorphemeId; + } + + if (morphBase[firstWid].tag == POSTag::p) + { + continue; + } + if (windowSize > 0 && langMdl->distantTokenMask(firstWid)) + { + regularDistantMorphs.emplace_back(curMorph); + nextDistantWids.emplace_back(firstWid); + } + else + { + regularMorphs.emplace_back(curMorph); + nextWids.emplace_back(firstWid); + } + } + + if (windowSize > 0) + { + regularMorphs.insert(regularMorphs.end(), regularDistantMorphs.begin(), regularDistantMorphs.end()); + nextWids.insert(nextWids.end(), nextDistantWids.begin(), nextDistantWids.end()); + } + + if (prevLmStates.size() > 0 && nextWids.size() > 0) + { + if (prevLmStates.size() == 1 && nextWids.size() == 1) + { + nextLmStates.resize(1); + scores.resize(1); + nextLmStates[0] = prevLmStates[0]; + scores[0] = nextLmStates[0].next(langMdl, nextWids[0]); + } + else + { + nextLmStates.resize(prevLmStates.size() * nextWids.size()); + scores.resize(prevLmStates.size() * nextWids.size()); + langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data()); + } + } + + for (size_t curId = 0; curId < regularMorphs.size(); ++curId) + { + const auto* curMorph = regularMorphs[curId]; + bestPathCont.clear(); + + size_t length = 1; + const Morpheme* lastMorph; + if (curMorph->isSingle()) + { + lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + length = curMorph->chunks.size(); + } + + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } + + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const float morphScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + size_t prevId = -1; + for (auto* prevPath : regularPrevPathes) + { + ++prevId; + auto& state = nextLmStates[prevId * regularMorphs.size() + curId]; + auto score = prevPath->accScore + morphScore + scores[prevId * regularMorphs.size() + curId]; + + FormEvaluator formEvaluator{ *prevPath, ownForms, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, score)) continue; + + for (size_t i = 1; i < length; ++i) + { + const auto wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + goto continueFor; + } + score += state.next(langMdl, wid); + } + + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(state), score, node, *prevPath, ruleBasedScorer); + continueFor:; + } + + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } + + for (auto* curMorph : combiningLMorphs) + { + bestPathCont.clear(); + const Morpheme* lastMorph; + if (curMorph->isSingle()) + { + lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + } + + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const float morphScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + for (auto* prevPath : regularPrevPathes) + { + auto state = prevPath->lmState; + float score = prevPath->accScore + morphScore; + + FormEvaluator formEvaluator{ *prevPath, ownForms, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, score)) continue; + + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(state), score, node, *prevPath, ruleBasedScorer); + } + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } + + for (auto* curMorph : combiningRMorphs) + { + bestPathCont.clear(); + size_t length = 1; + const Morpheme* lastMorph; + if (curMorph->isSingle()) + { + lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; + } + // if the morpheme has chunk set + else + { + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + length = curMorph->chunks.size(); + } + + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; + } + + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const float morphScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + for (auto& p : combiningPrevPathes) + { + auto* prev = p.first; + auto* prevPath = p.second; + float score = prevPath->accScore + morphScore; + // merge with only the same socket + if (prevPath->combineSocket != curMorph->combineSocket || curMorph->isSingle()) + { + continue; + } + if (prev->endPos < node->startPos) + { + if (allowedSpaceBetweenChunk) score -= spacePenalty; + else continue; + } + Wid firstWid = morphBase[prevPath->wid].getCombined()->lmMorphemeId; + + FormEvaluator formEvaluator{ *prevPath, ownForms, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, score)) continue; + + auto state = prevPath->lmState; + score += state.next(langMdl, firstWid); + + for (size_t i = 1; i < length; ++i) + { + const auto wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + goto continueFor2; + } + score += state.next(langMdl, wid); + } + + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(state), score, node, *prevPath, ruleBasedScorer); + continueFor2:; + } + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } + } + }; + + namespace lm + { + inline float half2float(uint16_t h) + { + union + { + uint32_t i; + float f; + } u; + u.i = (uint32_t)(h & 0x8000) << 16; + u.i |= ((uint32_t)(h & 0x7FFF) + 0x1C000) << 13; + return u.f; + } + + inline void dequantize(float* out, const int8_t* ints, size_t n, float scale) + { + for (size_t i = 0; i < n; ++i) + { + out[i] = ints[i] * scale; + } + } + + inline void addBias(uint8_t* out, const int8_t* ints, size_t n) + { + for (size_t i = 0; i < n; ++i) + { + out[i] = ints[i] + 128; + } + } + + template + CoNgramModel::CoNgramModel(utils::MemoryObject&& mem) : CoNgramModelBase{ mem } + { + auto* ptr = reinterpret_cast(mem.get()); + + Vector nodeSizes(header.numNodes); + streamvbyte_decode_0124(reinterpret_cast(ptr + header.nodeOffset), nodeSizes.data(), header.numNodes); + + static constexpr size_t kvAlignment = ArchInfo::alignment; + size_t paddedKVSize = 0; + for (size_t i = 0; i < nodeSizes.size(); ++i) + { + if (!nodeSizes[i]) continue; + paddedKVSize += padMultipleOf(nodeSizes[i] * (sizeof(VlKeyType) + sizeof(int32_t)), kvAlignment); + } + + keyValueData = make_unique(paddedKVSize + kvAlignment); + alignedKeyValueData = reinterpret_cast(padMultipleOf(reinterpret_cast(keyValueData.get()), kvAlignment)); + auto keyData = make_unique(header.numNodes - 1); + if (std::is_same::value) + { + streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), (uint32_t*)keyData.get(), header.numNodes - 1); + } + else + { + Vector tempKeyData(header.numNodes - 1); + streamvbyte_decode(reinterpret_cast(ptr + header.keyOffset), tempKeyData.data(), header.numNodes - 1); + std::copy(tempKeyData.begin(), tempKeyData.end(), keyData.get()); + } + Vector values(header.numNodes); + streamvbyte_decode_0124(reinterpret_cast(ptr + header.valueOffset), values.data(), header.numNodes); + + size_t numNonLeafNodes = 0, numLeafNodes = 0; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (nodeSizes[i]) numNonLeafNodes++; + else numLeafNodes++; + } + + nodeData = make_unique(numNonLeafNodes); + auto valueData = make_unique(header.numNodes - 1); + + size_t nonLeafIdx = 0, leafIdx = 0, nextOffset = 0; + Vector> keyRanges; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (nodeSizes[i]) + { + auto& node = nodeData[nonLeafIdx]; + if (!keyRanges.empty()) + { + auto& back = keyRanges.back(); + valueData[back[1]] = nonLeafIdx - back[0]; + } + node.value = values[i]; + node.numNexts = nodeSizes[i]; + keyRanges.emplace_back(std::array{ nonLeafIdx, (size_t)nextOffset, (size_t)(nextOffset + node.numNexts) }); + nextOffset += nodeSizes[i]; + nonLeafIdx++; + } + else + { + auto& back = keyRanges.back(); + valueData[back[1]] = -(int32_t)values[i]; + back[1]++; + while (keyRanges.back()[1] == keyRanges.back()[2]) + { + keyRanges.pop_back(); + if (keyRanges.empty()) break; + keyRanges.back()[1]++; + } + leafIdx++; + } + } + + uint8_t* kvDataPtr = const_cast(alignedKeyValueData); + nonLeafIdx = 0; + nextOffset = 0; + for (size_t i = 0; i < header.numNodes; ++i) + { + if (!nodeSizes[i]) continue; + auto& node = nodeData[nonLeafIdx]; + node.nextOffset = (uint32_t)(kvDataPtr - alignedKeyValueData); + memcpy(kvDataPtr, &keyData[nextOffset], node.numNexts * sizeof(VlKeyType)); + memcpy(kvDataPtr + node.numNexts * sizeof(VlKeyType), &valueData[nextOffset], node.numNexts * sizeof(int32_t)); + kvDataPtr += node.numNexts * (sizeof(VlKeyType) + sizeof(int32_t)); + nextOffset += node.numNexts; + nonLeafIdx++; + } + + allRootValueData = make_unique(header.vocabSize); + memset(allRootValueData.get(), 0, sizeof(int32_t) * header.vocabSize); + for (size_t i = 0; i < nodeData[0].numNexts; ++i) + { + allRootValueData[keyData[i]] = valueData[i]; + } + Vector tempBuf; + for (size_t i = 0; i < nonLeafIdx; ++i) + { + auto& node = nodeData[i]; + nst::prepareKV(const_cast(&alignedKeyValueData[node.nextOffset]), node.numNexts, tempBuf); + } + + Deque dq; + for (dq.emplace_back(&nodeData[0]); !dq.empty(); dq.pop_front()) + { + auto p = dq.front(); + for (size_t i = 0; i < p->numNexts; ++i) + { + auto kv = nst::extractKV(&alignedKeyValueData[p->nextOffset], p->numNexts, i); + if (kv.second <= 0) continue; + auto* child = &p[kv.second]; + child->lower = findLowerNode(p, kv.first) - child; + if (child->value == 0) + { + child->value = findLowerValue(p, kv.first); + } + dq.emplace_back(child); + } + } + + { + const size_t contextEmbSize = header.contextSize * contextEmbStride(); + const size_t distantEmbSize = windowSize > 0 ? header.vocabSize * distantEmbStride() : 0; + const size_t outputEmbSize = header.vocabSize * outputEmbStride(); + const size_t positionConfSize = windowSize > 0 ? (header.windowSize + 1) * sizeof(float) : 0; + const size_t distantMaskSize = windowSize > 0 ? (header.vocabSize + 7) / 8 : 0; + + allEmbs = make_unique(contextEmbSize + outputEmbSize + distantEmbSize + positionConfSize + distantMaskSize); + auto p = allEmbs.get(); + contextEmbPtr = reinterpret_cast(p); + distantEmbPtr = reinterpret_cast(p += contextEmbSize); + outputEmbPtr = reinterpret_cast(p += distantEmbSize); + positionConfidPtr = reinterpret_cast(p += outputEmbSize); + distantMaskPtr = reinterpret_cast(p += positionConfSize); + if (windowSize == 0) + { + distantEmbPtr = nullptr; + positionConfidPtr = nullptr; + distantMaskPtr = nullptr; + } + } + + auto* eptr = ptr + header.embOffset; + auto* optr = const_cast(contextEmbPtr); + for (size_t i = 0; i < header.contextSize; ++i) + { + if (quantized) + { + addBias(optr, reinterpret_cast(eptr), header.dim); + optr += header.dim; + eptr += header.dim; + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // scale + optr += sizeof(float); + eptr += sizeof(uint16_t); + } + else + { + const float scale = half2float(*reinterpret_cast(eptr + header.dim)); + dequantize(reinterpret_cast(optr), reinterpret_cast(eptr), header.dim, scale); + optr += header.dim * sizeof(float); + eptr += header.dim + sizeof(uint16_t); + } + + *reinterpret_cast(optr) = -half2float(*reinterpret_cast(eptr)); // bias + optr += sizeof(float); + eptr += sizeof(uint16_t); + if (windowSize > 0) + { + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // confidence + optr += sizeof(float); + eptr += sizeof(uint16_t); + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // valid token sum + optr += sizeof(float); + eptr += sizeof(uint16_t); + } + else + { + eptr += sizeof(uint16_t) * 2; + } + } + + optr = const_cast(outputEmbPtr); + for (size_t i = 0; i < header.vocabSize; ++i) + { + auto* qvals = reinterpret_cast(eptr); + if (quantized) + { + memcpy(optr, qvals, header.dim); + optr += header.dim; + eptr += header.dim; + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); + optr += sizeof(float); + eptr += sizeof(uint16_t); + *reinterpret_cast(optr) = accumulate(qvals, qvals + header.dim, 0) * 128; + optr += sizeof(int32_t); + } + else + { + const float scale = half2float(*reinterpret_cast(eptr + header.dim)); + dequantize(reinterpret_cast(optr), qvals, header.dim, scale); + optr += header.dim * sizeof(float); + eptr += header.dim + sizeof(uint16_t); + } + } + + if (windowSize > 0) + { + optr = const_cast(distantEmbPtr); + for (size_t i = 0; i < header.vocabSize; ++i) + { + if (quantized) + { + addBias(optr, reinterpret_cast(eptr), header.dim); + optr += header.dim; + eptr += header.dim; + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // scale + optr += sizeof(float); + eptr += sizeof(uint16_t); + } + else + { + const float scale = half2float(*reinterpret_cast(eptr + header.dim)); + dequantize(reinterpret_cast(optr), reinterpret_cast(eptr), header.dim, scale); + optr += header.dim * sizeof(float); + eptr += header.dim + sizeof(uint16_t); + } + + *reinterpret_cast(optr) = -half2float(*reinterpret_cast(eptr)); // bias + optr += sizeof(float); + eptr += sizeof(uint16_t); + *reinterpret_cast(optr) = half2float(*reinterpret_cast(eptr)); // confidence + optr += sizeof(float); + eptr += sizeof(uint16_t); + if (quantized) + { + optr += sizeof(float); + } + } + + const_cast(positionConfidPtr)[0] = 0; + for (size_t i = 0; i < header.windowSize; ++i) + { + const_cast(positionConfidPtr)[i + 1] = half2float(*reinterpret_cast(eptr)); + eptr += sizeof(uint16_t); + } + + optr = const_cast(distantMaskPtr); + const size_t compressedDistantMaskSize = (header.vocabSize + 7) / 8; + std::copy(eptr, eptr + compressedDistantMaskSize, optr); + } + } + + template + float CoNgramModel::progress(int32_t& nodeIdx, + uint32_t& contextIdx, + std::array& history, + KeyType next) const + { + const bool validDistantToken = distantTokenMask(next); + float ll = 0; + + if (windowSize > 0 && validDistantToken) + { + if constexpr (quantized) + { + int32_t contextIdcs[1 + windowSize]; + float lls[(1 + windowSize) * 2]; + int32_t nextIdx[1] = { (int32_t)next }; + + memcpy(lls, positionConfidPtr, (windowSize + 1) * sizeof(float)); + lls[0] += getContextConfid(contextIdx); + contextIdcs[0] = contextIdx; + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[i]; + lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; + contextIdcs[i + 1] = (historyToken ? historyToken : 0) + header.contextSize; + } + logSoftmax(lls, windowSize + 1); + qgemm::scatteredGEMMOpt( + 1 + windowSize, 1, header.dim, + getContextQuantEmb(0), contextIdcs, contextEmbStride(), + getOutputQuantEmb(0), nextIdx, outputEmbStride(), + &lls[1 + windowSize], 1); + for (size_t i = 0; i < 1 + windowSize; ++i) + { + lls[i] += lls[i + 1 + windowSize]; + } + lls[0] -= getContextValidTokenSum(contextIdx); + ll = logSumExp(lls, windowSize + 1); + ll += getContextValidTokenSum(contextIdx); + } + else + { + thread_local Eigen::MatrixXf mat; + mat.resize(header.dim, 1 + windowSize); + thread_local Eigen::VectorXf lls; + lls.resize(1 + windowSize); + + memcpy(lls.data(), positionConfidPtr, (windowSize + 1) * sizeof(float)); + lls[0] += getContextConfid(contextIdx); + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[i]; + lls[i + 1] += historyToken ? getDistantConfid(historyToken) : -99999; + } + logSoftmax(lls.data(), windowSize + 1); + + memcpy(mat.col(0).data(), getContextEmb(contextIdx), header.dim * sizeof(float)); + lls[0] += getContextBias(contextIdx); + for (size_t i = 0; i < windowSize; ++i) + { + const auto historyToken = history[i]; + if (historyToken) memcpy(mat.col(i + 1).data(), getDistantEmb(historyToken), header.dim * sizeof(float)); + else memset(mat.col(i + 1).data(), 0, header.dim * sizeof(float)); + lls[i + 1] += getDistantBias(historyToken); + } + lls.tail(windowSize).array() += getContextValidTokenSum(contextIdx); + Eigen::Map outputVec{ getOutputEmb(next), header.dim }; + gemm::template gemv( + mat.cols(), mat.rows(), + mat.data(), mat.colStride(), + outputVec.data(), lls.data() + ); + ll = logSumExp(lls.data(), windowSize + 1); + } + } + else + { + if constexpr (quantized) + { + const auto* contextPtr = getContextQuantEmb(contextIdx); + const auto* outputPtr = getOutputQuantEmb(next); + + int32_t acc = qgemm::dotprod(contextPtr, outputPtr, header.dim); + const float contextScale = *reinterpret_cast(contextPtr + header.dim), + outputScale = *reinterpret_cast(outputPtr + header.dim), + contextBias = *reinterpret_cast(contextPtr + header.dim + sizeof(float)); + const int32_t hsum = *reinterpret_cast(outputPtr + header.dim + sizeof(float)); + acc -= hsum; + ll = acc * contextScale * outputScale + contextBias; + } + else + { + ll = getContextBias(contextIdx); + Eigen::Map contextVec{ getContextEmb(contextIdx), header.dim }; + Eigen::Map outputVec{ getOutputEmb(next), header.dim }; + gemm::template gemv( + 1, header.dim, + contextVec.data(), contextVec.colStride(), + outputVec.data(), &ll + ); + } + } + + contextIdx = progressContextNode(nodeIdx, next); + if (windowSize > 0) + { + if (history[windowSize]) + { + memcpy(&history[0], &history[1], windowSize * sizeof(KeyType)); + } + history[windowSize] = validDistantToken ? next : 0; + } + return ll; + } + + template + struct CoNgramModel::TLSForProgressMatrix + { + Vector> contextCache; + Vector contextIdcs, historyIdcs, nextIdcs; + Vector inverseContextIdcs, inverseHistoryIdcs, inverseNextIdcs; + Vector resultBuf, confidenceBuf, scoreBuf; + UnorderedMap historyMap; + Vector uniqHistoryTokens; + Vector inputEmbBuf, outputEmbBuf; // only for non-quantized + Vector contextIdcs2, nextIdcs2; // only for quantized + }; + + // specialization for windowSize > 0 + template + template + inline auto CoNgramModel::nextState( + const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next, + bool cacheIsValid, pair& cache) const -> LmStateType + { + LmStateType ret{ state.node }; // partially initialized + if (cacheIsValid) + { + ret.node = cache.first; + ret.contextIdx = cache.second; + } + else + { + ret.contextIdx = progressContextNode(ret.node, next); + cache = std::make_pair(ret.node, ret.contextIdx); + } + + if (state.history[windowSize]) + { + memcpy(&ret.history[0], &state.history[1], windowSize * sizeof(KeyType)); + } + else + { + memcpy(&ret.history[0], &state.history[0], windowSize * sizeof(KeyType)); + } + ret.history[windowSize] = distantTokenMask(next) ? next : 0; + return ret; + } + + // specialization for windowSize == 0 + template + template + inline auto CoNgramModel::nextState( + const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next, + bool cacheIsValid, pair& cache) const -> LmStateType + { + LmStateType ret{ state.node }; // partially initialized + if (cacheIsValid) + { + ret.node = cache.first; + ret.contextIdx = cache.second; + } + else + { + ret.contextIdx = progressContextNode(ret.node, next); + cache = std::make_pair(ret.node, ret.contextIdx); + } + return ret; + } + + inline uint64_t mergePair(uint32_t a, uint32_t b) + { + return ((uint64_t)a << 32) | b; + } + + inline pair splitPair(uint64_t a) + { + return make_pair(a >> 32, a & 0xFFFFFFFF); + } + + template + inline void CoNgramModel::progressMatrixWSort( + TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const + { + static constexpr size_t scoreBatchSize = 32; + + tls.contextIdcs.resize(prevStateSize); + tls.historyIdcs.clear(); + tls.nextIdcs.resize(nextIdSize); + tls.inverseContextIdcs.resize(prevStateSize); + tls.inverseHistoryIdcs.clear(); + tls.inverseHistoryIdcs.resize(prevStateSize * windowSize, -1); + tls.inverseNextIdcs.resize(nextIdSize); + if (quantized) + { + tls.contextIdcs2.clear(); + tls.nextIdcs2.clear(); + } + else + { + tls.inputEmbBuf.resize(prevStateSize * header.dim); + tls.outputEmbBuf.resize(nextIdSize * header.dim); + } + tls.confidenceBuf.resize(prevStateSize * 2); + tls.scoreBuf.resize(scoreBatchSize * (windowSize + 2)); + + const size_t numInvalidDistantTokens = nextIdSize - numValidDistantTokens; + for (size_t i = 0; i < nextIdSize; ++i) + { + tls.nextIdcs[i] = mergePair(nextIds[i], i); + } + sort(tls.nextIdcs.begin(), tls.nextIdcs.begin() + numInvalidDistantTokens); + sort(tls.nextIdcs.begin() + numInvalidDistantTokens, tls.nextIdcs.end()); + size_t uniqOutputSize = 0; + for (size_t i = 0; i < nextIdSize; ++i) + { + const auto nextId = splitPair(tls.nextIdcs[i]).first; + const auto idx = splitPair(tls.nextIdcs[i]).second; + if (i == 0 || nextId != splitPair(tls.nextIdcs[i - 1]).first) + { + if (quantized) + { + tls.nextIdcs2.emplace_back(nextId); + } + else + { + copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &tls.outputEmbBuf[uniqOutputSize * header.dim]); + } + uniqOutputSize++; + } + tls.inverseNextIdcs[idx] = uniqOutputSize - 1; + } + tls.resultBuf.resize(prevStateSize * uniqOutputSize); + + for (size_t i = 0; i < prevStateSize; ++i) + { + tls.contextIdcs[i] = mergePair(prevStates[i].contextIdx, i); + } + sort(tls.contextIdcs.begin(), tls.contextIdcs.end()); + size_t uniqInputSize = 0; + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto contextId = splitPair(tls.contextIdcs[i]).first; + const auto idx = splitPair(tls.contextIdcs[i]).second; + if (i == 0 || contextId != splitPair(tls.contextIdcs[i - 1]).first) + { + if (quantized) + { + tls.contextIdcs2.emplace_back(contextId); + } + else + { + copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &tls.inputEmbBuf[uniqInputSize * header.dim]); + fill(&tls.resultBuf[uniqInputSize * uniqOutputSize], &tls.resultBuf[(uniqInputSize + 1) * uniqOutputSize], getContextBias(contextId)); + } + tls.confidenceBuf[uniqInputSize * 2] = getContextConfid(contextId); + tls.confidenceBuf[uniqInputSize * 2 + 1] = getContextValidTokenSum(contextId); + uniqInputSize++; + } + tls.inverseContextIdcs[idx] = uniqInputSize - 1; + } + + size_t uniqHistorySize = 0; + if (prevStateSize <= 8) // use vector for small size + { + for (size_t i = 0; i < prevStateSize; ++i) + { + for (size_t j = 0; j < windowSize; ++j) + { + const auto historyToken = prevStates[i].history[j]; + if (historyToken) + { + tls.historyIdcs.emplace_back(mergePair(historyToken, i * windowSize + j)); + } + } + } + sort(tls.historyIdcs.begin(), tls.historyIdcs.end()); + uniqHistorySize = 0; + for (size_t i = 0; i < tls.historyIdcs.size(); ++i) + { + const auto historyToken = splitPair(tls.historyIdcs[i]).first; + const auto idx = splitPair(tls.historyIdcs[i]).second; + if (i == 0 || historyToken != splitPair(tls.historyIdcs[i - 1]).first) + { + uniqHistorySize++; + } + tls.inverseHistoryIdcs[idx] = uniqHistorySize - 1; + } + if (!quantized) + { + tls.inputEmbBuf.resize((uniqInputSize + uniqHistorySize) * header.dim); + } + tls.confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); + tls.resultBuf.resize(padMultipleOf(uniqInputSize + uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); + + uniqHistorySize = 0; + for (size_t i = 0; i < tls.historyIdcs.size(); ++i) + { + const auto historyToken = splitPair(tls.historyIdcs[i]).first; + const auto idx = splitPair(tls.historyIdcs[i]).second; + if (i == 0 || historyToken != splitPair(tls.historyIdcs[i - 1]).first) + { + if (quantized) + { + tls.contextIdcs2.emplace_back(historyToken + header.contextSize); + } + else + { + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &tls.inputEmbBuf[(uniqInputSize + uniqHistorySize) * header.dim]); + fill(&tls.resultBuf[(uniqInputSize + uniqHistorySize) * uniqOutputSize], &tls.resultBuf[(uniqInputSize + uniqHistorySize + 1) * uniqOutputSize], getDistantBias(historyToken)); + } + tls.confidenceBuf[uniqInputSize * 2 + uniqHistorySize] = getDistantConfid(historyToken); + uniqHistorySize++; + } + } + } + else // use map for large size + { + tls.historyMap.clear(); + tls.uniqHistoryTokens.clear(); + for (size_t i = 0; i < prevStateSize; ++i) + { + for (size_t j = 0; j < windowSize; ++j) + { + const auto historyToken = prevStates[i].history[j]; + if (!historyToken) continue; + const auto idx = i * windowSize + j; + auto inserted = tls.historyMap.emplace(historyToken, tls.historyMap.size()); + tls.inverseHistoryIdcs[idx] = inserted.first->second; + if (inserted.second) tls.uniqHistoryTokens.emplace_back(historyToken); + } + } + uniqHistorySize = tls.historyMap.size(); + if (!quantized) + { + tls.inputEmbBuf.resize((uniqInputSize + uniqHistorySize) * header.dim); + } + tls.confidenceBuf.resize(uniqInputSize * 2 + uniqHistorySize); + tls.resultBuf.resize(padMultipleOf(uniqInputSize + uniqHistorySize, 8) * padMultipleOf(uniqOutputSize, 8)); + + for (size_t i = 0; i < tls.uniqHistoryTokens.size(); ++i) + { + const auto historyToken = tls.uniqHistoryTokens[i]; + if (quantized) + { + tls.contextIdcs2.emplace_back(historyToken + header.contextSize); + } + else + { + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &tls.inputEmbBuf[(uniqInputSize + i) * header.dim]); + fill(&tls.resultBuf[(uniqInputSize + i) * uniqOutputSize], &tls.resultBuf[(uniqInputSize + i + 1) * uniqOutputSize], getDistantBias(historyToken)); + } + tls.confidenceBuf[uniqInputSize * 2 + i] = getDistantConfid(historyToken); + } + } + + Eigen::Map resultMap{ tls.resultBuf.data(), (Eigen::Index)uniqOutputSize, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; + + if constexpr (quantized) + { + qgemm::scatteredGEMMOpt( + uniqInputSize + uniqHistorySize, uniqOutputSize, header.dim, + getContextQuantEmb(0), tls.contextIdcs2.data(), contextEmbStride(), + getOutputQuantEmb(0), tls.nextIdcs2.data(), outputEmbStride(), + tls.resultBuf.data(), uniqOutputSize); + } + else + { + Eigen::Map inputMap{ tls.inputEmbBuf.data(), header.dim, (Eigen::Index)(uniqInputSize + uniqHistorySize) }; + Eigen::Map outputMap{ tls.outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; + gemm::template gemm( + outputMap.cols(), inputMap.cols(), inputMap.rows(), + outputMap.data(), outputMap.colStride(), + inputMap.data(), inputMap.colStride(), + resultMap.data(), resultMap.colStride() + ); + } + + pair contextCache; + for (size_t j = 0; j < nextIdSize; ++j) + { + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; + outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); + } + } + + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; + for (size_t j = 0; j < numInvalidDistantTokens; ++j) + { + outScores[i * nextIdSize + j] = resultMap(tls.inverseNextIdcs[j], tls.inverseContextIdcs[i]); + } + } + + auto* validTokenSumBuf = tls.scoreBuf.data() + scoreBatchSize * (windowSize + 1); + for (size_t i = 0; i < prevStateSize * numValidDistantTokens; i += scoreBatchSize) + { + const size_t batchSize = std::min(scoreBatchSize, prevStateSize * numValidDistantTokens - i); + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + tls.scoreBuf[j] = tls.confidenceBuf[tls.inverseContextIdcs[pIdx] * 2]; + validTokenSumBuf[j] = tls.confidenceBuf[tls.inverseContextIdcs[pIdx] * 2 + 1]; + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; + tls.scoreBuf[j + (k + 1) * scoreBatchSize] = idx == -1 ? -99999 : tls.confidenceBuf[uniqInputSize * 2 + idx]; + } + } + Eigen::Map> scoreMap{ tls.scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; + scoreMap.rowwise() += Eigen::Map>{ positionConfidPtr, 1, windowSize + 1 }; + logSoftmaxTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); + scoreMap.template rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + tls.scoreBuf[j] += resultMap(tls.inverseNextIdcs[nIdx], tls.inverseContextIdcs[pIdx]); + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; + if (idx != -1) + { + tls.scoreBuf[j + (k + 1) * scoreBatchSize] += resultMap(tls.inverseNextIdcs[nIdx], uniqInputSize + idx); + } + } + } + logSumExpTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); + + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + const bool cacheIsValid = pIdx > 0 && prevStates[pIdx].node == prevStates[pIdx - 1].node; + outScores[pIdx * nextIdSize + nIdx] = tls.scoreBuf[j]; + } + } + } + + template + inline void CoNgramModel::progressMatrixWOSort( + TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const + { + static constexpr size_t scoreBatchSize = 32; + + if (quantized) + { + tls.contextIdcs2.clear(); + tls.nextIdcs2.clear(); + } + else + { + tls.inputEmbBuf.resize(prevStateSize * (1 + windowSize) * header.dim); + tls.outputEmbBuf.resize(nextIdSize * header.dim); + } + tls.confidenceBuf.resize(prevStateSize * (2 + windowSize)); + tls.scoreBuf.resize(scoreBatchSize * (windowSize + 2)); + tls.resultBuf.resize(padMultipleOf(prevStateSize * (1 + windowSize), 8) * padMultipleOf(nextIdSize, 8)); + + const size_t numInvalidDistantTokens = nextIdSize - numValidDistantTokens; + for (size_t i = 0; i < nextIdSize; ++i) + { + const auto nextId = nextIds[i]; + if (quantized) + { + tls.nextIdcs2.emplace_back(nextId); + } + else + { + copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &tls.outputEmbBuf[i * header.dim]); + } + } + + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto contextId = prevStates[i].contextIdx; + if (quantized) + { + tls.contextIdcs2.emplace_back(contextId); + } + else + { + copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &tls.inputEmbBuf[i * header.dim]); + fill(&tls.resultBuf[i * nextIdSize], &tls.resultBuf[(i + 1) * nextIdSize], getContextBias(contextId)); + } + tls.confidenceBuf[i * 2] = getContextConfid(contextId); + tls.confidenceBuf[i * 2 + 1] = getContextValidTokenSum(contextId); + } + + size_t uniqHistorySize = 0; + tls.inverseHistoryIdcs.clear(); + for (size_t i = 0; i < prevStateSize; ++i) + { + for (size_t j = 0; j < windowSize; ++j) + { + const auto historyToken = prevStates[i].history[j]; + if (historyToken) + { + if (quantized) + { + tls.contextIdcs2.emplace_back(historyToken + header.contextSize); + } + else + { + copy(getDistantEmb(historyToken), getDistantEmb(historyToken) + header.dim, &tls.inputEmbBuf[(prevStateSize + uniqHistorySize) * header.dim]); + fill(&tls.resultBuf[(prevStateSize + uniqHistorySize) * nextIdSize], &tls.resultBuf[(prevStateSize + uniqHistorySize + 1) * nextIdSize], getDistantBias(historyToken)); + } + tls.confidenceBuf[prevStateSize * 2 + uniqHistorySize] = getDistantConfid(historyToken); + uniqHistorySize++; + } + tls.inverseHistoryIdcs.emplace_back(historyToken ? uniqHistorySize - 1 : -1); + + } + } + + Eigen::Map resultMap{ tls.resultBuf.data(), (Eigen::Index)nextIdSize, (Eigen::Index)(prevStateSize + uniqHistorySize) }; + + if constexpr (quantized) + { + qgemm::scatteredGEMMOpt( + prevStateSize + uniqHistorySize, nextIdSize, header.dim, + getContextQuantEmb(0), tls.contextIdcs2.data(), contextEmbStride(), + getOutputQuantEmb(0), tls.nextIdcs2.data(), outputEmbStride(), + tls.resultBuf.data(), nextIdSize); + } + else + { + Eigen::Map inputMap{ tls.inputEmbBuf.data(), header.dim, (Eigen::Index)(prevStateSize + uniqHistorySize) }; + Eigen::Map outputMap{ tls.outputEmbBuf.data(), header.dim, (Eigen::Index)nextIdSize }; + gemm::template gemm( + outputMap.cols(), inputMap.cols(), inputMap.rows(), + outputMap.data(), outputMap.colStride(), + inputMap.data(), inputMap.colStride(), + resultMap.data(), resultMap.colStride() + ); + } + + pair contextCache; + for (size_t j = 0; j < nextIdSize; ++j) + { + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; + outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); + } + } + + + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + for (size_t j = 0; j < numInvalidDistantTokens; ++j) + { + outScores[i * nextIdSize + j] = resultMap(j, i); + } + } + + auto* validTokenSumBuf = tls.scoreBuf.data() + scoreBatchSize * (windowSize + 1); + + for (size_t i = 0; i < prevStateSize * numValidDistantTokens; i += scoreBatchSize) + { + const size_t batchSize = std::min(scoreBatchSize, prevStateSize * numValidDistantTokens - i); + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + tls.scoreBuf[j] = tls.confidenceBuf[pIdx * 2]; + validTokenSumBuf[j] = tls.confidenceBuf[pIdx * 2 + 1]; + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; + tls.scoreBuf[j + (k + 1) * scoreBatchSize] = idx == -1 ? -99999 : tls.confidenceBuf[prevStateSize * 2 + idx]; + } + } + Eigen::Map> scoreMap{ tls.scoreBuf.data(), (Eigen::Index)scoreBatchSize, windowSize + 1 }; + scoreMap.rowwise() += Eigen::Map>{ positionConfidPtr, 1, windowSize + 1 }; + logSoftmaxTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); + scoreMap.template rightCols().colwise() += Eigen::Map>{ validTokenSumBuf, scoreBatchSize, 1 }; + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + tls.scoreBuf[j] += resultMap(nIdx, pIdx); + for (size_t k = 0; k < windowSize; ++k) + { + const auto idx = tls.inverseHistoryIdcs[pIdx * windowSize + k]; + if (idx != -1) + { + tls.scoreBuf[j + (k + 1) * scoreBatchSize] += resultMap(nIdx, prevStateSize + idx); + } + } + } + logSumExpTransposed(tls.scoreBuf.data(), windowSize + 1, batchSize, scoreBatchSize); + + for (size_t j = 0; j < batchSize; ++j) + { + const auto pIdx = (i + j) / numValidDistantTokens; + const auto nIdx = (i + j) % numValidDistantTokens + numInvalidDistantTokens; + outScores[pIdx * nextIdSize + nIdx] = tls.scoreBuf[j]; + } + } + } + + template + void CoNgramModel::progressMatrix( + const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const + { + if constexpr (windowSize > 0) + { + thread_local TLSForProgressMatrix tls; + if (prevStateSize <= (quantized ? 16 : 8) && nextIdSize <= 16) + { + return progressMatrixWOSort(tls, prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores); + } + else + { + return progressMatrixWSort(tls, prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores); + } + } + else + { + return progressMatrixNoWindow(prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores); + } + } + + template + void CoNgramModel::progressMatrixNoWindow( + const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const + { + thread_local Vector contextIdcs, nextIdcs; + thread_local Vector inverseContextIdcs, inverseNextIdcs; + thread_local Vector inputEmbBuf, outputEmbBuf, resultBuf; + thread_local Vector contextIdcs2, nextIdcs2; + + contextIdcs.resize(prevStateSize); + nextIdcs.resize(nextIdSize); + inverseContextIdcs.resize(prevStateSize); + inverseNextIdcs.resize(nextIdSize); + if (quantized) + { + contextIdcs2.clear(); + nextIdcs2.clear(); + } + else + { + inputEmbBuf.resize(prevStateSize * header.dim); + outputEmbBuf.resize(nextIdSize * header.dim); + } + + for (size_t i = 0; i < nextIdSize; ++i) + { + nextIdcs[i] = mergePair(nextIds[i], i); + } + sort(nextIdcs.begin(), nextIdcs.end()); + size_t uniqOutputSize = 0; + for (size_t i = 0; i < nextIdSize; ++i) + { + const auto nextId = splitPair(nextIdcs[i]).first; + const auto idx = splitPair(nextIdcs[i]).second; + if (i == 0 || nextId != splitPair(nextIdcs[i - 1]).first) + { + if (quantized) + { + nextIdcs2.emplace_back(nextId); + } + else + { + copy(getOutputEmb(nextId), getOutputEmb(nextId) + header.dim, &outputEmbBuf[uniqOutputSize * header.dim]); + } + uniqOutputSize++; + } + inverseNextIdcs[idx] = uniqOutputSize - 1; + } + resultBuf.resize(padMultipleOf(prevStateSize, 8) * padMultipleOf(uniqOutputSize, 8)); + + for (size_t i = 0; i < prevStateSize; ++i) + { + contextIdcs[i] = mergePair(prevStates[i].contextIdx, i); + } + sort(contextIdcs.begin(), contextIdcs.end()); + size_t uniqInputSize = 0; + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto contextId = splitPair(contextIdcs[i]).first; + const auto idx = splitPair(contextIdcs[i]).second; + if (i == 0 || contextId != splitPair(contextIdcs[i - 1]).first) + { + if (quantized) + { + contextIdcs2.emplace_back(contextId); + } + else + { + copy(getContextEmb(contextId), getContextEmb(contextId) + header.dim, &inputEmbBuf[uniqInputSize * header.dim]); + fill(&resultBuf[uniqInputSize * uniqOutputSize], &resultBuf[(uniqInputSize + 1) * uniqOutputSize], getContextBias(contextId)); + } + uniqInputSize++; + } + inverseContextIdcs[idx] = uniqInputSize - 1; + } + + Eigen::Map resultMap{ resultBuf.data(), (Eigen::Index)uniqOutputSize, (Eigen::Index)uniqInputSize }; + if constexpr (quantized) + { + qgemm::scatteredGEMMOpt( + uniqInputSize, uniqOutputSize, header.dim, + getContextQuantEmb(0), contextIdcs2.data(), contextEmbStride(), + getOutputQuantEmb(0), nextIdcs2.data(), outputEmbStride(), + resultBuf.data(), uniqOutputSize); + } + else + { + Eigen::Map inputMap{ inputEmbBuf.data(), header.dim, (Eigen::Index)uniqInputSize }; + Eigen::Map outputMap{ outputEmbBuf.data(), header.dim, (Eigen::Index)uniqOutputSize }; + gemm::template gemm( + outputMap.cols(), inputMap.cols(), inputMap.rows(), + outputMap.data(), outputMap.colStride(), + inputMap.data(), inputMap.colStride(), + resultMap.data(), resultMap.colStride() + ); + } + + pair contextCache; + for (size_t j = 0; j < nextIdSize; ++j) + { + for (size_t i = 0; i < prevStateSize; ++i) + { + const auto& state = prevStates[i]; + const bool cacheIsValid = i > 0 && state.node == prevStates[i - 1].node; + outStates[i * nextIdSize + j] = nextState(state, nextIds[j], cacheIsValid, contextCache); + } + } + + for (size_t i = 0; i < prevStateSize; ++i) + { + for (size_t j = 0; j < nextIdSize; ++j) + { + outScores[i * nextIdSize + j] = resultMap(inverseNextIdcs[j], inverseContextIdcs[i]); + } + } + } + + static constexpr size_t serialAlignment = 16; + inline size_t alignedOffsetInc(size_t& offset, size_t inc, size_t alignment = serialAlignment) + { + return offset = (offset + inc + alignment - 1) & ~(alignment - 1); + } + + inline std::ostream& writePadding(std::ostream& os, size_t alignment = serialAlignment) + { + const size_t pos = os.tellp(); + size_t pad = ((pos + alignment - 1) & ~(alignment - 1)) - pos; + for (size_t i = 0; i < pad; ++i) + { + os.put(0); + } + return os; + } + + utils::MemoryObject CoNgramModelBase::build(const string& contextDefinition, const string& embedding, size_t maxContextLength, bool useVLE, bool reorderContextId) + { + ifstream contextStr, embeddingStr; + if (!openFile(contextStr, contextDefinition)) + { + throw IOException{ "Cannot open file : " + contextDefinition }; + } + + uint32_t maxClusterId = 0, maxContextId = 0; + size_t keySize = 0; + using Node = utils::TrieNodeEx>>; + utils::ContinuousTrie trie(1); + { + Vector, uint32_t>> contextMap; + UnorderedMap, uint32_t> erasedContexts; + Vector context; + string line; + while (getline(contextStr, line)) + { + auto tokens = split(line, '\t'); + if (tokens.size() <= 1) + { + throw IOException{ "Invalid format : " + contextDefinition }; + } + + auto clusterId = stol(tokens[0].begin(), tokens[0].end()); + if (clusterId < 0) throw IOException{ "Invalid format : " + contextDefinition }; + context.clear(); + for (size_t i = 1; i < tokens.size(); ++i) + { + auto id = stol(tokens[i].begin(), tokens[i].end()); + if (id < 0) throw IOException{ "Invalid format : " + contextDefinition }; + context.push_back(id); + maxContextId = max(maxContextId, (uint32_t)id); + } + if (context.size() > maxContextLength) + { + continue; + } + if (contextMap.size() < context.size()) contextMap.resize(context.size()); + contextMap[context.size() - 1][context] = (uint32_t)clusterId; + maxClusterId = max(maxClusterId, (uint32_t)(clusterId + 1)); + } + + for (size_t i = contextMap.size(); i-- > 0;) // remove redundant context + { + auto& c = contextMap[i]; + for (auto it = c.begin(); it != c.end();) + { + bool erase = false; + for (size_t j = i; j-- > 0; ) + { + auto& c2 = contextMap[j]; + context.clear(); + context.insert(context.end(), it->first.end() - j - 1, it->first.end()); + auto found = c2.find(context); + if (found != c2.end()) + { + erase = found->second == it->second; + break; + } + } + + if (erase) + { + if (it->first.size() < contextMap.size()) + { + erasedContexts.emplace(it->first, it->second); + } + it = c.erase(it); + } + else ++it; + } + } + + if (maxContextId <= 0xFFFF) + { + keySize = 2; + } + else if (useVLE && maxContextId <= 0xFFFFF) + { + keySize = 3; // variable length key + } + else + { + keySize = 4; + } + + for (auto& c : contextMap) + { + for (auto& p : c) + { + if (keySize == 3) + { + static constexpr size_t tMax = (1 << 16) - (1 << 10) * 2; + context.clear(); + for (auto id : p.first) + { + if (id < tMax) + { + context.emplace_back(id); + } + else + { + id -= tMax; + const size_t high = id >> 10, low = id & 0x3FF; + context.emplace_back(tMax + high); + context.emplace_back(tMax + (1 << 10) + low); + } + } + trie.build(context.begin(), context.end(), p.second + 1); + } + else + { + trie.build(p.first.begin(), p.first.end(), p.second + 1); + } + } + } + } + + Vector nodeSizes; + nodeSizes.reserve(trie.size()); + Vector keys; + keys.reserve(trie.size()); + Vector values; + values.reserve(trie.size()); + Vector valueNewIdx(maxClusterId + 1); + { + Vector valueCnts(valueNewIdx.size()); + Vector valueArgsorted(valueNewIdx.size()); + Vector rkeys; + trie.traverseWithKeys([&](const Node* node, const Vector& rkeys) + { + nodeSizes.emplace_back(node->next.size()); + for (auto& p : node->next) + { + keys.emplace_back(p.first); + } + values.emplace_back(node->val); + valueCnts[node->val]++; + }, rkeys); + + valueCnts[0] = -1; + + // remap value idx by frequency + if (reorderContextId) + { + iota(valueArgsorted.begin(), valueArgsorted.end(), 0); + sort(valueArgsorted.begin(), valueArgsorted.end(), [&](uint32_t a, uint32_t b) { return valueCnts[a] > valueCnts[b]; }); + for (size_t i = 0; i < valueArgsorted.size(); ++i) + { + valueNewIdx[valueArgsorted[i]] = (uint32_t)i; + } + for (auto& v : values) v = valueNewIdx[v]; + } + } + + assert(nodeSizes.size() - 1 == keys.size()); + + Vector compressedNodeSizes(streamvbyte_max_compressedbytes(nodeSizes.size())); + compressedNodeSizes.resize(streamvbyte_encode_0124(nodeSizes.data(), nodeSizes.size(), compressedNodeSizes.data())); + Vector compressedValues(streamvbyte_max_compressedbytes(values.size())); + compressedValues.resize(streamvbyte_encode_0124(values.data(), values.size(), compressedValues.data())); + Vector compressedKeys(streamvbyte_max_compressedbytes(keys.size())); + compressedKeys.resize(streamvbyte_encode(keys.data(), keys.size(), compressedKeys.data())); + + if (!openFile(embeddingStr, embedding, ios_base::binary)) + { + throw IOException{ "Cannot open file : " + embedding }; + } + const uint32_t dim = utils::read(embeddingStr); + const uint32_t contextSize = utils::read(embeddingStr); + const uint32_t outputSize = utils::read(embeddingStr); + const uint32_t windowSize = utils::read(embeddingStr); + + Vector contextEmb(dim * contextSize); + Vector contextEmbScale(contextSize); + Vector contextEmbBias(contextSize); + Vector contextValidTokenSum(contextSize); + Vector contextConfidence(contextSize); + Vector distantEmb(dim * outputSize); + Vector distantEmbScale(outputSize); + Vector distantEmbBias(outputSize); + Vector distantConfidence(outputSize); + vector positionConfidence(windowSize); + Vector outputEmb(dim * outputSize); + Vector outputEmbScale(outputSize); + Vector distantMask(outputSize); + + embeddingStr.read((char*)contextEmb.data(), contextEmb.size()); + embeddingStr.read((char*)contextEmbScale.data(), contextEmbScale.size() * sizeof(uint16_t)); + embeddingStr.read((char*)contextEmbBias.data(), contextEmbBias.size() * sizeof(uint16_t)); + embeddingStr.read((char*)contextValidTokenSum.data(), contextValidTokenSum.size() * sizeof(uint16_t)); + embeddingStr.read((char*)contextConfidence.data(), contextConfidence.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantEmb.data(), distantEmb.size()); + embeddingStr.read((char*)distantEmbScale.data(), distantEmbScale.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantEmbBias.data(), distantEmbBias.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantConfidence.data(), distantConfidence.size() * sizeof(uint16_t)); + embeddingStr.read((char*)positionConfidence.data(), positionConfidence.size() * sizeof(uint16_t)); + embeddingStr.read((char*)outputEmb.data(), outputEmb.size()); + embeddingStr.read((char*)outputEmbScale.data(), outputEmbScale.size() * sizeof(uint16_t)); + embeddingStr.read((char*)distantMask.data(), distantMask.size()); + + // remap context embedding + if (reorderContextId) + { + Vector newContextEmb(contextEmb.size()); + Vector newContextEmbScale(contextSize); + Vector newContextEmbBias(contextSize); + Vector newContextValidTokenSum(contextSize); + for (size_t i = 0; i < contextSize; ++i) + { + auto idx = valueNewIdx[i]; + auto src = contextEmb.data() + i * dim; + auto dst = newContextEmb.data() + idx * dim; + copy(src, src + dim, dst); + newContextEmbScale[idx] = contextEmbScale[i]; + newContextEmbBias[idx] = contextEmbBias[i]; + newContextValidTokenSum[idx] = contextValidTokenSum[i]; + } + contextEmb = move(newContextEmb); + contextEmbScale = move(newContextEmbScale); + contextEmbBias = move(newContextEmbBias); + contextValidTokenSum = move(newContextValidTokenSum); + } + + // compress distantMask into bits + const size_t compressedDistantMaskSize = (outputSize + 7) / 8; + { + for (size_t i = 0; i < outputSize; ++i) + { + if (i % 8 == 0) + { + distantMask[i / 8] = distantMask[i]; + } + else + { + distantMask[i / 8] |= distantMask[i] << (i % 8); + } + } + distantMask.resize(compressedDistantMaskSize); + } + + CoNgramModelHeader header; + memset(&header, 0, sizeof(CoNgramModelHeader)); + header.dim = dim; + header.contextSize = contextSize; + header.vocabSize = outputSize; + header.keySize = keySize; + header.windowSize = windowSize; + header.numNodes = nodeSizes.size(); + + size_t finalSize = 0; + header.nodeOffset = alignedOffsetInc(finalSize, sizeof(CoNgramModelHeader)); + header.keyOffset = alignedOffsetInc(finalSize, compressedNodeSizes.size()); + header.valueOffset = alignedOffsetInc(finalSize, compressedKeys.size()); + header.embOffset = alignedOffsetInc(finalSize, compressedValues.size()); + finalSize += dim * (contextSize + outputSize * 2); + finalSize += contextSize * sizeof(uint16_t) * 4; + finalSize += outputSize * sizeof(uint16_t) * 4; + finalSize += windowSize * sizeof(uint16_t); + finalSize += compressedDistantMaskSize; + + utils::MemoryOwner mem{ finalSize }; + utils::omstream ostr{ (char*)mem.get(), (std::ptrdiff_t)mem.size() }; + ostr.write((const char*)&header, sizeof(CoNgramModelHeader)); + writePadding(ostr); + ostr.write((const char*)compressedNodeSizes.data(), compressedNodeSizes.size()); + writePadding(ostr); + ostr.write((const char*)compressedKeys.data(), compressedKeys.size()); + writePadding(ostr); + ostr.write((const char*)compressedValues.data(), compressedValues.size()); + writePadding(ostr); + + for (size_t i = 0; i < contextSize; ++i) + { + ostr.write((const char*)&contextEmb[i * dim], dim); + ostr.write((const char*)&contextEmbScale[i], sizeof(uint16_t)); + ostr.write((const char*)&contextEmbBias[i], sizeof(uint16_t)); + ostr.write((const char*)&contextConfidence[i], sizeof(uint16_t)); + ostr.write((const char*)&contextValidTokenSum[i], sizeof(uint16_t)); + } + for (size_t i = 0; i < outputSize; ++i) + { + ostr.write((const char*)&outputEmb[i * dim], dim); + ostr.write((const char*)&outputEmbScale[i], sizeof(uint16_t)); + } + for (size_t i = 0; i < outputSize; ++i) + { + ostr.write((const char*)&distantEmb[i * dim], dim); + ostr.write((const char*)&distantEmbScale[i], sizeof(uint16_t)); + ostr.write((const char*)&distantEmbBias[i], sizeof(uint16_t)); + ostr.write((const char*)&distantConfidence[i], sizeof(uint16_t)); + } + ostr.write((const char*)positionConfidence.data(), positionConfidence.size() * sizeof(uint16_t)); + ostr.write((const char*)distantMask.data(), distantMask.size()); + return mem; + } + + template + void* CoNgramModel::getFindBestPathFn() const + { + return (void*)&BestPathFinder>::findBestPath; + } + + template + void* CoNgramModel::getNewJoinerFn() const + { + return (void*)&newJoinerWithKiwi; + } + + template + inline std::unique_ptr createOptimizedModelWithWindowSize(utils::MemoryObject&& mem) + { + auto& header = *reinterpret_cast(mem.get()); + if (!useDistantTokens) + { + return make_unique>(std::move(mem)); + } + + switch (header.windowSize) + { + case 7: + return make_unique>(std::move(mem)); + default: + throw std::runtime_error{ "Unsupported `window_size` : " + std::to_string((size_t)header.windowSize) }; + }; + } + + template + std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + { + auto& header = *reinterpret_cast(mem.get()); + switch (header.keySize) + { + case 2: + return createOptimizedModelWithWindowSize(std::move(mem)); + case 3: + return createOptimizedModelWithWindowSize(std::move(mem)); + case 4: + return createOptimizedModelWithWindowSize(std::move(mem)); + default: + throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; + } + } + + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + + template + struct CreateOptimizedModelGetter + { + template + struct Wrapper + { + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i), useDistantTokens, quantized>; + }; + }; + + std::unique_ptr CoNgramModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool useDistantTokens, bool quantized) + { + static tp::Table tables[] = { + CreateOptimizedModelGetter{}, + CreateOptimizedModelGetter{}, + }; + static tp::Table quantTables[] = { + CreateOptimizedModelGetter{}, + CreateOptimizedModelGetter{}, + }; + + if (quantized) + { + auto fn = quantTables[useDistantTokens ? 1 : 0][static_cast(archType)]; + if (fn) return (*fn)(std::move(mem)); + std::cerr << "Quantization is not supported for " << archToStr(archType) << ". Fall back to non-quantized model." << std::endl; + } + auto fn = tables[useDistantTokens ? 1 : 0][static_cast(archType)]; + if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; + return (*fn)(std::move(mem)); + } + } +} diff --git a/src/CoNgramModel.hpp b/src/CoNgramModel.hpp new file mode 100644 index 00000000..891cff02 --- /dev/null +++ b/src/CoNgramModel.hpp @@ -0,0 +1,446 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include "ArchAvailable.h" +#include "search.h" +#include "streamvbyte.h" +#include "MathFunc.h" + +namespace kiwi +{ + namespace lm + { + template + class CoNgramState; + + template + class CoNgramModel : public CoNgramModelBase + { + using MyNode = Node; + + std::unique_ptr nodeData; + std::unique_ptr keyValueData; + const uint8_t* alignedKeyValueData = nullptr; + std::unique_ptr allRootValueData; + std::unique_ptr allEmbs; + const uint8_t* contextEmbPtr = nullptr; // [numContexts, (dim + scale? + bias + confid + vts)] + const uint8_t* outputEmbPtr = nullptr; // [numOutputs, (dim + scale? + sum?)] + const uint8_t* distantEmbPtr = nullptr; // [numOutputs, (dim + scale? + bias + confid + pad?)] + const float* positionConfidPtr = nullptr; + const uint8_t* distantMaskPtr = nullptr; + + inline size_t contextEmbStride() const + { + if (quantized) return header.dim + (windowSize > 0 ? 4 : 2) * sizeof(float); + else return (header.dim + (windowSize > 0 ? 3 : 1)) * sizeof(float); + } + + inline size_t outputEmbStride() const + { + if (quantized) return header.dim + 2 * sizeof(float); + else return header.dim * sizeof(float); + } + + inline size_t distantEmbStride() const + { + if (quantized) return header.dim + 4 * sizeof(float); + else return (header.dim + 2) * sizeof(float); + } + + inline const float* getContextEmb(uint32_t idx) const + { + return reinterpret_cast(contextEmbPtr + idx * contextEmbStride()); + } + + inline const uint8_t* getContextQuantEmb(uint32_t idx) const + { + return contextEmbPtr + idx * contextEmbStride(); + } + + inline float getContextBias(uint32_t idx) const + { + const size_t offset = quantized ? + (header.dim + sizeof(float)) + : (header.dim * sizeof(float)); + return *reinterpret_cast(contextEmbPtr + idx * contextEmbStride() + offset); + } + + inline float getContextConfid(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? + (header.dim + 2 * sizeof(float)) + : (header.dim + 1) * sizeof(float); + return *reinterpret_cast(contextEmbPtr + idx * contextEmbStride() + offset); + } + + inline float getContextValidTokenSum(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? + (header.dim + 3 * sizeof(float)) + : (header.dim + 2) * sizeof(float); + return *reinterpret_cast(contextEmbPtr + idx * contextEmbStride() + offset); + } + + inline const float* getOutputEmb(uint32_t idx) const + { + return reinterpret_cast(outputEmbPtr + idx * outputEmbStride()); + } + + inline const int8_t* getOutputQuantEmb(uint32_t idx) const + { + return reinterpret_cast(outputEmbPtr + idx * outputEmbStride()); + } + + inline const float* getDistantEmb(uint32_t idx) const + { + return reinterpret_cast(distantEmbPtr + idx * distantEmbStride()); + } + + inline const uint8_t* getDistantQuantEmb(uint32_t idx) const + { + return distantEmbPtr + idx * distantEmbStride(); + } + + inline float getDistantBias(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? + (header.dim + sizeof(float)) + : (header.dim * sizeof(float)); + return *reinterpret_cast(distantEmbPtr + idx * distantEmbStride() + offset); + } + + inline float getDistantConfid(uint32_t idx) const + { + if (windowSize == 0) return 0; + const size_t offset = quantized ? + (header.dim + 2 * sizeof(float)) + : (header.dim + 1) * sizeof(float); + return *reinterpret_cast(distantEmbPtr + idx * distantEmbStride() + offset); + } + + MyNode* findLowerNode(MyNode* node, KeyType k) const + { + while (node->lower) + { + auto* lowerNode = node + node->lower; + auto* kvs = &alignedKeyValueData[lowerNode->nextOffset]; + int32_t found; + if ((found = nst::searchKV( + kvs, + lowerNode->numNexts, + k)) > 0) + { + return lowerNode + found; + } + node = lowerNode; + } + return node; + } + + uint32_t findLowerValue(MyNode* node, KeyType k) const + { + while (node->lower) + { + auto* lowerNode = node + node->lower; + auto* kvs = &alignedKeyValueData[lowerNode->nextOffset]; + int32_t found; + if ((found = nst::searchKV( + kvs, + lowerNode->numNexts, + k)) != 0) + { + if (found >= 0) + { + return lowerNode[found].value; + } + else + { + return -found; + } + } + node = lowerNode; + } + return node->value; + } + + public: + using VocabType = KeyType; + using LmStateType = CoNgramState; + + CoNgramModel(utils::MemoryObject&& mem); + + ModelType getType() const override + { + if (quantized) + { + if (windowSize > 0) return ModelType::congGlobal; + else return ModelType::cong; + } + else + { + if (windowSize > 0) return ModelType::congGlobalFp32; + else return ModelType::congFp32; + } + } + void* getFindBestPathFn() const override; + void* getNewJoinerFn() const override; + + uint32_t progressContextNode(int32_t& nodeIdx, KeyType next) const + { + if (std::is_same::value) + { + return progressContextNodeVl(nodeIdx, next); + } + + static constexpr size_t tMax = (1 << 16) - (1 << 10) * 2; + if (next < tMax) + { + return progressContextNodeVl(nodeIdx, next); + } + next -= tMax; + const size_t high = next >> 10, low = next & 0x3FF; + progressContextNodeVl(nodeIdx, tMax + high); + return progressContextNodeVl(nodeIdx, tMax + (1 << 10) + low); + } + + uint32_t progressContextNodeVl(int32_t& nodeIdx, VlKeyType next) const + { + static constexpr size_t N = 64 / sizeof(VlKeyType) + 1; + while (1) + { + int32_t v; + auto* node = &nodeData[nodeIdx]; + auto* kvs = &alignedKeyValueData[node->nextOffset]; + if (node != nodeData.get()) + { + PREFETCH_T0(node + node->lower); + if ((v = nst::searchKV( + kvs, + node->numNexts, next + )) == 0) + { + if (!node->lower) return 0; + nodeIdx += node->lower; + PREFETCH_T0(&alignedKeyValueData[nodeData[nodeIdx].nextOffset]); + continue; + } + } + else + { + v = allRootValueData[next]; + if (v == 0) + { + return 0; + } + } + + // non-leaf node + if (v > 0) + { + nodeIdx += v; + return nodeData[nodeIdx].value; + } + // leaf node + else + { + while (node->lower) + { + node += node->lower; + auto* lkvs = &alignedKeyValueData[node->nextOffset]; + int32_t lv; + if (node != nodeData.get()) + { + if ((lv = nst::searchKV( + lkvs, + node->numNexts, next + )) != 0) + { + if (lv > 0) + { + node += lv; + nodeIdx = node - &nodeData[0]; + return (uint32_t)-v; + } + } + } + else + { + lv = allRootValueData[next]; + if (lv > 0) + { + nodeIdx = lv; + return (uint32_t)-v; + } + } + } + nodeIdx = 0; + return (uint32_t)-v; + } + } + } + + inline bool distantTokenMask(uint32_t idx) const + { + if (windowSize > 0) return (distantMaskPtr[idx / 8] & (1 << (idx % 8))) != 0; + else return false; + } + + float progress(int32_t& nodeIdx, + uint32_t& contextIdx, + std::array& history, + KeyType next) const; + + template + LmStateType nextState(const typename std::enable_if<(_windowSize > 0), LmStateType>::type& state, KeyType next, + bool cacheIsValid, std::pair& cache) const; + + template + LmStateType nextState(const typename std::enable_if<_windowSize == 0, LmStateType>::type& state, KeyType next, + bool cacheIsValid, std::pair& cache) const; + + /* + * 총 prevStateSize개의 상태와 nextIdSize개의 다음 토큰을 받아서, 각 상태별로 다음 토큰이 등장할 확률을 계산하고 새 상태를 반환한다. + * 새 상태값은 outStates에 저장되고, 각 상태별 확률값은 outScores에 저장된다. + * nextIdSize개의 다음 토큰 중 마지막 numValidDistantTokens개의 토큰은 유효한 distant 토큰으로 처리된다. + */ + void progressMatrix(const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; + + struct TLSForProgressMatrix; + + inline void progressMatrixWSort(TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; + + inline void progressMatrixWOSort(TLSForProgressMatrix& tls, const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; + + void progressMatrixNoWindow(const LmStateType* prevStates, const KeyType* nextIds, + size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens, + LmStateType* outStates, float* outScores) const; + }; + + template + struct CoNgramState : public LmStateBase> + { + int32_t node = 0; + uint32_t contextIdx; + std::array history; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = true; + + CoNgramState() : contextIdx{ 0 }, history { { 0, } } + { + } + + CoNgramState(const ILangModel* lm) : contextIdx{ 0 }, history{ {0,} } + { + } + + CoNgramState(int32_t _node) : node{ _node } // partially initialized state + { + } + + bool operator==(const CoNgramState& other) const + { + static constexpr size_t cmpStart = windowSize / 2; + if (node != other.node) return false; + if (memcmp(&history[cmpStart], &other.history[cmpStart], (windowSize - cmpStart) * sizeof(VocabTy))) + { + return false; + } + return true; + } + + float nextImpl(const CoNgramModel* lm, VocabTy next) + { + return lm->progress(node, contextIdx, history, next); + } + }; + + template + struct CoNgramState<0, _arch, VocabTy, VlVocabTy, quantized> : public LmStateBase> + { + int32_t node = 0; + uint32_t contextIdx; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = true; + static constexpr size_t windowSize = 0; + + CoNgramState() : contextIdx{ 0 } + { + } + + CoNgramState(const ILangModel* lm) : contextIdx{ 0 } + { + } + + CoNgramState(int32_t _node) : node{ _node } // partially initialized state + { + } + + bool operator==(const CoNgramState& other) const + { + return node == other.node; + } + + float nextImpl(const CoNgramModel* lm, VocabTy next) + { + std::array history = { {0,} }; + return lm->progress(node, contextIdx, history, next); + } + }; + } + + static constexpr size_t largePrime = sizeof(size_t) == 8 ? 2305843009213693951ll : 2654435761ll; + + inline size_t rol(size_t x, size_t r) + { + return (x << r) | (x >> (sizeof(size_t) * 8 - r)); + } + + template<> + struct Hash + { + size_t operator()(uint32_t v) const + { + return ((size_t)v * largePrime) ^ rol((size_t)v, sizeof(size_t) * 4 + 1); + } + }; + + template + struct Hash> + { + size_t operator()(const lm::CoNgramState& state) const + { + size_t ret = Hash{}(state.node); + static constexpr size_t cmpStart = windowSize - sizeof(size_t) / sizeof(VocabTy); + size_t h = *reinterpret_cast(&state.history[cmpStart]); + h = (h * largePrime) ^ rol(h, sizeof(size_t) * 4 - 1); + ret = h ^ rol(ret, 3); + return ret; + } + }; + + template + struct Hash> + { + size_t operator()(const lm::CoNgramState<0, arch, VocabTy, VlVocabTy, quantized>& state) const + { + size_t ret = Hash{}(state.node); + return ret; + } + }; +} diff --git a/src/Combiner.cpp b/src/Combiner.cpp index 8c72606c..9d21608e 100644 --- a/src/Combiner.cpp +++ b/src/Combiner.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -779,7 +780,7 @@ void RuleSet::loadRules(istream& istr) while (getline(istr, line)) { if (line[0] == '#') continue; - while (!line.empty() && line.back() < 0x80 && isSpace(line.back())) line.pop_back(); + while (!line.empty() && ((uint8_t)line.back() < 0x80) && isSpace(line.back())) line.pop_back(); if (line.empty()) continue; auto fields = split(line, '\t'); @@ -789,8 +790,8 @@ void RuleSet::loadRules(istream& istr) } else if (fields.size() == 2) { - lTag = fields[0].to_string(); - rTag = fields[1].to_string(); + lTag = fields[0]; + rTag = fields[1]; } else { @@ -807,13 +808,13 @@ void RuleSet::loadRules(istream& istr) "+ignorercond", }; - transform(fields[3].begin(), fields[3].end(), const_cast(fields[3].begin()), static_cast(tolower)); + transform(fields[3].begin(), fields[3].end(), const_cast(fields[3].data()), static_cast(tolower)); for (auto f : split(fields[3], ',')) { size_t t = find(fs.begin(), fs.end(), f) - fs.begin(); if (t >= fs.size()) { - throw runtime_error{ "invalid feature value: " + f.to_string()}; + throw runtime_error{ "invalid feature value: " + string{ f } }; } switch (t) @@ -1115,7 +1116,7 @@ Vector CompiledRule::combineImpl( auto it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { ret.emplace_back(move(p.str)); } @@ -1131,7 +1132,7 @@ Vector CompiledRule::combineImpl( it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { ret.emplace_back(move(p.str)); } @@ -1161,7 +1162,7 @@ tuple CompiledRule::combineOneImpl( auto it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { if(p.score >= 0) return make_tuple(p.str, p.leftEnd, p.rightBegin); KString ret; @@ -1181,7 +1182,7 @@ tuple CompiledRule::combineOneImpl( it = findRule(leftTag, rightTag, cv, cp); if (it != map.end()) { - for (auto& p : mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) + for (auto& p : visit(CombineVisitor{ leftForm, rightForm }, dfa[it->second])) { return make_tuple(p.str, p.leftEnd, p.rightBegin); } @@ -1210,13 +1211,13 @@ tuple CompiledRule::combineOneImpl( Vector> CompiledRule::testLeftPattern(U16StringView leftForm, size_t ruleId) const { - return mapbox::util::apply_visitor(SearchLeftVisitor{ leftForm, true }, dfa[ruleId]); + return visit(SearchLeftVisitor{ leftForm, true }, dfa[ruleId]); } Vector> CompiledRule::testRightPattern(U16StringView rightForm, size_t ruleId) const { - return mapbox::util::apply_visitor(SearchLeftVisitor{ rightForm, false }, dfaRight[ruleId]); + return visit(SearchLeftVisitor{ rightForm, false }, dfaRight[ruleId]); } vector> CompiledRule::testLeftPattern(U16StringView leftForm, POSTag leftTag, POSTag rightTag, CondVowel cv, CondPolarity cp) const @@ -1231,7 +1232,7 @@ vector> CompiledRule::testLeftPattern(U16Str auto it = findRule(leftTag, rightTag, cv, cp); if (it == map.end()) return ret; - auto p = mapbox::util::apply_visitor(SearchLeftVisitor{ l, true }, dfa[it->second]); + auto p = visit(SearchLeftVisitor{ l, true }, dfa[it->second]); ret.insert(ret.end(), p.begin(), p.end()); return ret; } @@ -1270,7 +1271,7 @@ UnorderedMap> CompiledRule::getRuleIdsByRightTag() const Vector CompiledRule::combine(U16StringView leftForm, U16StringView rightForm, size_t ruleId) const { - return mapbox::util::apply_visitor(CombineVisitor{ leftForm, rightForm }, dfa[ruleId]); + return visit(CombineVisitor{ leftForm, rightForm }, dfa[ruleId]); } vector CompiledRule::combine(U16StringView leftForm, POSTag leftTag, U16StringView rightForm, POSTag rightTag, CondVowel cv, CondPolarity cp) const diff --git a/src/Combiner.h b/src/Combiner.h index 0d407d83..ff83dc9b 100644 --- a/src/Combiner.h +++ b/src/Combiner.h @@ -1,10 +1,9 @@ #pragma once -#include +#include #include #include #include -#include "string_view.hpp" #include "bitset.hpp" namespace kiwi @@ -118,7 +117,7 @@ namespace kiwi template struct VariantFromTuple> { - using type = mapbox::util::variant; + using type = std::variant; }; } diff --git a/src/Dataset.cpp b/src/Dataset.cpp index f0b258f7..dcc0de38 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -1,19 +1,36 @@ -#include +#include #include #include "FrozenTrie.hpp" #include "RaggedVector.hpp" using namespace kiwi; -HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, - double _dropoutProb, double _dropoutProbOnHistory) +HSDataset::HSDataset(size_t _batchSize, + size_t _causalContextSize, + size_t _windowSize, + bool _exclusiveWindow, + size_t _workers, + double _dropoutProb, + double _dropoutProbOnHistory, + double _nounAugmentingProb, + size_t _generateUnlikelihoods) : workers{ _workers ? make_unique(_workers) : nullptr }, - dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} }, - dropoutOnHistory{ _dropoutProbOnHistory }, + dropout{ {1 - _dropoutProb, _dropoutProb / 3, _dropoutProb / 3, _dropoutProb / 6, _dropoutProb / 6} }, + dropoutProbOnHistory{ (float)_dropoutProbOnHistory }, + nounAugmentor{ { + 1 - _nounAugmentingProb, + _nounAugmentingProb / 12, + _nounAugmentingProb / 12, + _nounAugmentingProb / 12, + _nounAugmentingProb / 4, + _nounAugmentingProb / 4, + _nounAugmentingProb / 4} }, locals( _workers ? workers->size() : 1), batchSize{ _batchSize }, causalContextSize{ _causalContextSize }, - windowSize{ _windowSize } + windowSize{ _windowSize }, + exclusiveWindow{ _exclusiveWindow }, + generateUnlikelihoods{ _generateUnlikelihoods } { } @@ -76,14 +93,223 @@ size_t HSDataset::numValidTokensInSent(size_t sentId) const size_t c = 0; for (auto t : sents.get()[sentId]) { + if (oovDict && t < 0) + { + POSTag tag = (*oovDict)[-t - 1].second; + t = getDefaultMorphemeId(clearIrregular(tag)); + } + if (tokenToVocab[t] == nonVocab) continue; ++c; } return c; } -template -size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut) +bool HSDataset::tokenizeUnlikely(Vector>& out, int32_t prefix, int32_t target, int32_t suffix, std::mt19937_64& rng) const +{ + auto form = (oovDict && target < 0) ? (*oovDict)[-target - 1].first : joinHangul((*forms)[(*morphemes)[target].kform].form); + + if (oovDict && prefix < 0) prefix = getDefaultMorphemeId((*oovDict)[-prefix - 1].second); + if (oovDict && suffix < 0) suffix = getDefaultMorphemeId((*oovDict)[-suffix - 1].second); + auto prefixForm = joinHangul((*forms)[(*morphemes)[prefix].kform].form); + auto suffixForm = joinHangul((*forms)[(*morphemes)[suffix].kform].form); + if (form.size() < 2) return false; + auto blocklist = kiwiInst->findMorpheme(form); + std::unordered_set blockset(blocklist.begin(), blocklist.end()); + + thread_local std::vector pretokenized; + pretokenized.clear(); + pretokenized.emplace_back(0, 1, std::vector{ BasicToken(prefixForm, -1, -1, (*morphemes)[prefix].tag) }); + pretokenized.emplace_back(form.size() + 1, form.size() + 2, std::vector{ BasicToken(suffixForm, -1, -1, (*morphemes)[suffix].tag) }); + + form.insert(form.begin(), ' '); + form.push_back(' '); + auto res = kiwiInst->analyze(form, 8, Match::allWithNormalizing, &blockset, pretokenized); + thread_local Vector validResIdx; + validResIdx.clear(); + for (size_t i = 0; i < res.size(); ++i) + { + auto& tokens = res[i].first; + if (tokens.size() <= 3) continue; + if (std::all_of(tokens.begin() + 1, tokens.end() - 1, [&](const TokenInfo& t) + { + return t.morph && !t.morph->getForm().empty() /*&& t.morph->lmMorphemeId != getDefaultMorphemeId(t.morph->tag)*/; + })) + { + validResIdx.emplace_back(i); + } + } + if (validResIdx.empty()) return false; + const float r = std::generate_canonical(rng); + auto& tokens = res[validResIdx[(size_t)(r * (float)validResIdx.size())]].first; + for (size_t i = 1; i < tokens.size() - 1; ++i) + { + out.emplace_back(tokens[i].morph->lmMorphemeId, tokens[i].morph->lmMorphemeId); + } + return true; +} + +inline int32_t getInput(int32_t t, const Vector>* oovDict) +{ + if (oovDict && t < 0) + { + POSTag tag = (*oovDict)[-t - 1].second; + return getDefaultMorphemeId(clearIrregular(tag)); + } + return t; +} + +inline int32_t getOutput(int32_t t, const Vector>* oovDict) +{ + return getInput(t, oovDict); +} + +inline int32_t getInput(const std::pair& t, const Vector>* oovDict) +{ + return getInput(t.first, oovDict); +} + +inline int32_t getOutput(const std::pair& t, const Vector>* oovDict) +{ + return getOutput(t.second, oovDict); +} + +template +void HSDataset::prepareInOutData(Deque& inData, Deque& outData, const Vector& tokens, std::mt19937_64& rng) const +{ + thread_local Deque history; + thread_local Vector contextualTokens; + if (windowSize) + { + history.clear(); + history.resize(windowSize, -1); + if (windowTokenValidness[getInput(tokens[0], oovDict.get())]) + { + history.back() = tokenToVocab[getInput(tokens[0], oovDict.get())]; + } + } + + if (causalContextSize && contextualMapper.size()) + { + auto* node = contextualMapper.root(); + contextualTokens.clear(); + contextualTokens.reserve(tokens.size()); + for (size_t i = 0; i < tokens.size(); ++i) + { + const int32_t v = tokenToVocab[getInput(tokens[i], oovDict.get())]; + auto* next = node->template nextOpt(contextualMapper, v); + while (!next) + { + node = node->fail(); + if (!node) break; + next = node->template nextOpt(contextualMapper, v); + } + if (next) + { + auto val = next->val(contextualMapper); + if (contextualMapper.hasMatch(val)) + { + contextualTokens.emplace_back(val - 1); + } + else if (contextualMapper.hasSubmatch(val)) + { + auto sub = next->fail(); + for (; sub; sub = sub->fail()) + { + val = sub->val(contextualMapper); + if (contextualMapper.hasMatch(val)) + { + break; + } + } + if (sub) contextualTokens.emplace_back(val - 1); + else contextualTokens.emplace_back(nonVocab); + } + node = next; + } + else + { + contextualTokens.emplace_back(nonVocab); + node = contextualMapper.root(); + } + } + } + + int32_t lastV = nonVocab; + for (size_t i = 1; i < tokens.size(); ++i) + { + const int32_t v = tokenToVocab[getInput(tokens[i], oovDict.get())]; + if (v == nonVocab) + { + continue; + } + const int32_t outV = getOutput(tokens[i], oovDict.get()) == 0 ? nonVocab : tokenToVocab[getOutput(tokens[i], oovDict.get())]; + + if (causalContextSize) + { + for (size_t j = 0; j < causalContextSize; ++j) + { + if (i + j < causalContextSize) + { + if (outV != nonVocab) inData.emplace_back(nonVocab); + } + else if (contextualMapper.size()) + { + if (outV != nonVocab) inData.emplace_back(contextualTokens[i + j - causalContextSize]); + } + else + { + auto t = getInput(tokens[i + j - causalContextSize], oovDict.get()); + if (dropoutProbOnHistory > 0 && std::generate_canonical(rng) < dropoutProbOnHistory) + { + t = getDefaultMorphemeId((*morphemes)[t].tag); + } + if (outV != nonVocab) inData.emplace_back(tokenToVocab[t]); + } + } + } + if (windowSize) + { + if (windowTokenValidness[v]) + { + if (outV != nonVocab) std::copy(history.begin(), history.end(), std::back_inserter(inData)); + if (exclusiveWindow) + { + if (lastV != nonVocab) + { + history.pop_front(); + history.push_back(lastV); + } + lastV = v; + } + else + { + history.pop_front(); + history.push_back(v); + } + } + else + { + if (outV != nonVocab) inData.resize(inData.size() + windowSize, -1); + if (exclusiveWindow) + { + if (lastV != nonVocab) + { + history.pop_front(); + history.push_back(lastV); + } + lastV = nonVocab; + } + } + } + + if (outV != nonVocab) outData.emplace_back(v); + } +} + +template +size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + UlInTy unlikelihoodIn, UlOutTy unlikelihoodOut, size_t* unlikelihoodSize) { const auto& prepareNext = [&](size_t, size_t localId, size_t sentFirst, size_t sentLast) { @@ -97,48 +323,101 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, tokens.emplace_back(sent[0]); for (auto p = sent.begin() + 1; p != sent.end() - 1; ++p) { - auto t = *p; - switch (dropout(local.rng)) + int32_t t = *p; + int32_t tWithOOV = *p; + if (oovDict && t < 0) { - case 0: // no dropout - tokens.emplace_back(t); + t = getDefaultMorphemeId((*oovDict)[-t - 1].second); + } + int32_t t1 = *(p + 1); + if (oovDict && t1 < 0) + { + t1 = getDefaultMorphemeId((*oovDict)[-t1 - 1].second); + } + const auto nounAugment = ((*morphemes)[t].tag == POSTag::nnp && !isSpecialClass((*morphemes)[t1].tag)) ? nounAugmentor(local.rng) : 0; + + switch (nounAugment) + { + case 1: // circumfix with sso and ssc + tokens.emplace_back(getDefaultMorphemeId(POSTag::sso)); + break; + case 2: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::singleQuoteOpen]); break; - case 1: // replacement - tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + case 3: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::doubleQuoteOpen]); break; - case 2: // deletion + case 4: // circumfix with sw + tokens.emplace_back(getDefaultMorphemeId(POSTag::sw)); break; - case 3: // insertion - tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); - tokens.emplace_back(t); + case 5: // replace with w_hashtag + tokens.emplace_back(getDefaultMorphemeId(POSTag::w_hashtag)); + break; + case 6: // replace with sh + tokens.emplace_back(getDefaultMorphemeId(POSTag::sh)); + break; + } + + if (nounAugment < 5) + { + switch (dropout(local.rng)) + { + case 0: // no dropout + tokens.emplace_back(tWithOOV); + break; + case 1: // replacement + tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + break; + case 2: // deletion + break; + case 3: // insertion + tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + tokens.emplace_back(tWithOOV); + break; + case 4: // insertion + tokens.emplace_back(tWithOOV); + tokens.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + break; + } + } + + switch (nounAugment) + { + case 1: // circumfix with sso and ssc + tokens.emplace_back(getDefaultMorphemeId(POSTag::ssc)); + break; + case 2: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::singleQuoteClose]); + break; + case 3: + tokens.emplace_back(specialMorphIds[(size_t)Kiwi::SpecialMorph::doubleQuoteClose]); + break; + case 4: // circumfix with sw + tokens.emplace_back(getDefaultMorphemeId(POSTag::sw)); break; } } tokens.emplace_back(sent[sent.size() - 1]); + const size_t offset = local.outData.size(); + prepareInOutData(local.inData, local.outData, tokens, local.rng); local.lmLProbsBuf.resize(tokens.size()); local.outNgramNodeBuf.resize(tokens.size()); - if (knlm) + if (auto knlm = std::dynamic_pointer_cast(langModel)) { knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin()); } - - auto& history = local.historyBuf; - history.clear(); - if (windowSize) + for (size_t i = 1; i < tokens.size(); ++i) { - history.resize(windowSize, -1); - if (windowTokenValidness[tokens[0]]) + int32_t t = tokens[i]; + if (oovDict && t < 0) { - history.back() = tokenToVocab[tokens[0]]; + t = getDefaultMorphemeId((*oovDict)[-t - 1].second); } - } - for (size_t i = 1; i < tokens.size(); ++i) - { - int32_t v = tokenToVocab[tokens[i]]; + int32_t v = tokenToVocab[t]; if (v == nonVocab) { - size_t r = local.outData.size() / batchSize; + size_t r = (offset + i - 1) / batchSize; if (local.restLmLProbsData.size() <= r) { local.restLmLProbsData.resize(r + 1); @@ -149,40 +428,6 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, continue; } - if (causalContextSize) - { - for (size_t j = 0; j < causalContextSize; ++j) - { - if (i + j < causalContextSize) - { - local.inData.emplace_back(nonVocab); - } - else - { - auto t = tokens[i + j - causalContextSize]; - if (dropoutOnHistory.p() > 0 && dropoutOnHistory(local.rng)) - { - t = getDefaultMorphemeId((*morphemes)[t].tag); - } - local.inData.emplace_back(tokenToVocab[t]); - } - } - } - if (windowSize) - { - if (windowTokenValidness[v]) - { - std::copy(history.begin(), history.end(), std::back_inserter(local.inData)); - history.pop_front(); - history.push_back(v); - } - else - { - local.inData.resize(local.inData.size() + windowSize, -1); - } - } - - local.outData.emplace_back(v); local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]); local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]); } @@ -193,6 +438,35 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, local.restLmLProbsData.resize(r + 1); local.restLmLProbsCntData.resize(r + 1); } + + if (doesGenerateUnlikelihoods()) + { + local.unlikelihoodBuf.clear(); + local.unlikelihoodBuf.emplace_back(tokens[0], 0); + for (size_t i = 1; i < tokens.size() - 1; ++i) + { + if (oovDict && tokens[i] < 0) + { + if (!tokenizeUnlikely(local.unlikelihoodBuf, tokens[i - 1], tokens[i], tokens[i + 1], local.rng)) + { + local.unlikelihoodBuf.emplace_back(tokens[i], 0); + } + continue; + } + + auto& morph = (*morphemes)[tokens[i]]; + if (tokens[i] < generateUnlikelihoods + || !(morph.tag == POSTag::nng || morph.tag == POSTag::nnp) + || getDefaultMorphemeId(morph.tag) == tokens[i] + || !tokenizeUnlikely(local.unlikelihoodBuf, tokens[i - 1], tokens[i], tokens[i + 1], local.rng)) + { + local.unlikelihoodBuf.emplace_back(tokens[i], 0); + } + } + local.unlikelihoodBuf.emplace_back(tokens.back(), 0); + + prepareInOutData(local.unlikelihoodInData, local.unlikelihoodOutData, local.unlikelihoodBuf, local.rng); + } } return localId; }; @@ -260,13 +534,20 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, auto& l = locals[localId]; - size_t rest = std::min(l.outData.size(), batchSize); + const size_t rest = std::min(l.outData.size(), batchSize); + const size_t unlikelihoodRest = std::min(l.unlikelihoodOutData.size(), batchSize); std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in); std::copy(l.outData.begin(), l.outData.begin() + rest, out); std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs); std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode); restLmOut = l.restLmLProbsData.front(); restLmCntOut = l.restLmLProbsCntData.front(); + if (doesGenerateUnlikelihoods() && unlikelihoodIn && unlikelihoodOut) + { + std::copy(l.unlikelihoodInData.begin(), l.unlikelihoodInData.begin() + unlikelihoodRest * (causalContextSize + windowSize), unlikelihoodIn); + std::copy(l.unlikelihoodOutData.begin(), l.unlikelihoodOutData.begin() + unlikelihoodRest, unlikelihoodOut); + if (unlikelihoodSize) *unlikelihoodSize = unlikelihoodRest; + } l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize)); l.outData.erase(l.outData.begin(), l.outData.begin() + rest); @@ -274,21 +555,29 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest); l.restLmLProbsData.pop_front(); l.restLmLProbsCntData.pop_front(); + if (doesGenerateUnlikelihoods() && unlikelihoodIn && unlikelihoodOut) + { + l.unlikelihoodInData.erase(l.unlikelihoodInData.begin(), l.unlikelihoodInData.begin() + unlikelihoodRest * (causalContextSize + windowSize)); + l.unlikelihoodOutData.erase(l.unlikelihoodOutData.begin(), l.unlikelihoodOutData.begin() + unlikelihoodRest); + } return rest; } -size_t HSDataset::next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut) +size_t HSDataset::next(int32_t* in, int32_t* out, float* lmLProbs, uint32_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int32_t* unlikelihoodIn, int32_t* unlikelihoodOut, size_t* unlikelihoodSize) { - return _next(in, out, lmLProbs, outNgramNode, restLmOut, restLmCntOut); + return _next(in, out, lmLProbs, outNgramNode, restLmOut, restLmCntOut, unlikelihoodIn, unlikelihoodOut, unlikelihoodSize); } -size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut) +size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut, + int64_t* unlikelihoodIn, int64_t* unlikelihoodOut, size_t* unlikelihoodSize) { - return _next(in, out, lmLProbs, outNgramNode, restLmOut, restLmCntOut); + return _next(in, out, lmLProbs, outNgramNode, restLmOut, restLmCntOut, unlikelihoodIn, unlikelihoodOut, unlikelihoodSize); } size_t HSDataset::ngramNodeSize() const { + auto knlm = std::dynamic_pointer_cast(langModel); return knlm ? knlm->nonLeafNodeSize() : 0; } @@ -328,7 +617,7 @@ std::vector kiwi::HSDataset::estimVocabFrequency() const return ret; } -Range::const_iterator> HSDataset::getSent(size_t idx) const +Range::const_iterator> HSDataset::getSent(size_t idx) const { return sents.get()[idx]; } @@ -357,13 +646,17 @@ std::vector HSDataset::getAugmentedSent(size_t idx) ret.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); ret.emplace_back(t); break; + case 4: // insertion + ret.emplace_back(t); + ret.emplace_back(getDefaultMorphemeId((*morphemes)[t].tag)); + break; } } ret.emplace_back(*sent.rbegin()); return ret; } -std::vector, size_t>> kiwi::HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers) const +std::vector, size_t>> HSDataset::extractPrefixes(size_t minCnt, size_t maxLength, size_t numWorkers, bool exclusiveCnt) const { using Pair = std::pair, size_t>; std::vector ret; @@ -373,15 +666,58 @@ std::vector, size_t>> kiwi::HSDataset::extractPr counter.addArray(&*sent.begin(), &*sent.end()); } auto trie = counter.count(); - trie.traverse([&](size_t cnt, const std::vector& prefix) + if (exclusiveCnt) { - if (cnt < minCnt) return; - if (std::find_if(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; }) != prefix.end()) + Vector, size_t>> cnts_by_length(maxLength); + trie.traverse([&](size_t cnt, const std::vector& prefix) + { + if (cnt < minCnt) return; + if (std::find_if(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; }) != prefix.end()) + { + return; + } + Vector p(prefix.begin(), prefix.end()); + cnts_by_length[p.size() - 1].emplace(move(p), cnt); + }); + + Vector suffix; + suffix.reserve(maxLength); + for (size_t i = 1; i < maxLength; ++i) { - return; + for (auto& p : cnts_by_length[i]) + { + suffix.clear(); + suffix.insert(suffix.end(), p.first.begin() + 1, p.first.end()); + auto it = cnts_by_length[i - 1].find(suffix); + if (it == cnts_by_length[i - 1].end() || it->second < p.second) + { + throw std::runtime_error("This should not happen"); + } + it->second -= p.second; + } } - ret.emplace_back(prefix, cnt); - }); + + for (auto& cnts : cnts_by_length) + { + for (auto& p : cnts) + { + if (p.second < minCnt) continue; + ret.emplace_back(std::vector{ p.first.begin(), p.first.end() }, p.second); + } + } + } + else + { + trie.traverse([&](size_t cnt, const std::vector& prefix) + { + if (cnt < minCnt) return; + if (std::find_if(prefix.begin(), prefix.end(), [](uint32_t t) { return t < 2; }) != prefix.end()) + { + return; + } + ret.emplace_back(prefix, cnt); + }); + } std::sort(ret.begin(), ret.end(), [](const Pair& a, const Pair& b) { diff --git a/src/FileUtils.cpp b/src/FileUtils.cpp index 07d219eb..b8376421 100644 --- a/src/FileUtils.cpp +++ b/src/FileUtils.cpp @@ -48,4 +48,19 @@ namespace kiwi f.exceptions(exc); return f; } + + bool isOpenable(const string& filePath) + { + ifstream ifs; + try + { + openFile(ifs, filePath); + } + catch (const IOException&) + { + return false; + } + return true; + } + } diff --git a/src/Joiner.cpp b/src/Joiner.cpp index c80eb13a..c0899453 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -162,7 +162,7 @@ namespace kiwi void Joiner::add(const u16string& form, POSTag tag, Space space) { - return add(nonstd::to_string_view(form), tag, space); + return add(toStringView(form), tag, space); } void Joiner::add(const char16_t* form, POSTag tag, Space space) @@ -212,339 +212,44 @@ namespace kiwi } } - AutoJoiner::~AutoJoiner() - { - reinterpret_cast(candBuf).~CandVector(); - } - - AutoJoiner::AutoJoiner(const AutoJoiner& o) - : kiwi{ o.kiwi } - { - new (&candBuf) CandVector{ reinterpret_cast(o.candBuf) }; - } + AutoJoiner::~AutoJoiner() = default; - AutoJoiner::AutoJoiner(AutoJoiner&& o) - : kiwi{ o.kiwi } - { - new (&candBuf) CandVector{ reinterpret_cast(o.candBuf) }; - } - - AutoJoiner& AutoJoiner::operator=(const AutoJoiner& o) - { - kiwi = o.kiwi; - reinterpret_cast(candBuf) = reinterpret_cast(o.candBuf); - return *this; - } + AutoJoiner::AutoJoiner(const AutoJoiner& o) = default; - AutoJoiner& AutoJoiner::operator=(AutoJoiner&& o) - { - kiwi = o.kiwi; - reinterpret_cast(candBuf) = reinterpret_cast(o.candBuf); - return *this; - } + AutoJoiner::AutoJoiner(AutoJoiner&& o) = default; - template - void AutoJoiner::add(size_t morphemeId, Space space, Vector>& candidates) - { - auto& morph = kiwi->morphemes[morphemeId]; - for (auto& cand : candidates) - { - cand.score += cand.lmState.next(kiwi->langMdl, morph.lmMorphemeId); - cand.joiner.add(morph.getForm(), morph.tag, space); - } - - sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const cmb::Candidate& b) - { - return a.score > b.score; - }); - } + AutoJoiner& AutoJoiner::operator=(const AutoJoiner& o) = default; - template - void AutoJoiner::foreachMorpheme(const Form* formHead, Func&& func) const - { - if (kiwi->isTypoTolerant()) - { - auto tformHead = reinterpret_cast(formHead); - do - { - if (tformHead->score() == 0) - { - for (auto m : tformHead->form(kiwi->forms.data()).candidate) - { - func(m); - } - } - ++tformHead; - } while (tformHead[-1].hash() == tformHead[0].hash()); - } - else - { - do - { - for (auto m : formHead->candidate) - { - func(m); - } - ++formHead; - } while (formHead[-1].form == formHead[0].form); - } - } - - template - void AutoJoiner::add(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates) - { - const Form* formHead; - auto node = kiwi->formTrie.root(); - for (auto c : normalizeHangul(form)) - { - node = node->template nextOpt(kiwi->formTrie, c); - if (!node) break; - } - - // prevent unknown or partial tag - POSTag fixedTag = tag; - if (tag == POSTag::unknown || tag == POSTag::p) - { - fixedTag = POSTag::nnp; - } - - if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie))) - { - Vector cands; - foreachMorpheme(formHead, [&](const Morpheme* m) - { - if (areTagsEqual(m->tag, fixedTag, inferRegularity)) - { - cands.emplace_back(m); - } - }); - - if (cands.size() <= 1) - { - auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId; - if (!cands.empty()) tag = cands[0]->tag; - for (auto& cand : candidates) - { - cand.score += cand.lmState.next(kiwi->langMdl, lmId); - cand.joiner.add(form, tag, space); - } - } - else - { - size_t oSize = candidates.size(); - for (size_t i = 1; i < cands.size(); ++i) - { - for (size_t o = 0; o < oSize; ++o) - { - candidates.emplace_back(candidates[o]); - auto& n = candidates.back(); - n.score += n.lmState.next(kiwi->langMdl, cands[i]->lmMorphemeId); - n.joiner.add(form, cands[i]->tag, space); - } - } - for (size_t o = 0; o < oSize; ++o) - { - auto& n = candidates[o]; - n.score += n.lmState.next(kiwi->langMdl, cands[0]->lmMorphemeId); - n.joiner.add(form, cands[0]->tag, space); - } - - UnorderedMap> bestScoreByState; - for (size_t i = 0; i < candidates.size(); ++i) - { - auto& c = candidates[i]; - auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, (uint32_t)i)); - if (!inserted.second) - { - if (inserted.first->second.first < c.score) - { - inserted.first->second = make_pair(c.score, i); - } - } - } - - if (bestScoreByState.size() < candidates.size()) - { - Vector> newCandidates; - newCandidates.reserve(bestScoreByState.size()); - for (auto& p : bestScoreByState) - { - newCandidates.emplace_back(std::move(candidates[p.second.second])); - } - candidates = std::move(newCandidates); - } - } - } - else - { - auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag)); - for (auto& cand : candidates) - { - cand.score += cand.lmState.next(kiwi->langMdl, lmId); - cand.joiner.add(form, tag, space); - } - } - sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const cmb::Candidate& b) - { - return a.score > b.score; - }); - } - - template - void AutoJoiner::addWithoutSearch(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates) - { - if (inferRegularity) - { - auto node = kiwi->formTrie.root(); - for (auto c : normalizeHangul(form)) - { - node = node->template nextOpt(kiwi->formTrie, c); - if (!node) break; - } - - if (node) - { - if (const Form* formHead = node->val(kiwi->formTrie)) - { - Vector cands; - foreachMorpheme(formHead, [&](const Morpheme* m) - { - if (areTagsEqual(m->tag, tag, true)) - { - cands.emplace_back(m); - } - }); - - if (!cands.empty()) - { - tag = cands[0]->tag; - } - } - } - } - candidates[0].joiner.add(form, tag, space); - } - - template - void AutoJoiner::addWithoutSearch(size_t morphemeId, Space space, Vector>>& candidates) - { - auto& morph = kiwi->morphemes[morphemeId]; - for (auto& cand : candidates) - { - cand.joiner.add(morph.getForm(), morph.tag, space); - } - } - - struct AutoJoiner::AddVisitor - { - AutoJoiner* joiner; - U16StringView form; - POSTag tag; - bool inferRegularity; - Space space; - - AddVisitor(AutoJoiner* _joiner, U16StringView _form, POSTag _tag, bool _inferRegularity, Space _space) - : joiner{ _joiner }, form{ _form }, tag{ _tag }, inferRegularity{ _inferRegularity }, space{ _space } - { - } - - template - void operator()(Vector>>& o) const - { - return joiner->addWithoutSearch(form, tag, inferRegularity, space, o); - } - - template - void operator()(Vector>& o) const - { - return joiner->add(form, tag, inferRegularity, space, o); - } - }; - - struct AutoJoiner::AddVisitor2 - { - AutoJoiner* joiner; - size_t morphemeId; - Space space; - - AddVisitor2(AutoJoiner* _joiner, size_t _morphemeId, Space _space) - : joiner{ _joiner }, morphemeId{ _morphemeId }, space{ _space } - { - } - - template - void operator()(Vector>>& o) const - { - return joiner->addWithoutSearch(morphemeId, space, o); - } - - template - void operator()(Vector>& o) const - { - return joiner->add(morphemeId, space, o); - } - }; - - struct GetU16Visitor - { - vector>* rangesOut; - - GetU16Visitor(vector>* _rangesOut) - : rangesOut{ _rangesOut } - { - } - - template - u16string operator()(const Vector>& o) const - { - return o[0].joiner.getU16(rangesOut); - } - }; - - struct GetU8Visitor - { - vector>* rangesOut; - - GetU8Visitor(vector>* _rangesOut) - : rangesOut{ _rangesOut } - { - } - - template - string operator()(const Vector>& o) const - { - return o[0].joiner.getU8(rangesOut); - } - }; + AutoJoiner& AutoJoiner::operator=(AutoJoiner&& o) = default; void AutoJoiner::add(size_t morphemeId, Space space) { - return mapbox::util::apply_visitor(AddVisitor2{ this, morphemeId, space }, reinterpret_cast(candBuf)); + return (*dfAdd)(this, morphemeId, space, candBuf.get>>()); } void AutoJoiner::add(const u16string& form, POSTag tag, bool inferRegularity, Space space) { - return mapbox::util::apply_visitor(AddVisitor{ this, nonstd::to_string_view(form), tag, inferRegularity, space }, reinterpret_cast(candBuf)); + return (*dfAdd2)(this, toStringView(form), tag, inferRegularity, space, candBuf.get>>()); } void AutoJoiner::add(const char16_t* form, POSTag tag, bool inferRegularity, Space space) { - return mapbox::util::apply_visitor(AddVisitor{ this, U16StringView{ form }, tag, inferRegularity, space }, reinterpret_cast(candBuf)); + return (*dfAdd2)(this, U16StringView{ form }, tag, inferRegularity, space, candBuf.get>>()); } void AutoJoiner::add(U16StringView form, POSTag tag, Space space) { - return mapbox::util::apply_visitor(AddVisitor{ this, form, tag, false, space }, reinterpret_cast(candBuf)); + return (*dfAdd2)(this, form, tag, false, space, candBuf.get>>()); } u16string AutoJoiner::getU16(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU16Visitor{ rangesOut }, reinterpret_cast(candBuf)); + return candBuf.get>>()[0].joiner.getU16(rangesOut); } string AutoJoiner::getU8(vector>* rangesOut) const { - return mapbox::util::apply_visitor(GetU8Visitor{ rangesOut }, reinterpret_cast(candBuf)); + return candBuf.get>>()[0].joiner.getU8(rangesOut); } } } diff --git a/src/Joiner.hpp b/src/Joiner.hpp index bf195e2f..b7537039 100644 --- a/src/Joiner.hpp +++ b/src/Joiner.hpp @@ -1,10 +1,9 @@ #pragma once #include #include -#include +#include #include "Combiner.h" #include "StrUtils.h" -#include "LmState.hpp" using namespace std; @@ -12,40 +11,234 @@ namespace kiwi { namespace cmb { - namespace detail + template + void AutoJoiner::addImpl(size_t morphemeId, Space space, Vector>& candidates) + { + auto& morph = kiwi->morphemes[morphemeId]; + for (auto& cand : candidates) + { + cand.score += cand.lmState.next(kiwi->langMdl.get(), morph.lmMorphemeId); + cand.joiner.add(morph.getForm(), morph.tag, space); + } + + sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const cmb::Candidate& b) + { + return a.score > b.score; + }); + } + + template + void AutoJoiner::foreachMorpheme(const Form* formHead, Func&& func) const { - template class Type, class> - struct VCUnpack2nd; + if (kiwi->isTypoTolerant()) + { + auto tformHead = reinterpret_cast(formHead); + do + { + if (tformHead->score() == 0) + { + for (auto m : tformHead->form(kiwi->forms.data()).candidate) + { + func(m); + } + } + ++tformHead; + } while (tformHead[-1].hash() == tformHead[0].hash()); + } + else + { + do + { + for (auto m : formHead->candidate) + { + func(m); + } + ++formHead; + } while (formHead[-1].form == formHead[0].form); + } + } + + template + void AutoJoiner::addImpl2(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates) + { + const Form* formHead; + auto node = kiwi->formTrie.root(); + for (auto c : normalizeHangul(form)) + { + node = node->template nextOpt(kiwi->formTrie, c); + if (!node) break; + } + + // prevent unknown or partial tag + POSTag fixedTag = tag; + if (tag == POSTag::unknown || tag == POSTag::p) + { + fixedTag = POSTag::nnp; + } + + if (node && kiwi->formTrie.hasMatch(formHead = node->val(kiwi->formTrie))) + { + Vector cands; + foreachMorpheme(formHead, [&](const Morpheme* m) + { + if (areTagsEqual(m->tag, fixedTag, inferRegularity)) + { + cands.emplace_back(m); + } + }); + + if (cands.size() <= 1) + { + auto lmId = cands.empty() ? getDefaultMorphemeId(clearIrregular(fixedTag)) : cands[0]->lmMorphemeId; + if (!cands.empty()) tag = cands[0]->tag; + for (auto& cand : candidates) + { + cand.score += cand.lmState.next(kiwi->langMdl.get(), lmId); + cand.joiner.add(form, tag, space); + } + } + else + { + size_t oSize = candidates.size(); + for (size_t i = 1; i < cands.size(); ++i) + { + for (size_t o = 0; o < oSize; ++o) + { + candidates.emplace_back(candidates[o]); + auto& n = candidates.back(); + n.score += n.lmState.next(kiwi->langMdl.get(), cands[i]->lmMorphemeId); + n.joiner.add(form, cands[i]->tag, space); + } + } + for (size_t o = 0; o < oSize; ++o) + { + auto& n = candidates[o]; + n.score += n.lmState.next(kiwi->langMdl.get(), cands[0]->lmMorphemeId); + n.joiner.add(form, cands[0]->tag, space); + } - template class Type, std::ptrdiff_t ...arches> - struct VCUnpack2nd> + UnorderedMap> bestScoreByState; + for (size_t i = 0; i < candidates.size(); ++i) + { + auto& c = candidates[i]; + auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, (uint32_t)i)); + if (!inserted.second) + { + if (inserted.first->second.first < c.score) + { + inserted.first->second = make_pair(c.score, i); + } + } + } + + if (bestScoreByState.size() < candidates.size()) + { + Vector> newCandidates; + newCandidates.reserve(bestScoreByState.size()); + for (auto& p : bestScoreByState) + { + newCandidates.emplace_back(std::move(candidates[p.second.second])); + } + candidates = std::move(newCandidates); + } + } + } + else + { + auto lmId = getDefaultMorphemeId(clearIrregular(fixedTag)); + for (auto& cand : candidates) + { + cand.score += cand.lmState.next(kiwi->langMdl.get(), lmId); + cand.joiner.add(form, tag, space); + } + } + sort(candidates.begin(), candidates.end(), [](const cmb::Candidate& a, const cmb::Candidate& b) + { + return a.score > b.score; + }); + } + + template + void AutoJoiner::addWithoutSearchImpl2(U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates) + { + if (inferRegularity) { - using type = std::tuple(arches)>>>...>; - }; + auto node = kiwi->formTrie.root(); + for (auto c : normalizeHangul(form)) + { + node = node->template nextOpt(kiwi->formTrie, c); + if (!node) break; + } + + if (node) + { + if (const Form* formHead = node->val(kiwi->formTrie)) + { + Vector cands; + foreachMorpheme(formHead, [&](const Morpheme* m) + { + if (areTagsEqual(m->tag, tag, true)) + { + cands.emplace_back(m); + } + }); - template class ... Types> - struct VCUnpack; + if (!cands.empty()) + { + tag = cands[0]->tag; + } + } + } + } + candidates[0].joiner.add(form, tag, space); + } - template class ... Types> - struct VCUnpack, Types...> + template + void AutoJoiner::addWithoutSearchImpl(size_t morphemeId, Space space, Vector>>& candidates) + { + auto& morph = kiwi->morphemes[morphemeId]; + for (auto& cand : candidates) { - using type = TupleCat>::type...>; - }; + cand.joiner.add(morph.getForm(), morph.tag, space); + } } - using CandTypeTuple = typename detail::VCUnpack::type, WrappedKnLM::type, WrappedKnLM::type, WrappedKnLM::type, - WrappedSbg<8, uint8_t>::type, WrappedSbg<8, uint16_t>::type, WrappedSbg<8, uint32_t>::type, WrappedSbg<8, uint64_t>::type - >::type; + template + struct AutoJoiner::Dispatcher + { + static void add(AutoJoiner* joiner, size_t morphemeId, Space space, Vector>& candidates) + { + return joiner->addImpl(morphemeId, space, candidates); + } - using CandVector = typename detail::VariantFromTuple::type; + static void add2(AutoJoiner* joiner, U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>& candidates) + { + return joiner->addImpl2(form, tag, inferRegularity, space, candidates); + } + }; + + template + struct AutoJoiner::Dispatcher> + { + static void add(AutoJoiner* joiner, size_t morphemeId, Space space, Vector>>& candidates) + { + return joiner->addWithoutSearchImpl(morphemeId, space, candidates); + } + + static void add2(AutoJoiner* joiner, U16StringView form, POSTag tag, bool inferRegularity, Space space, Vector>>& candidates) + { + return joiner->addWithoutSearchImpl2(form, tag, inferRegularity, space, candidates); + } + }; template AutoJoiner::AutoJoiner(const Kiwi& _kiwi, Candidate&& state) - : kiwi{ &_kiwi } + : kiwi{ &_kiwi }, candBuf{ Vector>{ { move(state) } } } { - new (&candBuf) CandVector(Vector>{ { move(state) } }); + using Dp = Dispatcher; + dfAdd = reinterpret_cast(&Dp::add); + dfAdd2 = reinterpret_cast(&Dp::add2); } + } } diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index 7543038d..66af12ba 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -4,16 +4,17 @@ #include #include #include +#include #include "ArchAvailable.h" #include "KTrie.h" #include "FeatureTestor.h" #include "FrozenTrie.hpp" -#include "LmState.hpp" #include "StrUtils.h" #include "SortUtils.hpp" #include "serializer.hpp" #include "Joiner.hpp" #include "PathEvaluator.hpp" +#include "Kiwi.hpp" using namespace std; @@ -42,68 +43,19 @@ namespace kiwi } Kiwi::Kiwi(ArchType arch, - LangModel _langMdl, + const std::shared_ptr & _langMdl, bool typoTolerant, bool continualTypoTolerant, bool lengtheningTypoTolerant) - : langMdl(_langMdl) + : langMdl{ _langMdl }, selectedArch{ arch } { - selectedArch = arch; dfSplitByTrie = (void*)getSplitByTrieFn(selectedArch, typoTolerant, continualTypoTolerant, lengtheningTypoTolerant); dfFindForm = (void*)getFindFormFn(selectedArch, typoTolerant); - - static tp::Table lmKnLM_8{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_16{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_32{ FindBestPathGetter::type>{} }; - static tp::Table lmKnLM_64{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_8{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_16{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_32{ FindBestPathGetter::type>{} }; - static tp::Table lmSbg_64{ FindBestPathGetter::type>{} }; - - if (langMdl.sbg) - { - switch (langMdl.sbg->getHeader().keySize) - { - case 1: - dfFindBestPath = (void*)lmSbg_8[static_cast(selectedArch)]; - break; - case 2: - dfFindBestPath = (void*)lmSbg_16[static_cast(selectedArch)]; - break; - case 4: - dfFindBestPath = (void*)lmSbg_32[static_cast(selectedArch)]; - break; - case 8: - dfFindBestPath = (void*)lmSbg_64[static_cast(selectedArch)]; - break; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } - else if(langMdl.knlm) - { - switch (langMdl.knlm->getHeader().key_size) - { - case 1: - dfFindBestPath = (void*)lmKnLM_8[static_cast(selectedArch)]; - break; - case 2: - dfFindBestPath = (void*)lmKnLM_16[static_cast(selectedArch)]; - break; - case 4: - dfFindBestPath = (void*)lmKnLM_32[static_cast(selectedArch)]; - break; - case 8: - dfFindBestPath = (void*)lmKnLM_64[static_cast(selectedArch)]; - break; - default: - throw Exception{ "Wrong `lmKeySize`" }; - } - } + dfFindBestPath = langMdl ? langMdl->getFindBestPathFn() : nullptr; + dfNewJoiner = langMdl ? langMdl->getNewJoinerFn() : nullptr; } Kiwi::~Kiwi() = default; @@ -651,7 +603,7 @@ namespace kiwi inline void insertPathIntoResults( vector& ret, Vector& spStatesByRet, - const Vector& pathes, + const Vector& pathes, size_t topN, Match matchOptions, bool integrateAllomorph, @@ -677,7 +629,7 @@ namespace kiwi Vector selectedPathes(pathes.size()); for (size_t i = 0; i < ret.size(); ++i) { - auto pred = [&](const PathEvaluator::ChunkResult& p) + auto pred = [&](const PathResult& p) { return p.prevState == spStatesByRet[i]; }; @@ -1059,7 +1011,7 @@ namespace kiwi if (nodes.size() <= 2) continue; findPretokenizedGroupOfNode(nodeInWhichPretokenized, nodes, pretokenizedPrev, pretokenizedFirst); - Vector res = (*reinterpret_cast(dfFindBestPath))( + Vector res = (*reinterpret_cast(dfFindBestPath))( this, spStatesByRet, nodes.data(), @@ -1197,137 +1149,15 @@ namespace kiwi return _asyncAnalyzeEcho(move(str), move(pretokenized), matchOptions, blocklist); } - using FnNewAutoJoiner = cmb::AutoJoiner(Kiwi::*)() const; - - template class LmState> - struct NewAutoJoinerGetter - { - template - struct Wrapper - { - static constexpr FnNewAutoJoiner value = &Kiwi::newJoinerImpl(i)>>; - }; - }; - cmb::AutoJoiner Kiwi::newJoiner(bool lmSearch) const { - static tp::Table lmVoid{ NewAutoJoinerGetter{} }; - static tp::Table lmKnLM_8{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmKnLM_16{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmKnLM_32{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmKnLM_64{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_8{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_16{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_32{ NewAutoJoinerGetter::type>{} }; - static tp::Table lmSbg_64{ NewAutoJoinerGetter::type>{} }; - - const auto archIdx = static_cast(selectedArch); - if (lmSearch) { - size_t vocabTySize = langMdl.knlm->getHeader().key_size; - if (langMdl.sbg) - { - switch (vocabTySize) - { - case 1: - return (this->*lmSbg_8[archIdx])(); - case 2: - return (this->*lmSbg_16[archIdx])(); - case 4: - return (this->*lmSbg_32[archIdx])(); - case 8: - return (this->*lmSbg_64[archIdx])(); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize)}; - } - } - else - { - switch (vocabTySize) - { - case 1: - return (this->*lmKnLM_8[archIdx])(); - case 2: - return (this->*lmKnLM_16[archIdx])(); - case 4: - return (this->*lmKnLM_32[archIdx])(); - case 8: - return (this->*lmKnLM_64[archIdx])(); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize) }; - } - } - } - else - { - return (this->*lmVoid[archIdx])(); - } - } - - using FnNewLmObject = std::unique_ptr(*)(const LangModel&); - - template - std::unique_ptr makeNewLmObject(const LangModel& lm) - { - return make_unique>(lm); - } - - template class LmState> - struct NewLmObjectGetter - { - template - struct Wrapper - { - static constexpr FnNewLmObject value = makeNewLmObject(i)>>; - }; - }; - - std::unique_ptr Kiwi::newLmObject() const - { - static tp::Table lmKnLM_8{ NewLmObjectGetter::type>{} }; - static tp::Table lmKnLM_16{ NewLmObjectGetter::type>{} }; - static tp::Table lmKnLM_32{ NewLmObjectGetter::type>{} }; - static tp::Table lmKnLM_64{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_8{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_16{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_32{ NewLmObjectGetter::type>{} }; - static tp::Table lmSbg_64{ NewLmObjectGetter::type>{} }; - - const auto archIdx = static_cast(selectedArch); - - size_t vocabTySize = langMdl.knlm->getHeader().key_size; - if (langMdl.sbg) - { - switch (vocabTySize) - { - case 1: - return (lmSbg_8[archIdx])(langMdl); - case 2: - return (lmSbg_16[archIdx])(langMdl); - case 4: - return (lmSbg_32[archIdx])(langMdl); - case 8: - return (lmSbg_64[archIdx])(langMdl); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize) }; - } + return (*reinterpret_cast(dfNewJoiner))(this); } else { - switch (vocabTySize) - { - case 1: - return (lmKnLM_8[archIdx])(langMdl); - case 2: - return (lmKnLM_16[archIdx])(langMdl); - case 4: - return (lmKnLM_32[archIdx])(langMdl); - case 8: - return (lmKnLM_64[archIdx])(langMdl); - default: - throw Exception{ "invalid `key_size`=" + to_string(vocabTySize) }; - } + return cmb::AutoJoiner{ *this, cmb::Candidate>{ *combiningRule, langMdl.get() }}; } } diff --git a/src/Kiwi.hpp b/src/Kiwi.hpp new file mode 100644 index 00000000..4c39e668 --- /dev/null +++ b/src/Kiwi.hpp @@ -0,0 +1,19 @@ +#pragma once +#include + +namespace kiwi +{ + using FnNewJoiner = cmb::AutoJoiner(*)(const Kiwi*); + + template + cmb::AutoJoiner Kiwi::newJoinerImpl() const + { + return cmb::AutoJoiner{ *this, cmb::Candidate{ *combiningRule, langMdl.get() }}; + } + + template + cmb::AutoJoiner newJoinerWithKiwi(const Kiwi* kiwi) + { + return kiwi->newJoinerImpl(); + } +} diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index f522ac25..fcecc445 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -4,11 +4,11 @@ #include #include #include +#include #include "ArchAvailable.h" #include "KTrie.h" #include "StrUtils.h" #include "FrozenTrie.hpp" -#include "Knlm.hpp" #include "serializer.hpp" #include "count.hpp" #include "FeatureTestor.h" @@ -16,6 +16,7 @@ #include "RaggedVector.hpp" #include "SkipBigramTrainer.hpp" #include "SkipBigramModel.hpp" +#include "CoNgramModel.hpp" #include "SortUtils.hpp" using namespace std; @@ -212,11 +213,11 @@ auto KiwiBuilder::loadMorphemesFromTxt(std::istream& is, Fn&& filter) -> Morphem if (cpolar != CondPolarity::none) throw FormatException{ "wrong line: " + line }; cpolar = CondPolarity::non_adj; } - else if (f.starts_with(u"complex ")) + else if (f.size() >= 8 && f.substr(0, 8) == u"complex ") { if (complex) throw FormatException{ "wrong line: " + line }; complex = true; - complexStr = f.substr(8).to_string(); + complexStr = u16string{ f.substr(8) }; } else if (f[0] == u'=') { @@ -570,7 +571,8 @@ void KiwiBuilder::_addCorpusTo( std::istream& is, MorphemeMap& morphMap, double splitRatio, - RaggedVector* splitOut + RaggedVector* splitOut, + UnorderedMap, size_t>* oovDict ) const { Vector wids; @@ -689,8 +691,16 @@ void KiwiBuilder::_addCorpusTo( } if (t < POSTag::p && t != POSTag::unknown) + { + if (oovDict && (t == POSTag::nng || t == POSTag::nnp)) + { + auto oovId = oovDict->emplace(make_pair(f, t), oovDict->size()).first->second; + wids.emplace_back(-(ptrdiff_t)(oovId + 1)); + } + else { wids.emplace_back(getDefaultMorphemeId(t)); + } continue; } @@ -699,19 +709,28 @@ void KiwiBuilder::_addCorpusTo( } } -void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const { - return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); } -void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const { - return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); } -void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector* splitOut) const +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const { - return _addCorpusTo(out, is, morphMap, splitRatio, splitOut); + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); +} + +void KiwiBuilder::addCorpusTo(RaggedVector& out, std::istream& is, MorphemeMap& morphMap, + double splitRatio, RaggedVector* splitOut, UnorderedMap, size_t>* oovDict) const +{ + return _addCorpusTo(out, is, morphMap, splitRatio, splitOut, oovDict); } void KiwiBuilder::updateForms() @@ -738,12 +757,13 @@ void KiwiBuilder::updateForms() } } -void KiwiBuilder::updateMorphemes() +void KiwiBuilder::updateMorphemes(size_t vocabSize) { + if (vocabSize == 0) vocabSize = langMdl->vocabSize(); for (auto& m : morphemes) { if (m.lmMorphemeId > 0) continue; - if (m.tag == POSTag::p || (&m - morphemes.data() + m.combined) < langMdl.knlm->getHeader().vocab_size) + if (m.tag == POSTag::p || (&m - morphemes.data() + m.combined) < vocabSize) { m.lmMorphemeId = &m - morphemes.data(); } @@ -760,7 +780,7 @@ void KiwiBuilder::loadMorphBin(std::istream& is) for (auto& form : forms) { const size_t idx = &form - &forms[0]; - if (idx < defaultFormSize + 27) continue; + if (idx < defaultFormSize) continue; formMap.emplace(form.form, idx); } } @@ -770,8 +790,28 @@ void KiwiBuilder::saveMorphBin(std::ostream& os) const serializer::writeMany(os, serializer::toKey("KIWI"), forms, morphemes); } -KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOption _options, bool useSBG) - : detector{ modelPath, _numThreads }, options{ _options }, numThreads{ _numThreads ? _numThreads : thread::hardware_concurrency() } +ModelType KiwiBuilder::getModelType(const string& modelPath) +{ + if (isOpenable(modelPath + "/cong.mdl")) + { + return ModelType::congGlobal; + } + else if (isOpenable(modelPath + "/skipbigram.mdl")) + { + return ModelType::sbg; + } + else if (isOpenable(modelPath + "/sj.knlm")) + { + return ModelType::knlm; + } + else + { + return ModelType::none; + } +} + +KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOption _options, ModelType _modelType) + : detector{ modelPath, _numThreads }, options{ _options }, modelType{ _modelType }, numThreads{ _numThreads ? _numThreads : thread::hardware_concurrency() } { archType = getSelectedArch(ArchType::default_); @@ -780,10 +820,29 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio utils::imstream iss{ mm }; loadMorphBin(iss); } - langMdl.knlm = lm::KnLangModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), archType); - if (useSBG) + + if (modelType == ModelType::none) { - langMdl.sbg = sb::SkipBigramModelBase::create(utils::MMap(modelPath + string{ "/skipbigram.mdl" }), archType); + modelType = getModelType(modelPath); + if (modelType == ModelType::none) + { + throw runtime_error{ "Cannot find any valid model files in the given path" }; + } + } + + if (modelType == ModelType::knlm || modelType == ModelType::knlmTransposed) + { + langMdl = lm::KnLangModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), archType, modelType == ModelType::knlmTransposed); + } + else if (modelType == ModelType::sbg) + { + langMdl = lm::SkipBigramModelBase::create(utils::MMap(modelPath + string{ "/sj.knlm" }), utils::MMap(modelPath + string{ "/skipbigram.mdl" }), archType); + } + else if (ModelType::cong <= modelType && modelType <= ModelType::congGlobalFp32 ) + { + langMdl = lm::CoNgramModelBase::create(utils::MMap(modelPath + string{ "/cong.mdl" }), archType, + (modelType == ModelType::congGlobal || modelType == ModelType::congGlobalFp32), + (modelType == ModelType::cong || modelType == ModelType::congGlobal)); } if (!!(options & BuildOption::loadDefaultDict)) @@ -834,24 +893,10 @@ void KiwiBuilder::initMorphemes() morphemes[defaultTagSize + 28].userScore = -1.5f; } -KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) +template +unique_ptr KiwiBuilder::buildKnLM(const ModelBuildArgs& args, size_t lmVocabSize, MorphemeMap& realMorph) const { - if (!(args.lmMinCnts.size() == 1 || args.lmMinCnts.size() == args.lmOrder)) - { - throw invalid_argument{ "lmMinCnts should have 1 or lmOrder elements" }; - } - - archType = getSelectedArch(ArchType::default_); - initMorphemes(); - - ifstream ifs; - auto realMorph = loadMorphemesFromTxt(openFile(ifs, args.morphemeDef), [&](POSTag tag, float cnt) - { - return cnt >= args.minMorphCnt; - }); - updateForms(); - - RaggedVector sents; + RaggedVector sents; for (auto& path : args.corpora) { ifstream ifs; @@ -880,12 +925,8 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) } } } - - size_t lmVocabSize = 0; - for (auto& p : realMorph) lmVocabSize = max(p.second.first, lmVocabSize); - lmVocabSize += 1; - Vector historyTx(lmVocabSize); + Vector historyTx(lmVocabSize); if (args.useLmTagHistory) { for (size_t i = 0; i < lmVocabSize; ++i) @@ -894,14 +935,25 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) } } - vector> bigramList; utils::ThreadPool pool; - if (args.numWorkers > 1) + if (args.numWorkers >= 1) { pool.~ThreadPool(); new (&pool) utils::ThreadPool{ args.numWorkers }; } - size_t lmMinCnt = *std::min(args.lmMinCnts.begin(), args.lmMinCnts.end()); + const size_t lmMinCnt = *std::min(args.lmMinCnts.begin(), args.lmMinCnts.end()); + std::vector minCnts; + if (args.lmMinCnts.size() == 1) + { + minCnts.clear(); + minCnts.resize(args.lmOrder, args.lmMinCnts[0]); + } + else if (args.lmMinCnts.size() == args.lmOrder) + { + minCnts = args.lmMinCnts; + } + + vector> bigramList; auto cntNodes = utils::count(sents.begin(), sents.end(), lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); // discount for bos node cnt if (args.useLmTagHistory) @@ -912,47 +964,72 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) { cntNodes.root().getNext(0)->val /= 2; } - std::vector minCnts; - if (args.lmMinCnts.size() == 1) + + return lm::KnLangModelBase::create(lm::KnLangModelBase::build( + cntNodes, + args.lmOrder, minCnts, + 2, 0, 1, 1e-5, + args.quantizeLm ? 8 : 0, + sizeof(VocabTy) == 2 ? args.compressLm : false, + &bigramList, + args.useLmTagHistory ? &historyTx : nullptr + ), archType); +} + + +KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) +{ + if (!(args.lmMinCnts.size() == 1 || args.lmMinCnts.size() == args.lmOrder)) { - minCnts.clear(); - minCnts.resize(args.lmOrder, args.lmMinCnts[0]); + throw invalid_argument{ "lmMinCnts should have 1 or lmOrder elements" }; } - else if (args.lmMinCnts.size() == args.lmOrder) + + archType = getSelectedArch(ArchType::default_); + initMorphemes(); + + ifstream ifs; + auto realMorph = loadMorphemesFromTxt(openFile(ifs, args.morphemeDef), [&](POSTag tag, float cnt) { - minCnts = args.lmMinCnts; + return cnt >= args.minMorphCnt; + }); + updateForms(); + + size_t lmVocabSize = 0; + for (auto& p : realMorph) lmVocabSize = max(p.second.first, lmVocabSize); + lmVocabSize += 1; + + if (lmVocabSize <= 0xFFFF) + { + langMdl = buildKnLM(args, lmVocabSize, realMorph); + } + else + { + langMdl = buildKnLM(args, lmVocabSize, realMorph); } - langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build( - cntNodes, - args.lmOrder, minCnts, - 2, 0, 1, 1e-5, - args.quantizeLm ? 8 : 0, - args.compressLm, - &bigramList, - args.useLmTagHistory ? &historyTx : nullptr - ), archType); updateMorphemes(); } + namespace kiwi { + template class SBDataFeeder { - const RaggedVector& sents; + const RaggedVector& sents; const lm::KnLangModelBase* lm = nullptr; Vector> lmBuf; Vector> nodeBuf; public: - SBDataFeeder(const RaggedVector& _sents, const lm::KnLangModelBase* _lm, size_t numThreads = 1) + SBDataFeeder(const RaggedVector& _sents, const lm::KnLangModelBase* _lm, size_t numThreads = 1) : sents{ _sents }, lm{ _lm }, lmBuf(numThreads), nodeBuf(numThreads) { } - sb::FeedingData operator()(size_t i, size_t threadId = 0) + lm::FeedingData operator()(size_t i, size_t threadId = 0) { - sb::FeedingData ret; + lm::FeedingData ret; ret.len = sents[i].size(); if (lmBuf[threadId].size() < ret.len) { @@ -971,9 +1048,11 @@ namespace kiwi KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) : KiwiBuilder{ modelPath } { + using Vid = uint16_t; + auto realMorph = restoreMorphemeMap(); - sb::SkipBigramTrainer sbg; - RaggedVector sents; + lm::SkipBigramTrainer sbg; + RaggedVector sents; for (auto& path : args.corpora) { ifstream ifs; @@ -1050,7 +1129,9 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) return true; }; - sbg = sb::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, args.sbgSize, langMdl.knlm->nonLeafNodeSize() }; + auto* knlm = dynamic_cast(langMdl.get()); + + sbg = lm::SkipBigramTrainer{ sents, sbgTokenFilter, sbgPairFilter, 0, 150, 20, true, 0.333f, 1, args.sbgSize, knlm->nonLeafNodeSize() }; Vector lmLogProbs; Vector baseNodes; auto tc = sbg.newContext(); @@ -1071,7 +1152,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) lmLogProbs.resize(sent.size()); baseNodes.resize(sent.size()); } - langMdl.knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin()); + knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin()); //float sum = sbg.evaluate(&sent[0], lmLogProbs.data(), sent.size()); float sum = accumulate(lmLogProbs.begin() + 1, lmLogProbs.begin() + sent.size(), 0.); size_t cnt = sent.size() - 1; @@ -1088,7 +1169,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) if (args.numWorkers <= 1) { - sbg.train(SBDataFeeder{ sents, langMdl.knlm.get() }, [&](const sb::ObservingData& od) + sbg.train(SBDataFeeder{ sents, knlm }, [&](const lm::ObservingData& od) { llCnt += od.cntRecent; llMean += (od.llRecent - llMean * od.cntRecent) / llCnt; @@ -1102,7 +1183,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) } else { - sbg.trainMulti(args.numWorkers, SBDataFeeder{ sents, langMdl.knlm.get(), 8 }, [&](const sb::ObservingData& od) + sbg.trainMulti(args.numWorkers, SBDataFeeder{ sents, knlm, 8 }, [&](const lm::ObservingData& od) { llCnt += od.cntRecent; llMean += (od.llRecent - llMean * od.cntRecent) / llCnt; @@ -1131,7 +1212,7 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, const ModelBuildArgs& args) lmLogProbs.resize(sent.size()); baseNodes.resize(sent.size()); } - langMdl.knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin(), baseNodes.begin()); + knlm->evaluate(sent.begin(), sent.end(), lmLogProbs.begin(), baseNodes.begin()); float sum = sbg.evaluate(&sent[0], baseNodes.data(), lmLogProbs.data(), sent.size()); size_t cnt = sent.size() - 1; llCnt += cnt; @@ -1165,7 +1246,8 @@ void KiwiBuilder::saveModel(const string& modelPath) const saveMorphBin(ofs); } { - auto mem = langMdl.knlm->getMemory(); + auto* knlm = dynamic_cast(langMdl.get()); + auto mem = knlm->getMemory(); ofstream ofs{ modelPath + "/sj.knlm", ios_base::binary }; ofs.write((const char*)mem.get(), mem.size()); } @@ -1289,7 +1371,7 @@ pair KiwiBuilder::addWord(U16StringView newForm, POSTag tag, flo pair KiwiBuilder::addWord(const std::u16string& newForm, POSTag tag, float score, size_t origMorphemeId, size_t lmMorphemeId) { - return addWord(nonstd::to_string_view(newForm), tag, score, origMorphemeId, lmMorphemeId); + return addWord(toStringView(newForm), tag, score, origMorphemeId, lmMorphemeId); } void KiwiBuilder::addCombinedMorpheme( @@ -1596,7 +1678,7 @@ pair KiwiBuilder::addWord(U16StringView form, POSTag tag, float pair KiwiBuilder::addWord(const u16string& form, POSTag tag, float score) { - return addWord(nonstd::to_string_view(form), tag, score); + return addWord(toStringView(form), tag, score); } pair KiwiBuilder::addWord(const char16_t* form, POSTag tag, float score) @@ -1634,7 +1716,7 @@ pair KiwiBuilder::addWord(U16StringView newForm, POSTag tag, flo pair KiwiBuilder::addWord(const u16string& newForm, POSTag tag, float score, const u16string& origForm) { - return addWord(nonstd::to_string_view(newForm), tag, score, origForm); + return addWord(toStringView(newForm), tag, score, origForm); } pair KiwiBuilder::addWord(const char16_t* newForm, POSTag tag, float score, const char16_t* origForm) @@ -1698,7 +1780,7 @@ bool KiwiBuilder::addPreAnalyzedWord(U16StringView form, const vector>& analyzed, vector> positions, float score) { - return addPreAnalyzedWord(nonstd::to_string_view(form), analyzed, positions, score); + return addPreAnalyzedWord(toStringView(form), analyzed, positions, score); } bool KiwiBuilder::addPreAnalyzedWord(const char16_t* form, const vector>& analyzed, vector> positions, float score) @@ -1716,7 +1798,7 @@ size_t KiwiBuilder::loadDictionary(const string& dictPath) u16string wstr; for (size_t lineNo = 1; getline(ifs, line); ++lineNo) { - utf8To16(nonstd::to_string_view(line), wstr); + utf8To16(toStringView(line), wstr); while (!wstr.empty() && kiwi::identifySpecialChr(wstr.back()) == POSTag::unknown) wstr.pop_back(); if (wstr.empty()) continue; if (wstr[0] == u'#') continue; @@ -1770,9 +1852,9 @@ size_t KiwiBuilder::loadDictionary(const string& dictPath) auto suffix = fields[0].substr(0, fields[0].size() - 1); addedCnt += addRule(morphemes[0].second, [&](const u16string& str) { - auto strv = nonstd::to_string_view(str); - if (!strv.ends_with(suffix)) return str; - return strv.substr(0, strv.size() - suffix.size()).to_string() + morphemes[0].first.to_string(); + auto strv = toStringView(str); + if (!(strv.size() >= suffix.size() && strv.substr(strv.size() - suffix.size()) == suffix)) return str; + return u16string{ strv.substr(0, strv.size() - suffix.size()) } + u16string{ morphemes[0].first }; }, score).size(); } else @@ -2141,22 +2223,29 @@ Kiwi KiwiBuilder::build(const TypoTransformer& typos, float typoCostThreshold) c ret.formTrie = freezeTrie(move(formTrie), archType); - for (auto& m : ret.morphemes) + ret.specialMorphIds = getSpecialMorphs(); + return ret; +} + +std::array(Kiwi::SpecialMorph::max)> KiwiBuilder::getSpecialMorphs() const +{ + std::array(Kiwi::SpecialMorph::max)> specialMorphIds = { {0,} }; + for (auto& m : morphemes) { - if (m.kform && *m.kform == u"'") + if (forms[m.kform].form == u"'") { - if (m.tag == POSTag::sso) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteOpen)] = &m - ret.morphemes.data(); - else if (m.tag == POSTag::ssc) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteClose)] = &m - ret.morphemes.data(); - else if (m.tag == POSTag::ss) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteNA)] = &m - ret.morphemes.data(); + if (m.tag == POSTag::sso) specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteOpen)] = &m - morphemes.data(); + else if (m.tag == POSTag::ssc) specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteClose)] = &m - morphemes.data(); + else if (m.tag == POSTag::ss) specialMorphIds[static_cast(Kiwi::SpecialMorph::singleQuoteNA)] = &m - morphemes.data(); } - else if (m.kform && *m.kform == u"\"") + else if (forms[m.kform].form == u"\"") { - if (m.tag == POSTag::sso) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteOpen)] = &m - ret.morphemes.data(); - else if (m.tag == POSTag::ssc) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteClose)] = &m - ret.morphemes.data(); - else if (m.tag == POSTag::ss) ret.specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteNA)] = &m - ret.morphemes.data(); + if (m.tag == POSTag::sso) specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteOpen)] = &m - morphemes.data(); + else if (m.tag == POSTag::ssc) specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteClose)] = &m - morphemes.data(); + else if (m.tag == POSTag::ss) specialMorphIds[static_cast(Kiwi::SpecialMorph::doubleQuoteNA)] = &m - morphemes.data(); } } - return ret; + return specialMorphIds; } vector KiwiBuilder::extractWords(const U16MultipleReader& reader, size_t minCnt, size_t maxWordLen, float minScore, float posThreshold, bool lmFilter) const @@ -2270,7 +2359,8 @@ void KiwiBuilder::convertHSData( const vector& inputPathes, const string& outputPath, const string& morphemeDefPath, - size_t morphemeDefMinCnt + size_t morphemeDefMinCnt, + bool generateOovDict ) const { unique_ptr dummyBuilder; @@ -2292,41 +2382,76 @@ void KiwiBuilder::convertHSData( srcBuilder = dummyBuilder.get(); } - RaggedVector sents; + UnorderedMap, size_t> oovDict; + RaggedVector sents; for (auto& path : inputPathes) { ifstream ifs; - srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph); + srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, 0, nullptr, generateOovDict ? &oovDict: nullptr); } ofstream ofs; sents.write_to_memory(openFile(ofs, outputPath, ios_base::binary)); + if (generateOovDict) + { + Vector> oovDictStr(oovDict.size()); + for (auto& p : oovDict) + { + oovDictStr[p.second] = make_pair(joinHangul(p.first.first), p.first.second); + } + + const uint32_t size = oovDictStr.size(); + ofs.write((const char*)&size, sizeof(uint32_t)); + for (auto& p : oovDictStr) + { + const uint32_t tagAndSize = (uint32_t)p.second | ((uint32_t)p.first.size() << 8); + ofs.write((const char*)&tagAndSize, sizeof(uint32_t)); + ofs.write((const char*)p.first.data(), p.first.size() * sizeof(char16_t)); + } + } } HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb, double dropoutProbOnHistory, + double nounAugmentingProb, + size_t generateUnlikelihoods, const TokenFilter& tokenFilter, const TokenFilter& windowFilter, double splitRatio, bool separateDefaultMorpheme, const string& morphemeDefPath, size_t morphemeDefMinCnt, + const vector>>& contextualMapper, HSDataset* splitDataset ) const { - HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb, dropoutProbOnHistory }; + HSDataset dataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, dropoutProbOnHistory, nounAugmentingProb, generateUnlikelihoods }; auto& sents = dataset.sents.get(); const KiwiBuilder* srcBuilder = this; MorphemeMap realMorph; size_t maxTokenId = 0; + + const bool doesGenerateUnlikelihoods = generateUnlikelihoods != (size_t)-1; + if (morphemeDefPath.empty()) { realMorph = restoreMorphemeMap(separateDefaultMorpheme); + dataset.langModel = langMdl; + if (doesGenerateUnlikelihoods) + { + dataset.kiwiInst = make_unique(build()); + dataset.kiwiInst->setMaxUnkFormSize(2); + } } else { + if (doesGenerateUnlikelihoods) + { + throw invalid_argument{ "cannot generate unlikelihoods with morpheme definition file" }; + } + dataset.dummyBuilder = make_shared(); dataset.dummyBuilder->initMorphemes(); ifstream ifs; @@ -2342,40 +2467,76 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, } } - auto& knlm = srcBuilder->langMdl.knlm; - dataset.knlm = knlm; dataset.morphemes = &srcBuilder->morphemes; dataset.forms = &srcBuilder->forms; + dataset.specialMorphIds = getSpecialMorphs(); if (splitDataset) { - *splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb }; + *splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, true, numWorkers, dropoutProb, 0, 0, generateUnlikelihoods }; splitDataset->dummyBuilder = dataset.dummyBuilder; - splitDataset->knlm = knlm; - splitDataset->morphemes = &srcBuilder->morphemes; - splitDataset->forms = &srcBuilder->forms; + splitDataset->langModel = dataset.langModel; + splitDataset->kiwiInst = dataset.kiwiInst; + splitDataset->morphemes = dataset.morphemes; + splitDataset->forms = dataset.forms; + splitDataset->specialMorphIds = dataset.specialMorphIds; } + UnorderedMap, size_t> oovDict; for (auto& path : inputPathes) { try { ifstream ifs; - auto cvtSents = RaggedVector::from_memory(openFile(ifs, path, ios_base::binary)); - if (splitRatio > 0) + auto cvtSents = RaggedVector::from_memory(openFile(ifs, path, ios_base::binary)); + uint32_t oovDictSize = 0; + Vector oovDictMap; + if (ifs.read((char*)&oovDictSize, sizeof(uint32_t))) { - throw invalid_argument("splitDataset cannot be used with binary input"); + for (uint32_t i = 0; i < oovDictSize; ++i) + { + uint32_t tagAndSize = 0; + ifs.read((char*)&tagAndSize, sizeof(uint32_t)); + u16string form(tagAndSize >> 8, 0); + ifs.read((char*)form.data(), form.size() * sizeof(char16_t)); + const POSTag tag = (POSTag)(tagAndSize & 0xff); + if (doesGenerateUnlikelihoods) + { + KString kform = normalizeHangul(form); + const auto oovId = (int32_t)oovDict.emplace(make_pair(kform, tag), oovDict.size()).first->second; + oovDictMap.emplace_back(-oovId - 1); + } + else + { + oovDictMap.emplace_back(getDefaultMorphemeId(tag)); + } + } } + + double splitCnt = 0; for (auto s : cvtSents) { - sents.emplace_back(); - sents.insert_data(s.begin(), s.end()); + splitCnt += splitRatio; + auto& o = splitDataset && splitCnt >= 1 ? splitDataset->sents.get() : sents; + o.emplace_back(); + if (oovDictMap.empty()) + { + o.insert_data(s.begin(), s.end()); + } + else + { + for (auto i : s) + { + o.add_data(i < 0 ? oovDictMap[-i - 1] : i); + } + } + splitCnt = fmod(splitCnt, 1.); } } catch (const runtime_error&) { ifstream ifs; - srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr); + srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr, doesGenerateUnlikelihoods ? &oovDict : nullptr); } } size_t tokenSize = sents.raw().empty() ? 0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1; @@ -2386,7 +2547,17 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, tokenSize = max(tokenSize, sents.raw().empty() ? (size_t)0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1); } - const size_t knlmVocabSize = knlm ? knlm->getHeader().vocab_size : maxTokenId; + if (doesGenerateUnlikelihoods) + { + dataset.oovDict = make_unique>>(oovDict.size()); + for (auto& p : oovDict) + { + (*dataset.oovDict)[p.second] = make_pair(joinHangul(p.first.first), p.first.second); + } + if (splitDataset) splitDataset->oovDict = dataset.oovDict; + } + + const size_t knlmVocabSize = dataset.langModel ? dataset.langModel->vocabSize() : maxTokenId; tokenSize = max(tokenSize, knlmVocabSize); size_t filteredKnlmVocabSize = 0; for (size_t i = 0; i < tokenSize; ++i) @@ -2424,6 +2595,17 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, dataset.totalTokens += dataset.numValidTokensInSent(i) - 1; } + if (!contextualMapper.empty()) + { + utils::ContinuousTrie> cmTrie(1); + for (auto& p : contextualMapper) + { + cmTrie.build(p.second.begin(), p.second.end(), p.first + 1); + } + cmTrie.fillFail(); + dataset.contextualMapper = utils::FrozenTrie{ cmTrie, ArchTypeHolder{} }; + } + if (splitDataset) { splitDataset->windowTokenValidness = dataset.windowTokenValidness; @@ -2434,6 +2616,30 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, { splitDataset->totalTokens += splitDataset->numValidTokensInSent(i) - 1; } + + if (!contextualMapper.empty()) + { + splitDataset->contextualMapper = dataset.contextualMapper; + } } return dataset; } + +void KiwiBuilder::buildMorphData(const string& morphemeDefPath, const string& outputPath, size_t minCnt) +{ + KiwiBuilder kb; + kb.initMorphemes(); + ifstream ifs; + auto realMorph = kb.loadMorphemesFromTxt(openFile(ifs, morphemeDefPath), [&](POSTag tag, float cnt) + { + return cnt >= minCnt; + }); + + size_t lmVocabSize = 0; + for (auto& p : realMorph) lmVocabSize = max(p.second.first, lmVocabSize); + lmVocabSize += 1; + kb.updateForms(); + kb.updateMorphemes(lmVocabSize); + ofstream ofs; + kb.saveMorphBin(openFile(ofs, outputPath + "/sj.morph", ios_base::binary)); +} diff --git a/src/Knlm.cpp b/src/Knlm.cpp index 2ab66388..f4683c9b 100644 --- a/src/Knlm.cpp +++ b/src/Knlm.cpp @@ -1,46 +1,185 @@ #include "Knlm.hpp" +#include "PathEvaluator.hpp" +#include "Joiner.hpp" +#include "Kiwi.hpp" namespace kiwi { namespace lm { - template + template + float KnLangModel::getLL(ptrdiff_t node_idx, KeyType next) const + { + DiffType v; + auto* node = &node_data[node_idx]; + if (node_idx == 0) + { + v = all_value_data[next]; + if (v == 0) return unk_ll; + } + else + { + if (!nst::search( + &key_data[node->next_offset], + &value_data[node->next_offset], + node->num_nexts, next, v + )) + { + return node->gamma + getLL(node_idx + node->lower, next); + } + } + + // non-leaf node + if (v > 0) + { + return node_data[node_idx + v].ll; + } + // leaf node + else + { + return reinterpret_cast(v); + } + } + + template + template + float KnLangModel::progress(IdxType& node_idx, KeyType next) const + { + float acc = 0; + while (1) + { + DiffType v; + auto* node = &node_data[node_idx]; + auto* keys = &key_data[node->next_offset]; + auto* values = &value_data[node->next_offset]; + PREFETCH_T0(node + node->lower); + if (node_idx == 0) + { + v = all_value_data[next]; + if (v == 0) + { + if (htx_data) + { + IdxType lv; + if (nst::search( + &key_data[0], + value_data, + node_data[0].num_nexts, htx_data[next], lv + )) node_idx = lv; + else node_idx = 0; + } + return acc + unk_ll; + } + } + else + { + if (!nst::search( + keys, + values, + node->num_nexts, next, v + )) + { + acc += node->gamma; + node_idx += node->lower; + PREFETCH_T0(&key_data[node_data[node_idx].next_offset]); + continue; + } + } + + // non-leaf node + if (v > 0) + { + node_idx += v; + return acc + node_data[node_idx].ll; + } + // leaf node + else + { + while (node->lower) + { + node += node->lower; + DiffType lv; + if (nst::search( + &key_data[node->next_offset], + &value_data[node->next_offset], + node->num_nexts, next, lv + )) + { + if (lv > 0) + { + node += lv; + node_idx = node - &node_data[0]; + return acc + reinterpret_cast(v); + } + } + } + if (htx_data) + { + IdxType lv; + if (nst::search( + &key_data[0], + value_data, + node_data[0].num_nexts, htx_data[next], lv + )) node_idx = lv; + else node_idx = 0; + } + else node_idx = 0; + return acc + reinterpret_cast(v); + } + } + } + + template + void* KnLangModel::getFindBestPathFn() const + { + return (void*)&BestPathFinder>::findBestPath; + } + + template + void* KnLangModel::getNewJoinerFn() const + { + return (void*)&newJoinerWithKiwi; + } + + template std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) { auto* ptr = reinterpret_cast(mem.get()); - auto& header = *reinterpret_cast(ptr); + auto& header = *reinterpret_cast(ptr); switch (header.key_size) { case 1: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); case 2: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); case 4: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); case 8: - return make_unique>(std::move(mem)); + return make_unique>(std::move(mem)); default: throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.key_size) }; } } - using FnCreateOptimizedModel = decltype(&createOptimizedModel); + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + template struct CreateOptimizedModelGetter { template struct Wrapper { - static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i), transposed>; }; }; - std::unique_ptr KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType) + std::unique_ptr KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType, bool transposed) { - static tp::Table table{ CreateOptimizedModelGetter{} }; - auto fn = table[static_cast(archType)]; + static tp::Table table{ CreateOptimizedModelGetter{} }; + static tp::Table tableTransposed{ CreateOptimizedModelGetter{} }; + auto fn = (transposed ? tableTransposed : table)[static_cast(archType)]; if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; return (*fn)(std::move(mem)); } } -} \ No newline at end of file +} diff --git a/src/Knlm.hpp b/src/Knlm.hpp index 5b1f0185..180e507f 100644 --- a/src/Knlm.hpp +++ b/src/Knlm.hpp @@ -17,80 +17,13 @@ namespace kiwi { namespace lm { - static constexpr size_t serialAlignment = 16; - - using QCode = qe::QCode<0, 2, 8, 16>; - - template - inline void dequantize( - Vector& restored_floats, Vector& restored_leaf_ll, - const char* llq_data, size_t llq_size, - const char* gammaq_data, size_t gammaq_size, - const float* ll_table, - const float* gamma_table, - size_t num_non_leaf_nodes, - size_t num_leaf_nodes - ) - { - FixedLengthEncoder llq{ llq_data, (ptrdiff_t)llq_size }; - FixedLengthEncoder gammaq{ gammaq_data, (ptrdiff_t)gammaq_size }; - - for (size_t i = 0; i < num_non_leaf_nodes; ++i) - { - restored_floats[i] = ll_table[llq.read()]; - } - - for (size_t i = 0; i < num_leaf_nodes; ++i) - { - restored_leaf_ll[i] = ll_table[llq.read()]; - } + template + class KnLMState; - for (size_t i = 0; i < num_non_leaf_nodes; ++i) - { - restored_floats[i + num_non_leaf_nodes] = gamma_table[gammaq.read()]; - } - } - - template<> - inline void dequantize<8>( - Vector& restored_floats, Vector& restored_leaf_ll, - const char* llq_data, size_t llq_size, - const char* gammaq_data, size_t gammaq_size, - const float* ll_table, - const float* gamma_table, - size_t num_non_leaf_nodes, - size_t num_leaf_nodes - ) - { - const uint8_t* non_leaf_q = reinterpret_cast(llq_data); - for (size_t i = 0; i < num_non_leaf_nodes; ++i) - { - restored_floats[i] = ll_table[non_leaf_q[i]]; - } - - const uint8_t* leaf_q = reinterpret_cast(llq_data + num_non_leaf_nodes); - for (size_t i = 0; i < num_leaf_nodes; ++i) - { - restored_leaf_ll[i] = ll_table[leaf_q[i]]; - } - - const uint8_t* gamma_q = reinterpret_cast(gammaq_data); - for (size_t i = 0; i < num_non_leaf_nodes; ++i) - { - restored_floats[i + num_non_leaf_nodes] = gamma_table[gamma_q[i]]; - } - } - - inline const void* toAlignedPtr(const void* ptr, size_t alignment = serialAlignment) - { - auto addr = reinterpret_cast(ptr); - return reinterpret_cast((addr + alignment - 1) & ~(alignment - 1)); - } - - template + template class KnLangModel : public KnLangModelBase { - using MyNode = Node; + using MyNode = KnLangModelNode; std::unique_ptr node_data; std::unique_ptr key_data; @@ -140,311 +73,22 @@ namespace kiwi const float* gamma_table, size_t num_non_leaf_nodes, size_t num_leaf_nodes - ) - { - using Fn = void(*)(Vector&, Vector&, - const char*, size_t, - const char*, size_t, - const float*, - const float*, - size_t, - size_t); - static constexpr Fn table[] = { - &dequantize... - }; - return table[bits - 1](restored_floats, restored_leaf_ll, - llq_data, llq_size, - gammaq_data, gammaq_size, - ll_table, gamma_table, - num_non_leaf_nodes, num_leaf_nodes - ); - } + ); public: - KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move(mem) } - { - auto* ptr = reinterpret_cast(base.get()); - auto& header = getHeader(); - const size_t quantized = header.quantized & 0x1F; - const bool compressed = header.quantized & 0x80; - - Vector d_node_size; - auto* node_sizes = reinterpret_cast(ptr + header.node_offset); - key_data = make_unique((header.ll_offset - header.key_offset) / sizeof(KeyType)); - std::memcpy(&key_data[0], ptr + header.key_offset, header.ll_offset - header.key_offset); - size_t num_leaf_nodes = 0; - if (compressed) - { - d_node_size.resize(header.num_nodes); - auto qc_header = reinterpret_cast(ptr + header.node_offset); - auto qc_body = reinterpret_cast(qc_header + (header.num_nodes + 3) / 4); - QCode::template decode<8>((uint16_t*)d_node_size.data(), qc_header, qc_body, 0, header.num_nodes); - node_sizes = d_node_size.data(); - } - - for (size_t i = 0; i < header.num_nodes; ++i) - { - if (node_sizes[i]) num_non_leaf_nodes++; - else num_leaf_nodes++; - } + using VocabType = KeyType; + using LmStateType = KnLMState; - // restore ll & gamma data - Vector restored_leaf_ll, restored_floats; - const float* ll_data = nullptr; - const float* gamma_data = nullptr; - const float* leaf_ll_data = nullptr; - if (quantized) - { - if (quantized > 16) - { - throw std::runtime_error{ "16+ bits quantization not supported." }; - } + KnLangModel(utils::MemoryObject&& mem); + ModelType getType() const override { return transposed ? ModelType::knlmTransposed : ModelType::knlm; } - restored_floats.resize(num_non_leaf_nodes * 2); - restored_leaf_ll.resize(num_leaf_nodes); - leaf_ll_data = restored_leaf_ll.data(); - ll_data = &restored_floats[0]; - gamma_data = &restored_floats[num_non_leaf_nodes]; - - const float* ll_table = reinterpret_cast(ptr + header.qtable_offset); - const float* gamma_table = ll_table + ((size_t)1 << quantized); - - dequantizeDispatch(tp::gen_seq<16>{}, quantized, restored_floats, restored_leaf_ll, - ptr + header.ll_offset, header.gamma_offset - header.ll_offset, - ptr + header.gamma_offset, header.qtable_offset - header.gamma_offset, - ll_table, - gamma_table, - num_non_leaf_nodes, - num_leaf_nodes - ); - extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized)); - } - else - { - ll_data = reinterpret_cast(ptr + header.ll_offset); - gamma_data = reinterpret_cast(ptr + header.gamma_offset); - leaf_ll_data = ll_data + num_non_leaf_nodes; - extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes); - } - - size_t htx_vocab_size = header.vocab_size; - if (header.htx_offset) - { - htx_data = reinterpret_cast(ptr + header.htx_offset); - htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1; - extra_buf = toAlignedPtr(htx_data + header.vocab_size); - } - - if (!header.extra_buf_size) - { - extra_buf = nullptr; - } - - // restore node's data - node_data = make_unique(num_non_leaf_nodes); - all_value_data = make_unique(header.num_nodes - 1 + htx_vocab_size); - value_data = &all_value_data[htx_vocab_size]; - std::fill(&all_value_data[0], value_data, 0); - - size_t non_leaf_idx = 0, leaf_idx = 0, next_offset = 0; - Vector> key_ranges; - for (size_t i = 0; i < header.num_nodes; ++i) - { - if (node_sizes[i]) - { - auto& node = node_data[non_leaf_idx]; - if (!key_ranges.empty()) - { - auto& back = key_ranges.back(); - value_data[back[1]] = non_leaf_idx - back[0]; - } - node.num_nexts = node_sizes[i]; - node.next_offset = next_offset; - node.ll = ll_data[non_leaf_idx]; - node.gamma = gamma_data[non_leaf_idx]; - next_offset += node_sizes[i]; - key_ranges.emplace_back(std::array{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) }); - non_leaf_idx++; - } - else - { - auto& back = key_ranges.back(); - reinterpret_cast(value_data[back[1]]) = leaf_ll_data[leaf_idx]; - back[1]++; - while (key_ranges.back()[1] == key_ranges.back()[2]) - { - key_ranges.pop_back(); - if (key_ranges.empty()) break; - key_ranges.back()[1]++; - } - leaf_idx++; - } - } - - for (size_t i = 0; i < node_data[0].num_nexts; ++i) - { - auto k = key_data[i]; - auto v = value_data[i]; - all_value_data[k] = v; - } - - Vector tempBuf; - for (size_t i = 0; i < non_leaf_idx; ++i) - { - auto& node = node_data[i]; - nst::prepare(&key_data[node.next_offset], &value_data[node.next_offset], node.num_nexts, tempBuf); - } - - if (htx_data) - { - ptrdiff_t node = 0; - progress(node, (KeyType)header.bos_id); - unk_ll = getLL(node, (KeyType)header.unk_id); - bos_node_idx = 0; - progress(bos_node_idx, htx_data[(KeyType)header.bos_id]); - } - else - { - unk_ll = getLL(0, (KeyType)header.unk_id); - bos_node_idx = 0; - progress(bos_node_idx, (KeyType)header.bos_id); - } - - Deque dq; - for (dq.emplace_back(&node_data[0]); !dq.empty(); dq.pop_front()) - { - auto p = dq.front(); - for (size_t i = 0; i < p->num_nexts; ++i) - { - auto k = key_data[p->next_offset + i]; - auto v = value_data[p->next_offset + i]; - if (v <= 0) continue; - auto* child = &p[v]; - child->lower = findLowerNode(p, k) - child; - dq.emplace_back(child); - } - } - } - - float getLL(ptrdiff_t node_idx, KeyType next) const - { - DiffType v; - auto* node = &node_data[node_idx]; - if (node_idx == 0) - { - v = all_value_data[next]; - if (v == 0) return unk_ll; - } - else - { - if (!nst::search( - &key_data[node->next_offset], - &value_data[node->next_offset], - node->num_nexts, next, v - )) - { - return node->gamma + getLL(node_idx + node->lower, next); - } - } + void* getFindBestPathFn() const override; + void* getNewJoinerFn() const override; - // non-leaf node - if (v > 0) - { - return node_data[node_idx + v].ll; - } - // leaf node - else - { - return reinterpret_cast(v); - } - } + float getLL(ptrdiff_t node_idx, KeyType next) const; template - float progress(IdxType& node_idx, KeyType next) const - { - float acc = 0; - while (1) - { - DiffType v; - auto* node = &node_data[node_idx]; - auto* keys = &key_data[node->next_offset]; - auto* values = &value_data[node->next_offset]; - PREFETCH_T0(node + node->lower); - if (node_idx == 0) - { - v = all_value_data[next]; - if (v == 0) - { - if (htx_data) - { - IdxType lv; - if (nst::search( - &key_data[0], - value_data, - node_data[0].num_nexts, htx_data[next], lv - )) node_idx = lv; - else node_idx = 0; - } - return acc + unk_ll; - } - } - else - { - if (!nst::search( - keys, - values, - node->num_nexts, next, v - )) - { - acc += node->gamma; - node_idx += node->lower; - PREFETCH_T0(&key_data[node_data[node_idx].next_offset]); - continue; - } - } - - // non-leaf node - if (v > 0) - { - node_idx += v; - return acc + node_data[node_idx].ll; - } - // leaf node - else - { - while (node->lower) - { - node += node->lower; - DiffType lv; - if (nst::search( - &key_data[node->next_offset], - &value_data[node->next_offset], - node->num_nexts, next, lv - )) - { - if (lv > 0) - { - node += lv; - node_idx = node - &node_data[0]; - return acc + reinterpret_cast(v); - } - } - } - if (htx_data) - { - IdxType lv; - if (nst::search( - &key_data[0], - value_data, - node_data[0].num_nexts, htx_data[next], lv - )) node_idx = lv; - else node_idx = 0; - } - else node_idx = 0; - return acc + reinterpret_cast(v); - } - } - } + float progress(IdxType& node_idx, KeyType next) const; float _progress(ptrdiff_t& node_idx, size_t next) const override { @@ -723,6 +367,100 @@ namespace kiwi } }; + template + struct KnLMState : public LmStateBase> + { + int32_t node = 0; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = _transposed; + + KnLMState() = default; + KnLMState(const ILangModel* lm) : node{ (int32_t)static_cast*>(lm)->getBosNodeIdx() } {} + + bool operator==(const KnLMState& other) const + { + return node == other.node; + } + + float nextImpl(const KnLangModel<_arch, VocabTy, transposed>* lm, VocabTy next) + { + return lm->progress(node, next); + } + + }; + + static constexpr size_t serialAlignment = 16; + + using QCode = qe::QCode<0, 2, 8, 16>; + + template + inline void dequantize( + Vector& restored_floats, Vector& restored_leaf_ll, + const char* llq_data, size_t llq_size, + const char* gammaq_data, size_t gammaq_size, + const float* ll_table, + const float* gamma_table, + size_t num_non_leaf_nodes, + size_t num_leaf_nodes + ) + { + FixedLengthEncoder llq{ llq_data, (ptrdiff_t)llq_size }; + FixedLengthEncoder gammaq{ gammaq_data, (ptrdiff_t)gammaq_size }; + + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i] = ll_table[llq.read()]; + } + + for (size_t i = 0; i < num_leaf_nodes; ++i) + { + restored_leaf_ll[i] = ll_table[llq.read()]; + } + + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i + num_non_leaf_nodes] = gamma_table[gammaq.read()]; + } + } + + template<> + inline void dequantize<8>( + Vector& restored_floats, Vector& restored_leaf_ll, + const char* llq_data, size_t llq_size, + const char* gammaq_data, size_t gammaq_size, + const float* ll_table, + const float* gamma_table, + size_t num_non_leaf_nodes, + size_t num_leaf_nodes + ) + { + const uint8_t* non_leaf_q = reinterpret_cast(llq_data); + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i] = ll_table[non_leaf_q[i]]; + } + + const uint8_t* leaf_q = reinterpret_cast(llq_data + num_non_leaf_nodes); + for (size_t i = 0; i < num_leaf_nodes; ++i) + { + restored_leaf_ll[i] = ll_table[leaf_q[i]]; + } + + const uint8_t* gamma_q = reinterpret_cast(gammaq_data); + for (size_t i = 0; i < num_non_leaf_nodes; ++i) + { + restored_floats[i + num_non_leaf_nodes] = gamma_table[gamma_q[i]]; + } + } + + inline const void* toAlignedPtr(const void* ptr, size_t alignment = serialAlignment) + { + auto addr = reinterpret_cast(ptr); + return reinterpret_cast((addr + alignment - 1) & ~(alignment - 1)); + } + + template void quantize(const std::vector& ll_table, const std::vector& gamma_table, const std::vector& ll, const std::vector& leaf_ll, @@ -797,7 +535,7 @@ namespace kiwi } template - utils::MemoryOwner buildCompressedModel(Header header, + utils::MemoryOwner buildCompressedModel(KnLangModelHeader header, const std::vector& min_cf_by_order, float unigram_alpha, utils::ContinuousTrie&& compressed_ngrams, @@ -968,7 +706,7 @@ namespace kiwi quantizeDispatch(tp::gen_seq<16>{}, quantized, ll_table, gamma_table, - ll, leaf_ll, gamma, + ll, leaf_ll, gamma, llq, gammaq ); } @@ -990,7 +728,7 @@ namespace kiwi size_t final_size = 0; - header.node_offset = alignedOffsetInc(final_size, sizeof(Header)); + header.node_offset = alignedOffsetInc(final_size, sizeof(KnLangModelHeader)); if (compressed) { header.key_offset = alignedOffsetInc(final_size, c_node_size.tellp()); @@ -1026,7 +764,7 @@ namespace kiwi utils::MemoryOwner ret{ final_size + extra_buf_size }; utils::omstream ostr{ (char*)ret.get(), (std::ptrdiff_t)ret.size() }; - ostr.write((const char*)&header, sizeof(Header)); + ostr.write((const char*)&header, sizeof(KnLangModelHeader)); writePadding(ostr); if (compressed) { @@ -1088,17 +826,17 @@ namespace kiwi using type = utils::TrieNodeEx; }; - template - utils::MemoryOwner KnLangModelBase::build(Trie&& ngram_cf, - size_t order, const std::vector& min_cf_by_order, + template + utils::MemoryOwner KnLangModelBase::build(Trie&& ngram_cf, + size_t order, const std::vector& min_cf_by_order, size_t unk_id, size_t bos_id, size_t eos_id, float unigram_alpha, size_t quantize, bool compress, - const std::vector>* bigram_list, const HistoryTx* history_transformer, + const std::vector>* bigram_list, const HistoryTx* history_transformer, const void* extra_buf, size_t extra_buf_size ) { using TrieNode = typename GetNodeType::type>::type>::type; using Key = typename TrieNode::Key; - if (quantize > 16) throw std::invalid_argument{ "16+ bits quantization not supported."}; + if (quantize > 16) throw std::invalid_argument{ "16+ bits quantization not supported." }; size_t max_vid = 0; utils::ContinuousTrie compressed_ngrams{ 1 }; std::vector unigram_pats, unigram_cnts; @@ -1186,7 +924,7 @@ namespace kiwi denom = std::accumulate(unigram_cnts.begin(), unigram_cnts.end(), 0.); for (auto& p : unigram_cnts) p /= denom; - Header header = { 0, }; + KnLangModelHeader header = { 0, }; header.order = order; header.diff_size = 4; header.unk_id = unk_id; @@ -1198,36 +936,245 @@ namespace kiwi if (max_vid <= 0xFF) { - return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - unigram_pats, unigram_cnts, ngram_ncnt, + return buildCompressedModel(header, min_cf_by_order, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } else if (max_vid <= 0xFFFF) { return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - unigram_pats, unigram_cnts, ngram_ncnt, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } else if (max_vid <= 0xFFFFFFFF) { return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - unigram_pats, unigram_cnts, ngram_ncnt, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } else { return buildCompressedModel(header, min_cf_by_order, - unigram_alpha, move(compressed_ngrams), - unigram_pats, unigram_cnts, ngram_ncnt, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, history_transformer, extra_buf, extra_buf_size); } } + + template + template + void KnLangModel::dequantizeDispatch( + tp::seq, + size_t bits, + Vector& restored_floats, Vector& restored_leaf_ll, + const char* llq_data, size_t llq_size, + const char* gammaq_data, size_t gammaq_size, + const float* ll_table, + const float* gamma_table, + size_t num_non_leaf_nodes, + size_t num_leaf_nodes + ) + { + using Fn = void(*)(Vector&, Vector&, + const char*, size_t, + const char*, size_t, + const float*, + const float*, + size_t, + size_t); + static constexpr Fn table[] = { + &dequantize... + }; + return table[bits - 1](restored_floats, restored_leaf_ll, + llq_data, llq_size, + gammaq_data, gammaq_size, + ll_table, gamma_table, + num_non_leaf_nodes, num_leaf_nodes + ); + } + + template + KnLangModel::KnLangModel(utils::MemoryObject&& mem) : KnLangModelBase{ std::move(mem) } + { + auto* ptr = reinterpret_cast(base.get()); + auto& header = getHeader(); + const size_t quantized = header.quantized & 0x1F; + const bool compressed = header.quantized & 0x80; + + Vector d_node_size; + auto* node_sizes = reinterpret_cast(ptr + header.node_offset); + key_data = make_unique((header.ll_offset - header.key_offset) / sizeof(KeyType)); + std::memcpy(&key_data[0], ptr + header.key_offset, header.ll_offset - header.key_offset); + size_t num_leaf_nodes = 0; + if (compressed) + { + d_node_size.resize(header.num_nodes); + auto qc_header = reinterpret_cast(ptr + header.node_offset); + auto qc_body = reinterpret_cast(qc_header + (header.num_nodes + 3) / 4); + QCode::template decode<8>((uint16_t*)d_node_size.data(), qc_header, qc_body, 0, header.num_nodes); + node_sizes = d_node_size.data(); + } + + for (size_t i = 0; i < header.num_nodes; ++i) + { + if (node_sizes[i]) num_non_leaf_nodes++; + else num_leaf_nodes++; + } + + // restore ll & gamma data + Vector restored_leaf_ll, restored_floats; + const float* ll_data = nullptr; + const float* gamma_data = nullptr; + const float* leaf_ll_data = nullptr; + if (quantized) + { + if (quantized > 16) + { + throw std::runtime_error{ "16+ bits quantization not supported." }; + } + + restored_floats.resize(num_non_leaf_nodes * 2); + restored_leaf_ll.resize(num_leaf_nodes); + leaf_ll_data = restored_leaf_ll.data(); + ll_data = &restored_floats[0]; + gamma_data = &restored_floats[num_non_leaf_nodes]; + + const float* ll_table = reinterpret_cast(ptr + header.qtable_offset); + const float* gamma_table = ll_table + ((size_t)1 << quantized); + + dequantizeDispatch(tp::gen_seq<16>{}, quantized, restored_floats, restored_leaf_ll, + ptr + header.ll_offset, header.gamma_offset - header.ll_offset, + ptr + header.gamma_offset, header.qtable_offset - header.gamma_offset, + ll_table, + gamma_table, + num_non_leaf_nodes, + num_leaf_nodes + ); + extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized)); + } + else + { + ll_data = reinterpret_cast(ptr + header.ll_offset); + gamma_data = reinterpret_cast(ptr + header.gamma_offset); + leaf_ll_data = ll_data + num_non_leaf_nodes; + extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes); + } + + size_t htx_vocab_size = header.vocab_size; + if (header.htx_offset) + { + htx_data = reinterpret_cast(ptr + header.htx_offset); + htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1; + extra_buf = toAlignedPtr(htx_data + header.vocab_size); + } + + if (!header.extra_buf_size) + { + extra_buf = nullptr; + } + + // restore node's data + node_data = make_unique(num_non_leaf_nodes); + all_value_data = make_unique(header.num_nodes - 1 + htx_vocab_size); + value_data = &all_value_data[htx_vocab_size]; + std::fill(&all_value_data[0], value_data, 0); + + size_t non_leaf_idx = 0, leaf_idx = 0, next_offset = 0; + Vector> key_ranges; + for (size_t i = 0; i < header.num_nodes; ++i) + { + if (node_sizes[i]) + { + auto& node = node_data[non_leaf_idx]; + if (!key_ranges.empty()) + { + auto& back = key_ranges.back(); + value_data[back[1]] = non_leaf_idx - back[0]; + } + node.num_nexts = node_sizes[i]; + node.next_offset = next_offset; + node.ll = ll_data[non_leaf_idx]; + node.gamma = gamma_data[non_leaf_idx]; + next_offset += node_sizes[i]; + key_ranges.emplace_back(std::array{ non_leaf_idx, (size_t)node.next_offset, (size_t)(node.next_offset + node.num_nexts) }); + non_leaf_idx++; + } + else + { + auto& back = key_ranges.back(); + reinterpret_cast(value_data[back[1]]) = leaf_ll_data[leaf_idx]; + back[1]++; + while (key_ranges.back()[1] == key_ranges.back()[2]) + { + key_ranges.pop_back(); + if (key_ranges.empty()) break; + key_ranges.back()[1]++; + } + leaf_idx++; + } + } + + for (size_t i = 0; i < node_data[0].num_nexts; ++i) + { + auto k = key_data[i]; + auto v = value_data[i]; + all_value_data[k] = v; + } + + Vector tempBuf; + for (size_t i = 0; i < non_leaf_idx; ++i) + { + auto& node = node_data[i]; + nst::prepare(&key_data[node.next_offset], &value_data[node.next_offset], node.num_nexts, tempBuf); + } + + if (htx_data) + { + ptrdiff_t node = 0; + progress(node, (KeyType)header.bos_id); + unk_ll = getLL(node, (KeyType)header.unk_id); + bos_node_idx = 0; + progress(bos_node_idx, htx_data[(KeyType)header.bos_id]); + } + else + { + unk_ll = getLL(0, (KeyType)header.unk_id); + bos_node_idx = 0; + progress(bos_node_idx, (KeyType)header.bos_id); + } + + Deque dq; + for (dq.emplace_back(&node_data[0]); !dq.empty(); dq.pop_front()) + { + auto p = dq.front(); + for (size_t i = 0; i < p->num_nexts; ++i) + { + auto k = key_data[p->next_offset + i]; + auto v = value_data[p->next_offset + i]; + if (v <= 0) continue; + auto* child = &p[v]; + child->lower = findLowerNode(p, k) - child; + dq.emplace_back(child); + } + } + } } -} \ No newline at end of file + + template + struct Hash> + { + size_t operator()(const lm::KnLMState& state) const + { + std::hash hasher; + return hasher(state.node); + } + }; + +} diff --git a/src/LmState.hpp b/src/LmState.hpp deleted file mode 100644 index c019875d..00000000 --- a/src/LmState.hpp +++ /dev/null @@ -1,234 +0,0 @@ -#pragma once - -#include -#include -#include "Knlm.hpp" -#include "SkipBigramModel.hpp" - -namespace kiwi -{ - template - class VoidState - { - public: - static constexpr ArchType arch = _arch; - - VoidState() = default; - VoidState(const LangModel& lm) {} - - bool operator==(const VoidState& other) const - { - return true; - } - - float next(const LangModel& lm, size_t next) - { - return 0; - } - }; - - template - class KnLMState - { - friend struct Hash>; - int32_t node = 0; - public: - static constexpr ArchType arch = _arch; - - KnLMState() = default; - KnLMState(const LangModel& lm) : node{ (int32_t)static_cast&>(*lm.knlm).getBosNodeIdx() } {} - - bool operator==(const KnLMState& other) const - { - return node == other.node; - } - - float next(const LangModel& lm, VocabTy next) - { - return static_cast&>(*lm.knlm).progress(node, next); - } - - void predict(const LangModel& lm, float* out) const - { - - } - }; - - template - class SbgState : public KnLMState<_arch, VocabTy> - { - friend struct Hash>; - size_t historyPos = 0; - std::array history = { {0,} }; - public: - static constexpr ArchType arch = _arch; - - SbgState() = default; - SbgState(const LangModel& lm) : KnLMState<_arch, VocabTy>{ lm } {} - - bool operator==(const SbgState& other) const - { - return KnLMState<_arch, VocabTy>::operator==(other) && historyPos == other.historyPos && history == other.history; - } - - void getLastHistory(VocabTy* out, size_t n) const - { - for (size_t i = 0; i < n; ++i) - { - out[i] = history[(historyPos + windowSize + i - n) % windowSize]; - } - } - - float next(const LangModel& lm, VocabTy next) - { - auto& sbg = static_cast&>(*lm.sbg); - float ll = KnLMState::next(lm, next); - if (sbg.isValidVocab(next)) - { - if (ll > -13) - { - ll = sbg.evaluate(history.data(), windowSize, next, ll); - } - history[historyPos] = next; - historyPos = (historyPos + 1) % windowSize; - } - return ll; - } - - void predict(const LangModel& lm, float* out) const - { - - } - }; - - // hash for LmState - template - struct Hash> - { - size_t operator()(const VoidState& state) const - { - return 0; - } - }; - - template - struct Hash> - { - size_t operator()(const KnLMState& state) const - { - std::hash hasher; - return hasher(state.node); - } - }; - - template - struct Hash> - { - size_t operator()(const SbgState& state) const - { - Hash> hasher; - std::hash vocabHasher; - size_t ret = hasher(state); - for (size_t i = 0; i < windowSize; ++i) - { - ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); - } - return ret; - } - }; - - template - struct WrappedKnLM - { - template using type = KnLMState; - }; - - template - struct WrappedSbg - { - template using type = SbgState; - }; - - template - class LmObject : public LmObjectBase - { - LangModel mdl; - public: - LmObject(const LangModel& _mdl) : mdl(_mdl) - { - } - - size_t vocabSize() const override - { - return mdl.knlm->getHeader().vocab_size; - } - - template - float evalSequence(It first, It last) const - { - float ret = 0; - LmStateTy state{ mdl }; - for (; first != last; ++first) - { - ret += state.next(mdl, *first); - } - return ret; - } - - float evalSequence(const uint32_t* seq, size_t length, size_t stride) const override - { - float ret = 0; - LmStateTy state{ mdl }; - for (size_t i = 0; i < length; ++i) - { - ret += state.next(mdl, *seq); - seq = reinterpret_cast(reinterpret_cast(seq) + stride); - } - return ret; - } - - void predictNext(const uint32_t* seq, size_t length, size_t stride, float* outScores) const override - { - LmStateTy state{ mdl }; - for (size_t i = 0; i < length; ++i) - { - state.next(mdl, *seq); - seq = reinterpret_cast(reinterpret_cast(seq) + stride); - } - state.predict(mdl, outScores); - } - - void evalSequences( - const uint32_t* prefix, size_t prefixLength, size_t prefixStride, - const uint32_t* suffix, size_t suffixLength, size_t suffixStride, - size_t seqSize, const uint32_t** seq, const size_t* seqLength, const size_t* seqStride, float* outScores - ) const override - { - float ret = 0; - LmStateTy state{ mdl }; - for (size_t i = 0; i < prefixLength; ++i) - { - ret += state.next(mdl, *prefix); - prefix = reinterpret_cast(reinterpret_cast(prefix) + prefixStride); - } - - Vector states(seqSize, state); - std::fill(outScores, outScores + seqSize, ret); - for (size_t s = 0; s < seqSize; ++s) - { - auto p = seq[s]; - for (size_t i = 0; i < seqLength[s]; ++i) - { - outScores[s] += states[s].next(mdl, *p); - p = reinterpret_cast(reinterpret_cast(p) + seqStride[s]); - } - - for (size_t i = 0; i < suffixLength; ++i) - { - outScores[s] += states[s].next(mdl, *suffix); - suffix = reinterpret_cast(reinterpret_cast(suffix) + suffixStride); - } - } - } - }; -} diff --git a/src/MathFunc.h b/src/MathFunc.h new file mode 100644 index 00000000..036a6cec --- /dev/null +++ b/src/MathFunc.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace kiwi +{ + namespace lm + { + template + float logSumExp(const float* arr, size_t size); + + template + void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + + template + void logSoftmax(float* arr, size_t size); + + template + void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } +} diff --git a/src/MathFunc.hpp b/src/MathFunc.hpp new file mode 100644 index 00000000..a990b1c5 --- /dev/null +++ b/src/MathFunc.hpp @@ -0,0 +1,323 @@ +#pragma once +#include +#include +#include +#include "MathFunc.h" +#include "SIMD.hpp" + +namespace kiwi +{ + namespace lm + { + template + float logSumExpImpl(const float* arr) + { + simd::Operator op; + + auto pmax = op.loadf(arr); + for (size_t i = op.packetSize; i < size; i += op.packetSize) + { + pmax = op.maxf(pmax, op.loadf(&arr[i])); + } + pmax = op.redmaxbf(pmax); + + auto sum = op.zerof(); + for (size_t i = 0; i < size; i += op.packetSize) + { + sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); + } + return std::log(op.redsumf(sum)) + op.firstf(pmax); + } + + template + struct LogSumExp + { + template + float operator()(const float* arr, std::integral_constant) + { + return logSumExpImpl(arr); + } + }; + + template<> + struct LogSumExp + { + template + float operator()(const float* arr, std::integral_constant) + { + float maxValue = *std::max_element(arr, arr + size); + float sum = 0; + for (size_t i = 0; i < size; ++i) + { + sum += std::exp(arr[i] - maxValue); + } + return std::log(sum) + maxValue; + } + }; + + template<> + struct LogSumExp : public LogSumExp + { + }; + + template + void logSoftmaxImpl(float* arr) + { + simd::Operator op; + + auto pmax = op.loadf(arr); + for (size_t i = op.packetSize; i < size; i += op.packetSize) + { + pmax = op.maxf(pmax, op.loadf(&arr[i])); + } + pmax = op.redmaxbf(pmax); + + auto sum = op.zerof(); + for (size_t i = 0; i < size; i += op.packetSize) + { + sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); + } + pmax = op.addf(op.logf(op.set1f(op.redsumf(sum))), pmax); + for (size_t i = 0; i < size; i += op.packetSize) + { + op.storef(&arr[i], op.subf(op.loadf(&arr[i]), pmax)); + } + } + + template + struct LogSoftmax + { + template + void operator()(float* arr, std::integral_constant) + { + return logSoftmaxImpl(arr); + } + }; + + template<> + struct LogSoftmax + { + template + void operator()(float* arr, std::integral_constant) + { + float maxValue = *std::max_element(arr, arr + size); + float sum = 0; + for (size_t i = 0; i < size; ++i) + { + sum += std::exp(arr[i] - maxValue); + } + maxValue += std::log(sum); + for (size_t i = 0; i < size; ++i) + { + arr[i] -= maxValue; + } + } + }; + + template<> + struct LogSoftmax : public LogSoftmax + { + }; + + template + struct LogSoftmaxTransposed; + + template + struct LogSoftmaxTransposed + { + static constexpr size_t size = 8; + + void block(float* arr, size_t stride) + { + simd::Operator op; + simd::FloatPacket a0 = op.loadf(arr), + a1 = op.loadf(arr + stride), + a2 = op.loadf(arr + stride * 2), + a3 = op.loadf(arr + stride * 3), + a4 = op.loadf(arr + stride * 4), + a5 = op.loadf(arr + stride * 5), + a6 = op.loadf(arr + stride * 6), + a7 = op.loadf(arr + stride * 7); + // find maximum + auto m = op.maxf(a0, a1); + m = op.maxf(m, a2); + m = op.maxf(m, a3); + m = op.maxf(m, a4); + m = op.maxf(m, a5); + m = op.maxf(m, a6); + m = op.maxf(m, a7); + + // subtract maximum + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + // exp, reduce sum and log + m = op.expf(a0); + m = op.addf(m, op.expf(a1)); + m = op.addf(m, op.expf(a2)); + m = op.addf(m, op.expf(a3)); + m = op.addf(m, op.expf(a4)); + m = op.addf(m, op.expf(a5)); + m = op.addf(m, op.expf(a6)); + m = op.addf(m, op.expf(a7)); + m = op.logf(m); + + // subtract + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + op.storef(arr, a0); + op.storef(arr + stride, a1); + op.storef(arr + stride * 2, a2); + op.storef(arr + stride * 3, a3); + op.storef(arr + stride * 4, a4); + op.storef(arr + stride * 5, a5); + op.storef(arr + stride * 6, a6); + op.storef(arr + stride * 7, a7); + } + + void operator()(float* arr, size_t batchSize, size_t stride) + { + simd::Operator op; + for (size_t i = 0; i < batchSize; i += op.packetSize) + { + block(arr, stride); + arr += op.packetSize; + } + } + }; + + template<> + struct LogSoftmaxTransposed + { + void operator()(float* arr, size_t batchSize, size_t stride) + { + throw std::runtime_error("Unsupported architecture"); + } + }; + + template<> + struct LogSoftmaxTransposed : public LogSoftmaxTransposed + { + }; + + template + struct LogSumExpTransposed; + + template + struct LogSumExpTransposed + { + static constexpr size_t size = 8; + + void block(float* arr, size_t stride) + { + simd::Operator op; + simd::FloatPacket a0 = op.loadf(arr), + a1 = op.loadf(arr + stride), + a2 = op.loadf(arr + stride * 2), + a3 = op.loadf(arr + stride * 3), + a4 = op.loadf(arr + stride * 4), + a5 = op.loadf(arr + stride * 5), + a6 = op.loadf(arr + stride * 6), + a7 = op.loadf(arr + stride * 7); + // find maximum + auto m = op.maxf(a0, a1); + m = op.maxf(m, a2); + m = op.maxf(m, a3); + m = op.maxf(m, a4); + m = op.maxf(m, a5); + m = op.maxf(m, a6); + m = op.maxf(m, a7); + + // subtract maximum + a0 = op.subf(a0, m); + a1 = op.subf(a1, m); + a2 = op.subf(a2, m); + a3 = op.subf(a3, m); + a4 = op.subf(a4, m); + a5 = op.subf(a5, m); + a6 = op.subf(a6, m); + a7 = op.subf(a7, m); + + // exp, reduce sum and log + auto s = op.expf(a0); + s = op.addf(s, op.expf(a1)); + s = op.addf(s, op.expf(a2)); + s = op.addf(s, op.expf(a3)); + s = op.addf(s, op.expf(a4)); + s = op.addf(s, op.expf(a5)); + s = op.addf(s, op.expf(a6)); + s = op.addf(s, op.expf(a7)); + s = op.logf(s); + + op.storef(arr, op.addf(m, s)); + } + + void operator()(float* arr, size_t batchSize, size_t stride) + { + simd::Operator op; + for (size_t i = 0; i < batchSize; i += op.packetSize) + { + block(arr, stride); + arr += op.packetSize; + } + } + }; + + template<> + struct LogSumExpTransposed + { + void operator()(float* arr, size_t batchSize, size_t stride) + { + throw std::runtime_error("Unsupported architecture"); + } + }; + + template<> + struct LogSumExpTransposed : public LogSumExpTransposed + { + }; + + template + float logSumExp(const float* arr, size_t size) + { + if (size == 8) return LogSumExp()(arr, std::integral_constant()); + if (size == 16) return LogSumExp()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template + void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride) + { + if (size == 8) return LogSumExpTransposed{}(arr, batchSize, stride); + throw std::runtime_error("Unsupported size"); + } + + template + void logSoftmax(float* arr, size_t size) + { + if (size == 8) return LogSoftmax()(arr, std::integral_constant()); + if (size == 16) return LogSoftmax()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template + void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride) + { + if (size == 8) return LogSoftmaxTransposed{}(arr, batchSize, stride); + throw std::runtime_error("Unsupported size"); + } + } +} diff --git a/src/PathEvaluator.h b/src/PathEvaluator.h new file mode 100644 index 00000000..ef1eae65 --- /dev/null +++ b/src/PathEvaluator.h @@ -0,0 +1,103 @@ +#pragma once +#include + +namespace kiwi +{ + struct SpecialState + { + uint8_t singleQuote : 1; + uint8_t doubleQuote : 1; + uint8_t bulletHash : 6; + + SpecialState() : singleQuote{ 0 }, doubleQuote{ 0 }, bulletHash{ 0 } + { + } + + operator uint8_t() const + { + return reinterpret_cast(*this); + } + + bool operator<(const SpecialState& o) const + { + return (uint8_t)(*this) < (uint8_t)o; + } + + bool operator==(const SpecialState& o) const + { + return (uint8_t)(*this) == (uint8_t)o; + } + }; + + struct PathNode + { + const Morpheme* morph = nullptr; + KString str; + uint32_t begin = 0, end = 0; + float wordScore = 0, typoCost = 0; + uint32_t typoFormId = 0; + uint32_t nodeId = 0; + + PathNode(const Morpheme* _morph = nullptr, + const KString& _str = {}, + uint32_t _begin = 0, + uint32_t _end = 0, + float _wordScore = 0, + float _typoCost = 0, + uint32_t _typoFormId = 0, + uint32_t _nodeId = 0 + ) + : morph{ _morph }, str{ _str }, begin{ _begin }, end{ _end }, + wordScore{ _wordScore }, typoCost{ _typoCost }, typoFormId{ _typoFormId }, nodeId{ _nodeId } + { + } + + bool operator==(const PathNode& o) const + { + return morph == o.morph + && str == o.str + && begin == o.begin + && end == o.end + && wordScore == o.wordScore + && typoCost == o.typoCost + && typoFormId == o.typoFormId; + } + }; + using Path = Vector; + + struct PathResult + { + Path path; + float score = 0; + SpecialState prevState; + SpecialState curState; + + PathResult(Path&& _path = {}, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) + : path{ move(_path) }, score{ _score }, prevState{ _prevState }, curState{ _curState } + { + } + + PathResult(const Path& _path, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) + : path{ _path }, score{ _score }, prevState{ _prevState }, curState{ _curState } + { + } + }; + + template + struct BestPathFinder + { + static Vector findBestPath(const Kiwi* kw, + const Vector& prevSpStates, + const KGraphNode* graph, + const size_t graphSize, + const size_t topN, + bool openEnd, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ); + }; + + using FnFindBestPath = decltype(&BestPathFinder::findBestPath); +} diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index 05dd1ebb..33799b1b 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -4,305 +4,21 @@ #include #include #include +#include #include "ArchAvailable.h" #include "KTrie.h" #include "FeatureTestor.h" #include "FrozenTrie.hpp" -#include "LmState.hpp" #include "StrUtils.h" #include "SortUtils.hpp" #include "LimitedVector.hpp" +#include "PathEvaluator.h" +#include "BestPathContainer.hpp" using namespace std; namespace kiwi { - struct SpecialState - { - uint8_t singleQuote : 1; - uint8_t doubleQuote : 1; - uint8_t bulletHash : 6; - - SpecialState() : singleQuote{ 0 }, doubleQuote{ 0 }, bulletHash{ 0 } - { - } - - operator uint8_t() const - { - return reinterpret_cast(*this); - } - - bool operator<(const SpecialState& o) const - { - return (uint8_t)(*this) < (uint8_t)o; - } - - bool operator==(const SpecialState& o) const - { - return (uint8_t)(*this) == (uint8_t)o; - } - }; - - template - struct WordLL; - - using Wid = uint32_t; - - enum class PathEvaluatingMode - { - topN, - top1, - top1Small, - }; - - class PathEvaluator - { - public: - struct Result - { - const Morpheme* morph = nullptr; - KString str; - uint32_t begin = 0, end = 0; - float wordScore = 0, typoCost = 0; - uint32_t typoFormId = 0; - uint32_t nodeId = 0; - - Result(const Morpheme* _morph = nullptr, - const KString& _str = {}, - uint32_t _begin = 0, - uint32_t _end = 0, - float _wordScore = 0, - float _typoCost = 0, - uint32_t _typoFormId = 0, - uint32_t _nodeId = 0 - ) - : morph{ _morph }, str{ _str }, begin{ _begin }, end{ _end }, - wordScore{ _wordScore }, typoCost{ _typoCost }, typoFormId{ _typoFormId }, nodeId{ _nodeId } - { - } - - bool operator==(const Result& o) const - { - return morph == o.morph - && str == o.str - && begin == o.begin - && end == o.end - && wordScore == o.wordScore - && typoCost == o.typoCost - && typoFormId == o.typoFormId; - } - }; - using Path = Vector; - - struct ChunkResult - { - Path path; - float score = 0; - SpecialState prevState; - SpecialState curState; - - ChunkResult(Path&& _path = {}, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) - : path{ move(_path) }, score{ _score }, prevState{ _prevState }, curState{ _curState } - {} - - ChunkResult(const Path& _path, float _score = 0, SpecialState _prevState = {}, SpecialState _curState = {}) - : path{ _path }, score{ _score }, prevState{ _prevState }, curState{ _curState } - {} - }; - - template - static Vector findBestPath(const Kiwi* kw, - const Vector& prevSpStates, - const KGraphNode* graph, - const size_t graphSize, - const size_t topN, - bool openEnd, - bool splitComplex = false, - bool splitSaisiot = false, - bool mergeSaisiot = false, - const std::unordered_set* blocklist = nullptr - ); - - template - static void evalPath(const Kiwi* kw, - const KGraphNode* startNode, - const KGraphNode* node, - const size_t topN, - Vector>>& cache, - const Vector& ownFormList, - size_t i, - size_t ownFormId, - CandTy&& cands, - bool unknownForm, - const Vector& prevSpStates, - bool splitComplex = false, - bool splitSaisiot = false, - bool mergeSaisiot = false, - const std::unordered_set* blocklist = nullptr - ); - - template - static void evalSingleMorpheme( - Vector>& resultOut, - const Kiwi* kw, - const Vector& ownForms, - const Vector>>& cache, - size_t ownFormId, - const Morpheme* curMorph, - const KGraphNode* node, - const KGraphNode* startNode, - const size_t topN, - const float ignoreCondScore, - const float nodeLevelDiscount, - const Vector& prevSpStates - ); - }; - - using FnFindBestPath = decltype(&PathEvaluator::findBestPath>); - - template class LmState> - struct FindBestPathGetter - { - template - struct Wrapper - { - static constexpr FnFindBestPath value = &PathEvaluator::findBestPath(i)>>; - }; - }; - - template - struct WordLL - { - const Morpheme* morpheme = nullptr; - float accScore = 0, accTypoCost = 0; - const WordLL* parent = nullptr; - LmState lmState; - Wid wid = 0; - uint16_t ownFormId = 0; - uint8_t combineSocket = 0; - uint8_t rootId = 0; - SpecialState spState; - - WordLL() = default; - - WordLL(const Morpheme* _morph, float _accScore, float _accTypoCost, const WordLL* _parent, LmState _lmState, SpecialState _spState) - : morpheme{ _morph }, - accScore{ _accScore }, - accTypoCost{ _accTypoCost }, - parent{ _parent }, - lmState{ _lmState }, - spState{ _spState }, - rootId{ parent ? parent->rootId : (uint8_t)0 } - { - } - - const WordLL* root() const - { - if (parent) return parent->root(); - else return this; - } - }; - - static constexpr uint8_t commonRootId = -1; - - template - struct PathHash - { - LmState lmState; - uint8_t rootId, spState; - - PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) - : lmState{ _lmState }, rootId{ _rootId }, spState { _spState } - { - } - - PathHash(const WordLL& wordLl, const Morpheme* morphBase) - : PathHash{ wordLl.lmState, wordLl.rootId, wordLl.spState } - { - } - - bool operator==(const PathHash& o) const - { - return lmState == o.lmState && rootId == o.rootId && spState == o.spState; - } - }; - - template - struct PathHash> - { - using LmState = SbgState; - - KnLMState<_arch, VocabTy> lmState; - array lastMorphemes; - uint8_t rootId, spState; - - PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) - : lmState{ _lmState }, rootId{ _rootId }, spState{ _spState } - { - _lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); - } - - - PathHash(const WordLL& wordLl, const Morpheme* morphBase) - : PathHash{ wordLl.lmState, wordLl.rootId, wordLl.spState } - { - } - - bool operator==(const PathHash& o) const - { - return lmState == o.lmState && lastMorphemes == o.lastMorphemes && spState == o.spState; - } - }; - - template - struct Hash> - { - size_t operator()(const PathHash& p) const - { - size_t ret = 0; - if (sizeof(PathHash) % sizeof(size_t)) - { - auto ptr = reinterpret_cast(&p); - for (size_t i = 0; i < sizeof(PathHash) / sizeof(uint32_t); ++i) - { - ret ^= ptr[i]; - } - } - else - { - auto ptr = reinterpret_cast(&p); - for (size_t i = 0; i < sizeof(PathHash) / sizeof(size_t); ++i) - { - ret ^= ptr[i]; - } - } - return ret; - } - }; - - struct WordLLGreater - { - template - bool operator()(const WordLL& a, const WordLL& b) const - { - return a.accScore > b.accScore; - } - }; - - template - inline std::ostream& printDebugPath(std::ostream& os, const WordLL& src) - { - if (src.parent) - { - printDebugPath(os, *src.parent); - } - - if (src.morpheme) src.morpheme->print(os); - else os << "NULL"; - os << " , "; - return os; - } - inline bool hasLeftBoundary(const KGraphNode* node) { // 시작 지점은 항상 왼쪽 경계로 처리 @@ -464,411 +180,672 @@ namespace kiwi || m == Kiwi::SpecialMorph::doubleQuoteOpen || m == Kiwi::SpecialMorph::doubleQuoteClose; } - template - class BestPathConatiner; - template - class BestPathConatiner + template + inline void insertToPathContainer( + BestPathConatiner& bestPathCont, + const size_t topN, + const Vector& prevSpStates, + const Morpheme* curMorph, + const Morpheme* morphBase, + LmState&& state, + const float score, + const KGraphNode* node, + const WordLL& prevPath, + const RuleBasedScorer& ruleBasedScorer + ) { - // pair: [index, size] - UnorderedMap, pair> bestPathIndex; - Vector> bestPathValues; - public: - inline void clear() + const auto insert = [&](uint8_t rootId) + { + const auto* prevMorpheme = &morphBase[prevPath.wid]; + auto spState = prevPath.spState; + if (rootId != commonRootId) + { + spState = prevSpStates[rootId]; + } + const float candScoreWithRule = score + ruleBasedScorer(prevMorpheme, spState); + + // update special state + if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; + else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; + if (ruleBasedScorer.curMorphSbType) + { + spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + } + + bestPathCont.insert(topN, prevPath.rootId, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(state), spState); + }; + + if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) { - bestPathIndex.clear(); - bestPathValues.clear(); + for (uint8_t rootId = 0; rootId < prevSpStates.size(); ++rootId) + { + insert(rootId); + } } + else + { + insert(commonRootId); + } + } - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, - const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + class FormEvaluator + { + const kchar_t* leftFormFirst; + const kchar_t* leftFormLast; + bool leftFormEndswithSSC; + POSTag prevTag; + + public: + template + FormEvaluator(const WordLL& prevPath, + const Vector& ownFormList, + const Morpheme* morphBase + ) { - auto inserted = bestPathIndex.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); - if (inserted.second) + if (prevPath.ownFormId) + { + leftFormFirst = ownFormList[prevPath.ownFormId - 1].data(); + leftFormLast = leftFormFirst + ownFormList[0].size(); + } + else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) { - bestPathValues.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); - if (rootId != commonRootId) bestPathValues.back().rootId = rootId; - bestPathValues.resize(bestPathValues.size() + topN - 1); + leftFormFirst = morphBase[prevPath.wid].kform->data(); + leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); } else { - auto bestPathFirst = bestPathValues.begin() + inserted.first->second.first; - auto bestPathLast = bestPathValues.begin() + inserted.first->second.first + inserted.first->second.second; - if (distance(bestPathFirst, bestPathLast) < topN) - { - *bestPathLast = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) bestPathLast->rootId = rootId; - push_heap(bestPathFirst, bestPathLast + 1, WordLLGreater{}); - ++inserted.first->second.second; - } - else - { - if (accScore > bestPathFirst->accScore) - { - pop_heap(bestPathFirst, bestPathLast, WordLLGreater{}); - *(bestPathLast - 1) = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) (*(bestPathLast - 1)).rootId = rootId; - push_heap(bestPathFirst, bestPathLast, WordLLGreater{}); - } - } + leftFormFirst = prevPath.morpheme->getForm().data(); + leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); } + leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; + prevTag = prevPath.morpheme->tag; } - inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + bool operator()(const Morpheme* curMorph, const float ignoreCondScore, float& candScore) const { - for (auto& p : bestPathIndex) + const CondVowel cvowel = curMorph->vowel; + const CondPolarity cpolar = curMorph->polar; + if (prevTag == POSTag::ssc || leftFormEndswithSSC) { - const auto index = p.second.first; - const auto size = p.second.second; - for (size_t i = 0; i < size; ++i) - { - resultOut.emplace_back(move(bestPathValues[index + i])); - auto& newPath = resultOut.back(); - - // fill the rest information of resultOut - newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) - { - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; - } - } + // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 } + else if (ignoreCondScore) + { + candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 0 : ignoreCondScore; + } + else + { + if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) return false; + } + return true; } }; template - class BestPathConatiner + struct LmEvalData { - UnorderedMap, WordLL> bestPathes; - public: - inline void clear() + LmState state; + float score = 0; + uint32_t length = 0; + }; + + template + struct PathEvaluator; + + template + struct PathEvaluator::type> + { + const Kiwi* kw; + const KGraphNode* startNode; + const size_t topN; + Vector>>& cache; + const Vector& ownFormList; + const Vector& prevSpStates; + + PathEvaluator(const Kiwi* _kw, + const KGraphNode* _startNode, + size_t _topN, + Vector>>& _cache, + const Vector& _ownFormList, + const Vector& _prevSpStates + ) + : kw{ _kw }, startNode{ _startNode }, topN{ _topN }, cache{ _cache }, ownFormList{ _ownFormList }, prevSpStates{ _prevSpStates } { - bestPathes.clear(); } - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, - const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) + template + void operator()( + const size_t nodeIdx, + const size_t ownFormId, + CandTy&& cands, + bool unknownForm, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ) const { - WordLL newPath{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) newPath.rootId = rootId; - auto inserted = bestPathes.emplace(ph, newPath); - if (!inserted.second) + const size_t langVocabSize = kw->langMdl->vocabSize(); + auto* const node = startNode + nodeIdx; + auto& nCache = cache[nodeIdx]; + Vector> refCache; + + float whitespaceDiscount = 0; + if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) { - auto& target = inserted.first->second; - if (accScore > target.accScore) - { - target = newPath; - } + whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; + } + const float typoDiscount = -node->typoCost * kw->typoCostWeight; + float unknownFormDiscount = 0; + if (unknownForm) + { + size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); + unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); } - } - inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) - { - for (auto& p : bestPathes) + const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; + + size_t totalPrevPathes = 0; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { - resultOut.emplace_back(move(p.second)); - auto& newPath = resultOut.back(); + totalPrevPathes += cache[prev - startNode].size(); + } - // fill the rest information of resultOut - newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) + for (bool ignoreCond : {false, true}) + { + for (auto& curMorph : cands) { - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; + if (splitComplex && curMorph->hasComplex()) continue; + if (blocklist && curMorph->hasMorpheme(*blocklist)) continue; + + // 덧붙은 받침(zCoda)을 위한 지름길 + if (curMorph->tag == POSTag::z_coda) + { + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { + auto lastTag = kw->morphemes[p.wid].tag; + if (!isJClass(lastTag) && !isEClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += curMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= curMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; + newPath.wid = curMorph->lmMorphemeId; + } + } + continue; + } + // 사이시옷(zSiot)을 위한 지름길 + if (curMorph->tag == POSTag::z_siot) + { + if (!(splitSaisiot || mergeSaisiot)) + { + continue; + } + + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& p : cache[prev - startNode]) + { + auto lastTag = kw->morphemes[p.wid].tag; + if (!isNNClass(lastTag)) continue; + nCache.emplace_back(p); + auto& newPath = nCache.back(); + newPath.accScore += curMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= curMorph->userScore; + newPath.parent = &p; + newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; + newPath.wid = curMorph->lmMorphemeId; + } + } + continue; + } + + // if the morpheme has chunk set + if (!curMorph->isSingle()) + { + // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 + if (node->prev && node[-(int)node->prev].endPos < node->startPos + && curMorph->kform + && curMorph->kform->size() == 1 + && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') + && curMorph->chunks[0]->kform + && curMorph->chunks[0]->kform->size() == 1 + && (*curMorph->chunks[0]->kform)[0] == u'하') + { + continue; + } + } + + if (topN > 1) + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } + else + { + evalSingleMorpheme(nCache, node, ownFormId, + curMorph, ignoreCond ? -10 : 0, nodeLevelDiscount); + } + } + if (!nCache.empty()) break; } - } - }; - - template - class BestPathConatiner - { - Vector> bestPathIndicesSmall; - Vector> bestPathValuesSmall; - public: - inline void clear() - { - bestPathIndicesSmall.clear(); - bestPathValuesSmall.clear(); - } + thread_local Vector maxScores; + maxScores.clear(); + maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); - inline void insert(const PathHash& ph, size_t topN, uint8_t rootId, - const Morpheme* morph, float accScore, float accTypoCost, const WordLL* parent, LmState&& lmState, SpecialState spState) - { - const auto it = find(bestPathIndicesSmall.begin(), bestPathIndicesSmall.end(), ph); - if (it == bestPathIndicesSmall.end()) + if (topN == 1) { - bestPathIndicesSmall.push_back(ph); - bestPathValuesSmall.emplace_back(morph, accScore, accTypoCost, parent, move(lmState), spState); - if (rootId != commonRootId) bestPathValuesSmall.back().rootId = rootId; + for (auto& c : nCache) + { + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + maxScores[rootId] = max(maxScores[rootId], c.accScore); + } } else { - auto& target = bestPathValuesSmall[it - bestPathIndicesSmall.begin()]; - if (accScore > target.accScore) + for (auto& c : nCache) { - target = WordLL{ morph, accScore, accTypoCost, parent, move(lmState), spState }; - if (rootId != commonRootId) target.rootId = rootId; + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + if (c.accScore > maxScores[rootId * topN]) + { + pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + maxScores[rootId * topN + topN - 1] = c.accScore; + push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + } } } + + size_t validCount = 0; + for (size_t i = 0; i < nCache.size(); ++i) + { + const auto rootId = nCache[i].rootId == commonRootId ? 0 : nCache[i].rootId + 1; + if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; + if (validCount != i) nCache[validCount] = move(nCache[i]); + validCount++; + } + nCache.resize(validCount); } - inline void writeTo(Vector>& resultOut, const Morpheme* curMorph, Wid lastSeqId, size_t ownFormId) + template + void evalSingleMorpheme( + Vector>& resultOut, + const KGraphNode* node, + const size_t ownFormId, + const Morpheme* curMorph, + const float ignoreCondScore, + const float nodeLevelDiscount + ) const { - for (auto& p : bestPathValuesSmall) + thread_local BestPathConatiner bestPathCont; + + const auto* langMdl = kw->getLangModel(); + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + + const size_t langVocabSize = langMdl->vocabSize(); + + const Morpheme* lastMorph; + Wid firstWid; + if (curMorph->isSingle()) + { + lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; + firstWid = curMorph->lmMorphemeId; + } + // if the morpheme has chunk set + else { - resultOut.emplace_back(move(p)); - auto& newPath = resultOut.back(); + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; + firstWid = curMorph->chunks[0]->lmMorphemeId; + } - // fill the rest information of resultOut - newPath.wid = lastSeqId; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) - { - newPath.combineSocket = curMorph->combineSocket; - newPath.ownFormId = ownFormId; - } + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) + { + lastSeqId = lastMorph - kw->morphemes.data(); + } + else + { + lastSeqId = lastMorph->lmMorphemeId; } - } - }; - template - void PathEvaluator::evalSingleMorpheme( - Vector>& resultOut, - const Kiwi* kw, - const Vector& ownForms, - const Vector>>& cache, - size_t ownFormId, - const Morpheme* curMorph, - const KGraphNode* node, - const KGraphNode* startNode, - const size_t topN, - const float ignoreCondScore, - const float nodeLevelDiscount, - const Vector& prevSpStates - ) - { - thread_local BestPathConatiner bestPathCont; - thread_local Vector rootIds; - const LangModel& langMdl = kw->langMdl; - const Morpheme* morphBase = kw->morphemes.data(); - const auto spacePenalty = kw->spacePenalty; - const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + bestPathCont.clear(); + const float additionalScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); - const size_t langVocabSize = langMdl.knlm->getHeader().vocab_size; + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; - const Morpheme* lastMorph; - Wid firstWid; - if (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot) - { - lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; - firstWid = curMorph->lmMorphemeId; - } - // if the morpheme has chunk set - else - { - lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; - firstWid = curMorph->chunks[0]->lmMorphemeId; - } + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + { + for (auto& prevPath : cache[prev - startNode]) + { + // 사이시옷 뒤에 명사가 아닌 태그가 오거나 공백이 있는 경우 제외 + if (prevPath.morpheme->tag == POSTag::z_siot && ( + !isNNClass(curMorph->tag) || prev->endPos < node->startPos + )) + { + continue; + } - Wid lastSeqId; - if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) - { - lastSeqId = lastMorph - kw->morphemes.data(); - } - else - { - lastSeqId = lastMorph->lmMorphemeId; - } + float candScore = prevPath.accScore + additionalScore; + if (prevPath.combineSocket) + { + // merge with only the same socket + if (prevPath.combineSocket != curMorph->combineSocket || curMorph->isSingle()) + { + continue; + } + if (prev->endPos < node->startPos) + { + if (allowedSpaceBetweenChunk) candScore -= spacePenalty; + else continue; + } + firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; + } + + FormEvaluator formEvaluator{ prevPath, ownFormList, morphBase }; + if (!formEvaluator(curMorph, ignoreCondScore, candScore)) continue; + auto cLmState = prevPath.lmState; + if (curMorph->combineSocket && curMorph->isSingle()) + { + // no-op + } + else + { + if (morphBase[firstWid].tag == POSTag::p) + { + // prohibit without + goto continueFor; + } + float ll = cLmState.next(langMdl, firstWid); + candScore += ll; + if (!curMorph->isSingle()) + { + for (size_t i = 1; i < curMorph->chunks.size(); ++i) + { + const auto wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + // prohibit without + goto continueFor; + } + ll = cLmState.next(langMdl, wid); + candScore += ll; + } + } + } - bestPathCont.clear(); - const float additionalScore = curMorph->userScore + nodeLevelDiscount + kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(cLmState), candScore, node, prevPath, ruleBasedScorer); + continueFor:; + } + } - RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); + } + }; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + template + struct MorphemeEvaluator + { + template + void eval( + Vector>& resultOut, + const Kiwi* kw, + const Vector& ownForms, + const Vector>>& cache, + size_t ownFormId, + const Vector& morphs, + const KGraphNode* node, + const KGraphNode* startNode, + const size_t topN, + const size_t totalPrevPathes, + const float ignoreCondScore, + const float nodeLevelDiscount, + const Vector& prevSpStates + ) const { - for (auto& prevPath : cache[prev - startNode]) + thread_local BestPathConatiner bestPathCont; + thread_local Vector> evalMatrix; + thread_local Vector nextWids; + + const auto* langMdl = kw->getLangModel(); + const Morpheme* morphBase = kw->morphemes.data(); + const auto spacePenalty = kw->spacePenalty; + const bool allowedSpaceBetweenChunk = kw->spaceTolerance > 0; + const size_t langVocabSize = langMdl->vocabSize(); + + evalMatrix.resize(totalPrevPathes * morphs.size()); + nextWids.clear(); + + size_t prevId = -1; + size_t length; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { - // 사이시옷 뒤에 명사가 아닌 태그가 오거나 공백이 있는 경우 제외 - if (prevPath.morpheme->tag == POSTag::z_siot && ( - !isNNClass(curMorph->tag) || prev->endPos < node->startPos - )) - { - continue; - } - - float candScore = prevPath.accScore + additionalScore; - if (prevPath.combineSocket) + for (auto& prevPath : cache[prev - startNode]) { - // merge with only the same socket - if (prevPath.combineSocket != curMorph->combineSocket || (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + ++prevId; + FormEvaluator formEvaluator{ prevPath, ownForms, morphBase }; + for (size_t curId = 0; curId < morphs.size(); ++curId) { + const auto curMorph = morphs[curId]; + float candScore = prevPath.accScore + curMorph->userScore + nodeLevelDiscount; + Wid firstWid; + if (curMorph->isSingle()) + { + firstWid = curMorph->lmMorphemeId; + } + else + { + firstWid = curMorph->chunks[0]->lmMorphemeId; + } + + if (prevPath.combineSocket) + { + // merge with only the same socket + if (prevPath.combineSocket != curMorph->combineSocket || curMorph->isSingle()) + { + goto invalidCandidate; + } + if (prev->endPos < node->startPos) + { + if (allowedSpaceBetweenChunk) candScore -= spacePenalty; + else goto invalidCandidate; + } + firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; + } + + if (!formEvaluator(curMorph, ignoreCondScore, candScore)) continue; + + length = 0; + if (curMorph->combineSocket && curMorph->isSingle()) + { + // no op + } + else + { + if (morphBase[firstWid].tag == POSTag::p) + { + goto invalidCandidate; + } + + if (curMorph->isSingle()) + { + length = 1; + } + else + { + length = curMorph->chunks.size(); + for (size_t i = 1; i < length; ++i) + { + const Wid wid = curMorph->chunks[i]->lmMorphemeId; + if (morphBase[wid].tag == POSTag::p) + { + goto invalidCandidate; + } + } + } + } + evalMatrix[prevId * morphs.size() + curId].state = prevPath.lmState; + evalMatrix[prevId * morphs.size() + curId].score = candScore; + evalMatrix[prevId * morphs.size() + curId].length = length; + if (length > 0) nextWids.emplace_back(firstWid); + if (length > 1) + { + for (size_t i = 1; i < length; ++i) + { + nextWids.emplace_back(curMorph->chunks[i]->lmMorphemeId); + } + } continue; + invalidCandidate: + evalMatrix[prevId * morphs.size() + curId].score = -INFINITY; + evalMatrix[prevId * morphs.size() + curId].length = 0; } - if (prev->endPos < node->startPos) - { - if (allowedSpaceBetweenChunk) candScore -= spacePenalty; - else continue; - } - firstWid = morphBase[prevPath.wid].getCombined()->lmMorphemeId; } + } - const kchar_t* leftFormFirst, * leftFormLast; - if (prevPath.ownFormId) + { + size_t widOffset = 0; + for (auto& e : evalMatrix) { - leftFormFirst = ownForms[prevPath.ownFormId - 1].data(); - leftFormLast = leftFormFirst + ownForms[0].size(); + //if (e.length == 0) continue; + float score = 0; + for (size_t i = 0; i < e.length; ++i) + { + score += e.state.next(langMdl, nextWids[widOffset + i]); + } + e.score += score; + widOffset += e.length; } - else if (morphBase[prevPath.wid].kform && !morphBase[prevPath.wid].kform->empty()) + } + + for (size_t curId = 0; curId < morphs.size(); ++curId) + { + const auto curMorph = morphs[curId]; + bestPathCont.clear(); + + const Morpheme* lastMorph; + if (curMorph->isSingle()) { - leftFormFirst = morphBase[prevPath.wid].kform->data(); - leftFormLast = leftFormFirst + morphBase[prevPath.wid].kform->size(); + lastMorph = curMorph->getCombined() ? curMorph->getCombined() : curMorph; } + // if the morpheme has chunk set else { - leftFormFirst = prevPath.morpheme->getForm().data(); - leftFormLast = leftFormFirst + prevPath.morpheme->getForm().size(); + lastMorph = curMorph->chunks[curMorph->chunks.size() - 1]; } - const CondVowel cvowel = curMorph->vowel; - const CondPolarity cpolar = curMorph->polar; - const bool leftFormEndswithSSC = leftFormFirst < leftFormLast && identifySpecialChr(leftFormLast[-1]) == POSTag::ssc; - if (prevPath.morpheme->tag == POSTag::ssc || leftFormEndswithSSC) - { - // 이전 형태소가 닫는 괄호인 경우 좌측 결합조건을 적용하지 않음 - } - else if (ignoreCondScore) + Wid lastSeqId; + if (within(lastMorph, kw->morphemes.data() + langVocabSize, kw->morphemes.data() + kw->morphemes.size())) { - candScore += FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar) ? 0 : ignoreCondScore; + lastSeqId = lastMorph - kw->morphemes.data(); } else { - if (!FeatureTestor::isMatched(leftFormFirst, leftFormLast, cvowel, cpolar)) continue; + lastSeqId = lastMorph->lmMorphemeId; } - auto cLmState = prevPath.lmState; - if (curMorph->combineSocket && (curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) - { - // no-op - } - else + RuleBasedScorer ruleBasedScorer{ kw, curMorph, node }; + const float morphScore = kw->tagScorer.evalLeftBoundary(hasLeftBoundary(node), curMorph->tag); + size_t prevId = -1; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { - if (morphBase[firstWid].tag == POSTag::p) + for (auto& prevPath : cache[prev - startNode]) { - // prohibit without - goto continueFor; - } - float ll = cLmState.next(langMdl, firstWid); - candScore += ll; - if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) - { - for (size_t i = 1; i < curMorph->chunks.size(); ++i) + ++prevId; + auto& em = evalMatrix[prevId * morphs.size() + curId]; + if (em.score < -99999) { - const auto wid = curMorph->chunks[i]->lmMorphemeId; - if (morphBase[wid].tag == POSTag::p) - { - // prohibit without - goto continueFor; - } - ll = cLmState.next(langMdl, wid); - candScore += ll; + continue; } - } - } - if ((ruleBasedScorer.curMorphSbType || isQuote(ruleBasedScorer.curMorphSpecialType)) && prevPath.rootId == commonRootId) - { - rootIds.resize(prevSpStates.size()); - iota(rootIds.begin(), rootIds.end(), 0); - } - else - { - rootIds.resize(1); - rootIds[0] = commonRootId; - } - - for (auto rootId : rootIds) - { - const auto* prevMorpheme = &morphBase[prevPath.wid]; - auto spState = prevPath.spState; - if (rootId != commonRootId) - { - spState = prevSpStates[rootId]; - } - const float candScoreWithRule = candScore + ruleBasedScorer(prevMorpheme, spState); - - // update special state - if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteOpen) spState.singleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::singleQuoteClose) spState.singleQuote = 0; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteOpen) spState.doubleQuote = 1; - else if (ruleBasedScorer.curMorphSpecialType == Kiwi::SpecialMorph::doubleQuoteClose) spState.doubleQuote = 0; - if (ruleBasedScorer.curMorphSbType) - { - spState.bulletHash = hashSbTypeOrder(ruleBasedScorer.curMorphSbType, ruleBasedScorer.curMorphSbOrder + 1); + insertToPathContainer(bestPathCont, topN, prevSpStates, curMorph, morphBase, move(em.state), em.score, node, prevPath, ruleBasedScorer); } - - PathHash ph{ cLmState, prevPath.rootId, spState }; - bestPathCont.insert(ph, topN, rootId, curMorph, candScoreWithRule, prevPath.accTypoCost + node->typoCost, &prevPath, move(cLmState), spState); } - continueFor:; + bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); } } + }; - bestPathCont.writeTo(resultOut, curMorph, lastSeqId, ownFormId); - return; - } - - template - void PathEvaluator::evalPath(const Kiwi* kw, - const KGraphNode* startNode, - const KGraphNode* node, - const size_t topN, - Vector>>& cache, - const Vector& ownFormList, - size_t i, - size_t ownFormId, - CandTy&& cands, - bool unknownForm, - const Vector& prevSpStates, - bool splitComplex, - bool splitSaisiot, - bool mergeSaisiot, - const std::unordered_set* blocklist - ) + template + struct PathEvaluator::type> { - const size_t langVocabSize = kw->langMdl.knlm->getHeader().vocab_size; - auto& nCache = cache[i]; - Vector> refCache; - - float whitespaceDiscount = 0; - if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) - { - whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; - } - const float typoDiscount = -node->typoCost * kw->typoCostWeight; - float unknownFormDiscount = 0; - if (unknownForm) + const Kiwi* kw; + const KGraphNode* startNode; + const size_t topN; + Vector>>& cache; + const Vector& ownFormList; + const Vector& prevSpStates; + + PathEvaluator(const Kiwi* _kw, + const KGraphNode* _startNode, + size_t _topN, + Vector>>& _cache, + const Vector& _ownFormList, + const Vector& _prevSpStates + ) + : kw{ _kw }, startNode{ _startNode }, topN{ _topN }, cache{ _cache }, ownFormList{ _ownFormList }, prevSpStates{ _prevSpStates } { - size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); - unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); } - const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; - - size_t totalPrevPathes = 0; - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) + template + void operator()( + const size_t nodeIdx, + const size_t ownFormId, + CandTy&& cands, + bool unknownForm, + bool splitComplex = false, + bool splitSaisiot = false, + bool mergeSaisiot = false, + const std::unordered_set* blocklist = nullptr + ) const { - totalPrevPathes += cache[prev - startNode].size(); - } - const bool useContainerForSmall = totalPrevPathes <= 48; + thread_local Vector maxScores; + thread_local Vector validMorphCands; + const size_t langVocabSize = kw->langMdl->vocabSize(); + auto* const node = startNode + nodeIdx; + auto& nCache = cache[nodeIdx]; + + float whitespaceDiscount = 0; + if (node->uform.empty() && !node->form->form.empty() && node->spaceErrors) + { + whitespaceDiscount = -kw->spacePenalty * node->spaceErrors; + } + const float typoDiscount = -node->typoCost * kw->typoCostWeight; + float unknownFormDiscount = 0; + if (unknownForm) + { + size_t unknownLen = node->uform.empty() ? node->form->form.size() : node->uform.size(); + unknownFormDiscount = -(unknownLen * kw->unkFormScoreScale + kw->unkFormScoreBias); + } - for (bool ignoreCond : {false, true}) - { + const float nodeLevelDiscount = whitespaceDiscount + typoDiscount + unknownFormDiscount; + const Morpheme* zCodaMorph = nullptr; + const Morpheme* zSiotMorph = nullptr; + validMorphCands.clear(); for (auto& curMorph : cands) { if (splitComplex && curMorph->hasComplex()) continue; @@ -876,6 +853,37 @@ namespace kiwi // 덧붙은 받침(zCoda)을 위한 지름길 if (curMorph->tag == POSTag::z_coda) + { + zCodaMorph = curMorph; + continue; + } + else if (curMorph->tag == POSTag::z_siot) + { + zSiotMorph = curMorph; + continue; + } + + if (!curMorph->isSingle()) + { + // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 + if (node->prev && node[-(int)node->prev].endPos < node->startPos + && curMorph->kform + && curMorph->kform->size() == 1 + && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') + && curMorph->chunks[0]->kform + && curMorph->chunks[0]->kform->size() == 1 + && (*curMorph->chunks[0]->kform)[0] == u'하') + { + continue; + } + } + validMorphCands.emplace_back(curMorph); + } + + for (bool ignoreCond : {false, true}) + { + // 덧붙은 받침(zCoda)을 위한 지름길 + if (zCodaMorph) { for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { @@ -885,23 +893,22 @@ namespace kiwi if (!isJClass(lastTag) && !isEClass(lastTag)) continue; nCache.emplace_back(p); auto& newPath = nCache.back(); - newPath.accScore += curMorph->userScore * kw->typoCostWeight; - newPath.accTypoCost -= curMorph->userScore; + newPath.accScore += zCodaMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= zCodaMorph->userScore; newPath.parent = &p; - newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; - newPath.wid = curMorph->lmMorphemeId; + newPath.morpheme = &kw->morphemes[zCodaMorph->lmMorphemeId]; + newPath.wid = zCodaMorph->lmMorphemeId; } } continue; } // 사이시옷(zSiot)을 위한 지름길 - if (curMorph->tag == POSTag::z_siot) + if (zSiotMorph) { if (!(splitSaisiot || mergeSaisiot)) { continue; } - for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { for (auto& p : cache[prev - startNode]) @@ -910,97 +917,91 @@ namespace kiwi if (!isNNClass(lastTag)) continue; nCache.emplace_back(p); auto& newPath = nCache.back(); - newPath.accScore += curMorph->userScore * kw->typoCostWeight; - newPath.accTypoCost -= curMorph->userScore; + newPath.accScore += zSiotMorph->userScore * kw->typoCostWeight; + newPath.accTypoCost -= zSiotMorph->userScore; newPath.parent = &p; - newPath.morpheme = &kw->morphemes[curMorph->lmMorphemeId]; - newPath.wid = curMorph->lmMorphemeId; + newPath.morpheme = &kw->morphemes[zSiotMorph->lmMorphemeId]; + newPath.wid = zSiotMorph->lmMorphemeId; } } continue; } - // if the morpheme has chunk set - if (!(curMorph->chunks.empty() || curMorph->complex || curMorph->saisiot)) + size_t totalPrevPathes = 0; + for (auto* prev = node->getPrev(); prev; prev = prev->getSibling()) { - // '하다/하게/하지'가 '다/게/지'로 축약된 경우인데 앞에 공백이 있는 경우는 탐색후보에서 제외 - if (node->prev && node[-(int)node->prev].endPos < node->startPos - && curMorph->kform - && curMorph->kform->size() == 1 - && ((*curMorph->kform)[0] == u'다' || (*curMorph->kform)[0] == u'게' || (*curMorph->kform)[0] == u'지') - && curMorph->chunks[0]->kform - && curMorph->chunks[0]->kform->size() == 1 - && (*curMorph->chunks[0]->kform)[0] == u'하') - { - continue; - } + totalPrevPathes += cache[prev - startNode].size(); } + MorphemeEvaluator me; if (topN > 1) { - evalSingleMorpheme(nCache, kw, ownFormList, cache, - ownFormId, curMorph, - node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + me.template eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } - else if (useContainerForSmall) + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) { - evalSingleMorpheme(nCache, kw, ownFormList, cache, - ownFormId, curMorph, - node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + me.template eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + } + else if (totalPrevPathes <= BestPathContainerTraits::maxSize) + { + me.template eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } else { - evalSingleMorpheme(nCache, kw, ownFormList, cache, - ownFormId, curMorph, - node, startNode, topN, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); + me.template eval(nCache, kw, ownFormList, cache, + ownFormId, validMorphCands, + node, startNode, topN, totalPrevPathes, ignoreCond ? -10 : 0, nodeLevelDiscount, prevSpStates); } - + if (!nCache.empty()) break; } - if (!nCache.empty()) break; - } - thread_local Vector maxScores; - maxScores.clear(); - maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); + maxScores.clear(); + maxScores.resize((1 + prevSpStates.size()) * topN, -INFINITY); - if (topN == 1) - { - for (auto& c : nCache) + if (topN == 1) { - if (c.morpheme->combineSocket) continue; - const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; - maxScores[rootId] = max(maxScores[rootId], c.accScore); + for (auto& c : nCache) + { + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + maxScores[rootId] = max(maxScores[rootId], c.accScore); + } } - } - else - { - for (auto& c : nCache) + else { - if (c.morpheme->combineSocket) continue; - const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; - if (c.accScore > maxScores[rootId * topN]) + for (auto& c : nCache) { - pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); - maxScores[rootId * topN + topN - 1] = c.accScore; - push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + if (c.morpheme->combineSocket) continue; + const auto rootId = c.rootId == commonRootId ? 0 : c.rootId + 1; + if (c.accScore > maxScores[rootId * topN]) + { + pop_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + maxScores[rootId * topN + topN - 1] = c.accScore; + push_heap(maxScores.begin() + rootId * topN, maxScores.begin() + (rootId + 1) * topN, greater{}); + } } } - } - size_t validCount = 0; - for (size_t i = 0; i < nCache.size(); ++i) - { - const auto rootId = nCache[i].rootId == commonRootId ? 0 : nCache[i].rootId + 1; - if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; - if (validCount != i) nCache[validCount] = move(nCache[i]); - validCount++; + size_t validCount = 0; + for (size_t i = 0; i < nCache.size(); ++i) + { + const auto rootId = nCache[i].rootId == commonRootId ? 0 : nCache[i].rootId + 1; + if (nCache[i].accScore + kw->cutOffThreshold < maxScores[rootId * topN]) continue; + if (validCount != i) nCache[validCount] = move(nCache[i]); + validCount++; + } + nCache.resize(validCount); } - nCache.resize(validCount); - } - + }; template - inline PathEvaluator::Path generateTokenList(const WordLL* result, + inline Path generateTokenList(const WordLL* result, const utils::ContainerSearcher>& csearcher, const KGraphNode* graph, const Vector& ownFormList, @@ -1021,7 +1022,7 @@ namespace kiwi return morphFirst + morph->lmMorphemeId; }; - PathEvaluator::Path ret; + Path ret; const WordLL* prev = steps.back()->parent; for (auto it = steps.rbegin(); it != steps.rend(); ++it) { @@ -1113,8 +1114,8 @@ namespace kiwi return ret; } - template - Vector PathEvaluator::findBestPath(const Kiwi* kw, + template + Vector BestPathFinder::findBestPath(const Kiwi* kw, const Vector& prevSpStates, const KGraphNode* graph, const size_t graphSize, @@ -1127,12 +1128,14 @@ namespace kiwi ) { static constexpr size_t eosId = 1; + using LmState = typename LangModel::LmStateType; + const auto* langMdl = kw->getLangModel(); Vector>> cache(graphSize); Vector ownFormList; Vector unknownNodeCands, unknownNodeLCands; - const size_t langVocabSize = kw->langMdl.knlm->getHeader().vocab_size; + const size_t langVocabSize = langMdl->vocabSize(); const KGraphNode* startNode = graph; const KGraphNode* endNode = graph + graphSize - 1; @@ -1150,7 +1153,7 @@ namespace kiwi } // start node - cache[0].emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ kw->langMdl }, SpecialState{}); + cache[0].emplace_back(&kw->morphemes[0], 0.f, 0.f, nullptr, LmState{ langMdl }, SpecialState{}); cache[0].back().rootId = commonRootId; #ifdef DEBUG_PRINT @@ -1163,6 +1166,9 @@ namespace kiwi } #endif + PathEvaluator evaluator{ + kw, startNode, topN, cache, ownFormList, uniqStates, + }; // middle nodes for (size_t i = 1; i < graphSize - 1; ++i) { @@ -1176,9 +1182,8 @@ namespace kiwi if (node->form) { - evalPath(kw, startNode, node, topN, cache, - ownFormList, i, ownFormId, node->form->candidate, - false, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); + evaluator(i, ownFormId, node->form->candidate, + false, splitComplex, splitSaisiot, mergeSaisiot, blocklist); if (all_of(node->form->candidate.begin(), node->form->candidate.end(), [](const Morpheme* m) { return m->combineSocket || !(m->chunks.empty() || m->complex || m->saisiot); @@ -1186,16 +1191,14 @@ namespace kiwi { ownFormList.emplace_back(node->form->form); ownFormId = ownFormList.size(); - evalPath(kw, startNode, node, topN, cache, - ownFormList, i, ownFormId, unknownNodeLCands, - true, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); + evaluator(i, ownFormId, unknownNodeLCands, + true, splitComplex, splitSaisiot, mergeSaisiot, blocklist); }; } else { - evalPath(kw, startNode, node, topN, cache, - ownFormList, i, ownFormId, unknownNodeCands, - true, uniqStates, splitComplex, splitSaisiot, mergeSaisiot, blocklist); + evaluator(i, ownFormId, unknownNodeCands, + true, splitComplex, splitSaisiot, mergeSaisiot, blocklist); } #ifdef DEBUG_PRINT @@ -1225,7 +1228,7 @@ namespace kiwi } if (p.morpheme->tag == POSTag::z_siot) continue; - float c = p.accScore + (openEnd ? 0 : p.lmState.next(kw->langMdl, eosId)); + float c = p.accScore + (openEnd ? 0 : p.lmState.next(langMdl, eosId)); if (p.spState.singleQuote) c -= 2; if (p.spState.doubleQuote) c -= 2; if (p.rootId == commonRootId) @@ -1265,7 +1268,7 @@ namespace kiwi #endif utils::ContainerSearcher> csearcher{ cache }; - Vector ret; + Vector ret; size_t numUniqRootIdAndSpState; { UnorderedSet> uniqRootIdAndSpState; @@ -1298,7 +1301,7 @@ namespace kiwi ret.emplace_back(move(tokens), cand[i].accScore, uniqStates[cand[i].rootId], cand[i].spState); } } - sort(ret.begin(), ret.end(), [](const ChunkResult& a, const ChunkResult& b) + sort(ret.begin(), ret.end(), [](const PathResult& a, const PathResult& b) { return a.score > b.score; }); diff --git a/src/SIMD.hpp b/src/SIMD.hpp index bf9834bc..6e894c00 100644 --- a/src/SIMD.hpp +++ b/src/SIMD.hpp @@ -11,6 +11,15 @@ #define STRONG_INLINE inline #endif +#if defined(_MSC_VER) +#define FORCE_INLINE __forceinline +#elif defined(__GNUC__) +#define FORCE_INLINE __attribute__((always_inline)) +#else +#define FORCE_INLINE inline +#endif + + #include "ArchAvailable.h" namespace kiwi @@ -27,9 +36,8 @@ namespace kiwi using IntPacket = typename PacketTrait::IntPacket; template - class OperatorBase + struct OperatorBase { - public: enum { packetSize = PacketTrait::size }; using FPacket = typename PacketTrait::FloatPacket; @@ -40,7 +48,9 @@ namespace kiwi static STRONG_INLINE FPacket mulf(FPacket a, FPacket b) { return O::mulf(a, b); } static STRONG_INLINE FPacket divf(FPacket a, FPacket b) { return O::divf(a, b); } static STRONG_INLINE FPacket maddf(FPacket a, FPacket b, FPacket c) { return O::maddf(a, b, c); } + static STRONG_INLINE IPacket set1i(int32_t a) { return O::set1i(a); } static STRONG_INLINE FPacket set1f(float a) { return O::set1f(a); } + static STRONG_INLINE FPacket set1frombits(uint32_t a) { return O::reinterpret_as_float(O::set1i(a)); } static STRONG_INLINE FPacket loadf(const float* a) { return O::loadf(a); } static STRONG_INLINE void storef(float* a, FPacket b) { return O::storef(a, b); } static STRONG_INLINE FPacket maxf(FPacket a, FPacket b) { return O::maxf(a, b); } @@ -48,16 +58,35 @@ namespace kiwi static STRONG_INLINE FPacket floorf(FPacket a) { return O::floorf(a); } static STRONG_INLINE FPacket negatef(FPacket a) { return O::negatef(a); } static STRONG_INLINE FPacket zerof() { return O::zerof(); } + static STRONG_INLINE FPacket cast_to_float(IPacket a) { return O::cast_to_float(a); } static STRONG_INLINE IPacket cast_to_int(FPacket a) { return O::cast_to_int(a); } static STRONG_INLINE FPacket reinterpret_as_float(IPacket a) { return O::reinterpret_as_float(a); } + static STRONG_INLINE IPacket reinterpret_as_int(FPacket a) { return O::reinterpret_as_int(a); } static STRONG_INLINE float firstf(FPacket a) { return O::firstf(a); } static STRONG_INLINE float redsumf(FPacket a) { return O::redsumf(a); } static STRONG_INLINE float redmaxf(FPacket a) { return O::redmaxf(a); } static STRONG_INLINE FPacket redmaxbf(FPacket a) { return O::redmaxbf(a); } + static STRONG_INLINE IPacket band(IPacket a, IPacket b) { return O::band(a, b); } + static STRONG_INLINE FPacket band(FPacket a, FPacket b) { return O::band(a, b); } + + static STRONG_INLINE IPacket bor(IPacket a, IPacket b) { return O::bor(a, b); } + static STRONG_INLINE FPacket bor(FPacket a, FPacket b) { return O::bor(a, b); } + + static STRONG_INLINE IPacket select(IPacket mask, IPacket a, IPacket b) { return O::select(mask, a, b); } + static STRONG_INLINE FPacket select(FPacket mask, FPacket a, FPacket b) { return O::select(mask, a, b); } + + static STRONG_INLINE FPacket cmp_eq(FPacket a, FPacket b) { return O::cmp_eq(a, b); } + static STRONG_INLINE FPacket cmp_le(FPacket a, FPacket b) { return O::cmp_le(a, b); } + static STRONG_INLINE FPacket cmp_lt(FPacket a, FPacket b) { return O::cmp_lt(a, b); } + static STRONG_INLINE FPacket cmp_lt_or_nan(FPacket a, FPacket b) { return O::cmp_lt_or_nan(a, b); } + template static STRONG_INLINE IPacket sll(IPacket a) { return O::template sll(a); } + template + static STRONG_INLINE IPacket srl(IPacket a) { return O::template srl(a); } + static STRONG_INLINE FPacket ldexpf_fast(FPacket a, FPacket exponent) { static constexpr int exponentBits = 8, mantissaBits = 23; @@ -70,6 +99,19 @@ namespace kiwi return mulf(a, reinterpret_as_float(sll(e))); } + static STRONG_INLINE FPacket frexpf_fast(FPacket x, FPacket& exp) + { + // ignore nan, inf, 0, denormalized numbers. + const IPacket exp_mask = set1i(0x7F800000), + inv_exp_mask = set1i(~0x7F800000), + norm_exp = set1i(126 << 23); + const FPacket exp_bias = set1f(126); + IPacket ix = reinterpret_as_int(x); + exp = subf(cast_to_float(srl<23>(band(ix, exp_mask))), exp_bias); + ix = bor(band(ix, inv_exp_mask), norm_exp); + return reinterpret_as_float(ix); + } + static STRONG_INLINE FPacket expf(FPacket _x) { const FPacket cst_1 = set1f(1.0f); @@ -117,10 +159,94 @@ namespace kiwi // TODO: replace pldexp with faster implementation since y in [-1, 1). return maxf(ldexpf_fast(y, m), _x); } + + static STRONG_INLINE FPacket logf(FPacket _x) + { + FPacket x = _x; + + const FPacket cst_1 = set1f(1.0f); + const FPacket cst_neg_half = set1f(-0.5f); + // The smallest non denormalized float number. + const FPacket cst_min_norm_pos = set1frombits(0x00800000u); + const FPacket cst_minus_inf = set1frombits(0xff800000u); + const FPacket cst_pos_inf = set1frombits(0x7f800000u); + + // Polynomial coefficients. + const FPacket cst_cephes_SQRTHF = set1f(0.707106781186547524f); + const FPacket cst_cephes_log_p0 = set1f(7.0376836292E-2f); + const FPacket cst_cephes_log_p1 = set1f(-1.1514610310E-1f); + const FPacket cst_cephes_log_p2 = set1f(1.1676998740E-1f); + const FPacket cst_cephes_log_p3 = set1f(-1.2420140846E-1f); + const FPacket cst_cephes_log_p4 = set1f(+1.4249322787E-1f); + const FPacket cst_cephes_log_p5 = set1f(-1.6668057665E-1f); + const FPacket cst_cephes_log_p6 = set1f(+2.0000714765E-1f); + const FPacket cst_cephes_log_p7 = set1f(-2.4999993993E-1f); + const FPacket cst_cephes_log_p8 = set1f(+3.3333331174E-1f); + + // Truncate input values to the minimum positive normal. + x = maxf(x, cst_min_norm_pos); + + FPacket e; + // extract significant in the range [0.5,1) and exponent + x = frexpf_fast(x, e); + + // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2)) + // and shift by -1. The values are then centered around 0, which improves + // the stability of the polynomial evaluation. + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + FPacket mask = cmp_lt(x, cst_cephes_SQRTHF); + FPacket tmp = band(x, mask); + x = subf(x, cst_1); + e = subf(e, band(cst_1, mask)); + x = addf(x, tmp); + + FPacket x2 = mulf(x, x); + FPacket x3 = mulf(x2, x); + + // Evaluate the polynomial approximant of degree 8 in three parts, probably + // to improve instruction-level parallelism. + FPacket y, y1, y2; + y = maddf(cst_cephes_log_p0, x, cst_cephes_log_p1); + y1 = maddf(cst_cephes_log_p3, x, cst_cephes_log_p4); + y2 = maddf(cst_cephes_log_p6, x, cst_cephes_log_p7); + y = maddf(y, x, cst_cephes_log_p2); + y1 = maddf(y1, x, cst_cephes_log_p5); + y2 = maddf(y2, x, cst_cephes_log_p8); + y = maddf(y, x3, y1); + y = maddf(y, x3, y2); + y = mulf(y, x3); + + y = maddf(cst_neg_half, x2, y); + x = addf(x, y); + + const FPacket cst_ln2 = set1f(0.69314718f); + x = maddf(e, cst_ln2, x); + + FPacket invalid_mask = cmp_lt_or_nan(_x, zerof()); + FPacket iszero_mask = cmp_eq(_x, zerof()); + FPacket pos_inf_mask = cmp_eq(_x, cst_pos_inf); + // Filter out invalid inputs, i.e.: + // - negative arg will be NAN + // - 0 will be -INF + // - +INF will be +INF + return select(iszero_mask, cst_minus_inf, + bor(select(pos_inf_mask, cst_pos_inf, x), invalid_mask)); + } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + return 0; + } }; + template + struct OperatorImpl; + template - class Operator; + struct Operator; } } @@ -142,17 +268,16 @@ namespace kiwi struct PacketTrait : public PacketTrait {}; -#if defined(_MSC_VER) || defined(__SSE2__) || defined(__AVX2__) - template<> - class Operator : public OperatorBase> +#if defined(_MSC_VER) || defined(__SSE2__) || defined(__SSE4_1__) || defined(__AVX2__) + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE __m128 addf(__m128 a, __m128 b) { return _mm_add_ps(a, b); } static STRONG_INLINE __m128 subf(__m128 a, __m128 b) { return _mm_sub_ps(a, b); } static STRONG_INLINE __m128 mulf(__m128 a, __m128 b) { return _mm_mul_ps(a, b); } static STRONG_INLINE __m128 divf(__m128 a, __m128 b) { return _mm_div_ps(a, b); } - static STRONG_INLINE __m128 maddf(__m128 a, __m128 b, __m128 c) { return addf(mulf(a, b), c); } + static STRONG_INLINE __m128 maddf(__m128 a, __m128 b, __m128 c) { return O::addf(O::mulf(a, b), c); } + static STRONG_INLINE __m128i set1i(int32_t a) { return _mm_set1_epi32(a); } static STRONG_INLINE __m128 set1f(float a) { return _mm_set1_ps(a); } static STRONG_INLINE __m128 loadf(const float* a) { return _mm_load_ps(a); } static STRONG_INLINE void storef(float* a, __m128 b) { return _mm_store_ps(a, b); } @@ -165,29 +290,44 @@ namespace kiwi return _mm_and_ps(a, mask); } - static STRONG_INLINE __m128 selectf(__m128 mask, __m128 a, __m128 b) + static STRONG_INLINE __m128 band(__m128 a, __m128 b) { return _mm_and_ps(a, b); } + static STRONG_INLINE __m128i band(__m128i a, __m128i b) { return _mm_and_si128(a, b); } + + static STRONG_INLINE __m128 bor(__m128 a, __m128 b) { return _mm_or_ps(a, b); } + static STRONG_INLINE __m128i bor(__m128i a, __m128i b) { return _mm_or_si128(a, b); } + + static STRONG_INLINE __m128 select(__m128 mask, __m128 a, __m128 b) { return _mm_or_ps(_mm_and_ps(mask, a), _mm_andnot_ps(mask, b)); } + static STRONG_INLINE __m128i select(__m128i mask, __m128i a, __m128i b) + { + return _mm_or_si128(_mm_and_si128(mask, a), _mm_andnot_si128(mask, b)); + } + + static STRONG_INLINE __m128 cmp_eq(__m128 a, __m128 b) { return _mm_cmpeq_ps(a, b); } + static STRONG_INLINE __m128 cmp_le(__m128 a, __m128 b) { return _mm_cmple_ps(a, b); } + static STRONG_INLINE __m128 cmp_lt(__m128 a, __m128 b) { return _mm_cmplt_ps(a, b); } + static STRONG_INLINE __m128 cmp_lt_or_nan(__m128 a, __m128 b) { return _mm_cmpnge_ps(a, b); } static STRONG_INLINE __m128 rint(__m128 a) { - const __m128 limit = set1f(static_cast(1 << 23)); - const __m128 abs_a = absf(a); - __m128 r = addf(abs_a, limit); + const __m128 limit = O::set1f(static_cast(1 << 23)); + const __m128 abs_a = O::absf(a); + __m128 r = O::addf(abs_a, limit); #ifdef __GNUC__ __asm__("" : "+g,x" (r)); #endif - r = subf(r, limit); + r = O::subf(r, limit); - r = selectf(_mm_cmplt_ps(abs_a, limit), - selectf(_mm_cmplt_ps(a, zerof()), negatef(r), r), a); + r = O::select(_mm_cmplt_ps(abs_a, limit), + O::select(_mm_cmplt_ps(a, O::zerof()), O::negatef(r), r), a); return r; } static STRONG_INLINE __m128 floorf(__m128 a) { - const __m128 cst_1 = set1f(1.0f); + const __m128 cst_1 = O::set1f(1.0f); __m128 tmp = rint(a); __m128 mask = _mm_cmpgt_ps(tmp, a); mask = _mm_and_ps(mask, cst_1); @@ -197,8 +337,13 @@ namespace kiwi static STRONG_INLINE __m128 zerof() { return _mm_setzero_ps(); } static STRONG_INLINE __m128 negatef(__m128 a) { return subf(zerof(), a); } static STRONG_INLINE __m128i cast_to_int(__m128 a) { return _mm_cvtps_epi32(a); } + static STRONG_INLINE __m128 cast_to_float(__m128i a) { return _mm_cvtepi32_ps(a); } + static STRONG_INLINE __m128i reinterpret_as_int(__m128 a) { return _mm_castps_si128(a); } static STRONG_INLINE __m128 reinterpret_as_float(__m128i a) { return _mm_castsi128_ps(a); } + template static STRONG_INLINE __m128i sll(__m128i a) { return _mm_slli_epi32(a, bit); } + template static STRONG_INLINE __m128i srl(__m128i a) { return _mm_srli_epi32(a, bit); } + static STRONG_INLINE float firstf(__m128 a) { return _mm_cvtss_f32(a); } static STRONG_INLINE float redsumf(__m128 a) @@ -219,49 +364,47 @@ namespace kiwi return _mm_max_ps(tmp, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(2, 3, 0, 1))); } }; -#endif -#if defined(_MSC_VER) || defined(__SSE4_1__) || defined(__AVX2__) template<> - class Operator : public OperatorBase> + struct Operator : public OperatorImpl> { - public: + }; +#endif - static STRONG_INLINE __m128 addf(__m128 a, __m128 b) { return _mm_add_ps(a, b); } - static STRONG_INLINE __m128 subf(__m128 a, __m128 b) { return _mm_sub_ps(a, b); } - static STRONG_INLINE __m128 mulf(__m128 a, __m128 b) { return _mm_mul_ps(a, b); } - static STRONG_INLINE __m128 divf(__m128 a, __m128 b) { return _mm_div_ps(a, b); } - static STRONG_INLINE __m128 maddf(__m128 a, __m128 b, __m128 c) { return addf(mulf(a, b), c); } - static STRONG_INLINE __m128 set1f(float a) { return _mm_set1_ps(a); } - static STRONG_INLINE __m128 loadf(const float* a) { return _mm_load_ps(a); } - static STRONG_INLINE void storef(float* a, __m128 b) { return _mm_store_ps(a, b); } - static STRONG_INLINE __m128 maxf(__m128 a, __m128 b) { return _mm_max_ps(a, b); } - static STRONG_INLINE __m128 minf(__m128 a, __m128 b) { return _mm_min_ps(a, b); } - static STRONG_INLINE __m128 floorf(__m128 a) { return _mm_floor_ps(a); } - static STRONG_INLINE __m128 zerof() { return _mm_setzero_ps(); } - static STRONG_INLINE __m128 negatef(__m128 a) { return subf(zerof(), a); } - static STRONG_INLINE __m128i cast_to_int(__m128 a) { return _mm_cvtps_epi32(a); } - static STRONG_INLINE __m128 reinterpret_as_float(__m128i a) { return _mm_castsi128_ps(a); } - template static STRONG_INLINE __m128i sll(__m128i a) { return _mm_slli_epi32(a, bit); } - static STRONG_INLINE float firstf(__m128 a) { return _mm_cvtss_f32(a); } +#if defined(_MSC_VER) || defined(__SSE4_1__) || defined(__AVX2__) + template + struct OperatorImpl : public OperatorImpl + { + static STRONG_INLINE __m128 select(__m128 mask, __m128 a, __m128 b) + { + return _mm_blendv_ps(b, a, mask); + } - static STRONG_INLINE float redsumf(__m128 a) + static STRONG_INLINE __m128i select(__m128i mask, __m128i a, __m128i b) { - __m128 tmp = _mm_add_ps(a, _mm_movehl_ps(a, a)); - return firstf(_mm_add_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1))); + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(b), _mm_castsi128_ps(a), _mm_castsi128_ps(mask))); } - static STRONG_INLINE float redmaxf(__m128 a) - { - __m128 tmp = _mm_max_ps(a, _mm_movehl_ps(a, a)); - return firstf(_mm_max_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1))); - } + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m128i pa, pb, sum = _mm_setzero_si128(); + __m128i one16 = _mm_set1_epi16(1), pt; + for (size_t i = 0; i < size; i += 16) + { + pa = _mm_loadu_si128(reinterpret_cast(a + i)); + pb = _mm_loadu_si128(reinterpret_cast(b + i)); + pt = _mm_maddubs_epi16(pa, pb); + sum = _mm_add_epi32(sum, _mm_madd_epi16(pt, one16)); + } + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(sum); + } + }; - static STRONG_INLINE __m128 redmaxbf(__m128 a) - { - __m128 tmp = _mm_max_ps(a, _mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 0, 3, 2))); - return _mm_max_ps(tmp, _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(2, 3, 0, 1))); - } + template<> + struct Operator : public OperatorImpl> + { }; #endif @@ -274,16 +417,15 @@ namespace kiwi using FloatPacket = __m256; }; - template<> - class Operator : public OperatorBase> + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE __m256 addf(__m256 a, __m256 b) { return _mm256_add_ps(a, b); } static STRONG_INLINE __m256 subf(__m256 a, __m256 b) { return _mm256_sub_ps(a, b); } static STRONG_INLINE __m256 mulf(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); } static STRONG_INLINE __m256 divf(__m256 a, __m256 b) { return _mm256_div_ps(a, b); } static STRONG_INLINE __m256 maddf(__m256 a, __m256 b, __m256 c) { return _mm256_fmadd_ps(a, b, c); } + static STRONG_INLINE __m256i set1i(int32_t a) { return _mm256_set1_epi32(a); } static STRONG_INLINE __m256 set1f(float a) { return _mm256_set1_ps(a); } static STRONG_INLINE __m256 loadf(const float* a) { return _mm256_load_ps(a); } static STRONG_INLINE void storef(float* a, __m256 b) { return _mm256_store_ps(a, b); } @@ -293,8 +435,11 @@ namespace kiwi static STRONG_INLINE __m256 zerof() { return _mm256_setzero_ps(); } static STRONG_INLINE __m256 negatef(__m256 a) { return subf(zerof(), a); } static STRONG_INLINE __m256i cast_to_int(__m256 a) { return _mm256_cvtps_epi32(a); } + static STRONG_INLINE __m256 cast_to_float(__m256i a) { return _mm256_cvtepi32_ps(a); } + static STRONG_INLINE __m256i reinterpret_as_int(__m256 a) { return _mm256_castps_si256(a); } static STRONG_INLINE __m256 reinterpret_as_float(__m256i a) { return _mm256_castsi256_ps(a); } template static STRONG_INLINE __m256i sll(__m256i a) { return _mm256_slli_epi32(a, bit); } + template static STRONG_INLINE __m256i srl(__m256i a) { return _mm256_srli_epi32(a, bit); } static STRONG_INLINE float firstf(__m256 a) { return _mm256_cvtss_f32(a); } static STRONG_INLINE float redsumf(__m256 a) @@ -314,8 +459,155 @@ namespace kiwi __m256 tmp = _mm256_max_ps(a, _mm256_permute2f128_ps(a, a, 1)); tmp = _mm256_max_ps(tmp, _mm256_shuffle_ps(tmp, tmp, _MM_SHUFFLE(1, 0, 3, 2))); return _mm256_max_ps(tmp, _mm256_shuffle_ps(tmp, tmp, _MM_SHUFFLE(2, 3, 0, 1))); + } + + static STRONG_INLINE __m256 band(__m256 a, __m256 b) { return _mm256_and_ps(a, b); } + static STRONG_INLINE __m256i band(__m256i a, __m256i b) { return _mm256_and_si256(a, b); } + + static STRONG_INLINE __m256 bor(__m256 a, __m256 b) { return _mm256_or_ps(a, b); } + static STRONG_INLINE __m256i bor(__m256i a, __m256i b) { return _mm256_or_si256(a, b); } + + static STRONG_INLINE __m256 select(__m256 mask, __m256 a, __m256 b) + { + return _mm256_blendv_ps(b, a, mask); + } + static STRONG_INLINE __m256i select(__m256i mask, __m256i a, __m256i b) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(b), _mm256_castsi256_ps(a), _mm256_castsi256_ps(mask))); + } + + static STRONG_INLINE __m256 cmp_eq(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_EQ_OQ); } + static STRONG_INLINE __m256 cmp_le(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_LE_OQ); } + static STRONG_INLINE __m256 cmp_lt(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_LT_OQ); } + static STRONG_INLINE __m256 cmp_lt_or_nan(__m256 a, __m256 b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); } + + static STRONG_INLINE void load_transposed(const float* a, size_t stride, + __m256& r0, __m256& r1, __m256& r2, __m256& r3, + __m256& r4, __m256& r5, __m256& r6, __m256& r7 + ) { + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + + r0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[0 * stride + 0])), _mm_load_ps(&a[4 * stride + 0]), 1); + r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[1 * stride + 0])), _mm_load_ps(&a[5 * stride + 0]), 1); + r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[2 * stride + 0])), _mm_load_ps(&a[6 * stride + 0]), 1); + r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[3 * stride + 0])), _mm_load_ps(&a[7 * stride + 0]), 1); + r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[0 * stride + 4])), _mm_load_ps(&a[4 * stride + 4]), 1); + r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[1 * stride + 4])), _mm_load_ps(&a[5 * stride + 4]), 1); + r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[2 * stride + 4])), _mm_load_ps(&a[6 * stride + 4]), 1); + r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&a[3 * stride + 4])), _mm_load_ps(&a[7 * stride + 4]), 1); + + t0 = _mm256_unpacklo_ps(r0, r1); + t1 = _mm256_unpackhi_ps(r0, r1); + t2 = _mm256_unpacklo_ps(r2, r3); + t3 = _mm256_unpackhi_ps(r2, r3); + t4 = _mm256_unpacklo_ps(r4, r5); + t5 = _mm256_unpackhi_ps(r4, r5); + t6 = _mm256_unpacklo_ps(r6, r7); + t7 = _mm256_unpackhi_ps(r6, r7); + + r0 = _mm256_shuffle_ps(t0, t2, 0x44); + r1 = _mm256_shuffle_ps(t0, t2, 0xEE); + r2 = _mm256_shuffle_ps(t1, t3, 0x44); + r3 = _mm256_shuffle_ps(t1, t3, 0xEE); + r4 = _mm256_shuffle_ps(t4, t6, 0x44); + r5 = _mm256_shuffle_ps(t4, t6, 0xEE); + r6 = _mm256_shuffle_ps(t5, t7, 0x44); + r7 = _mm256_shuffle_ps(t5, t7, 0xEE); + } + + static STRONG_INLINE void store_transposed(float* a, size_t stride, + __m256 r0, __m256 r1, __m256 r2, __m256 r3, + __m256 r4, __m256 r5, __m256 r6, __m256 r7 + ) + { + __m256 t0 = _mm256_unpacklo_ps(r0, r1); + __m256 t1 = _mm256_unpackhi_ps(r0, r1); + __m256 t2 = _mm256_unpacklo_ps(r2, r3); + __m256 t3 = _mm256_unpackhi_ps(r2, r3); + __m256 t4 = _mm256_unpacklo_ps(r4, r5); + __m256 t5 = _mm256_unpackhi_ps(r4, r5); + __m256 t6 = _mm256_unpacklo_ps(r6, r7); + __m256 t7 = _mm256_unpackhi_ps(r6, r7); + + r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0)); + r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2)); + r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0)); + r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2)); + r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0)); + r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2)); + r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0)); + r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2)); + + t0 = _mm256_permute2f128_ps(r0, r4, 0x20); + t1 = _mm256_permute2f128_ps(r1, r5, 0x20); + t2 = _mm256_permute2f128_ps(r2, r6, 0x20); + t3 = _mm256_permute2f128_ps(r3, r7, 0x20); + t4 = _mm256_permute2f128_ps(r0, r4, 0x31); + t5 = _mm256_permute2f128_ps(r1, r5, 0x31); + t6 = _mm256_permute2f128_ps(r2, r6, 0x31); + t7 = _mm256_permute2f128_ps(r3, r7, 0x31); + + _mm256_store_ps(&a[0 * stride], t0); + _mm256_store_ps(&a[1 * stride], t1); + _mm256_store_ps(&a[2 * stride], t2); + _mm256_store_ps(&a[3 * stride], t3); + _mm256_store_ps(&a[4 * stride], t4); + _mm256_store_ps(&a[5 * stride], t5); + _mm256_store_ps(&a[6 * stride], t6); + _mm256_store_ps(&a[7 * stride], t7); + } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m256i pa, pb, acc = _mm256_setzero_si256(); + __m256i one16 = _mm256_set1_epi16(1), pt; + for (size_t i = 0; i < size; i += 32) + { + pa = _mm256_loadu_si256(reinterpret_cast(&a[i])); + pb = _mm256_loadu_si256(reinterpret_cast(&b[i])); + pt = _mm256_maddubs_epi16(pa, pb); + acc = _mm256_add_epi32(acc, _mm256_madd_epi16(pt, one16)); + } + // reduce sum of eight int32_t to one int32_t + __m256i sum = _mm256_hadd_epi32(acc, acc); + sum = _mm256_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(_mm256_castsi256_si128(sum)) + _mm256_extract_epi32(sum, 4); } }; + + template<> + struct Operator : public OperatorImpl> + { + }; + + template<> + struct PacketTrait : public PacketTrait + { + }; + + template + struct OperatorImpl : public OperatorImpl + { + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m256i pa, pb, acc = _mm256_setzero_si256(); + for (size_t i = 0; i < size; i += 32) + { + pa = _mm256_loadu_si256(reinterpret_cast(&a[i])); + pb = _mm256_loadu_si256(reinterpret_cast(&b[i])); + acc = _mm256_dpbusd_epi32(acc, pa, pb); + } + // reduce sum of eight int32_t to one int32_t + __m256i sum = _mm256_hadd_epi32(acc, acc); + sum = _mm256_hadd_epi32(sum, sum); + return _mm_cvtsi128_si32(_mm256_castsi256_si128(sum)) + _mm256_extract_epi32(sum, 4); + } + }; + + template<> + struct Operator : public OperatorImpl> + { + }; #endif #if defined(_MSC_VER) || defined(__AVX512F__) || defined(__AVX512BW__) @@ -327,16 +619,15 @@ namespace kiwi using FloatPacket = __m512; }; - template<> - class Operator : public OperatorBase> + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE __m512 addf(__m512 a, __m512 b) { return _mm512_add_ps(a, b); } static STRONG_INLINE __m512 subf(__m512 a, __m512 b) { return _mm512_sub_ps(a, b); } static STRONG_INLINE __m512 mulf(__m512 a, __m512 b) { return _mm512_mul_ps(a, b); } static STRONG_INLINE __m512 divf(__m512 a, __m512 b) { return _mm512_div_ps(a, b); } static STRONG_INLINE __m512 maddf(__m512 a, __m512 b, __m512 c) { return _mm512_fmadd_ps(a, b, c); } + static STRONG_INLINE __m512i set1i(int32_t a) { return _mm512_set1_epi32(a); } static STRONG_INLINE __m512 set1f(float a) { return _mm512_set1_ps(a); } static STRONG_INLINE __m512 loadf(const float* a) { return _mm512_load_ps(a); } static STRONG_INLINE void storef(float* a, __m512 b) { return _mm512_store_ps(a, b); } @@ -346,8 +637,11 @@ namespace kiwi static STRONG_INLINE __m512 zerof() { return _mm512_setzero_ps(); } static STRONG_INLINE __m512 negatef(__m512 a) { return subf(zerof(), a); } static STRONG_INLINE __m512i cast_to_int(__m512 a) { return _mm512_cvtps_epi32(a); } + static STRONG_INLINE __m512 cast_to_float(__m512i a) { return _mm512_cvtepi32_ps(a); } + static STRONG_INLINE __m512i reinterpret_as_int(__m512 a) { return _mm512_castps_si512(a); } static STRONG_INLINE __m512 reinterpret_as_float(__m512i a) { return _mm512_castsi512_ps(a); } template static STRONG_INLINE __m512i sll(__m512i a) { return _mm512_slli_epi32(a, bit); } + template static STRONG_INLINE __m512i srl(__m512i a) { return _mm512_srli_epi32(a, bit); } static STRONG_INLINE float firstf(__m512 a) { return _mm512_cvtss_f32(a); } static STRONG_INLINE float redsumf(__m512 a) @@ -376,8 +670,91 @@ namespace kiwi static STRONG_INLINE __m512 redmaxbf(__m512 a) { return set1f(redmaxf(a)); + } + + static STRONG_INLINE __m512 band(__m512 a, __m512 b) { return _mm512_and_ps(a, b); } + static STRONG_INLINE __m512i band(__m512i a, __m512i b) { return _mm512_and_si512(a, b); } + + static STRONG_INLINE __m512 bor(__m512 a, __m512 b) { return _mm512_or_ps(a, b); } + static STRONG_INLINE __m512i bor(__m512i a, __m512i b) { return _mm512_or_si512(a, b); } + + static STRONG_INLINE __m512 select(__m512 mask, __m512 a, __m512 b) + { + __mmask16 mask16 = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mask16, a, b); + } + static STRONG_INLINE __m512i select(__m512i mask, __m512i a, __m512i b) + { + __mmask16 mask16 = _mm512_cmp_epi32_mask(mask, _mm512_setzero_si512(), _MM_CMPINT_EQ); + return _mm512_mask_blend_epi32(mask16, a, b); + } + + static STRONG_INLINE __m512 cmp_eq(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + } + static STRONG_INLINE __m512 cmp_le(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + } + static STRONG_INLINE __m512 cmp_lt(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + } + static STRONG_INLINE __m512 cmp_lt_or_nan(__m512 a, __m512 b) + { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); + } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m512i pa, pb, acc = _mm512_setzero_si512(); + __m512i one16 = _mm512_set1_epi16(1), pt; + for (size_t i = 0; i < size; i += 64) + { + pa = _mm512_loadu_si512(reinterpret_cast(&a[i])); + pb = _mm512_loadu_si512(reinterpret_cast(&b[i])); + pt = _mm512_maddubs_epi16(pa, pb); + acc = _mm512_add_epi32(acc, _mm512_madd_epi16(pt, one16)); + } + return _mm512_reduce_add_epi32(acc); } }; + + template<> + struct Operator : public OperatorImpl> + { + }; + + template<> + struct PacketTrait : public PacketTrait + { + }; + + template + struct OperatorImpl : public OperatorImpl + { + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + __m512i pa, pb, acc = _mm512_setzero_si512(); + for (size_t i = 0; i < size; i += 64) + { + pa = _mm512_loadu_si512(reinterpret_cast(&a[i])); + pb = _mm512_loadu_si512(reinterpret_cast(&b[i])); + acc = _mm512_dpbusd_epi32(acc, pa, pb); + } + return _mm512_reduce_add_epi32(acc); + } + }; + + template<> + struct Operator : public OperatorImpl> + { + }; #endif } } @@ -396,11 +773,9 @@ namespace kiwi using FloatPacket = float32x4_t; }; - template<> - class Operator : public OperatorBase> + template + struct OperatorImpl : public OperatorBase { - public: - static STRONG_INLINE float32x4_t addf(float32x4_t a, float32x4_t b) { return vaddq_f32(a, b); } static STRONG_INLINE float32x4_t subf(float32x4_t a, float32x4_t b) { return vsubq_f32(a, b); } static STRONG_INLINE float32x4_t mulf(float32x4_t a, float32x4_t b) { return vmulq_f32(a, b); } @@ -435,6 +810,41 @@ namespace kiwi { return set1f(redmaxf(a)); } + + static STRONG_INLINE int32_t dotprod(const uint8_t* a, const int8_t* b, size_t size) + { + int32x4_t sum = vdupq_n_s32(0); + uint16x8_t pa; + int8x16_t pb; + for (size_t i = 0; i < size; i += 16) + { + // + } + sum = vpaddq_s32(sum, sum); + sum = vpaddq_s32(sum, sum); + return vgetq_lane_s32(sum, 0); + } + + static STRONG_INLINE int32_t dotprod(const int8_t* a, const int8_t* b, size_t size) + { + int32x4_t sum = vdupq_n_s32(0); + int8x16_t pa, pb; + for (size_t i = 0; i < size; i += 16) + { + pa = vld1q_s8(a + i); + pb = vld1q_s8(b + i); + sum = vpadalq_s16(sum, vmull_s8(vget_low_s8(pb), vget_low_s8(pa))); + sum = vpadalq_s16(sum, vmull_s8(vget_high_s8(pb), vget_high_s8(pa))); + } + sum = vpaddq_s32(sum, sum); + sum = vpaddq_s32(sum, sum); + return vgetq_lane_s32(sum, 0); + } + }; + + template<> + struct Operator : public OperatorImpl> + { }; } } diff --git a/src/SkipBigramModel.cpp b/src/SkipBigramModel.cpp new file mode 100644 index 00000000..36d731fb --- /dev/null +++ b/src/SkipBigramModel.cpp @@ -0,0 +1,113 @@ +#include "PathEvaluator.hpp" +#include "Joiner.hpp" +#include "Kiwi.hpp" +#include "SkipBigramModel.hpp" + +namespace kiwi +{ + template + struct PathHash> + { + using LmState = lm::SbgState; + + lm::KnLMState<_arch, VocabTy> lmState; + std::array lastMorphemes; + uint8_t rootId, spState; + + PathHash(LmState _lmState = {}, uint8_t _rootId = 0, SpecialState _spState = {}) + : lmState{ _lmState.knlm }, rootId{ _rootId }, spState{ _spState } + { + _lmState.getLastHistory(lastMorphemes.data(), lastMorphemes.size()); + } + + + PathHash(const WordLL& wordLl, const Morpheme* morphBase) + : PathHash{ wordLl.lmState, wordLl.rootId, wordLl.spState } + { + } + + bool operator==(const PathHash& o) const + { + return lmState == o.lmState && lastMorphemes == o.lastMorphemes && spState == o.spState; + } + }; + + template + struct Hash>> + { + size_t operator()(const PathHash>& state) const + { + size_t ret = 0; + if (sizeof(state) % sizeof(size_t)) + { + auto ptr = reinterpret_cast(&state); + for (size_t i = 0; i < sizeof(state) / sizeof(uint32_t); ++i) + { + ret = ptr[i] ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + } + else + { + auto ptr = reinterpret_cast(&state); + for (size_t i = 0; i < sizeof(state) / sizeof(size_t); ++i) + { + ret = ptr[i] ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + } + return ret; + } + }; + + namespace lm + { + template + void* SkipBigramModel::getFindBestPathFn() const + { + return (void*)&BestPathFinder>::findBestPath; + } + + template + void* SkipBigramModel::getNewJoinerFn() const + { + return (void*)&newJoinerWithKiwi; + } + + template + std::unique_ptr createOptimizedModel(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem) + { + auto& header = *reinterpret_cast(sbgMem.get()); + switch (header.keySize) + { + case 1: + return make_unique>(std::move(knlmMem), std::move(sbgMem)); + case 2: + return make_unique>(std::move(knlmMem), std::move(sbgMem)); + case 4: + return make_unique>(std::move(knlmMem), std::move(sbgMem)); + case 8: + return make_unique>(std::move(knlmMem), std::move(sbgMem)); + default: + throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; + } + } + + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + + struct CreateOptimizedModelGetter + { + template + struct Wrapper + { + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; + }; + }; + + std::unique_ptr SkipBigramModelBase::create(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem, ArchType archType) + { + static tp::Table table{ CreateOptimizedModelGetter{} }; + auto fn = table[static_cast(archType)]; + if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; + return (*fn)(std::move(knlmMem), std::move(sbgMem)); + } + } +} diff --git a/src/SkipBigramModel.hpp b/src/SkipBigramModel.hpp index 160afd3a..598b6355 100644 --- a/src/SkipBigramModel.hpp +++ b/src/SkipBigramModel.hpp @@ -5,15 +5,23 @@ #include #include #include "ArchAvailable.h" +#include "Knlm.hpp" +#include "MathFunc.h" #include "search.h" namespace kiwi { - namespace sb + namespace lm { + template + class SbgState; + template class SkipBigramModel : public SkipBigramModelBase { + friend class SbgState; + + KnLangModel knlm; std::unique_ptr ptrs; std::unique_ptr restoredFloats; std::unique_ptr keyData; @@ -22,12 +30,19 @@ namespace kiwi const float* compensations = nullptr; float logWindowSize; public: - SkipBigramModel(utils::MemoryObject&& mem) : SkipBigramModelBase{ std::move(mem) } + using VocabType = KeyType; + using LmStateType = SbgState; + + size_t getMemorySize() const override { return base.size() + knlm.getMemorySize(); } + void* getFindBestPathFn() const override; + void* getNewJoinerFn() const override; + + SkipBigramModel(utils::MemoryObject&& knlmMem, utils::MemoryObject&& sbgMem) : SkipBigramModelBase{ std::move(sbgMem) }, knlm{ std::move(knlmMem) } { auto* ptr = reinterpret_cast(base.get()); auto& header = getHeader(); - const KeyType* kSizes = reinterpret_cast(ptr += sizeof(Header)); + const KeyType* kSizes = reinterpret_cast(ptr += sizeof(SkipBigramModelHeader)); ptrs = make_unique(header.vocabSize + 1); ptrs[0] = 0; for (size_t i = 0; i < header.vocabSize; ++i) @@ -95,45 +110,94 @@ namespace kiwi return !!vocabValidness[k]; } - float evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const; + float evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const + { + if (!cnt) return base; + if (!vocabValidness[next]) return base; + +#if defined(__GNUC__) && __GNUC__ < 5 + alignas(256) float arr[windowSize * 2]; +#else + alignas(ArchInfo::alignment) float arr[windowSize * 2]; +#endif + std::fill(arr, arr + windowSize, base); + std::fill(arr + windowSize, arr + windowSize * 2, -INFINITY); + + size_t b = ptrs[next], e = ptrs[next + 1]; + size_t size = e - b; + + for (size_t i = 0; i < cnt; ++i) + { + arr[i] = discnts[history[i]] + base; + float out; + if (nst::search(&keyData[b], &compensations[b], size, history[i], out)) + { + arr[i + windowSize] = out; + } + } + return logSumExp(arr, windowSize * 2) - logWindowSize; + } + }; - template - std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + template + struct SbgState : public LmStateBase> { - auto& header = *reinterpret_cast(mem.get()); - switch (header.keySize) + KnLMState<_arch, VocabTy> knlm; + size_t historyPos = 0; + std::array history = { {0,} }; + + static constexpr ArchType arch = _arch; + static constexpr bool transposed = false; + + SbgState() = default; + SbgState(const ILangModel* lm) : knlm{ &static_cast*>(lm)->knlm } {} + + bool operator==(const SbgState& other) const { - case 1: - return make_unique>(std::move(mem)); - case 2: - return make_unique>(std::move(mem)); - case 4: - return make_unique>(std::move(mem)); - case 8: - return make_unique>(std::move(mem)); - default: - throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.keySize) }; + return knlm == other.knlm && historyPos == other.historyPos && history == other.history; } - } - using FnCreateOptimizedModel = decltype(&createOptimizedModel); + void getLastHistory(VocabTy* out, size_t n) const + { + for (size_t i = 0; i < n; ++i) + { + out[i] = history[(historyPos + windowSize + i - n) % windowSize]; + } + } - struct CreateOptimizedModelGetter - { - template - struct Wrapper + float nextImpl(const SkipBigramModel<_arch, VocabTy, windowSize>* lm, VocabTy next) { - static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; - }; + float ll = lm->knlm.progress(knlm.node, next); + if (lm->isValidVocab(next)) + { + if (ll > -13) + { + ll = lm->evaluate(history.data(), windowSize, next, ll); + } + history[historyPos] = next; + historyPos = (historyPos + 1) % windowSize; + } + return ll; + } }; + } - inline std::unique_ptr SkipBigramModelBase::create(utils::MemoryObject&& mem, ArchType archType) + + template + struct Hash> + { + size_t operator()(const lm::SbgState& state) const { - static tp::Table table{ CreateOptimizedModelGetter{} }; - auto fn = table[static_cast(archType)]; - if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; - return (*fn)(std::move(mem)); + Hash> hasher; + std::hash vocabHasher; + size_t ret = hasher(state.knlm); + for (size_t i = 0; i < windowSize; ++i) + { + ret = vocabHasher(state.history[i]) ^ ((ret << 3) | (ret >> (sizeof(size_t) * 8 - 3))); + } + return ret; } - } + }; + } diff --git a/src/SkipBigramModelImpl.hpp b/src/SkipBigramModelImpl.hpp deleted file mode 100644 index 887ab5f9..00000000 --- a/src/SkipBigramModelImpl.hpp +++ /dev/null @@ -1,76 +0,0 @@ -#pragma once - -#include -#include "SkipBigramModel.hpp" -#include "SIMD.hpp" - -namespace kiwi -{ - namespace sb - { - template - struct LogExpSum - { - template - float operator()(const float* arr, std::integral_constant) - { - float maxValue = *std::max_element(arr, arr + size); - float sum = 0; - for (size_t i = 0; i < size; ++i) - { - sum += std::exp(arr[i] - maxValue); - } - return std::log(sum) + maxValue; - } - }; - - template - float logExpSumImpl(const float* arr) - { - simd::Operator op; - - auto pmax = op.loadf(arr); - for (size_t i = op.packetSize; i < size; i += op.packetSize) - { - pmax = op.maxf(pmax, op.loadf(&arr[i])); - } - pmax = op.redmaxbf(pmax); - - auto sum = op.zerof(); - for (size_t i = 0; i < size; i += op.packetSize) - { - sum = op.addf(sum, op.expf(op.subf(op.loadf(&arr[i]), pmax))); - } - return std::log(op.redsumf(sum)) + op.firstf(pmax); - } - - template - float SkipBigramModel::evaluate(const KeyType* history, size_t cnt, KeyType next, float base) const - { - if (!cnt) return base; - if (!vocabValidness[next]) return base; - -#if defined(__GNUC__) && __GNUC__ < 5 - alignas(256) float arr[windowSize * 2]; -#else - alignas(ArchInfo::alignment) float arr[windowSize * 2]; -#endif - std::fill(arr, arr + windowSize, base); - std::fill(arr + windowSize, arr + windowSize * 2, -INFINITY); - - size_t b = ptrs[next], e = ptrs[next + 1]; - size_t size = e - b; - - for (size_t i = 0; i < cnt; ++i) - { - arr[i] = discnts[history[i]] + base; - float out; - if (nst::search(&keyData[b], &compensations[b], size, history[i], out)) - { - arr[i + windowSize] = out; - } - } - return LogExpSum{}(arr, std::integral_constant{}) - logWindowSize; - } - } -} diff --git a/src/SkipBigramTrainer.hpp b/src/SkipBigramTrainer.hpp index 9195a607..b9bcf7c3 100644 --- a/src/SkipBigramTrainer.hpp +++ b/src/SkipBigramTrainer.hpp @@ -12,7 +12,7 @@ namespace kiwi { - namespace sb + namespace lm { struct TrainContext { @@ -768,7 +768,7 @@ namespace kiwi utils::MemoryOwner convertToModel(float trimThreshold = -15, bool quantize = true) const { - Header header = { 0, }; + SkipBigramModelHeader header = { 0, }; header.vocabSize = ptrs.size() - 1; header.keySize = sizeof(VocabTy); header.windowSize = windowSize; @@ -821,7 +821,7 @@ namespace kiwi mse = nuq::nuquant(compensationTable.data(), allCompensations, 256); std::transform(compensationTable.begin(), compensationTable.end(), compensationTable.begin(), [](float f) { return -std::pow(f, 16.f); }); - size_t totalModelSize = sizeof(Header); + size_t totalModelSize = sizeof(SkipBigramModelHeader); totalModelSize += header.vocabSize * sizeof(VocabTy); totalModelSize += finalVocabSize * sizeof(VocabTy); totalModelSize += header.vocabSize * sizeof(uint8_t); @@ -832,8 +832,8 @@ namespace kiwi utils::MemoryOwner ret{ totalModelSize }; auto* ptr = reinterpret_cast(ret.get()); - *reinterpret_cast(ptr) = header; - auto* ks = reinterpret_cast(ptr += sizeof(Header)); + *reinterpret_cast(ptr) = header; + auto* ks = reinterpret_cast(ptr += sizeof(SkipBigramModelHeader)); for (auto& v : compensations) { *ks++ = v.first.size(); @@ -870,7 +870,7 @@ namespace kiwi } else { - size_t totalModelSize = sizeof(Header); + size_t totalModelSize = sizeof(SkipBigramModelHeader); totalModelSize += header.vocabSize * sizeof(VocabTy); totalModelSize += finalVocabSize * sizeof(VocabTy); totalModelSize += header.vocabSize * sizeof(float); @@ -879,8 +879,8 @@ namespace kiwi utils::MemoryOwner ret{ totalModelSize }; auto* ptr = reinterpret_cast(ret.get()); - *reinterpret_cast(ptr) = header; - auto* ks = reinterpret_cast(ptr += sizeof(Header)); + *reinterpret_cast(ptr) = header; + auto* ks = reinterpret_cast(ptr += sizeof(SkipBigramModelHeader)); for (auto& v : compensations) { *ks++ = v.first.size(); diff --git a/src/StrUtils.h b/src/StrUtils.h index 73b6390a..142bc9ad 100644 --- a/src/StrUtils.h +++ b/src/StrUtils.h @@ -3,7 +3,7 @@ #include #include #include -#include "string_view.hpp" +#include namespace kiwi { @@ -76,7 +76,7 @@ namespace kiwi size_t t = s.find(delim, p); if (t == s.npos) { - *(result++) = nonstd::basic_string_view{ &s[e] , s.size() - e}; + *(result++) = std::basic_string_view{ &s[e] , s.size() - e}; return result; } else @@ -91,28 +91,28 @@ namespace kiwi } else { - *(result++) = nonstd::basic_string_view{ &s[e] , t - e }; + *(result++) = std::basic_string_view{ &s[e] , t - e }; p = t + 1; e = t + 1; } } } - *(result++) = nonstd::basic_string_view{ &s[e] , s.size() - e }; + *(result++) = std::basic_string_view{ &s[e] , s.size() - e }; return result; } template - inline std::vector> split(nonstd::basic_string_view s, BaseChr delim, BaseChr delimEscape = 0) + inline std::vector> split(std::basic_string_view s, BaseChr delim, BaseChr delimEscape = 0) { - std::vector> ret; + std::vector> ret; split(s, delim, std::back_inserter(ret), -1, delimEscape); return ret; } template - inline std::vector> split(const std::basic_string& s, BaseChr delim, BaseChr delimEscape = 0) + inline std::vector> split(const std::basic_string& s, BaseChr delim, BaseChr delimEscape = 0) { - std::vector> ret; + std::vector> ret; split(s, delim, std::back_inserter(ret), -1, delimEscape); return ret; } @@ -141,9 +141,9 @@ namespace kiwi template> inline std::basic_string replace( - nonstd::basic_string_view s, - nonstd::basic_string_view from, - nonstd::basic_string_view to) + std::basic_string_view s, + std::basic_string_view from, + std::basic_string_view to) { std::basic_string ret; ret.reserve(s.size()); @@ -153,15 +153,15 @@ namespace kiwi template> inline std::basic_string replace( - nonstd::basic_string_view s, + std::basic_string_view s, const BaseChr(&from)[fromSize], const BaseChr(&to)[toSize]) { - return replace(s, nonstd::basic_string_view{ from, fromSize - 1 }, nonstd::basic_string_view{ to, toSize - 1 }); + return replace(s, std::basic_string_view{ from, fromSize - 1 }, std::basic_string_view{ to, toSize - 1 }); } - inline void utf8To16(nonstd::string_view str, std::u16string& ret) + inline void utf8To16(std::string_view str, std::u16string& ret) { ret.clear(); for (auto it = str.begin(); it != str.end(); ++it) @@ -224,7 +224,7 @@ namespace kiwi } } - inline std::u16string utf8To16(nonstd::string_view str) + inline std::u16string utf8To16(std::string_view str) { std::u16string ret; utf8To16(str, ret); @@ -232,7 +232,7 @@ namespace kiwi } template - inline std::u16string utf8To16(nonstd::string_view str, std::vector& bytePositions) + inline std::u16string utf8To16(std::string_view str, std::vector& bytePositions) { std::u16string ret; bytePositions.clear(); @@ -302,7 +302,7 @@ namespace kiwi } template - inline std::u16string utf8To16ChrPoisition(nonstd::string_view str, std::vector& chrPositions) + inline std::u16string utf8To16ChrPoisition(std::string_view str, std::vector& chrPositions) { std::u16string ret; size_t chrPosition = 0; @@ -371,7 +371,7 @@ namespace kiwi return ret; } - inline std::string utf16To8(nonstd::u16string_view str) + inline std::string utf16To8(std::u16string_view str) { std::string ret; for (auto it = str.begin(); it != str.end(); ++it) @@ -417,7 +417,7 @@ namespace kiwi } template - inline std::string utf16To8(nonstd::u16string_view str, std::vector& positions) + inline std::string utf16To8(std::u16string_view str, std::vector& positions) { std::string ret; positions.clear(); @@ -504,7 +504,7 @@ namespace kiwi return normalizeHangul(hangul.begin(), hangul.end()); } - inline KString normalizeHangul(nonstd::u16string_view hangul) + inline KString normalizeHangul(std::u16string_view hangul) { return normalizeHangul(hangul.begin(), hangul.end()); } @@ -553,7 +553,7 @@ namespace kiwi return normalizeHangulWithPosition(hangul.begin(), hangul.end()); } - inline std::pair> normalizeHangulWithPosition(nonstd::u16string_view hangul) + inline std::pair> normalizeHangulWithPosition(std::u16string_view hangul) { return normalizeHangulWithPosition(hangul.begin(), hangul.end()); } @@ -563,12 +563,12 @@ namespace kiwi return normalizeHangul(utf8To16(hangul)); } - inline KString normalizeHangul(nonstd::string_view hangul) + inline KString normalizeHangul(std::string_view hangul) { return normalizeHangul(utf8To16(hangul)); } - inline POSTag toPOSTag(nonstd::u16string_view tagStr) + inline POSTag toPOSTag(std::u16string_view tagStr) { if (tagStr == u"NNG") return POSTag::nng; if (tagStr == u"NNP") return POSTag::nnp; @@ -745,4 +745,10 @@ namespace kiwi || (0x2E80 <= c && c <= 0x2EFF) ; } + + template + inline std::basic_string_view toStringView(const std::basic_string& str) + { + return std::basic_string_view{ str.data(), str.size() }; + } } diff --git a/src/SubstringExtractor.cpp b/src/SubstringExtractor.cpp index 1ac6d5e4..ac238df1 100644 --- a/src/SubstringExtractor.cpp +++ b/src/SubstringExtractor.cpp @@ -266,6 +266,11 @@ namespace kiwi _addArray(first, last); } + void PrefixCounter::addArray(const int32_t* first, const int32_t* last) + { + _addArray(first, last); + } + void PrefixCounter::addArray(const uint64_t* first, const uint64_t* last) { _addArray(first, last); @@ -384,7 +389,7 @@ namespace kiwi utils::MemoryOwner mem; { auto trie = count(); - mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCfByOrder, unkTokenId, bosTokenId, eosTokenId, + mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCfByOrder, unkTokenId, bosTokenId, eosTokenId, 1e-5f, 0, false, nullptr, (const Vector*)nullptr, extraBuf.data(), extraBuf.size()); } diff --git a/src/SwTokenizer.cpp b/src/SwTokenizer.cpp index ad948989..6ef3d29b 100644 --- a/src/SwTokenizer.cpp +++ b/src/SwTokenizer.cpp @@ -576,7 +576,7 @@ namespace kiwi }; }; - inline void utf8To16IgnoringErrors(nonstd::string_view str, std::u16string& ret) + inline void utf8To16IgnoringErrors(std::string_view str, std::u16string& ret) { ret.clear(); for (auto it = str.begin(); it != str.end(); ++it) @@ -675,7 +675,7 @@ namespace kiwi } } - inline std::u16string utf8To16IgnoringErrors(nonstd::string_view str) + inline std::u16string utf8To16IgnoringErrors(std::string_view str) { std::u16string ret; utf8To16IgnoringErrors(str, ret); diff --git a/src/Utils.cpp b/src/Utils.cpp index a1a4f17c..0976377d 100644 --- a/src/Utils.cpp +++ b/src/Utils.cpp @@ -6,12 +6,12 @@ namespace kiwi { std::u16string utf8To16(const std::string & str) { - return utf8To16(nonstd::to_string_view(str)); + return utf8To16(toStringView(str)); } std::u16string utf8To16(const std::string& str, std::vector& bytePositions) { - return utf8To16(nonstd::to_string_view(str), bytePositions); + return utf8To16(toStringView(str), bytePositions); } size_t utf8FromCode(std::string& ret, char32_t code) @@ -54,7 +54,7 @@ namespace kiwi std::string utf16To8(const std::u16string & str) { - return utf16To8(nonstd::to_string_view(str)); + return utf16To8(toStringView(str)); } /** @@ -292,7 +292,7 @@ namespace kiwi POSTag toPOSTag(const std::u16string& tagStr) { - return toPOSTag(nonstd::to_string_view(tagStr)); + return toPOSTag(toStringView(tagStr)); } const char* tagToString(POSTag t) diff --git a/src/WordDetector.cpp b/src/WordDetector.cpp index 61dd759e..01552f13 100644 --- a/src/WordDetector.cpp +++ b/src/WordDetector.cpp @@ -461,7 +461,7 @@ void WordDetector::loadNounTailModelFromTxt(std::istream & is) auto fields = split(utf8To16(line), u'\t'); if (fields.size() < 4) continue; float p = stof(fields[1].begin(), fields[1].end()); - nounTailScore[fields[0].to_string()] = p; + nounTailScore[u16string{ fields[0] }] = p; } } diff --git a/src/archImpl/avx2.cpp b/src/archImpl/avx2.cpp index 76f27d75..ec7c1b5e 100644 --- a/src/archImpl/avx2.cpp +++ b/src/archImpl/avx2.cpp @@ -1,22 +1,124 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenAVX2 +#include namespace kiwi { - namespace sb + namespace qgemm { + // emulate _mm256_dpbusd_epi32 using AVX2 + static FORCE_INLINE __m256i dpbusd(__m256i src, __m256i a, __m256i b) + { + __m256i one16 = _mm256_set1_epi16(1); + __m256i t0 = _mm256_maddubs_epi16(a, b); + __m256i t1 = _mm256_madd_epi16(t0, one16); + return _mm256_add_epi32(src, t1); + } + } +} + +#define DPBUSD dpbusd +#include "avx2_qgemm.hpp" + +namespace kiwi +{ + namespace lm + { + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + template<> - struct LogExpSum + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) { - template - float operator()(const float* arr, std::integral_constant) + return scatteredGEMV2_256(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) { - return logExpSumImpl(arr); + return scatteredGEMMSmall_256(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); } }; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/archImpl/avx2_qgemm.hpp b/src/archImpl/avx2_qgemm.hpp new file mode 100644 index 00000000..7963df37 --- /dev/null +++ b/src/archImpl/avx2_qgemm.hpp @@ -0,0 +1,456 @@ +#pragma once +#include "../qgemm.hpp" +#include + +namespace kiwi +{ + namespace qgemm + { + inline void pack4x32to4x8x4( + const void* a0, const void* a1, const void* a2, const void* a3, + __m256i& p0, __m256i& p1, __m256i& p2, __m256i& p3 + ) + { + __m256i q0, q1, q2, q3; + // 00, 01, 02, 03, 04, 05, 06, 07, ... + p0 = _mm256_loadu_si256((const __m256i*)a0); + p1 = _mm256_loadu_si256((const __m256i*)a1); + p2 = _mm256_loadu_si256((const __m256i*)a2); + p3 = _mm256_loadu_si256((const __m256i*)a3); + + // 00, 10, 01, 11, 04, 14, 05, 15, ... + q0 = _mm256_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 06, 16, 07, 17, ... + q1 = _mm256_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 24, 34, 25, 35, ... + q2 = _mm256_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 26, 36, 27, 37, ... + q3 = _mm256_unpackhi_epi32(p2, p3); + + // 00, 10, 20, 30, 04, 14, 24, 34, ... + p0 = _mm256_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 05, 15, 25, 35, ... + p1 = _mm256_unpackhi_epi64(q0, q2); + // 02, 12, 22, 32, 06, 16, 26, 36, ... + p2 = _mm256_unpacklo_epi64(q1, q3); + // 03, 13, 23, 33, 07, 17, 27, 37, ... + p3 = _mm256_unpackhi_epi64(q1, q3); + } + + inline void scatteredGEMV_256( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 4, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m256i pa[4], pb, pbs, psum, pbSum; + __m128 paScale, paBias, pbScale, r; + pbScale = _mm_set1_ps(bScale); + pbSum = _mm256_set1_epi32(-bSum / 2); + __m128i pr; + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, + }; + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pbs = _mm256_shuffle_epi32(pb, 0x00); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm256_shuffle_epi32(pb, 0x55); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xAA); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xFF); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + pr = _mm_add_epi32(_mm256_castsi256_si128(psum), _mm256_extracti128_si256(psum, 1)); + aIdx += 4; + + paScale = _mm_loadu_ps(aScale); + paBias = _mm_loadu_ps(aBias); + r = _mm_fmadd_ps(_mm_mul_ps(_mm_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm_storeu_ps(c, r); + c += microM; + } + } + + inline void scatteredGEMV8x1_256( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 8, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m256i pa[4], pb, pbs, psum, pbSum; + __m256 paScale, paBias, pbScale, r; + __m256i pr = _mm256_setzero_si256(); + pbScale = _mm256_set1_ps(bScale); + pbSum = _mm256_set1_epi32(-bSum / 2); + + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pbs = _mm256_shuffle_epi32(pb, 0x00); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm256_shuffle_epi32(pb, 0x55); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xAA); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xFF); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + pr = _mm256_add_epi32(psum, _mm256_castsi128_si256(_mm256_extracti128_si256(psum, 1))); + aIdx += 4; + } + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pbs = _mm256_shuffle_epi32(pb, 0x00); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm256_shuffle_epi32(pb, 0x55); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xAA); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm256_shuffle_epi32(pb, 0xFF); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm256_add_epi32(psum, _mm256_castsi128_si256(_mm256_extracti128_si256(psum, 1))); + pr = _mm256_inserti128_si256(pr, _mm256_castsi256_si128(psum), 1); + aIdx += 4; + } + + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + c += 8; + } + + inline void scatteredGEMV2_256( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 2, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM * 2; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + float bScale[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) + }; + int32_t bSum[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) + }; + + __m256i pa[4], pb[2], psum[2], pbSum[2], pt[2]; + __m256 paScale, paBias, pbScale, r; + __m256i pr; + pbScale = _mm256_castsi256_ps(_mm256_set1_epi64x(*reinterpret_cast(bScale))); + pbSum[0] = _mm256_set1_epi32(-bSum[0] / 2); + pbSum[1] = _mm256_set1_epi32(-bSum[1] / 2); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + for (size_t j = 0; j < k; j += 32) + { + pack4x32to4x8x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm256_loadu_si256((const __m256i*)(bBuffer + j)); + pb[1] = _mm256_loadu_si256((const __m256i*)(bBuffer + j + packK)); + psum[0] = DPBUSD(psum[0], pa[0], _mm256_shuffle_epi32(pb[0], 0x00)); + psum[0] = DPBUSD(psum[0], pa[1], _mm256_shuffle_epi32(pb[0], 0x55)); + psum[0] = DPBUSD(psum[0], pa[2], _mm256_shuffle_epi32(pb[0], 0xAA)); + psum[0] = DPBUSD(psum[0], pa[3], _mm256_shuffle_epi32(pb[0], 0xFF)); + psum[1] = DPBUSD(psum[1], pa[0], _mm256_shuffle_epi32(pb[1], 0x00)); + psum[1] = DPBUSD(psum[1], pa[1], _mm256_shuffle_epi32(pb[1], 0x55)); + psum[1] = DPBUSD(psum[1], pa[2], _mm256_shuffle_epi32(pb[1], 0xAA)); + psum[1] = DPBUSD(psum[1], pa[3], _mm256_shuffle_epi32(pb[1], 0xFF)); + aPtr += 32; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 2] = aScale[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 2] = aBias[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm256_add_epi32(psum[0], _mm256_castsi128_si256(_mm256_extracti128_si256(psum[0], 1))); + psum[1] = _mm256_add_epi32(psum[1], _mm256_castsi128_si256(_mm256_extracti128_si256(psum[1], 1))); + + // 00, 01, 10, 11, ... + pt[0] = _mm256_unpacklo_epi32(psum[0], psum[1]); + // 20, 21, 30, 31, ... + pt[1] = _mm256_unpackhi_epi32(psum[0], psum[1]); + + // 00, 01, 10, 11, 20, 21, 30, 31 + pr = _mm256_inserti128_si256(pt[0], _mm256_castsi256_si128(pt[1]), 1); + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + + aIdx += microM; + c += microM * 2; + } + } + + inline int32_t reduce_sum(__m128i x) + { + __m128i hi64 = _mm_unpackhi_epi64(x, x); + __m128i sum64 = _mm_add_epi32(hi64, x); + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i sum32 = _mm_add_epi32(sum64, hi32); + return _mm_cvtsi128_si32(sum32); + } + + inline int32_t reduce_sum(__m256i x) + { + __m128i sum128 = _mm_add_epi32( + _mm256_castsi256_si128(x), + _mm256_extracti128_si256(x, 1)); + return reduce_sum(sum128); + } + + template + inline void scatteredGEMMSmall_256(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + static_assert(m <= 3, "m should be less than or equal to 3"); + static_assert(n <= 3, "n should be less than or equal to 3"); + __m256i pa[3], pb[3], psum[3][3]; + const uint8_t* aPtr[3]; + const int8_t* bPtr[3]; + + psum[0][0] = _mm256_setzero_si256(); + if (m > 1) psum[1][0] = _mm256_setzero_si256(); + if (m > 2) psum[2][0] = _mm256_setzero_si256(); + if (n > 1) psum[0][1] = _mm256_setzero_si256(); + if (m > 1 && n > 1) psum[1][1] = _mm256_setzero_si256(); + if (m > 2 && n > 1) psum[2][1] = _mm256_setzero_si256(); + if (n > 2) psum[0][2] = _mm256_setzero_si256(); + if (m > 1 && n > 2) psum[1][2] = _mm256_setzero_si256(); + if (m > 2 && n > 2) psum[2][2] = _mm256_setzero_si256(); + + aPtr[0] = aBase + aIdx[0] * aIdxScale; + if (m > 1) aPtr[1] = aBase + aIdx[1] * aIdxScale; + if (m > 2) aPtr[2] = aBase + aIdx[2] * aIdxScale; + + bPtr[0] = bBase + bIdx[0] * bIdxScale; + if (n > 1) bPtr[1] = bBase + bIdx[1] * bIdxScale; + if (n > 2) bPtr[2] = bBase + bIdx[2] * bIdxScale; + + for (size_t x = 0; x < k; x += 32) + { + if (m > 0) + { + pa[0] = _mm256_loadu_si256((const __m256i*)aPtr[0]); + aPtr[0] += 32; + } + if (m > 1) + { + pa[1] = _mm256_loadu_si256((const __m256i*)aPtr[1]); + aPtr[1] += 32; + } + if (m > 2) + { + pa[2] = _mm256_loadu_si256((const __m256i*)aPtr[2]); + aPtr[2] += 32; + } + + if (n > 0) + { + pb[0] = _mm256_loadu_si256((const __m256i*)bPtr[0]); + bPtr[0] += 32; + } + if (n > 1) + { + pb[1] = _mm256_loadu_si256((const __m256i*)bPtr[1]); + bPtr[1] += 32; + } + if (n > 2) + { + pb[2] = _mm256_loadu_si256((const __m256i*)bPtr[2]); + bPtr[2] += 32; + } + + psum[0][0] = DPBUSD(psum[0][0], pa[0], pb[0]); + if (m > 1) psum[1][0] = DPBUSD(psum[1][0], pa[1], pb[0]); + if (m > 2) psum[2][0] = DPBUSD(psum[2][0], pa[2], pb[0]); + if (n > 1) psum[0][1] = DPBUSD(psum[0][1], pa[0], pb[1]); + if (m > 1 && n > 1) psum[1][1] = DPBUSD(psum[1][1], pa[1], pb[1]); + if (m > 2 && n > 1) psum[2][1] = DPBUSD(psum[2][1], pa[2], pb[1]); + if (n > 2) psum[0][2] = DPBUSD(psum[0][2], pa[0], pb[2]); + if (m > 1 && n > 2) psum[1][2] = DPBUSD(psum[1][2], pa[1], pb[2]); + if (m > 2 && n > 2) psum[2][2] = DPBUSD(psum[2][2], pa[2], pb[2]); + } + + float contextScale[3], outputScale[3], contextBias[3]; + int32_t hsum[3]; + + if (m > 0) + { + contextScale[0] = *reinterpret_cast(aPtr[0]); + contextBias[0] = *reinterpret_cast(aPtr[0] + 4); + } + if (m > 1) + { + contextScale[1] = *reinterpret_cast(aPtr[1]); + contextBias[1] = *reinterpret_cast(aPtr[1] + 4); + } + if (m > 2) + { + contextScale[2] = *reinterpret_cast(aPtr[2]); + contextBias[2] = *reinterpret_cast(aPtr[2] + 4); + } + + + if (n > 0) + { + outputScale[0] = *reinterpret_cast(bPtr[0]); + hsum[0] = *reinterpret_cast(bPtr[0] + 4); + } + if (n > 1) + { + outputScale[1] = *reinterpret_cast(bPtr[1]); + hsum[1] = *reinterpret_cast(bPtr[1] + 4); + } + if (n > 2) + { + outputScale[2] = *reinterpret_cast(bPtr[2]); + hsum[2] = *reinterpret_cast(bPtr[2] + 4); + } + + { + int32_t acc = reduce_sum(psum[0][0]); + c[0] = (acc - hsum[0]) * contextScale[0] * outputScale[0] + contextBias[0]; + } + if (m > 1) + { + int32_t acc = reduce_sum(psum[1][0]); + c[ldc] = (acc - hsum[0]) * contextScale[1] * outputScale[0] + contextBias[1]; + } + if (m > 2) + { + int32_t acc = reduce_sum(psum[2][0]); + c[ldc * 2] = (acc - hsum[0]) * contextScale[2] * outputScale[0] + contextBias[2]; + } + if (n > 1) + { + int32_t acc = reduce_sum(psum[0][1]); + c[1] = (acc - hsum[1]) * contextScale[0] * outputScale[1] + contextBias[0]; + } + if (m > 1 && n > 1) + { + int32_t acc = reduce_sum(psum[1][1]); + c[ldc + 1] = (acc - hsum[1]) * contextScale[1] * outputScale[1] + contextBias[1]; + } + if (m > 2 && n > 1) + { + int32_t acc = reduce_sum(psum[2][1]); + c[ldc * 2 + 1] = (acc - hsum[1]) * contextScale[2] * outputScale[1] + contextBias[2]; + } + if (n > 2) + { + int32_t acc = reduce_sum(psum[0][2]); + c[2] = (acc - hsum[2]) * contextScale[0] * outputScale[2] + contextBias[0]; + } + if (m > 1 && n > 2) + { + int32_t acc = reduce_sum(psum[1][2]); + c[ldc + 2] = (acc - hsum[2]) * contextScale[1] * outputScale[2] + contextBias[1]; + } + if (m > 2 && n > 2) + { + int32_t acc = reduce_sum(psum[2][2]); + c[ldc * 2 + 2] = (acc - hsum[2]) * contextScale[2] * outputScale[2] + contextBias[2]; + } + } + } +} diff --git a/src/archImpl/avx512_qgemm.hpp b/src/archImpl/avx512_qgemm.hpp new file mode 100644 index 00000000..f5eedce4 --- /dev/null +++ b/src/archImpl/avx512_qgemm.hpp @@ -0,0 +1,673 @@ +#pragma once +#include "../qgemm.hpp" +#include + +#define UNROLL4() do { {LOOP_BODY(0)} {LOOP_BODY(1)} {LOOP_BODY(2)} {LOOP_BODY(3)} } while(0) + +namespace kiwi +{ + namespace qgemm + { + inline void pack4x64to4x16x4( + const void* a0, const void* a1, const void* a2, const void* a3, + __m512i& p0, __m512i& p1, __m512i& p2, __m512i& p3 + ) + { + __m512i q0, q1, q2, q3; + // 00, 01, 02, 03, 04, 05, 06, 07, ... + p0 = _mm512_loadu_si512(a0); + p1 = _mm512_loadu_si512(a1); + p2 = _mm512_loadu_si512(a2); + p3 = _mm512_loadu_si512(a3); + + // 00, 10, 01, 11, 04, 14, 05, 15, ... + q0 = _mm512_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 06, 16, 07, 17, ... + q1 = _mm512_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 24, 34, 25, 35, ... + q2 = _mm512_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 26, 36, 27, 37, ... + q3 = _mm512_unpackhi_epi32(p2, p3); + + // 00, 10, 20, 30, 04, 14, 24, 34, ... + p0 = _mm512_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 05, 15, 25, 35, ... + p1 = _mm512_unpackhi_epi64(q0, q2); + // 02, 12, 22, 32, 06, 16, 26, 36, ... + p2 = _mm512_unpacklo_epi64(q1, q3); + // 03, 13, 23, 33, 07, 17, 27, 37, ... + p3 = _mm512_unpackhi_epi64(q1, q3); + } + + inline void scatteredGEMV_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 16, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m512i pa[4], pb, pbs, psum, pbSum, pr = _mm512_setzero_si512(); + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_set1_ps(bScale); + pbSum = _mm512_set1_epi32(-bSum / 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); +#define LOOP_BODY(mj) \ + const int32_t aOffsets[4] = {\ + mj * 4 < microM ? (int32_t)(aIdx[0] * aIdxScale) : 0,\ + mj * 4 + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0,\ + mj * 4 + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0,\ + mj * 4 + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0,\ + };\ + auto* aPtr = aBase;\ + psum = pbSum;\ + for (size_t j = 0; j < k; j += 64)\ + {\ + pack4x64to4x16x4(aPtr + aOffsets[0],\ + aPtr + aOffsets[1],\ + aPtr + aOffsets[2],\ + aPtr + aOffsets[3],\ + pa[0], pa[1], pa[2], pa[3]);\ + pb = _mm512_loadu_si512(bBuffer + j);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_AAAA);\ + psum = DPBUSD(psum, pa[0], pbs);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_BBBB);\ + psum = DPBUSD(psum, pa[1], pbs);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_CCCC);\ + psum = DPBUSD(psum, pa[2], pbs);\ + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_DDDD);\ + psum = DPBUSD(psum, pa[3], pbs);\ + aPtr += 64;\ + }\ + for (size_t i = 0; i < 4; ++i)\ + {\ + aScale[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i]);\ + aBias[mj * 4 + i] = *reinterpret_cast(aPtr + aOffsets[i] + 4);\ + }\ + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4));\ + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8));\ + pr = _mm512_inserti32x4(pr, _mm512_castsi512_si128(psum), mj);\ + aIdx += 4; + + UNROLL4(); +#undef LOOP_BODY + + paScale = _mm512_loadu_ps(aScale); + paBias = _mm512_loadu_ps(aBias); + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm512_storeu_ps(c, r); + c += microM; + } + } + + inline void scatteredGEMV8x1_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 8, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m512i pa[4], pb, pbs, psum, pbSum; + __m256 paScale, paBias, pbScale, r; + __m256i pr = _mm256_setzero_si256(); + pbScale = _mm256_set1_ps(bScale); + pbSum = _mm512_set1_epi32(-bSum / 4); + + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm512_loadu_si512(bBuffer + j); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_AAAA); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_BBBB); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_CCCC); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_DDDD); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); + pr = _mm512_castsi512_si256(psum); + aIdx += 4; + } + { + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aIdx[0] * aIdxScale, + aPtr + aIdx[1] * aIdxScale, + aPtr + aIdx[2] * aIdxScale, + aPtr + aIdx[3] * aIdxScale, + pa[0], pa[1], pa[2], pa[3]); + pb = _mm512_loadu_si512(bBuffer + j); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_AAAA); + psum = DPBUSD(psum, pa[0], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_BBBB); + psum = DPBUSD(psum, pa[1], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_CCCC); + psum = DPBUSD(psum, pa[2], pbs); + pbs = _mm512_shuffle_epi32(pb, _MM_PERM_DDDD); + psum = DPBUSD(psum, pa[3], pbs); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale); + aBias[i + 4] = *reinterpret_cast(aPtr + aIdx[i] * aIdxScale + 4); + } + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 4)); + psum = _mm512_add_epi32(psum, _mm512_alignr_epi32(psum, psum, 8)); + pr = _mm256_inserti32x4(pr, _mm512_castsi512_si128(psum), 1); + aIdx += 4; + } + + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + c += 8; + } + + inline void scatteredGEMV2_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 2, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM * 2; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + float bScale[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k) + }; + int32_t bSum[2] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4) + }; + + __m512i pa[4], pb[2], psum[2], pbSum[2], pt[2]; + __m256 paScale, paBias, pbScale, r; + __m256i pr; + pbScale = _mm256_castsi256_ps(_mm256_set1_epi64x(*reinterpret_cast(bScale))); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + psum[0] = DPBUSD(psum[0], pa[0], _mm512_shuffle_epi32(pb[0], _MM_PERM_AAAA)); + psum[0] = DPBUSD(psum[0], pa[1], _mm512_shuffle_epi32(pb[0], _MM_PERM_BBBB)); + psum[0] = DPBUSD(psum[0], pa[2], _mm512_shuffle_epi32(pb[0], _MM_PERM_CCCC)); + psum[0] = DPBUSD(psum[0], pa[3], _mm512_shuffle_epi32(pb[0], _MM_PERM_DDDD)); + psum[1] = DPBUSD(psum[1], pa[0], _mm512_shuffle_epi32(pb[1], _MM_PERM_AAAA)); + psum[1] = DPBUSD(psum[1], pa[1], _mm512_shuffle_epi32(pb[1], _MM_PERM_BBBB)); + psum[1] = DPBUSD(psum[1], pa[2], _mm512_shuffle_epi32(pb[1], _MM_PERM_CCCC)); + psum[1] = DPBUSD(psum[1], pa[3], _mm512_shuffle_epi32(pb[1], _MM_PERM_DDDD)); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 2] = aScale[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 2] = aBias[i * 2 + 1] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + + // 00, 01, 10, 11, ... + pt[0] = _mm512_unpacklo_epi32(psum[0], psum[1]); + // 20, 21, 30, 31, ... + pt[1] = _mm512_unpackhi_epi32(psum[0], psum[1]); + + // 00, 01, 10, 11, 20, 21, 30, 31 + pr = _mm256_permute2x128_si256(_mm512_castsi512_si256(pt[0]), _mm512_castsi512_si256(pt[1]), 0x20); + paScale = _mm256_loadu_ps(aScale); + paBias = _mm256_loadu_ps(aBias); + r = _mm256_fmadd_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(pr), pbScale), paScale, paBias); + _mm256_storeu_ps(c, r); + + aIdx += microM; + c += microM * 2; + } + } + + inline void scatteredGEMV3_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 3, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); + float bScale[4] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), + 0 + }; + int32_t bSum[3] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4) + }; + __m512i pa[4], pb[3], psum[3], pbSum[3]; + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_permutexvar_ps( + _mm512_setr_epi32(0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 3, 3, 3), + _mm512_castps128_ps512(_mm_loadu_ps(bScale))); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); + __m512i shfIdxT = _mm512_setr_epi32(0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + psum[2] = pbSum[2]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); + psum[0] = DPBUSD(psum[0], pa[0], _mm512_shuffle_epi32(pb[0], _MM_PERM_AAAA)); + psum[0] = DPBUSD(psum[0], pa[1], _mm512_shuffle_epi32(pb[0], _MM_PERM_BBBB)); + psum[0] = DPBUSD(psum[0], pa[2], _mm512_shuffle_epi32(pb[0], _MM_PERM_CCCC)); + psum[0] = DPBUSD(psum[0], pa[3], _mm512_shuffle_epi32(pb[0], _MM_PERM_DDDD)); + psum[1] = DPBUSD(psum[1], pa[0], _mm512_shuffle_epi32(pb[1], _MM_PERM_AAAA)); + psum[1] = DPBUSD(psum[1], pa[1], _mm512_shuffle_epi32(pb[1], _MM_PERM_BBBB)); + psum[1] = DPBUSD(psum[1], pa[2], _mm512_shuffle_epi32(pb[1], _MM_PERM_CCCC)); + psum[1] = DPBUSD(psum[1], pa[3], _mm512_shuffle_epi32(pb[1], _MM_PERM_DDDD)); + psum[2] = DPBUSD(psum[2], pa[0], _mm512_shuffle_epi32(pb[2], _MM_PERM_AAAA)); + psum[2] = DPBUSD(psum[2], pa[1], _mm512_shuffle_epi32(pb[2], _MM_PERM_BBBB)); + psum[2] = DPBUSD(psum[2], pa[2], _mm512_shuffle_epi32(pb[2], _MM_PERM_CCCC)); + psum[2] = DPBUSD(psum[2], pa[3], _mm512_shuffle_epi32(pb[2], _MM_PERM_DDDD)); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); + + // 00, 10, 20, 30, 01, 11, 21, 31 + psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); + + // 00, 01, 02, 10, 11, 12, 20, 21, 22, 30, 31, 32, ... + psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( + 0, 4, 16, 1, 5, 17, 2, 6, 18, 3, 7, 19, 0, 0, 0, 0 + ), psum[2]); + + paScale = _mm512_castps128_ps512(_mm_loadu_ps(aScale)); + paScale = _mm512_permutexvar_ps(shfIdxT, paScale); + paBias = _mm512_castps128_ps512(_mm_loadu_ps(aBias)); + paBias = _mm512_permutexvar_ps(shfIdxT, paBias); + + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); + _mm512_mask_storeu_ps(c, 0x0FFF, r); + + aIdx += microM; + c += microM * 3; + } + } + + inline void scatteredGEMV4_512( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + constexpr size_t packM = 4, packN = 4, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM * 4; + memcpy(bBuffer, bBase + bIdx[0] * bIdxScale, k); + memcpy(bBuffer + packK, bBase + bIdx[1] * bIdxScale, k); + memcpy(bBuffer + packK * 2, bBase + bIdx[2] * bIdxScale, k); + memcpy(bBuffer + packK * 3, bBase + bIdx[3] * bIdxScale, k); + float bScale[4] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k), + *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k) + }; + int32_t bSum[4] = { + *reinterpret_cast(bBase + bIdx[0] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[1] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[2] * bIdxScale + k + 4), + *reinterpret_cast(bBase + bIdx[3] * bIdxScale + k + 4) + }; + __m512i pa[4], pb[4], psum[4], pbSum[4]; + __m512 paScale, paBias, pbScale, r; + pbScale = _mm512_broadcast_f32x4(_mm_loadu_ps(bScale)); + pbSum[0] = _mm512_set1_epi32(-bSum[0] / 4); + pbSum[1] = _mm512_set1_epi32(-bSum[1] / 4); + pbSum[2] = _mm512_set1_epi32(-bSum[2] / 4); + pbSum[3] = _mm512_set1_epi32(-bSum[3] / 4); + + for (size_t mi = 0; mi < m; mi += packM) + { + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, + }; + auto* aPtr = aBase; + psum[0] = pbSum[0]; + psum[1] = pbSum[1]; + psum[2] = pbSum[2]; + psum[3] = pbSum[3]; + for (size_t j = 0; j < k; j += 64) + { + pack4x64to4x16x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb[0] = _mm512_loadu_si512(bBuffer + j); + pb[1] = _mm512_loadu_si512(bBuffer + packK + j); + pb[2] = _mm512_loadu_si512(bBuffer + packK * 2 + j); + pb[3] = _mm512_loadu_si512(bBuffer + packK * 3 + j); + psum[0] = DPBUSD(psum[0], pa[0], _mm512_shuffle_epi32(pb[0], _MM_PERM_AAAA)); + psum[0] = DPBUSD(psum[0], pa[1], _mm512_shuffle_epi32(pb[0], _MM_PERM_BBBB)); + psum[0] = DPBUSD(psum[0], pa[2], _mm512_shuffle_epi32(pb[0], _MM_PERM_CCCC)); + psum[0] = DPBUSD(psum[0], pa[3], _mm512_shuffle_epi32(pb[0], _MM_PERM_DDDD)); + psum[1] = DPBUSD(psum[1], pa[0], _mm512_shuffle_epi32(pb[1], _MM_PERM_AAAA)); + psum[1] = DPBUSD(psum[1], pa[1], _mm512_shuffle_epi32(pb[1], _MM_PERM_BBBB)); + psum[1] = DPBUSD(psum[1], pa[2], _mm512_shuffle_epi32(pb[1], _MM_PERM_CCCC)); + psum[1] = DPBUSD(psum[1], pa[3], _mm512_shuffle_epi32(pb[1], _MM_PERM_DDDD)); + psum[2] = DPBUSD(psum[2], pa[0], _mm512_shuffle_epi32(pb[2], _MM_PERM_AAAA)); + psum[2] = DPBUSD(psum[2], pa[1], _mm512_shuffle_epi32(pb[2], _MM_PERM_BBBB)); + psum[2] = DPBUSD(psum[2], pa[2], _mm512_shuffle_epi32(pb[2], _MM_PERM_CCCC)); + psum[2] = DPBUSD(psum[2], pa[3], _mm512_shuffle_epi32(pb[2], _MM_PERM_DDDD)); + psum[3] = DPBUSD(psum[3], pa[0], _mm512_shuffle_epi32(pb[3], _MM_PERM_AAAA)); + psum[3] = DPBUSD(psum[3], pa[1], _mm512_shuffle_epi32(pb[3], _MM_PERM_BBBB)); + psum[3] = DPBUSD(psum[3], pa[2], _mm512_shuffle_epi32(pb[3], _MM_PERM_CCCC)); + psum[3] = DPBUSD(psum[3], pa[3], _mm512_shuffle_epi32(pb[3], _MM_PERM_DDDD)); + aPtr += 64; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i * 4] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i * 4] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 4)); + psum[0] = _mm512_add_epi32(psum[0], _mm512_alignr_epi32(psum[0], psum[0], 8)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 4)); + psum[1] = _mm512_add_epi32(psum[1], _mm512_alignr_epi32(psum[1], psum[1], 8)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 4)); + psum[2] = _mm512_add_epi32(psum[2], _mm512_alignr_epi32(psum[2], psum[2], 8)); + psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 4)); + psum[3] = _mm512_add_epi32(psum[3], _mm512_alignr_epi32(psum[3], psum[3], 8)); + + // 00, 10, 20, 30, 01, 11, 21, 31 + psum[0] = _mm512_inserti32x4(psum[0], _mm512_castsi512_si128(psum[1]), 1); + // 02, 12, 22, 32, 03, 13, 23, 33 + psum[2] = _mm512_inserti32x4(psum[2], _mm512_castsi512_si128(psum[3]), 1); + + // 00, 01, 02, 03, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33 + psum[0] = _mm512_permutex2var_epi32(psum[0], _mm512_setr_epi32( + 0, 4, 16, 20, 1, 5, 17, 21, 2, 6, 18, 22, 3, 7, 19, 23 + ), psum[2]); + + paScale = _mm512_loadu_ps(aScale); + paScale = _mm512_shuffle_ps(paScale, paScale, 0); + paBias = _mm512_loadu_ps(aBias); + paBias = _mm512_shuffle_ps(paBias, paBias, 0); + + r = _mm512_fmadd_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(psum[0]), pbScale), paScale, paBias); + _mm512_storeu_ps(c, r); + + aIdx += microM; + c += microM * 4; + } + } + + template + inline void scatteredGEMMSmall_512(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + static_assert(m <= 3, "m should be less than or equal to 3"); + static_assert(n <= 3, "n should be less than or equal to 3"); + __m512i pa[3], pb[3], psum[3][3]; + const uint8_t* aPtr[3]; + const int8_t* bPtr[3]; + + psum[0][0] = _mm512_setzero_si512(); + if (m > 1) psum[1][0] = _mm512_setzero_si512(); + if (m > 2) psum[2][0] = _mm512_setzero_si512(); + if (n > 1) psum[0][1] = _mm512_setzero_si512(); + if (m > 1 && n > 1) psum[1][1] = _mm512_setzero_si512(); + if (m > 2 && n > 1) psum[2][1] = _mm512_setzero_si512(); + if (n > 2) psum[0][2] = _mm512_setzero_si512(); + if (m > 1 && n > 2) psum[1][2] = _mm512_setzero_si512(); + if (m > 2 && n > 2) psum[2][2] = _mm512_setzero_si512(); + + aPtr[0] = aBase + aIdx[0] * aIdxScale; + if (m > 1) aPtr[1] = aBase + aIdx[1] * aIdxScale; + if (m > 2) aPtr[2] = aBase + aIdx[2] * aIdxScale; + + bPtr[0] = bBase + bIdx[0] * bIdxScale; + if (n > 1) bPtr[1] = bBase + bIdx[1] * bIdxScale; + if (n > 2) bPtr[2] = bBase + bIdx[2] * bIdxScale; + + for (size_t x = 0; x < k; x += 64) + { + if (m > 0) + { + pa[0] = _mm512_loadu_si512(aPtr[0]); + aPtr[0] += 64; + } + if (m > 1) + { + pa[1] = _mm512_loadu_si512(aPtr[1]); + aPtr[1] += 64; + } + if (m > 2) + { + pa[2] = _mm512_loadu_si512(aPtr[2]); + aPtr[2] += 64; + } + + if (n > 0) + { + pb[0] = _mm512_loadu_si512(bPtr[0]); + bPtr[0] += 64; + } + if (n > 1) + { + pb[1] = _mm512_loadu_si512(bPtr[1]); + bPtr[1] += 64; + } + if (n > 2) + { + pb[2] = _mm512_loadu_si512(bPtr[2]); + bPtr[2] += 64; + } + + psum[0][0] = DPBUSD(psum[0][0], pa[0], pb[0]); + if (m > 1) psum[1][0] = DPBUSD(psum[1][0], pa[1], pb[0]); + if (m > 2) psum[2][0] = DPBUSD(psum[2][0], pa[2], pb[0]); + if (n > 1) psum[0][1] = DPBUSD(psum[0][1], pa[0], pb[1]); + if (m > 1 && n > 1) psum[1][1] = DPBUSD(psum[1][1], pa[1], pb[1]); + if (m > 2 && n > 1) psum[2][1] = DPBUSD(psum[2][1], pa[2], pb[1]); + if (n > 2) psum[0][2] = DPBUSD(psum[0][2], pa[0], pb[2]); + if (m > 1 && n > 2) psum[1][2] = DPBUSD(psum[1][2], pa[1], pb[2]); + if (m > 2 && n > 2) psum[2][2] = DPBUSD(psum[2][2], pa[2], pb[2]); + } + + float contextScale[3], outputScale[3], contextBias[3]; + int32_t hsum[3]; + + if (m > 0) + { + contextScale[0] = *reinterpret_cast(aPtr[0]); + contextBias[0] = *reinterpret_cast(aPtr[0] + 4); + } + if (m > 1) + { + contextScale[1] = *reinterpret_cast(aPtr[1]); + contextBias[1] = *reinterpret_cast(aPtr[1] + 4); + } + if (m > 2) + { + contextScale[2] = *reinterpret_cast(aPtr[2]); + contextBias[2] = *reinterpret_cast(aPtr[2] + 4); + } + + + if (n > 0) + { + outputScale[0] = *reinterpret_cast(bPtr[0]); + hsum[0] = *reinterpret_cast(bPtr[0] + 4); + } + if (n > 1) + { + outputScale[1] = *reinterpret_cast(bPtr[1]); + hsum[1] = *reinterpret_cast(bPtr[1] + 4); + } + if (n > 2) + { + outputScale[2] = *reinterpret_cast(bPtr[2]); + hsum[2] = *reinterpret_cast(bPtr[2] + 4); + } + + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][0]); + c[0] = (acc - hsum[0]) * contextScale[0] * outputScale[0] + contextBias[0]; + } + if (m > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][0]); + c[ldc] = (acc - hsum[0]) * contextScale[1] * outputScale[0] + contextBias[1]; + } + if (m > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][0]); + c[ldc * 2] = (acc - hsum[0]) * contextScale[2] * outputScale[0] + contextBias[2]; + } + if (n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][1]); + c[1] = (acc - hsum[1]) * contextScale[0] * outputScale[1] + contextBias[0]; + } + if (m > 1 && n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][1]); + c[ldc + 1] = (acc - hsum[1]) * contextScale[1] * outputScale[1] + contextBias[1]; + } + if (m > 2 && n > 1) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][1]); + c[ldc * 2 + 1] = (acc - hsum[1]) * contextScale[2] * outputScale[1] + contextBias[2]; + } + if (n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[0][2]); + c[2] = (acc - hsum[2]) * contextScale[0] * outputScale[2] + contextBias[0]; + } + if (m > 1 && n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[1][2]); + c[ldc + 2] = (acc - hsum[2]) * contextScale[1] * outputScale[2] + contextBias[1]; + } + if (m > 2 && n > 2) + { + int32_t acc = _mm512_reduce_add_epi32(psum[2][2]); + c[ldc * 2 + 2] = (acc - hsum[2]) * contextScale[2] * outputScale[2] + contextBias[2]; + } + } + } +} \ No newline at end of file diff --git a/src/archImpl/avx512bw.cpp b/src/archImpl/avx512bw.cpp index ad290331..df2cd031 100644 --- a/src/archImpl/avx512bw.cpp +++ b/src/archImpl/avx512bw.cpp @@ -1,22 +1,162 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenAVX512 +#include namespace kiwi { - namespace sb + namespace qgemm { + // emulate _mm512_dpbusd_epi32 using AVX512BW + static FORCE_INLINE __m512i dpbusd(__m512i src, __m512i a, __m512i b) + { + __m512i one16 = _mm512_set1_epi16(1); + __m512i t0 = _mm512_maddubs_epi16(a, b); + __m512i t1 = _mm512_madd_epi16(t0, one16); + return _mm512_add_epi32(src, t1); + } + } +} + +#define DPBUSD dpbusd +#include "avx512_qgemm.hpp" + +namespace kiwi +{ + namespace lm + { + template<> + float logSumExp(const float* arr, size_t size) + { + if (size == 8) return LogSumExp()(arr, std::integral_constant()); + if (size == 16) return LogSumExp()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template<> + void logSoftmax(float* arr, size_t size) + { + if (size == 8) return LogSoftmax()(arr, std::integral_constant()); + if (size == 16) return LogSoftmax()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + template<> - struct LogExpSum + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) { - template - float operator()(const float* arr, std::integral_constant) + return scatteredGEMV2_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV3( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV3_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV4( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV4_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) { - return logExpSumImpl(arr); + return scatteredGEMMSmall_512(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); } }; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/archImpl/avx512vnni.cpp b/src/archImpl/avx512vnni.cpp new file mode 100644 index 00000000..60fc86bd --- /dev/null +++ b/src/archImpl/avx512vnni.cpp @@ -0,0 +1,138 @@ +#include "../MathFunc.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define DPBUSD _mm512_dpbusd_epi32 +#include "avx512_qgemm.hpp" + +namespace kiwi +{ + namespace lm + { + template<> + float logSumExp(const float* arr, size_t size) + { + if (size == 8) return LogSumExp()(arr, std::integral_constant()); + if (size == 16) return LogSumExp()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template<> + void logSoftmax(float* arr, size_t size) + { + if (size == 8) return LogSoftmax()(arr, std::integral_constant()); + if (size == 16) return LogSoftmax()(arr, std::integral_constant()); + throw std::runtime_error("Unsupported size"); + } + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_512(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV2_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV3( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV3_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + inline void scatteredGEMV4( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV4_512(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMSmall_512(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + return gemm(m, n, k, aT, strideA, b, strideB, c, strideC); + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + return gemv(m, k, aT, strideA, b, c); + } + } +} diff --git a/src/archImpl/avx_vnni.cpp b/src/archImpl/avx_vnni.cpp new file mode 100644 index 00000000..178c4b8b --- /dev/null +++ b/src/archImpl/avx_vnni.cpp @@ -0,0 +1,100 @@ +#include "../MathFunc.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define DPBUSD _mm256_dpbusd_epi32 +#include "avx2_qgemm.hpp" + +namespace kiwi +{ + namespace lm + { + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV8x1_256(m, k, aBase, aIdx, aIdxScale, b, c); + } + + template<> + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMV2_256(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + + template<> + struct ScatteredGEMMSmall + { + template + static void op(size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMSmall_256(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + return gemm(m, n, k, aT, strideA, b, strideB, c, strideC); + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + return gemv(m, k, aT, strideA, b, c); + } + } +} diff --git a/src/archImpl/neon.cpp b/src/archImpl/neon.cpp index 57d80ec8..e37b9ca4 100644 --- a/src/archImpl/neon.cpp +++ b/src/archImpl/neon.cpp @@ -1,22 +1,60 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenNEON +#include namespace kiwi { - namespace sb + namespace lm + { + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm { template<> - struct LogExpSum + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) { - template - float operator()(const float* arr, std::integral_constant) - { - return logExpSumImpl(arr); - } - }; + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/archImpl/none.cpp b/src/archImpl/none.cpp index 42ab13b8..95060632 100644 --- a/src/archImpl/none.cpp +++ b/src/archImpl/none.cpp @@ -1,17 +1,73 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" +#include "../gemm.h" + +#include namespace kiwi { - namespace sb + namespace lm { - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + return gemm(m, n, k, aT, strideA, b, strideB, c, strideC); + } + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + return gemv(m, k, aT, strideA, b, c); + } } } diff --git a/src/archImpl/sse2.cpp b/src/archImpl/sse2.cpp index ecab2c80..6669bd7a 100644 --- a/src/archImpl/sse2.cpp +++ b/src/archImpl/sse2.cpp @@ -1,22 +1,48 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" +#include "../gemm.h" + +#define Eigen EigenSSE2 +#include namespace kiwi { - namespace sb + namespace lm + { + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace gemm { template<> - struct LogExpSum + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) { - template - float operator()(const float* arr, std::integral_constant) - { - return logExpSumImpl(arr); - } - }; + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } + diff --git a/src/archImpl/sse4_1.cpp b/src/archImpl/sse4_1.cpp index 8c7efa4d..cf19986d 100644 --- a/src/archImpl/sse4_1.cpp +++ b/src/archImpl/sse4_1.cpp @@ -1,22 +1,176 @@ -#include "../SkipBigramModelImpl.hpp" +#include "../MathFunc.hpp" +#include "../qgemm.hpp" +#include "../gemm.h" + +#define Eigen EigenSSE41 +#include namespace kiwi { - namespace sb + namespace lm { - template<> - struct LogExpSum + template float logSumExp(const float* arr, size_t size); + template void logSumExpTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + template void logSoftmax(float* arr, size_t size); + template void logSoftmaxTransposed(float* arr, size_t size, size_t batchSize, size_t stride); + } + + namespace qgemm + { + template int32_t dotprod(const uint8_t* a, const int8_t* b, size_t n); + + static FORCE_INLINE __m128i dpbusd(__m128i src, __m128i a, __m128i b) + { + __m128i one16 = _mm_set1_epi16(1); + __m128i t0 = _mm_maddubs_epi16(a, b); + __m128i t1 = _mm_madd_epi16(t0, one16); + return _mm_add_epi32(src, t1); + } + + inline void pack4x16to4x4x4( + const void* a0, const void* a1, const void* a2, const void* a3, + __m128i& p0, __m128i& p1, __m128i& p2, __m128i& p3 + ) { - template - float operator()(const float* arr, std::integral_constant) + __m128i q0, q1, q2, q3; + // 00, 01, 02, 03, 04, 05, 06, 07, ... + p0 = _mm_loadu_si128((const __m128i*)a0); + p1 = _mm_loadu_si128((const __m128i*)a1); + p2 = _mm_loadu_si128((const __m128i*)a2); + p3 = _mm_loadu_si128((const __m128i*)a3); + + // 00, 10, 01, 11, 04, 14, 05, 15, ... + q0 = _mm_unpacklo_epi32(p0, p1); + // 02, 12, 03, 13, 06, 16, 07, 17, ... + q1 = _mm_unpackhi_epi32(p0, p1); + // 20, 30, 21, 31, 24, 34, 25, 35, ... + q2 = _mm_unpacklo_epi32(p2, p3); + // 22, 32, 23, 33, 26, 36, 27, 37, ... + q3 = _mm_unpackhi_epi32(p2, p3); + + // 00, 10, 20, 30, 04, 14, 24, 34, ... + p0 = _mm_unpacklo_epi64(q0, q2); + // 01, 11, 21, 31, 05, 15, 25, 35, ... + p1 = _mm_unpackhi_epi64(q0, q2); + // 02, 12, 22, 32, 06, 16, 26, 36, ... + p2 = _mm_unpacklo_epi64(q1, q3); + // 03, 13, 23, 33, 07, 17, 27, 37, ... + p3 = _mm_unpackhi_epi64(q1, q3); + } + + inline void scatteredGEMV_128( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + constexpr size_t packM = 4, packN = 1, packK = 384; + auto* buffer = SharedThreadLocalBuffer<>::get(); + int8_t* bBuffer = reinterpret_cast(buffer); + float* aScale = reinterpret_cast(bBuffer + packN * packK); + float* aBias = aScale + packM; + memcpy(bBuffer, b, k); + float bScale = *reinterpret_cast(b + k); + int32_t bSum = *reinterpret_cast(b + k + 4); + + __m128i pa[4], pb, pbs, psum, pbSum; + __m128 paScale, paBias, pbScale, r; + pbScale = _mm_set1_ps(bScale); + pbSum = _mm_set1_epi32(-bSum); + + for (size_t mi = 0; mi < m; mi += packM) { - return logExpSumImpl(arr); + const size_t microM = std::min(packM, m - mi); + const int32_t aOffsets[4] = { + (int32_t)(aIdx[0] * aIdxScale), + 1 < microM ? (int32_t)(aIdx[1] * aIdxScale) : 0, + 2 < microM ? (int32_t)(aIdx[2] * aIdxScale) : 0, + 3 < microM ? (int32_t)(aIdx[3] * aIdxScale) : 0, + }; + auto* aPtr = aBase; + psum = pbSum; + for (size_t j = 0; j < k; j += 16) + { + pack4x16to4x4x4(aPtr + aOffsets[0], + aPtr + aOffsets[1], + aPtr + aOffsets[2], + aPtr + aOffsets[3], + pa[0], pa[1], pa[2], pa[3]); + pb = _mm_loadu_si128((const __m128i*)(bBuffer + j)); + pbs = _mm_shuffle_epi32(pb, 0x00); + psum = dpbusd(psum, pa[0], pbs); + pbs = _mm_shuffle_epi32(pb, 0x55); + psum = dpbusd(psum, pa[1], pbs); + pbs = _mm_shuffle_epi32(pb, 0xAA); + psum = dpbusd(psum, pa[2], pbs); + pbs = _mm_shuffle_epi32(pb, 0xFF); + psum = dpbusd(psum, pa[3], pbs); + aPtr += 16; + } + for (size_t i = 0; i < 4; ++i) + { + aScale[i] = *reinterpret_cast(aPtr + aOffsets[i]); + aBias[i] = *reinterpret_cast(aPtr + aOffsets[i] + 4); + } + aIdx += 4; + + paScale = _mm_loadu_ps(aScale); + paBias = _mm_loadu_ps(aBias); + r = _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_mm_cvtepi32_ps(psum), pbScale), paScale), paBias); + _mm_storeu_ps(c, r); + c += microM; } - }; + } + + template<> + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + return scatteredGEMV_128(m, k, aBase, aIdx, aIdxScale, b, c); + } - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; - template class SkipBigramModel; + template void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } + + namespace gemm + { + template<> + void gemm( + size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map> bMap(b, k, n, Eigen::OuterStride<>(strideB)); + Eigen::Map> cMap(c, m, n, Eigen::OuterStride<>(strideC)); + cMap.noalias() += aMap.transpose() * bMap; + } + + + template<> + void gemv( + size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ) + { + Eigen::Map> aMap(aT, k, m, Eigen::OuterStride<>(strideA)); + Eigen::Map bMap(b, k); + Eigen::Map cMap(c, m); + cMap.noalias() += aMap.transpose() * bMap; + } } } diff --git a/src/capi/kiwi_c.cpp b/src/capi/kiwi_c.cpp index 1ba7abf4..858c04ef 100644 --- a/src/capi/kiwi_c.cpp +++ b/src/capi/kiwi_c.cpp @@ -110,8 +110,13 @@ kiwi_builder_h kiwi_builder_init(const char* model_path, int num_threads, int op try { BuildOption buildOption = (BuildOption)(options & 0xFF); - bool useSBG = !!(options & KIWI_BUILD_MODEL_TYPE_SBG); - return (kiwi_builder_h)new KiwiBuilder{ model_path, (size_t)num_threads, buildOption, useSBG}; + const auto mtMask = options & (KIWI_BUILD_MODEL_TYPE_KNLM | KIWI_BUILD_MODEL_TYPE_SBG | KIWI_BUILD_MODEL_TYPE_CONG | KIWI_BUILD_MODEL_TYPE_CONG_GLOBAL); + const ModelType modelType = mtMask == KIWI_BUILD_MODEL_TYPE_KNLM ? ModelType::knlm + : mtMask == KIWI_BUILD_MODEL_TYPE_SBG ? ModelType::sbg + : mtMask == KIWI_BUILD_MODEL_TYPE_CONG ? ModelType::cong + : mtMask == KIWI_BUILD_MODEL_TYPE_CONG_GLOBAL ? ModelType::congGlobal + : ModelType::none; + return (kiwi_builder_h)new KiwiBuilder{ model_path, (size_t)num_threads, buildOption, modelType }; } catch (...) { diff --git a/src/count.hpp b/src/count.hpp index ac6e7505..4d6deed8 100644 --- a/src/count.hpp +++ b/src/count.hpp @@ -29,10 +29,10 @@ namespace kiwi #else template using map = std::map; #endif - using Vid = uint16_t; - using CTrieNode = TrieNodeEx>>; + template + using CTrieNode = TrieNodeEx>>; - static constexpr Vid non_vocab_id = (Vid)-1; + static constexpr size_t non_vocab_id = (size_t)-1; template class StrideIter : public _Iterator @@ -77,14 +77,15 @@ namespace kiwi { struct vvhash { - size_t operator()(const std::pair& k) const + template + size_t operator()(const std::pair& k) const { - return std::hash{}(k.first) ^ std::hash{}(k.second); + return std::hash{}(k.first) ^ std::hash{}(k.second); } }; } - template + template void countUnigrams(std::vector& unigramCf, std::vector& unigramDf, _DocIter docBegin, _DocIter docEnd ) @@ -93,7 +94,7 @@ namespace kiwi { auto doc = *docIt; if (!doc.size()) continue; - std::unordered_set uniqs; + std::unordered_set uniqs; for (size_t i = 0; i < doc.size(); ++i) { if (doc[i] == non_vocab_id) continue; @@ -110,24 +111,24 @@ namespace kiwi } } - template> - void countBigrams(map, size_t>& bigramCf, - map, size_t>& bigramDf, + template> + void countBigrams(map, size_t>& bigramCf, + map, size_t>& bigramDf, _DocIter docBegin, _DocIter docEnd, _Freqs&& vocabFreqs, _Freqs&& vocabDf, size_t candMinCnt, size_t candMinDf, const _HistoryTx* historyTransformer = nullptr ) { - std::unordered_set, detail::vvhash> uniqBigram; + std::unordered_set, detail::vvhash> uniqBigram; for (auto docIt = docBegin; docIt != docEnd; ++docIt) { auto doc = *docIt; if (!doc.size()) continue; - Vid prevWord = doc[0]; + VocabTy prevWord = doc[0]; for (size_t j = 1; j < doc.size(); ++j) { - Vid curWord = doc[j]; + VocabTy curWord = doc[j]; if (curWord != non_vocab_id && vocabFreqs[curWord] >= candMinCnt && vocabDf[curWord] >= candMinDf) { if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf) @@ -144,8 +145,8 @@ namespace kiwi } } - template> - void countNgrams(ContinuousTrie& dest, + template> + void countNgrams(ContinuousTrie>& dest, _DocIter docBegin, _DocIter docEnd, _Freqs&& vocabFreqs, _Freqs&& vocabDf, _BigramPairs&& validPairs, size_t candMinCnt, size_t candMinDf, size_t maxNgrams, @@ -154,7 +155,7 @@ namespace kiwi { if (dest.empty()) { - dest = ContinuousTrie{ 1, 1024 }; + dest = ContinuousTrie>{ 1, 1024 }; } const auto& allocNode = [&]() { return dest.newNode(); }; const auto& historyTx = [&](size_t i) { return (*historyTransformer)[i]; }; @@ -165,7 +166,7 @@ namespace kiwi if (!doc.size()) continue; dest.reserveMore(doc.size() * maxNgrams * 2); - Vid prevWord = _reverse ? *doc.rbegin() : *doc.begin(); + VocabTy prevWord = _reverse ? *doc.rbegin() : *doc.begin(); size_t labelLen = 0; auto node = &dest[0]; if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf) @@ -175,7 +176,7 @@ namespace kiwi labelLen = 1; } - const auto func = [&](Vid curWord) + const auto func = [&](VocabTy curWord) { if (curWord != non_vocab_id && (vocabFreqs[curWord] < candMinCnt || vocabDf[curWord] < candMinDf)) { @@ -239,19 +240,21 @@ namespace kiwi } } - inline void mergeNgramCounts(ContinuousTrie& dest, ContinuousTrie&& src) + template + inline void mergeNgramCounts(ContinuousTrie>& dest, ContinuousTrie>&& src) { if (src.empty()) return; - if (dest.empty()) dest = ContinuousTrie{ 1 }; + if (dest.empty()) dest = ContinuousTrie>{ 1 }; - std::vector rkeys; - src.traverseWithKeys([&](const CTrieNode* node, const std::vector& rkeys) + std::vector rkeys; + src.traverseWithKeys([&](const CTrieNode* node, const std::vector& rkeys) { dest.build(rkeys.begin(), rkeys.end(), 0)->val += node->val; }, rkeys); } - inline float branchingEntropy(const CTrieNode* node, size_t minCnt) + template + inline float branchingEntropy(const CTrieNode* node, size_t minCnt) { float entropy = 0; size_t rest = node->val; @@ -300,16 +303,16 @@ namespace kiwi return std::move(data[0]); } - template> - ContinuousTrie count(_DocIter docBegin, _DocIter docEnd, + template> + ContinuousTrie> count(_DocIter docBegin, _DocIter docEnd, size_t minCf, size_t minDf, size_t maxNgrams, - ThreadPool* pool = nullptr, std::vector>* bigramList = nullptr, + ThreadPool* pool = nullptr, std::vector>* bigramList = nullptr, const _HistoryTx* historyTransformer = nullptr ) { // counting unigrams & bigrams std::vector unigramCf, unigramDf; - map, size_t> bigramCf, bigramDf; + map, size_t> bigramCf, bigramDf; if (pool && pool->size() > 1) { @@ -325,7 +328,7 @@ namespace kiwi { futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid) { - countUnigrams(localdata[tid].first, localdata[tid].second, + countUnigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd) ); @@ -351,7 +354,7 @@ namespace kiwi } else { - countUnigrams(unigramCf, unigramDf, docBegin, docEnd); + countUnigrams(unigramCf, unigramDf, docBegin, docEnd); } if (pool && pool->size() > 1) @@ -402,7 +405,7 @@ namespace kiwi } } - ContinuousTrie trieNodes{ 1 }; + ContinuousTrie> trieNodes{ 1 }; if (historyTransformer) { for (size_t i = 0; i < unigramCf.size(); ++i) @@ -434,7 +437,7 @@ namespace kiwi // counting ngrams else { - std::unordered_set, detail::vvhash> validPairs; + std::unordered_set, detail::vvhash> validPairs; for (auto& p : bigramCf) { if (p.second >= minCf && bigramDf[p.first] >= minDf) validPairs.emplace(p.first); @@ -442,7 +445,7 @@ namespace kiwi if (pool && pool->size() > 1) { - using LocalFw = ContinuousTrie; + using LocalFw = ContinuousTrie>; std::vector localdata(pool->size()); std::vector> futures; const size_t stride = pool->size() * 8; diff --git a/src/gemm.h b/src/gemm.h new file mode 100644 index 00000000..7ade297e --- /dev/null +++ b/src/gemm.h @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace kiwi +{ + namespace gemm + { + // c += a.transpose() * b + template + void gemm(size_t m, size_t n, size_t k, + const float* aT, size_t strideA, + const float* b, size_t strideB, + float* c, size_t strideC + ); + + // c += a.transpose() * b + template + void gemv(size_t m, size_t k, + const float* aT, size_t strideA, + const float* b, + float* c + ); + } +} diff --git a/src/qgemm.h b/src/qgemm.h new file mode 100644 index 00000000..9871582c --- /dev/null +++ b/src/qgemm.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include + +namespace kiwi +{ + namespace qgemm + { + template + int32_t dotprod( + const uint8_t* a, const int8_t* b, size_t n + ); + + template + void scatteredGEMM( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + + template + void scatteredGEMMBaseline( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + + template + void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ); + } +} diff --git a/src/qgemm.hpp b/src/qgemm.hpp new file mode 100644 index 00000000..0874b440 --- /dev/null +++ b/src/qgemm.hpp @@ -0,0 +1,194 @@ +#pragma once +#include +#include +#include "qgemm.h" +#include "SIMD.hpp" + +namespace kiwi +{ + namespace qgemm + { + static constexpr size_t TLBSize = 32768; + + template + struct SharedThreadLocalBuffer + { + thread_local static uint8_t buffer[size]; + static uint8_t* get() + { + return buffer; + } + }; + + template + thread_local uint8_t SharedThreadLocalBuffer::buffer[size]; + + template + int32_t dotprod( + const uint8_t* a, const int8_t* b, size_t n + ) + { + simd::Operator op; + return op.dotprod(a, b, n); + } + + template + void scatteredGEMMBaseline( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + thread_local Vector buffer; + buffer.resize((m + n) * (k + 8)); + uint8_t* aBuffer = buffer.data(); + int8_t* bBuffer = reinterpret_cast(aBuffer + m * (k + 8)); + simd::Operator op; + + for (size_t i = 0; i < m; ++i) + { + std::memcpy(aBuffer + i * (k + 8), &aBase[aIdx[i] * aIdxScale], k + 8); + } + for (size_t i = 0; i < n; ++i) + { + std::memcpy(bBuffer + i * (k + 8), &bBase[bIdx[i] * bIdxScale], k + 8); + } + + for (size_t i = 0; i < m; ++i) + { + for (size_t j = 0; j < n; ++j) + { + const auto* aPtr = aBuffer + i * (k + 8); + const auto* bPtr = bBuffer + j * (k + 8); + int32_t acc = op.dotprod(aPtr, bPtr, k); + const float contextScale = *reinterpret_cast(aPtr + k), + outputScale = *reinterpret_cast(bPtr + k), + contextBias = *reinterpret_cast(aPtr + k + 4); + const int32_t hsum = *reinterpret_cast(bPtr + k + 4); + c[i * ldc + j] = (acc - hsum) * contextScale * outputScale + contextBias; + } + } + } + + template + inline void scatteredGEMV( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + const int32_t bIdx[1] = { 0 }; + return scatteredGEMMBaseline(m, 1, k, aBase, aIdx, aIdxScale, b, bIdx, 0, c, 1); + } + + template + inline void scatteredGEMV8x1( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* b, + float* c + ) + { + const int32_t bIdx[1] = { 0 }; + return scatteredGEMMBaseline(m, 1, k, aBase, aIdx, aIdxScale, b, bIdx, 0, c, 1); + } + + template + inline void scatteredGEMV2( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMMBaseline(m, 2, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, 2); + } + + template + inline void scatteredGEMV3( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMMBaseline(m, 3, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, 3); + } + + template + inline void scatteredGEMV4( + size_t m, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c + ) + { + return scatteredGEMMBaseline(m, 4, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, 4); + } + + template + struct ScatteredGEMMSmall + { + template + static void op( + size_t, size_t, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc) + { + return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + }; + + template + void scatteredGEMMOpt( + size_t m, size_t n, size_t k, + const uint8_t* aBase, const int32_t* aIdx, size_t aIdxScale, + const int8_t* bBase, const int32_t* bIdx, size_t bIdxScale, + float* c, size_t ldc + ) + { + using Fn = decltype(&scatteredGEMMBaseline); + static constexpr Fn fnTable[] = { + scatteredGEMMBaseline, + ScatteredGEMMSmall::template op<1, 2>, + ScatteredGEMMSmall::template op<1, 3>, + ScatteredGEMMSmall::template op<2, 1>, + ScatteredGEMMSmall::template op<2, 2>, + ScatteredGEMMSmall::template op<2, 3>, + ScatteredGEMMSmall::template op<3, 1>, + ScatteredGEMMSmall::template op<3, 2>, + ScatteredGEMMSmall::template op<3, 3> + }; + + if (m <= 3 && n <= 3) + { + return (*fnTable[(m - 1) * 3 + (n - 1)])(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + + if (n == 1 && ldc == 1) + { + if (m == 8) + { + return scatteredGEMV8x1(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); + } + else + { + return scatteredGEMV(m, k, aBase, aIdx, aIdxScale, bBase + bIdx[0] * bIdxScale, c); + } + } + + if (m >= 4) + { + if (n == 2 && ldc == 2) return scatteredGEMV2(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + if (n == 3 && ldc == 3) return scatteredGEMV3(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + if (n == 4 && ldc == 4) return scatteredGEMV4(m, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c); + } + return scatteredGEMMBaseline(m, n, k, aBase, aIdx, aIdxScale, bBase, bIdx, bIdxScale, c, ldc); + } + + // real implementations are in `archImpl/.cpp` + } +} diff --git a/src/sais/mp_utils.hpp b/src/sais/mp_utils.hpp index ccec344a..f176edbe 100644 --- a/src/sais/mp_utils.hpp +++ b/src/sais/mp_utils.hpp @@ -58,7 +58,7 @@ namespace mp ThreadPool(size_t threads = 0); template auto runParallel(size_t workers, F&& f, Args&&... args) - -> std::vector::type>>; + -> std::vector::type>>; ~ThreadPool(); size_t size() const { return workers.size(); } size_t limitedSize() const { return std::min(size(), _limitedSize); }; @@ -106,9 +106,9 @@ namespace mp template auto ThreadPool::runParallel(size_t workers, F&& f, Args&&... args) - -> std::vector::type>> + -> std::vector::type>> { - using return_type = typename std::result_of::type; + using return_type = typename std::invoke_result::type; std::vector> ret; { auto b = std::make_shared(getBarrier(workers)); @@ -305,7 +305,7 @@ namespace mp } template::type, void>::value, int>::type = 0> + typename std::enable_if::type, void>::value, int>::type = 0> inline auto runParallel(ThreadPool* pool, Fn&& func, Args&&... args) -> std::vector { static_assert(detail::AllOfType, detail::IsRunParallelArg>::value, "`runParallel` receives arguments of wrong type."); @@ -331,7 +331,7 @@ namespace mp } template::type, void>::value, int>::type = 0> + typename std::enable_if::type, void>::value, int>::type = 0> inline void runParallel(ThreadPool* pool, Fn&& func, Args&&... args) { static_assert(detail::AllOfType, detail::IsRunParallelArg>::value, "`runParallel` receives arguments of wrong type."); @@ -360,7 +360,7 @@ namespace mp } template::type, void>::value, int>::type = 0> + typename std::enable_if::type, void>::value, int>::type = 0> inline void forParallel(ThreadPool* pool, ptrdiff_t start, ptrdiff_t stop, ptrdiff_t step, Fn&& func, Args&&... args) { static_assert(detail::AllOfType, detail::IsRunParallelArg>::value, "`forParallel` receives arguments of wrong type."); @@ -396,7 +396,7 @@ namespace mp ThreadPool* pool; public: OverrideLimitedSize(ThreadPool* _pool, size_t newSize) - : pool{ _pool }, prevSize{ pool ? pool->limitedSize() : -1 } + : pool{ _pool }, prevSize{ _pool ? _pool->limitedSize() : -1 } { if (pool) pool->_limitedSize = newSize; } diff --git a/src/search.cpp b/src/search.cpp index 6f14f0ef..140a18cf 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -1,9 +1,11 @@ #include #include +#include #include #include +#include #include "ArchAvailable.h" #include "search.h" @@ -14,11 +16,23 @@ template bool detail::searchImpl(const uint32_t*, size_t, uint32_t, size_t&);\ template bool detail::searchImpl(const uint64_t*, size_t, uint64_t, size_t&);\ template bool detail::searchImpl(const char16_t*, size_t, char16_t, size_t&);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint8_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint16_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint32_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, uint64_t);\ + template uint32_t detail::searchKVImpl(const void*, size_t, char16_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint8_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint16_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint32_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, uint64_t);\ + template uint64_t detail::searchKVImpl(const void*, size_t, char16_t);\ template Vector detail::reorderImpl(const uint8_t*, size_t);\ template Vector detail::reorderImpl(const uint16_t*, size_t);\ template Vector detail::reorderImpl(const uint32_t*, size_t);\ template Vector detail::reorderImpl(const uint64_t*, size_t);\ - template Vector detail::reorderImpl(const char16_t*, size_t); + template Vector detail::reorderImpl(const char16_t*, size_t);\ + template size_t detail::getPacketSizeImpl();\ + template size_t detail::findAllImpl(const uint8_t*, size_t, uint8_t);\ #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 || KIWI_ARCH_X86 || KIWI_ARCH_X86_64 #include @@ -26,6 +40,14 @@ #include #endif +#if defined(_MSC_VER) +#define FORCE_INLINE __forceinline +#elif defined(__GNUC__) +#define FORCE_INLINE __attribute__((always_inline)) +#else +#define FORCE_INLINE inline +#endif + #ifdef __GNUC__ #define ARCH_TARGET(x) __attribute__((target(x))) #else @@ -149,12 +171,33 @@ namespace kiwi template bool detail::searchImpl(const IntTy* keys, size_t size, IntTy target, size_t& ret) { - return OptimizedImpl::template search(keys, size, target, ret); + return OptimizedImpl::search(keys, size, target, ret); + + } + + template + ValueTy detail::searchKVImpl(const void* kv, size_t size, IntTy target) + { + return OptimizedImpl::template searchKV(kv, size, target); + } + + template + size_t detail::getPacketSizeImpl() + { + return OptimizedImpl::packetSize; + } + + template + size_t detail::findAllImpl(const uint8_t* arr, size_t size, uint8_t key) + { + return OptimizedImpl::findAll(arr, size, key); } template<> struct OptimizedImpl { + static constexpr size_t packetSize = 0; + template static Vector reorder(const IntTy* keys, size_t size) { @@ -166,12 +209,37 @@ namespace kiwi { return bstSearch(keys, size, target, ret); } + + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + size_t idx; + const IntTy* keys = reinterpret_cast(kv); + const ValueTy* values = reinterpret_cast(keys + size); + if (search(keys, size, target, idx)) + { + return values[idx]; + } + else return 0; + } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + size_t ret = 0; + for (size_t i = 0; i < size; ++i) + { + ret |= (size_t)(arr[i] == key) << i; + } + return ret; + } }; INSTANTIATE_IMPL(ArchType::none); template<> struct OptimizedImpl { + static constexpr size_t packetSize = 0; + template static Vector reorder(const IntTy* keys, size_t size) { @@ -197,26 +265,31 @@ namespace kiwi ret = left1; return true; } - }; - INSTANTIATE_IMPL(ArchType::balanced); - - template - struct SignedType { using type = IntTy; }; - - template<> - struct SignedType { using type = int8_t; }; - - template<> - struct SignedType { using type = int16_t; }; - - template<> - struct SignedType { using type = int32_t; }; - template<> - struct SignedType { using type = int64_t; }; + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + size_t idx; + const IntTy* keys = reinterpret_cast(kv); + const ValueTy* values = reinterpret_cast(keys + size); + if (search(keys, size, target, idx)) + { + return values[idx]; + } + else return 0; + } - template<> - struct SignedType { using type = int16_t; }; + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + size_t ret = 0; + for (size_t i = 0; i < size; ++i) + { + ret |= (size_t)(arr[i] == key) << i; + } + return ret; + } + }; + INSTANTIATE_IMPL(ArchType::balanced); } } @@ -241,7 +314,7 @@ namespace kiwi template ARCH_TARGET("sse2") - bool nstSearchSSE2(const IntTy* keys, size_t size, IntTy target, size_t& ret) + inline bool nstSearchSSE2(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -401,21 +474,154 @@ namespace kiwi return false; } + template + ARCH_TARGET("sse2") + inline ValueTy nstSearchKVSSE2(const uint8_t* kv, size_t size, IntTy target) + { + size_t i = 0, r; + + __m128i ptarget, pkey, peq, pgt; + switch (sizeof(IntTy)) + { + case 1: + ptarget = _mm_set1_epi8(target); + break; + case 2: + ptarget = _mm_set1_epi16(target); + break; + case 4: + ptarget = _mm_set1_epi32(target); + break; + } + + if (size < n) + { + pkey = _mm_loadu_si128(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm_cmpeq_epi8(ptarget, pkey); + break; + case 2: + peq = _mm_cmpeq_epi16(ptarget, pkey); + break; + case 4: + peq = _mm_cmpeq_epi32(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = _mm_loadu_si128(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm_cmpeq_epi8(ptarget, pkey); + pgt = _mm_cmpgt_epi8(ptarget, pkey); + break; + case 2: + peq = _mm_cmpeq_epi16(ptarget, pkey); + pgt = _mm_cmpgt_epi16(ptarget, pkey); + break; + case 4: + peq = _mm_cmpeq_epi32(ptarget, pkey); + pgt = _mm_cmpgt_epi32(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + + r = utils::popcount((uint32_t)_mm_movemask_epi8(pgt)) / sizeof(IntTy); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + ARCH_TARGET("sse2") + inline size_t findAllSSE2(const uint8_t* arr, size_t size, uint8_t key) + { + __m128i pkey = _mm_set1_epi8(key); + if (size <= 16) + { + __m128i parr = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i pcmp = _mm_cmpeq_epi8(pkey, parr); + return (size_t)_mm_movemask_epi8(pcmp) & (((size_t)1 << size) - 1); + } + else if (size <= 32) + { + __m128i parr0 = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i parr1 = _mm_loadu_si128(reinterpret_cast(arr + 16)); + __m128i pcmp0 = _mm_cmpeq_epi8(pkey, parr0); + __m128i pcmp1 = _mm_cmpeq_epi8(pkey, parr1); + return ((size_t)_mm_movemask_epi8(pcmp0) | ((size_t)_mm_movemask_epi8(pcmp1) << 16)) & (((size_t)1 << size) - 1); + } + else if (size <= 48) + { + __m128i parr0 = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i parr1 = _mm_loadu_si128(reinterpret_cast(arr + 16)); + __m128i parr2 = _mm_loadu_si128(reinterpret_cast(arr + 32)); + __m128i pcmp0 = _mm_cmpeq_epi8(pkey, parr0); + __m128i pcmp1 = _mm_cmpeq_epi8(pkey, parr1); + __m128i pcmp2 = _mm_cmpeq_epi8(pkey, parr2); + return ((size_t)_mm_movemask_epi8(pcmp0) | ((size_t)_mm_movemask_epi8(pcmp1) << 16) | ((size_t)_mm_movemask_epi8(pcmp2) << 32)) & (((size_t)1 << size) - 1); + } + else + { + __m128i parr0 = _mm_loadu_si128(reinterpret_cast(arr)); + __m128i parr1 = _mm_loadu_si128(reinterpret_cast(arr + 16)); + __m128i parr2 = _mm_loadu_si128(reinterpret_cast(arr + 32)); + __m128i parr3 = _mm_loadu_si128(reinterpret_cast(arr + 48)); + __m128i pcmp0 = _mm_cmpeq_epi8(pkey, parr0); + __m128i pcmp1 = _mm_cmpeq_epi8(pkey, parr1); + __m128i pcmp2 = _mm_cmpeq_epi8(pkey, parr2); + __m128i pcmp3 = _mm_cmpeq_epi8(pkey, parr3); + return ((size_t)_mm_movemask_epi8(pcmp0) | ((size_t)_mm_movemask_epi8(pcmp1) << 16) | ((size_t)_mm_movemask_epi8(pcmp2) << 32) | ((size_t)_mm_movemask_epi8(pcmp3) << 48)) & (((size_t)1 << size) - 1); + } + } + template<> struct OptimizedImpl { + static constexpr size_t packetSize = 16; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchSSE2<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchSSE2((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVSSE2((const uint8_t*)kv, size, target); + } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllSSE2(arr, size, key); } }; INSTANTIATE_IMPL(ArchType::sse2); @@ -455,7 +661,7 @@ namespace kiwi template ARCH_TARGET("avx2") - bool nstSearchAVX2(const IntTy* keys, size_t size, IntTy target, size_t& ret) + inline bool nstSearchAVX2(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -642,9 +848,122 @@ namespace kiwi return false; } + template + ARCH_TARGET("avx2") + inline ValueTy nstSearchKVAVX2(const uint8_t* kv, size_t size, IntTy target) + { + if (size < (n + 1) / 2) + { + return nstSearchKVSSE2<(n + 1) / 2, IntTy, ValueTy>(kv, size, target); + } + + size_t i = 0, r; + + __m256i ptarget, pkey, peq, pgt; + switch (sizeof(IntTy)) + { + case 1: + ptarget = _mm256_set1_epi8(target); + break; + case 2: + ptarget = _mm256_set1_epi16(target); + break; + case 4: + ptarget = _mm256_set1_epi32(target); + break; + case 8: + ptarget = _mm256_set1_epi64x(target); + break; + } + + if (size < n) + { + pkey = _mm256_loadu_si256(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm256_cmpeq_epi8(ptarget, pkey); + break; + case 2: + peq = _mm256_cmpeq_epi16(ptarget, pkey); + break; + case 4: + peq = _mm256_cmpeq_epi32(ptarget, pkey); + break; + case 8: + peq = _mm256_cmpeq_epi64(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = _mm256_loadu_si256(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm256_cmpeq_epi8(ptarget, pkey); + pgt = _mm256_cmpgt_epi8(ptarget, pkey); + break; + case 2: + peq = _mm256_cmpeq_epi16(ptarget, pkey); + pgt = _mm256_cmpgt_epi16(ptarget, pkey); + break; + case 4: + peq = _mm256_cmpeq_epi32(ptarget, pkey); + pgt = _mm256_cmpgt_epi32(ptarget, pkey); + break; + case 8: + peq = _mm256_cmpeq_epi64(ptarget, pkey); + pgt = _mm256_cmpgt_epi64(ptarget, pkey); + break; + } + + if (testEq(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + + r = utils::popcount((uint32_t)_mm256_movemask_epi8(pgt)) / sizeof(IntTy); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + ARCH_TARGET("avx2") + inline size_t findAllAVX2(const uint8_t* arr, size_t size, uint8_t key) + { + if (size <= 32) + { + __m256i pkey = _mm256_set1_epi8(key); + __m256i parr = _mm256_loadu_si256(reinterpret_cast(arr)); + __m256i pcmp = _mm256_cmpeq_epi8(pkey, parr); + return (size_t)_mm256_movemask_epi8(pcmp) & (((size_t)1 << size) - 1); + } + else + { + __m256i pkey = _mm256_set1_epi8(key); + __m256i parr0 = _mm256_loadu_si256(reinterpret_cast(arr)); + __m256i parr1 = _mm256_loadu_si256(reinterpret_cast(arr + 32)); + __m256i pcmp0 = _mm256_cmpeq_epi8(pkey, parr0); + __m256i pcmp1 = _mm256_cmpeq_epi8(pkey, parr1); + return ((size_t)_mm256_movemask_epi8(pcmp0) | ((size_t)_mm256_movemask_epi8(pcmp1) << 32)) & (((size_t)1 << size) - 1); + } + } + template - ARCH_TARGET("avx512bw") - bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) + ARCH_TARGET("avx,avx2,avx512f,avx512bw,avx512dq") + inline bool nstSearchAVX512(const IntTy* keys, size_t size, IntTy target, size_t& ret) { size_t i = 0, r; @@ -832,21 +1151,140 @@ namespace kiwi return false; } + template + ARCH_TARGET("avx,avx2,avx512f,avx512bw,avx512dq") + inline ValueTy nstSearchKVAVX512(const uint8_t* kv, size_t size, IntTy target) + { + if (size < (n + 1) / 2) + { + return nstSearchKVAVX2<(n + 1) / 2, IntTy, ValueTy>(kv, size, target); + } + + size_t i = 0, r; + const IntTy* keys; + + __m512i ptarget, pkey; + uint64_t peq, pgt; + switch (sizeof(IntTy)) + { + case 1: + ptarget = _mm512_set1_epi8(target); + break; + case 2: + ptarget = _mm512_set1_epi16(target); + break; + case 4: + ptarget = _mm512_set1_epi32(target); + break; + case 8: + ptarget = _mm512_set1_epi64(target); + break; + } + + if (size < n) + { + keys = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))]); + pkey = _mm512_loadu_si512(reinterpret_cast(keys)); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm512_cmpeq_epi8_mask(ptarget, pkey); + break; + case 2: + peq = _mm512_cmpeq_epi16_mask(ptarget, pkey); + break; + case 4: + peq = _mm512_cmpeq_epi32_mask(ptarget, pkey); + break; + case 8: + peq = _mm512_cmpeq_epi64_mask(ptarget, pkey); + break; + } + + if (testEqMask(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&keys[groupSize]); + return values[r]; + } + return 0; + } + + while (i < size) + { + keys = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))]); + pkey = _mm512_loadu_si512(reinterpret_cast(keys)); + switch (sizeof(IntTy)) + { + case 1: + peq = _mm512_cmpeq_epi8_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi8_mask(ptarget, pkey); + break; + case 2: + peq = _mm512_cmpeq_epi16_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi16_mask(ptarget, pkey); + break; + case 4: + peq = _mm512_cmpeq_epi32_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi32_mask(ptarget, pkey); + break; + case 8: + peq = _mm512_cmpeq_epi64_mask(ptarget, pkey); + pgt = _mm512_cmpgt_epi64_mask(ptarget, pkey); + break; + } + + if (testEqMask(peq, 0, size - i, r)) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&keys[groupSize]); + return values[r]; + } + + r = utils::popcount(pgt); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + ARCH_TARGET("avx512bw") + inline size_t findAllAVX512(const uint8_t* arr, size_t size, uint8_t key) + { + __m512i pkey = _mm512_set1_epi8(key); + __m512i parr = _mm512_loadu_si512(reinterpret_cast(arr)); + __mmask64 pcmp = _mm512_cmpeq_epi8_mask(pkey, parr); + return (size_t)pcmp & (((size_t)1 << size) - 1); + } + template<> struct OptimizedImpl { + static constexpr size_t packetSize = 16; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchSSE2<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchSSE2((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVSSE2((const uint8_t*)kv, size, target); + } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllSSE2(arr, size, key); } }; INSTANTIATE_IMPL(ArchType::sse4_1); @@ -854,40 +1292,80 @@ namespace kiwi template<> struct OptimizedImpl { + static constexpr size_t packetSize = 32; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<32 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchAVX2<32 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchAVX2((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVAVX2((const uint8_t*)kv, size, target); + } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllAVX2(arr, size, key); } }; INSTANTIATE_IMPL(ArchType::avx2); + template<> + struct OptimizedImpl : public OptimizedImpl + { + }; + INSTANTIATE_IMPL(ArchType::avx_vnni); + template<> struct OptimizedImpl { + static constexpr size_t packetSize = 64; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<64 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchAVX512<64 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchAVX512((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVAVX512((const uint8_t*)kv, size, target); + } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllAVX512(arr, size, key); } }; INSTANTIATE_IMPL(ArchType::avx512bw); + + template<> + struct OptimizedImpl : public OptimizedImpl + { + }; + INSTANTIATE_IMPL(ArchType::avx512vnni); } } #endif @@ -899,7 +1377,7 @@ namespace kiwi { template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int8_t* keys, size_t size, int8_t target, size_t& ret) + inline bool nstSearchNeon(const int8_t* keys, size_t size, int8_t target, size_t& ret) { size_t i = 0; @@ -932,7 +1410,7 @@ namespace kiwi template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int16_t* keys, size_t size, int16_t target, size_t& ret) + inline bool nstSearchNeon(const int16_t* keys, size_t size, int16_t target, size_t& ret) { size_t i = 0; @@ -962,7 +1440,7 @@ namespace kiwi template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int32_t* keys, size_t size, int32_t target, size_t& ret) + inline bool nstSearchNeon(const int32_t* keys, size_t size, int32_t target, size_t& ret) { size_t i = 0; @@ -992,7 +1470,7 @@ namespace kiwi template ARCH_TARGET("arch=armv8-a") - bool nstSearchNeon(const int64_t* keys, size_t size, int64_t target, size_t& ret) + inline bool nstSearchNeon(const int64_t* keys, size_t size, int64_t target, size_t& ret) { size_t i = 0; @@ -1020,21 +1498,307 @@ namespace kiwi return false; } + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const uint8_t* kv, size_t size, int8_t target) + { + using IntTy = int8_t; + size_t i = 0; + int8x16_t ptarget = vdupq_n_s8(target), pkey; + uint8x16_t peq, pgt, pmasked; + + static const uint8x16_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128 }; + + if (size < n) + { + pkey = vld1q_s8(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s8(ptarget, pkey); + pmasked = vandq_u8(peq, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + uint32_t mm = mm0 | ((uint32_t)mm1 << 8); + uint32_t r = utils::countTrailingZeroes(mm); + + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s8(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s8(ptarget, pkey); + pgt = vcgtq_s8(ptarget, pkey); + pmasked = vandq_u8(peq, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + uint32_t mm = mm0 | ((uint32_t)mm1 << 8); + uint32_t r = utils::countTrailingZeroes(mm); + + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u8(vandq_u8(pgt, vdupq_n_u8(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const uint8_t* kv, size_t size, int16_t target) + { + using IntTy = int16_t; + size_t i = 0; + int16x8_t ptarget = vdupq_n_s16(target), pkey; + uint16x8_t peq, pgt; + + static const uint16x8_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8, 16, 32, 64, 128 }; + + if (size < n) + { + pkey = vld1q_s16(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s16(ptarget, pkey); + uint32_t mm = vaddvq_u16(vandq_u16(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s16(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s16(ptarget, pkey); + pgt = vcgtq_s16(ptarget, pkey); + uint32_t mm = vaddvq_u16(vandq_u16(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u16(vandq_u16(pgt, vdupq_n_u16(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const uint8_t* kv, size_t size, int32_t target) + { + using IntTy = int32_t; + size_t i = 0; + int32x4_t ptarget = vdupq_n_s32(target), pkey; + uint32x4_t peq, pgt; + + static const uint32x4_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8 }; + + if (size < n) + { + pkey = vld1q_s32(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s32(ptarget, pkey); + uint32_t mm = vaddvq_u32(vandq_u32(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s32(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s32(ptarget, pkey); + pgt = vcgtq_s32(ptarget, pkey); + uint32_t mm = vaddvq_u32(vandq_u32(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u32(vandq_u32(pgt, vdupq_n_u32(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + template + ARCH_TARGET("arch=armv8-a") + inline ValueTy nstSearchKVNeon(const uint8_t* kv, size_t size, int64_t target) + { + using IntTy = int64_t; + size_t i = 0; + int64x2_t ptarget = vdupq_n_s64(target), pkey; + uint64x2_t peq, pgt; + + static const uint64x2_t __attribute__((aligned(16))) mask = { 1, 2 }; + + if (size < n) + { + pkey = vld1q_s64(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s64(ptarget, pkey); + uint32_t mm = vaddvq_u64(vandq_u64(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + return 0; + } + + while (i < size) + { + pkey = vld1q_s64(reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy))])); + peq = vceqq_s64(ptarget, pkey); + pgt = vcgtq_s64(ptarget, pkey); + uint32_t mm = vaddvq_u64(vandq_u64(peq, mask)); + uint32_t r = utils::countTrailingZeroes(mm); + if (mm && (i + r) < size) + { + const size_t groupSize = std::min(n - 1, size - i); + const ValueTy* values = reinterpret_cast(&kv[i * (sizeof(IntTy) + sizeof(ValueTy)) + groupSize * sizeof(IntTy)]); + return values[r]; + } + r = vaddvq_u64(vandq_u64(pgt, vdupq_n_u64(1))); + i = i * n + (n - 1) * (r + 1); + } + return 0; + } + + ARCH_TARGET("arch=armv8-a") + inline size_t findAllNeon(const uint8_t* arr, size_t size, uint8_t key) + { + int8x16_t pkey = vdupq_n_s8(key); + static const uint8x16_t __attribute__((aligned(16))) mask = { 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128 }; + uint8x16_t pmasked; + if (size <= 16) + { + int8x16_t parr = vld1q_s8(reinterpret_cast(arr)); + uint8x16_t pcmp = vceqq_s8(pkey, parr); + pmasked = vandq_u8(pcmp, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + return (mm0 | ((size_t)mm1 << 8)) & (((size_t)1 << size) - 1); + } + else if (size <= 32) + { + int8x16_t parr0 = vld1q_s8(reinterpret_cast(arr)); + int8x16_t parr1 = vld1q_s8(reinterpret_cast(arr + 16)); + uint8x16_t pcmp0 = vceqq_s8(pkey, parr0); + uint8x16_t pcmp1 = vceqq_s8(pkey, parr1); + pmasked = vandq_u8(pcmp0, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r0 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp1, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r1 = mm0 | ((size_t)mm1 << 8); + return (r0 | (r1 << 16)) & (((size_t)1 << size) - 1); + } + else if (size <= 48) + { + int8x16_t parr0 = vld1q_s8(reinterpret_cast(arr)); + int8x16_t parr1 = vld1q_s8(reinterpret_cast(arr + 16)); + int8x16_t parr2 = vld1q_s8(reinterpret_cast(arr + 32)); + uint8x16_t pcmp0 = vceqq_s8(pkey, parr0); + uint8x16_t pcmp1 = vceqq_s8(pkey, parr1); + uint8x16_t pcmp2 = vceqq_s8(pkey, parr2); + pmasked = vandq_u8(pcmp0, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r0 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp1, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r1 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp2, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r2 = mm0 | ((size_t)mm1 << 8); + return (r0 | (r1 << 16) | (r2 << 32)) & (((size_t)1 << size) - 1); + } + else + { + int8x16_t parr0 = vld1q_s8(reinterpret_cast(arr)); + int8x16_t parr1 = vld1q_s8(reinterpret_cast(arr + 16)); + int8x16_t parr2 = vld1q_s8(reinterpret_cast(arr + 32)); + int8x16_t parr3 = vld1q_s8(reinterpret_cast(arr + 48)); + uint8x16_t pcmp0 = vceqq_s8(pkey, parr0); + uint8x16_t pcmp1 = vceqq_s8(pkey, parr1); + uint8x16_t pcmp2 = vceqq_s8(pkey, parr2); + uint8x16_t pcmp3 = vceqq_s8(pkey, parr3); + pmasked = vandq_u8(pcmp0, mask); + uint8_t mm0 = vaddv_u8(vget_low_u8(pmasked)); + uint8_t mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r0 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp1, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r1 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp2, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r2 = mm0 | ((size_t)mm1 << 8); + pmasked = vandq_u8(pcmp3, mask); + mm0 = vaddv_u8(vget_low_u8(pmasked)); + mm1 = vaddv_u8(vget_high_u8(pmasked)); + size_t r3 = mm0 | ((size_t)mm1 << 8); + return (r0 | (r1 << 16) | (r2 << 32) | (r3 << 48)) & (((size_t)1 << size) - 1); + } + } + + template<> struct OptimizedImpl { + static constexpr size_t packetSize = 16; + template static Vector reorder(const IntTy* keys, size_t size) { using SignedIntTy = typename SignedType::type; - return getNstOrder<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, true); + return getNstOrder((const SignedIntTy*)keys, size, true); } template static bool search(const IntTy* keys, size_t size, IntTy target, size_t& ret) { using SignedIntTy = typename SignedType::type; - return nstSearchNeon<16 / sizeof(IntTy) + 1>((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + return nstSearchNeon((const SignedIntTy*)keys, size, (SignedIntTy)target, ret); + } + + template + static ValueTy searchKV(const void* kv, size_t size, IntTy target) + { + using SignedIntTy = typename SignedType::type; + return nstSearchKVNeon((const uint8_t*)kv, size, (SignedIntTy)target); + } + + static size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + return findAllNeon(arr, size, key); } }; INSTANTIATE_IMPL(ArchType::neon); diff --git a/src/search.h b/src/search.h index e3a118de..1e153452 100644 --- a/src/search.h +++ b/src/search.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include @@ -20,8 +21,17 @@ namespace kiwi template bool searchImpl(const IntTy* keys, size_t size, IntTy target, size_t& ret); + template + ValueTy searchKVImpl(const void* keys, size_t size, IntTy target); + template Vector reorderImpl(const IntTy* keys, size_t size); + + template + size_t getPacketSizeImpl(); + + template + size_t findAllImpl(const uint8_t* arr, size_t size, uint8_t key); } template @@ -50,6 +60,56 @@ namespace kiwi } } + template + void prepareKV(void* dest, size_t size, Vector& tempBuf) + { + const size_t packetSize = detail::getPacketSizeImpl() / sizeof(IntTy); + if (size <= 1 || packetSize <= 1) return; + auto order = detail::reorderImpl(reinterpret_cast(dest), size); + if (order.empty()) return; + + if (tempBuf.size() < (sizeof(IntTy) + sizeof(Value)) * size) + { + tempBuf.resize((sizeof(IntTy) + sizeof(Value)) * size); + } + std::memcpy(tempBuf.data(), dest, (sizeof(IntTy) + sizeof(Value)) * size); + auto tempKeys = (IntTy*)tempBuf.data(); + auto tempValues = (Value*)(tempKeys + size); + for (size_t i = 0; i < size; i += packetSize) + { + const size_t groupSize = std::min(packetSize, size - i); + for (size_t j = 0; j < groupSize; ++j) + { + *reinterpret_cast(dest) = tempKeys[order[i + j]]; + dest = reinterpret_cast(dest) + sizeof(IntTy); + } + for (size_t j = 0; j < groupSize; ++j) + { + *reinterpret_cast(dest) = tempValues[order[i + j]]; + dest = reinterpret_cast(dest) + sizeof(Value); + } + } + } + + template + std::pair extractKV(const void* kv, size_t totSize, size_t idx) + { + const size_t packetSize = detail::getPacketSizeImpl() / sizeof(IntTy); + if (packetSize <= 1) + { + const auto* key = reinterpret_cast(kv); + const auto* value = reinterpret_cast(key + totSize); + return std::make_pair(key[idx], value[idx]); + } + + const size_t groupIdx = idx / packetSize; + const size_t groupOffset = idx % packetSize; + const auto* group = reinterpret_cast(kv) + groupIdx * packetSize * (sizeof(IntTy) + sizeof(Value)); + const auto* key = reinterpret_cast(group); + const auto* value = reinterpret_cast(key + std::min(packetSize, totSize - groupIdx * packetSize)); + return std::make_pair(key[groupOffset], value[groupOffset]); + } + template bool search(const IntTy* keys, const Value* values, size_t size, IntTy target, Out& ret) { @@ -73,5 +133,18 @@ namespace kiwi } else return false; } + + template + Out searchKV(const void* kv, size_t size, IntTy target) + { + return detail::searchKVImpl::type>(kv, size, target); + } + + template + size_t findAll(const uint8_t* arr, size_t size, uint8_t key) + { + if (size == 0) return 0; + return detail::findAllImpl(arr, size, key); + } } } diff --git a/src/string_view.hpp b/src/string_view.hpp deleted file mode 100644 index 6f4f724b..00000000 --- a/src/string_view.hpp +++ /dev/null @@ -1,1773 +0,0 @@ -// Copyright 2017-2020 by Martin Moene -// -// string-view lite, a C++17-like string_view for C++98 and later. -// For more information see https://github.com/martinmoene/string-view-lite -// -// Distributed under the Boost Software License, Version 1.0. -// (See accompanying file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt) - -#pragma once - -#ifndef NONSTD_SV_LITE_H_INCLUDED -#define NONSTD_SV_LITE_H_INCLUDED - -#define string_view_lite_MAJOR 1 -#define string_view_lite_MINOR 6 -#define string_view_lite_PATCH 0 - -#define string_view_lite_VERSION nssv_STRINGIFY(string_view_lite_MAJOR) "." nssv_STRINGIFY(string_view_lite_MINOR) "." nssv_STRINGIFY(string_view_lite_PATCH) - -#define nssv_STRINGIFY( x ) nssv_STRINGIFY_( x ) -#define nssv_STRINGIFY_( x ) #x - -// string-view lite configuration: - -#define nssv_STRING_VIEW_DEFAULT 0 -#define nssv_STRING_VIEW_NONSTD 1 -#define nssv_STRING_VIEW_STD 2 - -// tweak header support: - -#ifdef __has_include -# if __has_include() -# include -# endif -#define nssv_HAVE_TWEAK_HEADER 1 -#else -#define nssv_HAVE_TWEAK_HEADER 0 -//# pragma message("string_view.hpp: Note: Tweak header not supported.") -#endif - -// string_view selection and configuration: - -#if !defined( nssv_CONFIG_SELECT_STRING_VIEW ) -# define nssv_CONFIG_SELECT_STRING_VIEW ( nssv_HAVE_STD_STRING_VIEW ? nssv_STRING_VIEW_STD : nssv_STRING_VIEW_NONSTD ) -#endif - -#ifndef nssv_CONFIG_STD_SV_OPERATOR -# define nssv_CONFIG_STD_SV_OPERATOR 0 -#endif - -#ifndef nssv_CONFIG_USR_SV_OPERATOR -# define nssv_CONFIG_USR_SV_OPERATOR 1 -#endif - -#ifdef nssv_CONFIG_CONVERSION_STD_STRING -# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS nssv_CONFIG_CONVERSION_STD_STRING -# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS nssv_CONFIG_CONVERSION_STD_STRING -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS -# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS 1 -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS -# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS 1 -#endif - -#ifndef nssv_CONFIG_NO_STREAM_INSERTION -# define nssv_CONFIG_NO_STREAM_INSERTION 0 -#endif - -// Control presence of exception handling (try and auto discover): - -#ifndef nssv_CONFIG_NO_EXCEPTIONS -# if defined(_MSC_VER) -# include // for _HAS_EXCEPTIONS -# endif -# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (_HAS_EXCEPTIONS) -# define nssv_CONFIG_NO_EXCEPTIONS 0 -# else -# define nssv_CONFIG_NO_EXCEPTIONS 1 -# endif -#endif - -// C++ language version detection (C++20 is speculative): -// Note: VC14.0/1900 (VS2015) lacks too much from C++14. - -#ifndef nssv_CPLUSPLUS -# if defined(_MSVC_LANG ) && !defined(__clang__) -# define nssv_CPLUSPLUS (_MSC_VER == 1900 ? 201103L : _MSVC_LANG ) -# else -# define nssv_CPLUSPLUS __cplusplus -# endif -#endif - -#define nssv_CPP98_OR_GREATER ( nssv_CPLUSPLUS >= 199711L ) -#define nssv_CPP11_OR_GREATER ( nssv_CPLUSPLUS >= 201103L ) -#define nssv_CPP11_OR_GREATER_ ( nssv_CPLUSPLUS >= 201103L ) -#define nssv_CPP14_OR_GREATER ( nssv_CPLUSPLUS >= 201402L ) -#define nssv_CPP17_OR_GREATER ( nssv_CPLUSPLUS >= 201703L ) -#define nssv_CPP20_OR_GREATER ( nssv_CPLUSPLUS >= 202000L ) - -// use C++17 std::string_view if available and requested: - -#if nssv_CPP17_OR_GREATER && defined(__has_include ) -# if __has_include( ) -# define nssv_HAVE_STD_STRING_VIEW 1 -# else -# define nssv_HAVE_STD_STRING_VIEW 0 -# endif -#else -# define nssv_HAVE_STD_STRING_VIEW 0 -#endif - -#define nssv_USES_STD_STRING_VIEW ( (nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_STD) || ((nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_DEFAULT) && nssv_HAVE_STD_STRING_VIEW) ) - -#define nssv_HAVE_STARTS_WITH ( nssv_CPP20_OR_GREATER || !nssv_USES_STD_STRING_VIEW ) -#define nssv_HAVE_ENDS_WITH nssv_HAVE_STARTS_WITH - -// -// Use C++17 std::string_view: -// - -#if nssv_USES_STD_STRING_VIEW - -#include - -// Extensions for std::string: - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - - template< class CharT, class Traits, class Allocator = std::allocator > - std::basic_string - to_string(std::basic_string_view v, Allocator const& a = Allocator()) - { - return std::basic_string(v.begin(), v.end(), a); - } - - template< class CharT, class Traits, class Allocator > - std::basic_string_view - to_string_view(std::basic_string const& s) - { - return std::basic_string_view(s.data(), s.size()); - } - - // Literal operators sv and _sv: - -#if nssv_CONFIG_STD_SV_OPERATOR - - using namespace std::literals::string_view_literals; - -#endif - -#if nssv_CONFIG_USR_SV_OPERATOR - - inline namespace literals { - inline namespace string_view_literals { - - - constexpr std::string_view operator "" _sv(const char* str, size_t len) noexcept // (1) - { - return std::string_view{ str, len }; - } - - constexpr std::u16string_view operator "" _sv(const char16_t* str, size_t len) noexcept // (2) - { - return std::u16string_view{ str, len }; - } - - constexpr std::u32string_view operator "" _sv(const char32_t* str, size_t len) noexcept // (3) - { - return std::u32string_view{ str, len }; - } - - constexpr std::wstring_view operator "" _sv(const wchar_t* str, size_t len) noexcept // (4) - { - return std::wstring_view{ str, len }; - } - - } - } // namespace literals::string_view_literals - -#endif // nssv_CONFIG_USR_SV_OPERATOR - -} // namespace nonstd - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - - using std::string_view; - using std::wstring_view; - using std::u16string_view; - using std::u32string_view; - using std::basic_string_view; - - // literal "sv" and "_sv", see above - - using std::operator==; - using std::operator!=; - using std::operator<; - using std::operator<=; - using std::operator>; - using std::operator>=; - - using std::operator<<; - -} // namespace nonstd - -#else // nssv_HAVE_STD_STRING_VIEW - -// -// Before C++17: use string_view lite: -// - -// Compiler versions: -// -// MSVC++ 6.0 _MSC_VER == 1200 nssv_COMPILER_MSVC_VERSION == 60 (Visual Studio 6.0) -// MSVC++ 7.0 _MSC_VER == 1300 nssv_COMPILER_MSVC_VERSION == 70 (Visual Studio .NET 2002) -// MSVC++ 7.1 _MSC_VER == 1310 nssv_COMPILER_MSVC_VERSION == 71 (Visual Studio .NET 2003) -// MSVC++ 8.0 _MSC_VER == 1400 nssv_COMPILER_MSVC_VERSION == 80 (Visual Studio 2005) -// MSVC++ 9.0 _MSC_VER == 1500 nssv_COMPILER_MSVC_VERSION == 90 (Visual Studio 2008) -// MSVC++ 10.0 _MSC_VER == 1600 nssv_COMPILER_MSVC_VERSION == 100 (Visual Studio 2010) -// MSVC++ 11.0 _MSC_VER == 1700 nssv_COMPILER_MSVC_VERSION == 110 (Visual Studio 2012) -// MSVC++ 12.0 _MSC_VER == 1800 nssv_COMPILER_MSVC_VERSION == 120 (Visual Studio 2013) -// MSVC++ 14.0 _MSC_VER == 1900 nssv_COMPILER_MSVC_VERSION == 140 (Visual Studio 2015) -// MSVC++ 14.1 _MSC_VER >= 1910 nssv_COMPILER_MSVC_VERSION == 141 (Visual Studio 2017) -// MSVC++ 14.2 _MSC_VER >= 1920 nssv_COMPILER_MSVC_VERSION == 142 (Visual Studio 2019) - -#if defined(_MSC_VER ) && !defined(__clang__) -# define nssv_COMPILER_MSVC_VER (_MSC_VER ) -# define nssv_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + (_MSC_VER < 1900 ) ) ) -#else -# define nssv_COMPILER_MSVC_VER 0 -# define nssv_COMPILER_MSVC_VERSION 0 -#endif - -#define nssv_COMPILER_VERSION( major, minor, patch ) ( 10 * ( 10 * (major) + (minor) ) + (patch) ) - -#if defined( __apple_build_version__ ) -# define nssv_COMPILER_APPLECLANG_VERSION nssv_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) -# define nssv_COMPILER_CLANG_VERSION 0 -#elif defined( __clang__ ) -# define nssv_COMPILER_APPLECLANG_VERSION 0 -# define nssv_COMPILER_CLANG_VERSION nssv_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) -#else -# define nssv_COMPILER_APPLECLANG_VERSION 0 -# define nssv_COMPILER_CLANG_VERSION 0 -#endif - -#if defined(__GNUC__) && !defined(__clang__) -# define nssv_COMPILER_GNUC_VERSION nssv_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#else -# define nssv_COMPILER_GNUC_VERSION 0 -#endif - -// half-open range [lo..hi): -#define nssv_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) ) - -// Presence of language and library features: - -#ifdef _HAS_CPP0X -# define nssv_HAS_CPP0X _HAS_CPP0X -#else -# define nssv_HAS_CPP0X 0 -#endif - -// Unless defined otherwise below, consider VC14 as C++11 for variant-lite: - -#if nssv_COMPILER_MSVC_VER >= 1900 -# undef nssv_CPP11_OR_GREATER -# define nssv_CPP11_OR_GREATER 1 -#endif - -#define nssv_CPP11_90 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1500) -#define nssv_CPP11_100 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1600) -#define nssv_CPP11_110 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1700) -#define nssv_CPP11_120 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1800) -#define nssv_CPP11_140 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1900) -#define nssv_CPP11_141 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1910) - -#define nssv_CPP14_000 (nssv_CPP14_OR_GREATER) -#define nssv_CPP17_000 (nssv_CPP17_OR_GREATER) - -// Presence of C++11 language features: - -#define nssv_HAVE_CONSTEXPR_11 nssv_CPP11_140 -#define nssv_HAVE_EXPLICIT_CONVERSION nssv_CPP11_140 -#define nssv_HAVE_INLINE_NAMESPACE nssv_CPP11_140 -#define nssv_HAVE_IS_DEFAULT nssv_CPP11_140 -#define nssv_HAVE_IS_DELETE nssv_CPP11_140 -#define nssv_HAVE_NOEXCEPT nssv_CPP11_140 -#define nssv_HAVE_NULLPTR nssv_CPP11_100 -#define nssv_HAVE_REF_QUALIFIER nssv_CPP11_140 -#define nssv_HAVE_UNICODE_LITERALS nssv_CPP11_140 -#define nssv_HAVE_USER_DEFINED_LITERALS nssv_CPP11_140 -#define nssv_HAVE_WCHAR16_T nssv_CPP11_100 -#define nssv_HAVE_WCHAR32_T nssv_CPP11_100 - -#if ! ( ( nssv_CPP11_OR_GREATER && nssv_COMPILER_CLANG_VERSION ) || nssv_BETWEEN( nssv_COMPILER_CLANG_VERSION, 300, 400 ) ) -# define nssv_HAVE_STD_DEFINED_LITERALS nssv_CPP11_140 -#else -# define nssv_HAVE_STD_DEFINED_LITERALS 0 -#endif - -// Presence of C++14 language features: - -#define nssv_HAVE_CONSTEXPR_14 nssv_CPP14_000 - -// Presence of C++17 language features: - -#define nssv_HAVE_NODISCARD nssv_CPP17_000 - -// Presence of C++ library features: - -#define nssv_HAVE_STD_HASH nssv_CPP11_120 - -// Presence of compiler intrinsics: - -// Providing char-type specializations for compare() and length() that -// use compiler intrinsics can improve compile- and run-time performance. -// -// The challenge is in using the right combinations of builtin availability -// and its constexpr-ness. -// -// | compiler | __builtin_memcmp (constexpr) | memcmp (constexpr) | -// |----------|------------------------------|---------------------| -// | clang | 4.0 (>= 4.0 ) | any (? ) | -// | clang-a | 9.0 (>= 9.0 ) | any (? ) | -// | gcc | any (constexpr) | any (? ) | -// | msvc | >= 14.2 C++17 (>= 14.2 ) | any (? ) | - -#define nssv_HAVE_BUILTIN_VER ( (nssv_CPP17_000 && nssv_COMPILER_MSVC_VERSION >= 142) || nssv_COMPILER_GNUC_VERSION > 0 || nssv_COMPILER_CLANG_VERSION >= 400 || nssv_COMPILER_APPLECLANG_VERSION >= 900 ) -#define nssv_HAVE_BUILTIN_CE ( nssv_HAVE_BUILTIN_VER ) - -#define nssv_HAVE_BUILTIN_MEMCMP ( (nssv_HAVE_CONSTEXPR_14 && nssv_HAVE_BUILTIN_CE) || !nssv_HAVE_CONSTEXPR_14 ) -#define nssv_HAVE_BUILTIN_STRLEN ( (nssv_HAVE_CONSTEXPR_11 && nssv_HAVE_BUILTIN_CE) || !nssv_HAVE_CONSTEXPR_11 ) - -#ifdef __has_builtin -# define nssv_HAVE_BUILTIN( x ) __has_builtin( x ) -#else -# define nssv_HAVE_BUILTIN( x ) 0 -#endif - -#if nssv_HAVE_BUILTIN(__builtin_memcmp) || nssv_HAVE_BUILTIN_VER -# define nssv_BUILTIN_MEMCMP __builtin_memcmp -#else -# define nssv_BUILTIN_MEMCMP memcmp -#endif - -#if nssv_HAVE_BUILTIN(__builtin_strlen) || nssv_HAVE_BUILTIN_VER -# define nssv_BUILTIN_STRLEN __builtin_strlen -#else -# define nssv_BUILTIN_STRLEN strlen -#endif - -// C++ feature usage: - -#if nssv_HAVE_CONSTEXPR_11 -# define nssv_constexpr constexpr -#else -# define nssv_constexpr /*constexpr*/ -#endif - -#if nssv_HAVE_CONSTEXPR_14 -# define nssv_constexpr14 constexpr -#else -# define nssv_constexpr14 /*constexpr*/ -#endif - -#if nssv_HAVE_EXPLICIT_CONVERSION -# define nssv_explicit explicit -#else -# define nssv_explicit /*explicit*/ -#endif - -#if nssv_HAVE_INLINE_NAMESPACE -# define nssv_inline_ns inline -#else -# define nssv_inline_ns /*inline*/ -#endif - -#if nssv_HAVE_NOEXCEPT -# define nssv_noexcept noexcept -#else -# define nssv_noexcept /*noexcept*/ -#endif - -//#if nssv_HAVE_REF_QUALIFIER -//# define nssv_ref_qual & -//# define nssv_refref_qual && -//#else -//# define nssv_ref_qual /*&*/ -//# define nssv_refref_qual /*&&*/ -//#endif - -#if nssv_HAVE_NULLPTR -# define nssv_nullptr nullptr -#else -# define nssv_nullptr NULL -#endif - -#if nssv_HAVE_NODISCARD -# define nssv_nodiscard [[nodiscard]] -#else -# define nssv_nodiscard /*[[nodiscard]]*/ -#endif - -// Additional includes: - -#include -#include -#include -#include -#include // std::char_traits<> - -#if ! nssv_CONFIG_NO_STREAM_INSERTION -# include -#endif - -#if ! nssv_CONFIG_NO_EXCEPTIONS -# include -#endif - -#if nssv_CPP11_OR_GREATER -# include -#endif - -// Clang, GNUC, MSVC warning suppression macros: - -#if defined(__clang__) -# pragma clang diagnostic ignored "-Wreserved-user-defined-literal" -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wuser-defined-literals" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wliteral-suffix" -#endif // __clang__ - -#if nssv_COMPILER_MSVC_VERSION >= 140 -# define nssv_SUPPRESS_MSGSL_WARNING(expr) [[gsl::suppress(expr)]] -# define nssv_SUPPRESS_MSVC_WARNING(code, descr) __pragma(warning(suppress: code) ) -# define nssv_DISABLE_MSVC_WARNINGS(codes) __pragma(warning(push)) __pragma(warning(disable: codes)) -#else -# define nssv_SUPPRESS_MSGSL_WARNING(expr) -# define nssv_SUPPRESS_MSVC_WARNING(code, descr) -# define nssv_DISABLE_MSVC_WARNINGS(codes) -#endif - -#if defined(__clang__) -# define nssv_RESTORE_WARNINGS() _Pragma("clang diagnostic pop") -#elif defined(__GNUC__) -# define nssv_RESTORE_WARNINGS() _Pragma("GCC diagnostic pop") -#elif nssv_COMPILER_MSVC_VERSION >= 140 -# define nssv_RESTORE_WARNINGS() __pragma(warning(pop )) -#else -# define nssv_RESTORE_WARNINGS() -#endif - -// Suppress the following MSVC (GSL) warnings: -// - C4455, non-gsl : 'operator ""sv': literal suffix identifiers that do not -// start with an underscore are reserved -// - C26472, gsl::t.1 : don't use a static_cast for arithmetic conversions; -// use brace initialization, gsl::narrow_cast or gsl::narow -// - C26481: gsl::b.1 : don't use pointer arithmetic. Use span instead - -nssv_DISABLE_MSVC_WARNINGS(4455 26481 26472) -//nssv_DISABLE_CLANG_WARNINGS( "-Wuser-defined-literals" ) -//nssv_DISABLE_GNUC_WARNINGS( -Wliteral-suffix ) - -namespace nonstd { - namespace sv_lite { - - // - // basic_string_view declaration: - // - - template - < - class CharT, - class Traits = std::char_traits - > - class basic_string_view; - - namespace detail { - - // support constexpr comparison in C++14; - // for C++17 and later, use provided traits: - - template< typename CharT > - inline nssv_constexpr14 int compare(CharT const* s1, CharT const* s2, std::size_t count) - { - while (count-- != 0) - { - if (*s1 < *s2) return -1; - if (*s1 > *s2) return +1; - ++s1; ++s2; - } - return 0; - } - -#if nssv_HAVE_BUILTIN_MEMCMP - - // specialization of compare() for char, see also generic compare() above: - - inline nssv_constexpr14 int compare(char const* s1, char const* s2, std::size_t count) - { - return nssv_BUILTIN_MEMCMP(s1, s2, count); - } - -#endif - -#if nssv_HAVE_BUILTIN_STRLEN - - // specialization of length() for char, see also generic length() further below: - - inline nssv_constexpr std::size_t length(char const* s) - { - return nssv_BUILTIN_STRLEN(s); - } - -#endif - -#if defined(__OPTIMIZE__) - - // gcc, clang provide __OPTIMIZE__ - // Expect tail call optimization to make length() non-recursive: - - template< typename CharT > - inline nssv_constexpr std::size_t length(CharT* s, std::size_t result = 0) - { - return *s == '\0' ? result : length(s + 1, result + 1); - } - -#else // OPTIMIZE - - // non-recursive: - - template< typename CharT > - inline nssv_constexpr14 std::size_t length(CharT* s) - { - std::size_t result = 0; - while (*s++ != '\0') - { - ++result; - } - return result; - } - -#endif // OPTIMIZE - -#if nssv_CPP11_OR_GREATER && ! nssv_CPP17_OR_GREATER -#if defined(__OPTIMIZE__) - - // gcc, clang provide __OPTIMIZE__ - // Expect tail call optimization to make search() non-recursive: - - template< class CharT, class Traits = std::char_traits > - constexpr const CharT* search(basic_string_view haystack, basic_string_view needle) - { - return haystack.starts_with(needle) ? haystack.begin() : - haystack.empty() ? haystack.end() : search(haystack.substr(1), needle); - } - -#else // OPTIMIZE - - // non-recursive: - - template< class CharT, class Traits = std::char_traits > - constexpr const CharT* search(basic_string_view haystack, basic_string_view needle) - { - return std::search(haystack.begin(), haystack.end(), needle.begin(), needle.end()); - } - -#endif // OPTIMIZE -#endif // nssv_CPP11_OR_GREATER && ! nssv_CPP17_OR_GREATER - - } // namespace detail - - // - // basic_string_view: - // - - template - < - class CharT, - class Traits /* = std::char_traits */ - > - class basic_string_view - { - public: - // Member types: - - typedef Traits traits_type; - typedef CharT value_type; - - typedef CharT* pointer; - typedef CharT const* const_pointer; - typedef CharT& reference; - typedef CharT const& const_reference; - - typedef const_pointer iterator; - typedef const_pointer const_iterator; - typedef std::reverse_iterator< const_iterator > reverse_iterator; - typedef std::reverse_iterator< const_iterator > const_reverse_iterator; - - typedef std::size_t size_type; - typedef std::ptrdiff_t difference_type; - - // 24.4.2.1 Construction and assignment: - - nssv_constexpr basic_string_view() nssv_noexcept - : data_(nssv_nullptr) - , size_(0) - {} - -#if nssv_CPP11_OR_GREATER - nssv_constexpr basic_string_view(basic_string_view const& other) nssv_noexcept = default; -#else - nssv_constexpr basic_string_view(basic_string_view const& other) nssv_noexcept - : data_(other.data_) - , size_(other.size_) - {} -#endif - - nssv_constexpr basic_string_view(CharT const* s, size_type count) nssv_noexcept // non-standard noexcept - : data_(s) - , size_(count) - {} - - nssv_constexpr basic_string_view(CharT const* s) nssv_noexcept // non-standard noexcept - : data_(s) -#if nssv_CPP17_OR_GREATER - , size_(Traits::length(s)) -#elif nssv_CPP11_OR_GREATER - , size_(detail::length(s)) -#else - , size_(Traits::length(s)) -#endif - {} - -#if nssv_HAVE_NULLPTR -# if nssv_HAVE_IS_DELETE - nssv_constexpr basic_string_view(std::nullptr_t) nssv_noexcept = delete; -# else - private: nssv_constexpr basic_string_view(std::nullptr_t) nssv_noexcept; public: -# endif -#endif - - // Assignment: - -#if nssv_CPP11_OR_GREATER - nssv_constexpr14 basic_string_view& operator=(basic_string_view const& other) nssv_noexcept = default; -#else - nssv_constexpr14 basic_string_view& operator=(basic_string_view const& other) nssv_noexcept - { - data_ = other.data_; - size_ = other.size_; - return *this; - } -#endif - - // 24.4.2.2 Iterator support: - - nssv_constexpr const_iterator begin() const nssv_noexcept { return data_; } - nssv_constexpr const_iterator end() const nssv_noexcept { return data_ + size_; } - - nssv_constexpr const_iterator cbegin() const nssv_noexcept { return begin(); } - nssv_constexpr const_iterator cend() const nssv_noexcept { return end(); } - - nssv_constexpr const_reverse_iterator rbegin() const nssv_noexcept { return const_reverse_iterator(end()); } - nssv_constexpr const_reverse_iterator rend() const nssv_noexcept { return const_reverse_iterator(begin()); } - - nssv_constexpr const_reverse_iterator crbegin() const nssv_noexcept { return rbegin(); } - nssv_constexpr const_reverse_iterator crend() const nssv_noexcept { return rend(); } - - // 24.4.2.3 Capacity: - - nssv_constexpr size_type size() const nssv_noexcept { return size_; } - nssv_constexpr size_type length() const nssv_noexcept { return size_; } - nssv_constexpr size_type max_size() const nssv_noexcept { return (std::numeric_limits< size_type >::max)(); } - - // since C++20 - nssv_nodiscard nssv_constexpr bool empty() const nssv_noexcept - { - return 0 == size_; - } - - // 24.4.2.4 Element access: - - nssv_constexpr const_reference operator[](size_type pos) const - { - return data_at(pos); - } - - nssv_constexpr14 const_reference at(size_type pos) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos < size()); -#else - if (pos >= size()) - { - throw std::out_of_range("nonstd::string_view::at()"); - } -#endif - return data_at(pos); - } - - nssv_constexpr const_reference front() const { return data_at(0); } - nssv_constexpr const_reference back() const { return data_at(size() - 1); } - - nssv_constexpr const_pointer data() const nssv_noexcept { return data_; } - - // 24.4.2.5 Modifiers: - - nssv_constexpr14 void remove_prefix(size_type n) - { - assert(n <= size()); - data_ += n; - size_ -= n; - } - - nssv_constexpr14 void remove_suffix(size_type n) - { - assert(n <= size()); - size_ -= n; - } - - nssv_constexpr14 void swap(basic_string_view& other) nssv_noexcept - { - const basic_string_view tmp(other); - other = *this; - *this = tmp; - } - - // 24.4.2.6 String operations: - - size_type copy(CharT* dest, size_type n, size_type pos = 0) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos <= size()); -#else - if (pos > size()) - { - throw std::out_of_range("nonstd::string_view::copy()"); - } -#endif - const size_type rlen = (std::min)(n, size() - pos); - - (void)Traits::copy(dest, data() + pos, rlen); - - return rlen; - } - - nssv_constexpr14 basic_string_view substr(size_type pos = 0, size_type n = npos) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert(pos <= size()); -#else - if (pos > size()) - { - throw std::out_of_range("nonstd::string_view::substr()"); - } -#endif - return basic_string_view(data() + pos, (std::min)(n, size() - pos)); - } - - // compare(), 6x: - - nssv_constexpr14 int compare(basic_string_view other) const nssv_noexcept // (1) - { -#if nssv_CPP17_OR_GREATER - if (const int result = Traits::compare(data(), other.data(), (std::min)(size(), other.size()))) -#else - if (const int result = detail::compare(data(), other.data(), (std::min)(size(), other.size()))) -#endif - { - return result; - } - - return size() == other.size() ? 0 : size() < other.size() ? -1 : 1; - } - - nssv_constexpr int compare(size_type pos1, size_type n1, basic_string_view other) const // (2) - { - return substr(pos1, n1).compare(other); - } - - nssv_constexpr int compare(size_type pos1, size_type n1, basic_string_view other, size_type pos2, size_type n2) const // (3) - { - return substr(pos1, n1).compare(other.substr(pos2, n2)); - } - - nssv_constexpr int compare(CharT const* s) const // (4) - { - return compare(basic_string_view(s)); - } - - nssv_constexpr int compare(size_type pos1, size_type n1, CharT const* s) const // (5) - { - return substr(pos1, n1).compare(basic_string_view(s)); - } - - nssv_constexpr int compare(size_type pos1, size_type n1, CharT const* s, size_type n2) const // (6) - { - return substr(pos1, n1).compare(basic_string_view(s, n2)); - } - - // 24.4.2.7 Searching: - - // starts_with(), 3x, since C++20: - - nssv_constexpr bool starts_with(basic_string_view v) const nssv_noexcept // (1) - { - return size() >= v.size() && compare(0, v.size(), v) == 0; - } - - nssv_constexpr bool starts_with(CharT c) const nssv_noexcept // (2) - { - return starts_with(basic_string_view(&c, 1)); - } - - nssv_constexpr bool starts_with(CharT const* s) const // (3) - { - return starts_with(basic_string_view(s)); - } - - // ends_with(), 3x, since C++20: - - nssv_constexpr bool ends_with(basic_string_view v) const nssv_noexcept // (1) - { - return size() >= v.size() && compare(size() - v.size(), npos, v) == 0; - } - - nssv_constexpr bool ends_with(CharT c) const nssv_noexcept // (2) - { - return ends_with(basic_string_view(&c, 1)); - } - - nssv_constexpr bool ends_with(CharT const* s) const // (3) - { - return ends_with(basic_string_view(s)); - } - - // find(), 4x: - - nssv_constexpr size_type find(basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return assert(v.size() == 0 || v.data() != nssv_nullptr) - , pos >= size() - ? npos : to_pos( -#if nssv_CPP11_OR_GREATER && ! nssv_CPP17_OR_GREATER - detail::search(substr(pos), v) -#else - std::search(cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq) -#endif - ); - } - - nssv_constexpr size_type find(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find(CharT const* s, size_type pos, size_type n) const // (3) - { - return find(basic_string_view(s, n), pos); - } - - nssv_constexpr size_type find(CharT const* s, size_type pos = 0) const // (4) - { - return find(basic_string_view(s), pos); - } - - // rfind(), 4x: - - nssv_constexpr14 size_type rfind(basic_string_view v, size_type pos = npos) const nssv_noexcept // (1) - { - if (size() < v.size()) - { - return npos; - } - - if (v.empty()) - { - return (std::min)(size(), pos); - } - - const_iterator last = cbegin() + (std::min)(size() - v.size(), pos) + v.size(); - const_iterator result = std::find_end(cbegin(), last, v.cbegin(), v.cend(), Traits::eq); - - return result != last ? size_type(result - cbegin()) : npos; - } - - nssv_constexpr14 size_type rfind(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return rfind(basic_string_view(&c, 1), pos); - } - - nssv_constexpr14 size_type rfind(CharT const* s, size_type pos, size_type n) const // (3) - { - return rfind(basic_string_view(s, n), pos); - } - - nssv_constexpr14 size_type rfind(CharT const* s, size_type pos = npos) const // (4) - { - return rfind(basic_string_view(s), pos); - } - - // find_first_of(), 4x: - - nssv_constexpr size_type find_first_of(basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return pos >= size() - ? npos - : to_pos(std::find_first_of(cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq)); - } - - nssv_constexpr size_type find_first_of(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find_first_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_first_of(CharT const* s, size_type pos, size_type n) const // (3) - { - return find_first_of(basic_string_view(s, n), pos); - } - - nssv_constexpr size_type find_first_of(CharT const* s, size_type pos = 0) const // (4) - { - return find_first_of(basic_string_view(s), pos); - } - - // find_last_of(), 4x: - - nssv_constexpr size_type find_last_of(basic_string_view v, size_type pos = npos) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? find_last_of(v, size() - 1) - : to_pos(std::find_first_of(const_reverse_iterator(cbegin() + pos + 1), crend(), v.cbegin(), v.cend(), Traits::eq)); - } - - nssv_constexpr size_type find_last_of(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return find_last_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_last_of(CharT const* s, size_type pos, size_type count) const // (3) - { - return find_last_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type find_last_of(CharT const* s, size_type pos = npos) const // (4) - { - return find_last_of(basic_string_view(s), pos); - } - - // find_first_not_of(), 4x: - - nssv_constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0) const nssv_noexcept // (1) - { - return pos >= size() - ? npos - : to_pos(std::find_if(cbegin() + pos, cend(), not_in_view(v))); - } - - nssv_constexpr size_type find_first_not_of(CharT c, size_type pos = 0) const nssv_noexcept // (2) - { - return find_first_not_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_first_not_of(CharT const* s, size_type pos, size_type count) const // (3) - { - return find_first_not_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type find_first_not_of(CharT const* s, size_type pos = 0) const // (4) - { - return find_first_not_of(basic_string_view(s), pos); - } - - // find_last_not_of(), 4x: - - nssv_constexpr size_type find_last_not_of(basic_string_view v, size_type pos = npos) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? find_last_not_of(v, size() - 1) - : to_pos(std::find_if(const_reverse_iterator(cbegin() + pos + 1), crend(), not_in_view(v))); - } - - nssv_constexpr size_type find_last_not_of(CharT c, size_type pos = npos) const nssv_noexcept // (2) - { - return find_last_not_of(basic_string_view(&c, 1), pos); - } - - nssv_constexpr size_type find_last_not_of(CharT const* s, size_type pos, size_type count) const // (3) - { - return find_last_not_of(basic_string_view(s, count), pos); - } - - nssv_constexpr size_type find_last_not_of(CharT const* s, size_type pos = npos) const // (4) - { - return find_last_not_of(basic_string_view(s), pos); - } - - // Constants: - -#if nssv_CPP17_OR_GREATER - static nssv_constexpr size_type npos = size_type(-1); -#elif nssv_CPP11_OR_GREATER - enum : size_type { npos = size_type(-1) }; -#else - enum { npos = size_type(-1) }; -#endif - - private: - struct not_in_view - { - const basic_string_view v; - - nssv_constexpr explicit not_in_view(basic_string_view v_) : v(v_) {} - - nssv_constexpr bool operator()(CharT c) const - { - return npos == v.find_first_of(c); - } - }; - - nssv_constexpr size_type to_pos(const_iterator it) const - { - return it == cend() ? npos : size_type(it - cbegin()); - } - - nssv_constexpr size_type to_pos(const_reverse_iterator it) const - { - return it == crend() ? npos : size_type(crend() - it - 1); - } - - nssv_constexpr const_reference data_at(size_type pos) const - { -#if nssv_BETWEEN( nssv_COMPILER_GNUC_VERSION, 1, 500 ) - return data_[pos]; -#else - return assert(pos < size()), data_[pos]; -#endif - } - - private: - const_pointer data_; - size_type size_; - - public: -#if nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS - - template< class Allocator > - basic_string_view(std::basic_string const& s) nssv_noexcept - : data_(s.data()) - , size_(s.size()) - {} - -#if nssv_HAVE_EXPLICIT_CONVERSION - - template< class Allocator > - explicit operator std::basic_string() const - { - return to_string(Allocator()); - } - -#endif // nssv_HAVE_EXPLICIT_CONVERSION - -#if nssv_CPP11_OR_GREATER - - template< class Allocator = std::allocator > - std::basic_string - to_string(Allocator const& a = Allocator()) const - { - return std::basic_string(begin(), end(), a); - } - -#else - - std::basic_string - to_string() const - { - return std::basic_string(begin(), end()); - } - - template< class Allocator > - std::basic_string - to_string(Allocator const& a) const - { - return std::basic_string(begin(), end(), a); - } - -#endif // nssv_CPP11_OR_GREATER - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS - }; - - // - // Non-member functions: - // - - // 24.4.3 Non-member comparison functions: - // lexicographically compare two string views (function template): - - template< class CharT, class Traits > - nssv_constexpr bool operator== ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator!= ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits > - nssv_constexpr bool operator< ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator<= ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator> ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits > - nssv_constexpr bool operator>= ( - basic_string_view lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - // Let S be basic_string_view, and sv be an instance of S. - // Implementations shall provide sufficient additional overloads marked - // constexpr and noexcept so that an object t with an implicit conversion - // to S can be compared according to Table 67. - -#if ! nssv_CPP11_OR_GREATER || nssv_BETWEEN( nssv_COMPILER_MSVC_VERSION, 100, 141 ) - -// accommodate for older compilers: - -// == - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.size() == detail::length(rhs) && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return detail::length(lhs) == rhs.size() && rhs.compare(lhs) == 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator==( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - // != - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits> - nssv_constexpr bool operator!=( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return !(lhs == rhs); - } - - // < - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) > 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) > 0; - } - - // <= - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) >= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator<=( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) >= 0; - } - - // > - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) < 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) < 0; - } - - // >= - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - basic_string_view lhs, - CharT const* rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - CharT const* lhs, - basic_string_view rhs) nssv_noexcept - { - return rhs.compare(lhs) <= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - basic_string_view lhs, - std::basic_string rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - template< class CharT, class Traits> - nssv_constexpr bool operator>=( - std::basic_string rhs, - basic_string_view lhs) nssv_noexcept - { - return rhs.compare(lhs) <= 0; - } - -#else // newer compilers: - -#define nssv_BASIC_STRING_VIEW_I(T,U) typename std::decay< basic_string_view >::type - -#if defined(_MSC_VER) // issue 40 -# define nssv_MSVC_ORDER(x) , int=x -#else -# define nssv_MSVC_ORDER(x) /*, int=x*/ -#endif - -// == - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator==( - basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator==( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs) nssv_noexcept - { - return lhs.size() == rhs.size() && lhs.compare(rhs) == 0; - } - - // != - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator!= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator!= ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return !(lhs == rhs); - } - - // < - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator< ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator< ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) < 0; - } - - // <= - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator<= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator<= ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) <= 0; - } - - // > - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator> ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator> ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) > 0; - } - - // >= - - template< class CharT, class Traits nssv_MSVC_ORDER(1) > - nssv_constexpr bool operator>= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - - template< class CharT, class Traits nssv_MSVC_ORDER(2) > - nssv_constexpr bool operator>= ( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view < CharT, Traits > rhs) nssv_noexcept - { - return lhs.compare(rhs) >= 0; - } - -#undef nssv_MSVC_ORDER -#undef nssv_BASIC_STRING_VIEW_I - -#endif // compiler-dependent approach to comparisons - - // 24.4.4 Inserters and extractors: - -#if ! nssv_CONFIG_NO_STREAM_INSERTION - - namespace detail { - - template< class Stream > - void write_padding(Stream& os, std::streamsize n) - { - for (std::streamsize i = 0; i < n; ++i) - os.rdbuf()->sputc(os.fill()); - } - - template< class Stream, class View > - Stream& write_to_stream(Stream& os, View const& sv) - { - typename Stream::sentry sentry(os); - - if (!sentry) - return os; - - const std::streamsize length = static_cast(sv.length()); - - // Whether, and how, to pad: - const bool pad = (length < os.width()); - const bool left_pad = pad && (os.flags() & std::ios_base::adjustfield) == std::ios_base::right; - - if (left_pad) - write_padding(os, os.width() - length); - - // Write span characters: - os.rdbuf()->sputn(sv.begin(), length); - - if (pad && !left_pad) - write_padding(os, os.width() - length); - - // Reset output stream width: - os.width(0); - - return os; - } - - } // namespace detail - - template< class CharT, class Traits > - std::basic_ostream& - operator<<( - std::basic_ostream& os, - basic_string_view sv) - { - return detail::write_to_stream(os, sv); - } - -#endif // nssv_CONFIG_NO_STREAM_INSERTION - - // Several typedefs for common character types are provided: - - typedef basic_string_view string_view; - typedef basic_string_view wstring_view; -#if nssv_HAVE_WCHAR16_T - typedef basic_string_view u16string_view; - typedef basic_string_view u32string_view; -#endif - - } -} // namespace nonstd::sv_lite - -// -// 24.4.6 Suffix for basic_string_view literals: -// - -#if nssv_HAVE_USER_DEFINED_LITERALS - -namespace nonstd { - nssv_inline_ns namespace literals { - nssv_inline_ns namespace string_view_literals { - -#if nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - - nssv_constexpr nonstd::sv_lite::string_view operator "" sv(const char* str, size_t len) nssv_noexcept // (1) - { - return nonstd::sv_lite::string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u16string_view operator "" sv(const char16_t* str, size_t len) nssv_noexcept // (2) - { - return nonstd::sv_lite::u16string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u32string_view operator "" sv(const char32_t* str, size_t len) nssv_noexcept // (3) - { - return nonstd::sv_lite::u32string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::wstring_view operator "" sv(const wchar_t* str, size_t len) nssv_noexcept // (4) - { - return nonstd::sv_lite::wstring_view{ str, len }; - } - -#endif // nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - -#if nssv_CONFIG_USR_SV_OPERATOR - - nssv_constexpr nonstd::sv_lite::string_view operator "" _sv(const char* str, size_t len) nssv_noexcept // (1) - { - return nonstd::sv_lite::string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u16string_view operator "" _sv(const char16_t* str, size_t len) nssv_noexcept // (2) - { - return nonstd::sv_lite::u16string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::u32string_view operator "" _sv(const char32_t* str, size_t len) nssv_noexcept // (3) - { - return nonstd::sv_lite::u32string_view{ str, len }; - } - - nssv_constexpr nonstd::sv_lite::wstring_view operator "" _sv(const wchar_t* str, size_t len) nssv_noexcept // (4) - { - return nonstd::sv_lite::wstring_view{ str, len }; - } - -#endif // nssv_CONFIG_USR_SV_OPERATOR - - } - } -} // namespace nonstd::literals::string_view_literals - -#endif - -// -// Extensions for std::string: -// - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - namespace sv_lite { - - // Exclude MSVC 14 (19.00): it yields ambiguous to_string(): - -#if nssv_CPP11_OR_GREATER && nssv_COMPILER_MSVC_VERSION != 140 - - template< class CharT, class Traits, class Allocator = std::allocator > - std::basic_string - to_string(basic_string_view v, Allocator const& a = Allocator()) - { - return std::basic_string(v.begin(), v.end(), a); - } - -#else - - template< class CharT, class Traits > - std::basic_string - to_string(basic_string_view v) - { - return std::basic_string(v.begin(), v.end()); - } - - template< class CharT, class Traits, class Allocator > - std::basic_string - to_string(basic_string_view v, Allocator const& a) - { - return std::basic_string(v.begin(), v.end(), a); - } - -#endif // nssv_CPP11_OR_GREATER - - template< class CharT, class Traits, class Allocator > - basic_string_view - to_string_view(std::basic_string const& s) - { - return basic_string_view(s.data(), s.size()); - } - - } -} // namespace nonstd::sv_lite - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -// -// make types and algorithms available in namespace nonstd: -// - -namespace nonstd { - - using sv_lite::basic_string_view; - using sv_lite::string_view; - using sv_lite::wstring_view; - -#if nssv_HAVE_WCHAR16_T - using sv_lite::u16string_view; -#endif -#if nssv_HAVE_WCHAR32_T - using sv_lite::u32string_view; -#endif - - // literal "sv" - - using sv_lite::operator==; - using sv_lite::operator!=; - using sv_lite::operator<; - using sv_lite::operator<=; - using sv_lite::operator>; - using sv_lite::operator>=; - -#if ! nssv_CONFIG_NO_STREAM_INSERTION - using sv_lite::operator<<; -#endif - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - using sv_lite::to_string; - using sv_lite::to_string_view; -#endif - -} // namespace nonstd - -// 24.4.5 Hash support (C++11): - -// Note: The hash value of a string view object is equal to the hash value of -// the corresponding string object. - -#if nssv_HAVE_STD_HASH - -#include - -namespace std { - - template<> - struct hash< nonstd::string_view > - { - public: - std::size_t operator()(nonstd::string_view v) const nssv_noexcept - { - return std::hash()(std::string(v.data(), v.size())); - } - }; - - template<> - struct hash< nonstd::wstring_view > - { - public: - std::size_t operator()(nonstd::wstring_view v) const nssv_noexcept - { - return std::hash()(std::wstring(v.data(), v.size())); - } - }; - - template<> - struct hash< nonstd::u16string_view > - { - public: - std::size_t operator()(nonstd::u16string_view v) const nssv_noexcept - { - return std::hash()(std::u16string(v.data(), v.size())); - } - }; - - template<> - struct hash< nonstd::u32string_view > - { - public: - std::size_t operator()(nonstd::u32string_view v) const nssv_noexcept - { - return std::hash()(std::u32string(v.data(), v.size())); - } - }; - -} // namespace std - -#endif // nssv_HAVE_STD_HASH - -nssv_RESTORE_WARNINGS() - -#endif // nssv_HAVE_STD_STRING_VIEW -#endif // NONSTD_SV_LITE_H_INCLUDED \ No newline at end of file diff --git a/test/test_combiner.cpp b/test/test_combiner.cpp index c21763b9..9c4c88b5 100644 --- a/test/test_combiner.cpp +++ b/test/test_combiner.cpp @@ -133,22 +133,22 @@ TEST(KiwiCppCombiner, Joiner) TEST(KiwiCppCombiner, Allomorph) { - using Tuple = std::tuple; + using Tuple = std::tuple; auto& rule = getCompiledRule(); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"를"}, CondVowel::vowel, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"을"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"를"}, CondVowel::vowel, (uint8_t)0}, + Tuple{ std::u16string_view{u"을"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::jko); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"가"}, CondVowel::vowel, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"이"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"가"}, CondVowel::vowel, (uint8_t)0}, + Tuple{ std::u16string_view{u"이"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::jks); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"로"}, CondVowel::vocalic, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"으로"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"로"}, CondVowel::vocalic, (uint8_t)0}, + Tuple{ std::u16string_view{u"으로"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::jkb); auto joiner = rule.newJoiner(); @@ -182,8 +182,8 @@ TEST(KiwiCppCombiner, Allomorph) EXPECT_EQ(joiner.getU16(), u"북으로"); rule.addAllomorph({ - Tuple{ nonstd::u16string_view{u"면"}, CondVowel::vocalic, (uint8_t)0}, - Tuple{ nonstd::u16string_view{u"으면"}, CondVowel::non_vowel, (uint8_t)0} + Tuple{ std::u16string_view{u"면"}, CondVowel::vocalic, (uint8_t)0}, + Tuple{ std::u16string_view{u"으면"}, CondVowel::non_vowel, (uint8_t)0} }, POSTag::ec); joiner = rule.newJoiner(); diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index e8030a49..481ef8ba 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -50,7 +50,7 @@ constexpr std::vector> toPair(const ATy(&init)[n]) Kiwi& reuseKiwiInstance() { - static Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::default_, }.build(); + static Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::default_, ModelType::knlm }.build(); return kiwi; } @@ -147,7 +147,7 @@ TEST(KiwiCpp, SingleConsonantMorpheme) TEST(KiwiCpp, SpecialTokenErrorOnContinualTypo) { - KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_, }; + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_, ModelType::knlm }; Kiwi typoKiwi = builder.build(DefaultTypoSet::continualTypoSet); auto res = typoKiwi.analyze(u"감사합니다 -친구들과", Match::allWithNormalizing).first; @@ -382,7 +382,7 @@ TEST(KiwiCpp, TagRoundTrip) TEST(KiwiCpp, UserTag) { - KiwiBuilder kw{ MODEL_PATH, 0, BuildOption::default_, }; + KiwiBuilder kw{ MODEL_PATH, 0, BuildOption::default_, ModelType::knlm, }; EXPECT_TRUE(kw.addWord(u"사용자태그", POSTag::user0).second); EXPECT_TRUE(kw.addWord(u"이것도유저", POSTag::user1).second); EXPECT_TRUE(kw.addWord(u"특수한표지", POSTag::user2).second); @@ -432,7 +432,7 @@ TEST(KiwiCpp, HSDataset) { size_t totalBatchCnt = 0, totalTokenCnt = 0, s; dataset.reset(); - while (s = dataset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt)) + while ((s = dataset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt))) { EXPECT_LE(s, batchSize); totalTokenCnt += s; @@ -450,13 +450,13 @@ TEST(KiwiCpp, HSDataset) }; HSDataset trainset, devset; - trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., 0., tokenFilter, {}, 0.1, false, {}, 0, &devset); + trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., 0., 0., false, tokenFilter, {}, 0.1, false, {}, 0, {}, &devset); for (size_t i = 0; i < 2; ++i) { { size_t totalBatchCnt = 0, totalTokenCnt = 0, s; trainset.reset(); - while (s = trainset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt)) + while ((s = trainset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt))) { EXPECT_LE(s, batchSize); totalTokenCnt += s; @@ -468,7 +468,7 @@ TEST(KiwiCpp, HSDataset) { size_t totalBatchCnt = 0, totalTokenCnt = 0, s; devset.reset(); - while (s = devset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt)) + while ((s = devset.next(in.data(), out.data(), lmLProbs.data(), outNgramBase.data(), restLm, restLmCnt))) { EXPECT_LE(s, batchSize); totalTokenCnt += s; @@ -519,6 +519,7 @@ TEST(KiwiCpp, SentenceBoundaryErrors) EXPECT_EQ(sentRanges.size(), 1); if (sentRanges.size() > 1) { + kiwi.splitIntoSents(str, Match::allWithNormalizing, &res); for (auto& r : sentRanges) { std::cerr << std::string{ &str[r.first], r.second - r.first } << std::endl; @@ -626,12 +627,14 @@ TEST(KiwiCpp, FalsePositiveSB) u"도서전에서 관람객의 관심을 받을 것으로 예상되는 프로그램으로는 '인문학 아카데미'가 있어요. 이 프로그램에서는 유시민 전 의원, 광고인 박웅현 씨 등이 문화 역사 미학 등 다양한 분야에 대해 강의할 예정이다. 또한, '북 멘토 프로그램'도 이어져요. 이 프로그램에서는 각 분야 전문가들이 경험과 노하우를 전수해 주는 프로그램으로, 시 창작(이정록 시인), 번역(강주헌 번역가), 북 디자인(오진경 북디자이너) 등의 분야에서 멘토링이 이뤄져요.", }) { - auto tokens = kiwi.analyze(str, 10, Match::allWithNormalizing)[0].first; + auto res = kiwi.analyze(str, 1, Match::allWithNormalizing); + auto tokens = res[0].first; auto sbCount = std::count_if(tokens.begin(), tokens.end(), [](const TokenInfo& t) { return t.tag == POSTag::sb; }); EXPECT_EQ(sbCount, 0); + kiwi.analyze(str, 10, Match::allWithNormalizing); } } @@ -925,12 +928,14 @@ TEST(KiwiCpp, AnalyzeWithLoadDefaultDict) TEST(KiwiCpp, AnalyzeSBG) { - Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::none, true }.build(); - kiwi.analyze(TEST_SENT, Match::all); - - auto tokens = kiwi.analyze(u"이 번호로 전화를 이따가 꼭 반드시 걸어.", kiwi::Match::allWithNormalizing).first; - EXPECT_EQ(tokens.size(), 11); - EXPECT_EQ(tokens[8].str, u"걸"); + Kiwi kiwi = KiwiBuilder{ MODEL_PATH, 0, BuildOption::none, ModelType::knlm }.build(); + Kiwi kiwiSbg = KiwiBuilder{ MODEL_PATH, 0, BuildOption::none, ModelType::sbg }.build(); + kiwiSbg.analyze(TEST_SENT, Match::all); + + auto res = kiwi.analyze(u"이 번호로 전화를 이따가 꼭 반드시 걸어.", 3, kiwi::Match::allWithNormalizing); + auto resSbg = kiwiSbg.analyze(u"이 번호로 전화를 이따가 꼭 반드시 걸어.", 3, kiwi::Match::allWithNormalizing); + EXPECT_EQ(resSbg[0].first.size(), 11); + EXPECT_EQ(resSbg[0].first[8].str, u"걸"); } TEST(KiwiCpp, AnalyzeMultithread) @@ -1205,7 +1210,7 @@ TEST(KiwiCpp, IssueP111_SentenceSplitError) auto res = kiwi.splitIntoSents(text); EXPECT_GT(res.size(), 1); - KiwiBuilder builder{ MODEL_PATH, 1 }; + KiwiBuilder builder{ MODEL_PATH, 1, BuildOption::default_, ModelType::knlm }; EXPECT_TRUE(builder.addWord(u"모", POSTag::nng).second); Kiwi kiwi2 = builder.build(); auto res2 = kiwi2.splitIntoSents(text); @@ -1255,7 +1260,7 @@ TEST(KiwiCpp, AddRule) auto ores = okiwi.analyze(u"했어요! 하잖아요! 할까요? 좋아요!", Match::allWithNormalizing); { - KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict }; + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict, ModelType::knlm }; auto inserted = builder.addRule(POSTag::ef, [](std::u16string input) { if (input.back() == u'요') @@ -1272,7 +1277,7 @@ TEST(KiwiCpp, AddRule) } { - KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict }; + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_ & ~BuildOption::loadTypoDict, ModelType::knlm }; auto inserted = builder.addRule(POSTag::ef, [](std::u16string input) { if (input.back() == u'요') @@ -1674,6 +1679,12 @@ TEST(KiwiCpp, IssueP189) TEST(KiwiCpp, Issue205) { + if (sizeof(void*) != 8) + { + std::cerr << "This test is only available in 64-bit mode" << std::endl; + return; + } + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_, }; builder.addWord(u"함박 스테이크"); auto kiwi1 = builder.build(); diff --git a/third_party/cpuinfo b/third_party/cpuinfo index 866ae6e5..05dd959f 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit 866ae6e5ffe93a1f63be738078da94cf3005cce2 +Subproject commit 05dd959fa26c7e68fa229495a35f55e06a3b9655 diff --git a/third_party/streamvbyte b/third_party/streamvbyte new file mode 160000 index 00000000..f27641e3 --- /dev/null +++ b/third_party/streamvbyte @@ -0,0 +1 @@ +Subproject commit f27641e3194d14d667e30928a418685d943ab62c diff --git a/third_party/variant b/third_party/variant deleted file mode 160000 index f87fcbda..00000000 --- a/third_party/variant +++ /dev/null @@ -1 +0,0 @@ -Subproject commit f87fcbda9daf13fba47a6a889696b0ad23fc098d diff --git a/tools/Evaluator.cpp b/tools/Evaluator.cpp index 1bef793c..21412c51 100644 --- a/tools/Evaluator.cpp +++ b/tools/Evaluator.cpp @@ -4,16 +4,47 @@ #include #include "../src/StrUtils.h" #include "Evaluator.h" +#include "toolUtils.h" #include "LCS.hpp" using namespace std; using namespace kiwi; -TokenInfo parseWordPOS(const u16string& str) +unique_ptr Evaluator::create(const std::string& evalType) +{ + if (evalType == "morph") return std::make_unique(); + if (evalType == "disamb") return std::make_unique(); + throw runtime_error{ "Unknown Evaluator Type" }; +} + +const char* modelTypeToStr(ModelType type) +{ + switch (type) + { + case ModelType::knlm: return "knlm"; + case ModelType::knlmTransposed: return "knlm-transposed"; + case ModelType::sbg: return "sbg"; + case ModelType::cong: return "cong"; + case ModelType::congGlobal: return "cong-global"; + case ModelType::congFp32: return "cong-fp32"; + case ModelType::congGlobalFp32: return "cong-global-fp32"; + } + return "unknown"; +} + +inline ostream& operator<<(ostream& o, const kiwi::TokenInfo& t) +{ + o << utf16To8(t.str); + if (t.senseId) o << "__" << (int)t.senseId; + o << "/" << kiwi::tagToString(t.tag); + return o; +} + +inline TokenInfo parseWordPOS(const u16string& str) { auto p = str.rfind('/'); if (p == str.npos) return {}; - u16string form = replace(nonstd::u16string_view(str.data(), p), u"_", u" "); + u16string form = replace(std::u16string_view(str.data(), p), u"_", u" "); if (str[p + 1] == 'E') { if (form[0] == u'아' || form[0] == u'여') form[0] = u'어'; @@ -36,13 +67,98 @@ TokenInfo parseWordPOS(const u16string& str) tagStr.erase(tagStr.begin() + tagStr.find('-'), tagStr.end()); } POSTag tag = toPOSTag(tagStr); - if (tag >= POSTag::max) throw runtime_error{ "Wrong Input '" + utf16To8(str.substr(p + 1)) + "'" }; + if (clearIrregular(tag) >= POSTag::max) throw runtime_error{ "Wrong Input '" + utf16To8(str.substr(p + 1)) + "'" }; return { form, tag, 0, 0 }; } -Evaluator::Evaluator(const std::string& testSetFile, const Kiwi* _kw, Match _matchOption, size_t _topN) - : kw{ _kw }, matchOption{ _matchOption }, topN{ _topN } +int Evaluator::operator()(const string& modelPath, + const string& output, + const vector& input, + bool normCoda, bool zCoda, bool multiDict, ModelType modelType, + float typoCostWeight, bool bTypo, bool cTypo, bool lTypo, + int repeat) +{ + try + { + if (typoCostWeight > 0 && !bTypo && !cTypo && !lTypo) + { + bTypo = true; + } + else if (typoCostWeight == 0) + { + bTypo = false; + cTypo = false; + lTypo = false; + } + + tutils::Timer timer; + auto option = (BuildOption::default_ & ~BuildOption::loadMultiDict) | (multiDict ? BuildOption::loadMultiDict : BuildOption::none); + auto typo = getDefaultTypoSet(DefaultTypoSet::withoutTypo); + + if (bTypo) + { + typo |= getDefaultTypoSet(DefaultTypoSet::basicTypoSet); + } + + if (cTypo) + { + typo |= getDefaultTypoSet(DefaultTypoSet::continualTypoSet); + } + + if (lTypo) + { + typo |= getDefaultTypoSet(DefaultTypoSet::lengtheningTypoSet); + } + + Kiwi kw = KiwiBuilder{ modelPath, 1, option, modelType }.build( + typo + ); + if (typoCostWeight > 0) kw.setTypoCostWeight(typoCostWeight); + + cout << "Loading Time : " << timer.getElapsed() << " ms" << endl; + cout << "ArchType : " << archToStr(kw.archType()) << endl; + cout << "Model Type : " << modelTypeToStr(kw.modelType()) << endl; + if (kw.getLangModel()) + { + cout << "LM Size : " << (kw.getLangModel()->getMemorySize() / 1024. / 1024.) << " MB" << endl; + } + cout << "Mem Usage : " << (tutils::getCurrentPhysicalMemoryUsage() / 1024.) << " MB\n" << endl; + + double avgMicro = 0, avgMacro = 0; + double cnt = 0; + for (auto& tf : input) + { + cout << "Test file: " << tf << endl; + try + { + auto result = eval(output, tf, kw, normCoda, zCoda, repeat); + avgMicro += result.first; + avgMacro += result.second; + ++cnt; + cout << "================" << endl; + } + catch (const std::exception& e) + { + cerr << e.what() << endl; + } + } + + cout << endl << "================" << endl; + cout << "Avg Score" << endl; + cout << avgMicro / cnt << ", " << avgMacro / cnt << endl; + cout << "================" << endl; + return 0; + } + catch (const exception& e) + { + cerr << e.what() << endl; + return -1; + } +} + +auto MorphEvaluator::loadTestset(const string& testSetFile) const -> vector { + vector ret; ifstream f{ testSetFile }; if (!f) throw std::ios_base::failure{ "Cannot open '" + testSetFile + "'" }; string line; @@ -55,32 +171,24 @@ Evaluator::Evaluator(const std::string& testSetFile, const Kiwi* _kw, Match _mat vector tokens; for (size_t i = 1; i < fd.size(); ++i) { - for (auto s : split(fd[i], u' ')) tokens.emplace_back(s.to_string()); + for (auto s : split(fd[i], u' ')) tokens.emplace_back(s); } TestResult tr; - tr.q = fd[0].to_string(); + tr.q = u16string{ fd[0] }; for (auto& t : tokens) tr.a.emplace_back(parseWordPOS(t)); - testsets.emplace_back(std::move(tr)); - } -} - -void Evaluator::run() -{ - for (auto& tr : testsets) - { - auto cands = kw->analyze(tr.q, topN, matchOption); - tr.r = cands[0].first; + ret.emplace_back(std::move(tr)); } + return ret; } -Evaluator::Score Evaluator::evaluate() +auto MorphEvaluator::computeScore(vector& preds, vector& errors) const -> Score { errors.clear(); size_t totalCount = 0, microCorrect = 0, microCount = 0; double totalScore = 0; - for (auto& tr : testsets) + for (auto& tr : preds) { if (tr.a != tr.r) { @@ -128,15 +236,31 @@ Evaluator::Score Evaluator::evaluate() return ret; } -ostream& operator<<(ostream& o, const kiwi::TokenInfo& t) +auto DisambEvaluator::computeScore(vector& preds, vector& errors) const -> Score { - o << utf16To8(t.str); - if (t.senseId) o << "__" << (int)t.senseId; - o << "/" << kiwi::tagToString(t.tag); - return o; + errors.clear(); + Score score; + for (auto& tr : preds) + { + bool correct = false; + for (auto& token : tr.result.first) + { + if (token.str == tr.target.str && + clearIrregular(token.tag) == clearIrregular(tr.target.tag)) + { + correct = true; + break; + } + } + if (correct) score.acc += 1; + else errors.emplace_back(tr); + score.totalCount++; + } + score.acc /= score.totalCount; + return score; } -void Evaluator::TestResult::writeResult(ostream& out) const +void MorphEvaluator::TestResult::writeResult(ostream& out) const { out << utf16To8(q) << '\t' << score << endl; for (auto& _r : da) @@ -151,3 +275,114 @@ void Evaluator::TestResult::writeResult(ostream& out) const out << endl; out << endl; } + +pair MorphEvaluator::eval(const string& output, const string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) +{ + const size_t topN = 1; + const Match matchOption = (normCoda ? Match::allWithNormalizing : Match::all) & ~(zCoda ? Match::none : Match::zCoda); + vector testsets = loadTestset(file), errors; + tutils::Timer total; + for (int i = 0; i < repeat; ++i) + { + for (auto& tr : testsets) + { + auto cands = kiwi.analyze(tr.q, topN, matchOption); + tr.r = cands[0].first; + } + } + double tm = total.getElapsed() / repeat; + auto score = computeScore(testsets, errors); + + cout << score.micro << ", " << score.macro << endl; + cout << "Total (" << score.totalCount << " lines) Time : " << tm << " ms" << endl; + cout << "Time per Line : " << tm / score.totalCount << " ms" << endl; + + if (!output.empty()) + { + const size_t last_slash_idx = file.find_last_of("\\/"); + string name; + if (last_slash_idx != file.npos) name = file.substr(last_slash_idx + 1); + else name = file; + + ofstream out{ output + "/" + name }; + out << score.micro << ", " << score.macro << endl; + out << "Total (" << score.totalCount << ") Time : " << tm << " ms" << endl; + out << "Time per Unit : " << tm / score.totalCount << " ms" << endl; + for (auto t : errors) + { + t.writeResult(out); + } + } + return make_pair(score.micro, score.macro); +} + +auto DisambEvaluator::loadTestset(const string& testSetFile) const -> vector +{ + vector ret; + ifstream f{ testSetFile }; + if (!f) throw std::ios_base::failure{ "Cannot open '" + testSetFile + "'" }; + string line; + while (getline(f, line)) + { + while (line.back() == '\n' || line.back() == '\r') line.pop_back(); + auto wstr = utf8To16(line); + auto fd = split(wstr, u'\t'); + if (fd.size() < 2) continue; + TestResult tr; + tr.target = parseWordPOS(u16string{ fd[0] }); + tr.text = u16string{ fd[1] }; + ret.emplace_back(move(tr)); + } + return ret; +} + +void DisambEvaluator::TestResult::writeResult(ostream& out) const +{ + out << target << '\t' << utf16To8(text) << '\t' << score << endl; + for (auto& _r : result.first) + { + out << _r << '\t'; + } + out << endl; + out << endl; +} + +pair DisambEvaluator::eval(const string& output, const string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) +{ + const size_t topN = 1; + const Match matchOption = (normCoda ? Match::allWithNormalizing : Match::all) & ~(zCoda ? Match::none : Match::zCoda); + vector testsets = loadTestset(file), errors; + tutils::Timer total; + for (int i = 0; i < repeat; ++i) + { + for (auto& tr : testsets) + { + auto cands = kiwi.analyze(tr.text, topN, matchOption); + tr.result = cands[0]; + } + } + double tm = total.getElapsed() / repeat; + auto score = computeScore(testsets, errors); + + cout << score.acc << endl; + cout << "Total (" << score.totalCount << " lines) Time : " << tm << " ms" << endl; + cout << "Time per Line : " << tm / score.totalCount << " ms" << endl; + + if (!output.empty()) + { + const size_t last_slash_idx = file.find_last_of("\\/"); + string name; + if (last_slash_idx != file.npos) name = file.substr(last_slash_idx + 1); + else name = file; + + ofstream out{ output + "/" + name }; + out << score.acc << endl; + out << "Total (" << score.totalCount << ") Time : " << tm << " ms" << endl; + out << "Time per Unit : " << tm / score.totalCount << " ms" << endl; + for (auto t : errors) + { + t.writeResult(out); + } + } + return make_pair(score.acc, score.acc); +} diff --git a/tools/Evaluator.h b/tools/Evaluator.h index a91161d3..77221306 100644 --- a/tools/Evaluator.h +++ b/tools/Evaluator.h @@ -3,7 +3,23 @@ class Evaluator { + virtual std::pair eval(const std::string& output, const std::string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) = 0; public: + + virtual ~Evaluator() = default; + + static std::unique_ptr create(const std::string& evalType); + + int operator()(const std::string& modelPath, + const std::string& output, + const std::vector& input, + bool normCoda, bool zCoda, bool multiDict, kiwi::ModelType modelType, + float typoCostWeight, bool bTypo, bool cTypo, bool lTypo, + int repeat); +}; + +class MorphEvaluator : public Evaluator +{ using AnswerType = std::vector; struct TestResult { @@ -15,7 +31,7 @@ class Evaluator float score; void writeResult(std::ostream& out) const; }; - + struct Score { double micro = 0; @@ -23,15 +39,31 @@ class Evaluator size_t totalCount = 0; }; -private: - std::vector testsets, errors; - const kiwi::Kiwi* kw = nullptr; - kiwi::Match matchOption; - size_t topN = 1; -public: - Evaluator(const std::string& testSetFile, const kiwi::Kiwi* _kw, kiwi::Match _matchOption = kiwi::Match::all, size_t topN = 1); - void run(); - Score evaluate(); - const std::vector& getErrors() const { return errors; } + std::pair eval(const std::string& output, const std::string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) override; + + std::vector loadTestset(const std::string& file) const; + Score computeScore(std::vector& preds, std::vector& errors) const; }; +class DisambEvaluator : public Evaluator +{ + struct TestResult + { + std::u16string text; + kiwi::TokenInfo target; + kiwi::TokenResult result; + float score = 0; + void writeResult(std::ostream& out) const; + }; + + struct Score + { + double acc = 0; + size_t totalCount = 0; + }; + + std::pair eval(const std::string& output, const std::string& file, kiwi::Kiwi& kiwi, bool normCoda, bool zCoda, int repeat) override; + + std::vector loadTestset(const std::string& file) const; + Score computeScore(std::vector& preds, std::vector& errors) const; +}; diff --git a/tools/cong_builder.cpp b/tools/cong_builder.cpp new file mode 100644 index 00000000..6dc9b38d --- /dev/null +++ b/tools/cong_builder.cpp @@ -0,0 +1,69 @@ +#include +#include + +#include +#include +#include +#include "toolUtils.h" + +using namespace std; +using namespace kiwi; + +int run(const std::string& morphemeDef, const std::string& contextDef, const std::string& embedding, + size_t minCnt, size_t maxLength, const std::string& output, bool useVLE, bool reorderContextIdx = true) +{ + try + { + tutils::Timer timer; + KiwiBuilder::buildMorphData(morphemeDef, output, minCnt); + auto ret = lm::CoNgramModelBase::build(contextDef, embedding, maxLength, useVLE, reorderContextIdx); + ret.writeToFile(output + "/cong.mdl"); + double tm = timer.getElapsed(); + cout << "Total: " << tm << " ms " << endl; + return 0; + } + catch (const exception& e) + { + cerr << e.what() << endl; + return -1; + } +} + +using namespace TCLAP; + +int main(int argc, const char* argv[]) +{ + tutils::setUTF8Output(); + + CmdLine cmd{ "Kiwi PCLanguageModel Builder", ' ', "0.21.0" }; + + ValueArg mdef{ "m", "morpheme-def", "morpheme definition", true, "", "string" }; + ValueArg cdef{ "c", "context-def", "context definition", true, "", "string" }; + ValueArg emb{ "e", "emb", "embedding file", true, "", "string" }; + ValueArg minCnt{ "n", "min-cnt", "min count of morpheme", false, 10, "int" }; + ValueArg maxLength{ "l", "max-length", "max length of n-grams", false, (size_t)-1, "int"}; + ValueArg output{ "o", "output", "", true, "", "string" }; + SwitchArg useVLE{ "", "use-vle", "use VLE", false }; + SwitchArg preserveContextIdx{ "p", "preserve-context-idx", "preserve context index", false }; + + cmd.add(mdef); + cmd.add(cdef); + cmd.add(emb); + cmd.add(minCnt); + cmd.add(maxLength); + cmd.add(output); + cmd.add(useVLE); + cmd.add(preserveContextIdx); + + try + { + cmd.parse(argc, argv); + } + catch (const ArgException& e) + { + cerr << "error: " << e.error() << " for arg " << e.argId() << endl; + return -1; + } + + return run(mdef, cdef, emb, minCnt, maxLength, output, useVLE, !preserveContextIdx); +} diff --git a/tools/diff_tokens.cpp b/tools/diff_tokens.cpp new file mode 100644 index 00000000..41d94aff --- /dev/null +++ b/tools/diff_tokens.cpp @@ -0,0 +1,245 @@ +#include +#include + +#include +#include + +#include "toolUtils.h" + +using namespace std; +using namespace kiwi; +using namespace TCLAP; + +Kiwi loadKiwiFromArg(const string& model, const string& modelType, size_t numThreads = 2) +{ + ModelType kiwiModelType = tutils::parseModelType(modelType); + BuildOption opt = BuildOption::default_; + opt &= ~BuildOption::loadMultiDict; + KiwiBuilder builder{ model, numThreads < 2 ? 2 : numThreads, opt, kiwiModelType }; + return builder.build(); +} + +inline bool isEqual(const TokenInfo* a, size_t aSize, const TokenInfo* b, size_t bSize, bool ignoreTag = false) +{ + if (aSize != bSize) return false; + for (size_t i = 0; i < aSize; ++i) + { + if (a[i].str != b[i].str) return false; + if (!ignoreTag && a[i].tag != b[i].tag) return false; + } + return true; +} + +inline ostream& operator<<(ostream& ostr, const TokenInfo& token) +{ + return ostr << utf16To8(token.str) << '/' << tagToString(token.tag); +} + +bool printDiffTokens(ostream& ostr, const string& raw, const TokenInfo* a, size_t aSize, const TokenInfo* b, size_t bSize, bool ignoreTag = false, bool showSame = false) +{ + if (isEqual(a, aSize, b, bSize, ignoreTag) != showSame) return false; + ostr << raw << '\t'; + for (size_t i = 0; i < aSize; ++i) + { + if (i) ostr << ' '; + ostr << a[i]; + } + if (!showSame || ignoreTag) + { + ostr << '\t'; + for (size_t i = 0; i < bSize; ++i) + { + if (i) ostr << ' '; + ostr << b[i]; + } + } + ostr << endl; + return true; +} + +pair diffTokens(ostream& ostr, const string& raw, const TokenResult& a, const TokenResult& b, bool sentenceLevel, bool ignoreTag = false, bool showSame = false) +{ + size_t diff = 0, total = 0; + if (sentenceLevel) + { + thread_local vector> aBounds, bBounds, sentBounds; + aBounds.clear(); + bBounds.clear(); + sentBounds.clear(); + auto& aTokens = a.first; + auto& bTokens = b.first; + for (size_t i = 1; i < aTokens.size(); ++i) + { + if (aTokens[i - 1].sentPosition != aTokens[i].sentPosition) + { + aBounds.emplace_back(aTokens[i - 1].endPos(), i); + } + } + + for (size_t i = 1; i < bTokens.size(); ++i) + { + if (bTokens[i - 1].sentPosition != bTokens[i].sentPosition) + { + bBounds.emplace_back(bTokens[i - 1].endPos(), i); + } + } + + // find intersection between aBounds and bBounds and store in sentBounds + sentBounds.emplace_back(0, 0); + auto aIt = aBounds.begin(); + auto bIt = bBounds.begin(); + while (aIt != aBounds.end() && bIt != bBounds.end()) + { + if (aIt->first < bIt->first) + { + ++aIt; + } + else if (aIt->first > bIt->first) + { + ++bIt; + } + else + { + sentBounds.emplace_back(aIt->second, bIt->second); + ++aIt; + ++bIt; + } + } + sentBounds.emplace_back(aTokens.size(), bTokens.size()); + + const u16string u16raw = utf8To16(raw); + + for (size_t i = 1; i < sentBounds.size(); ++i) + { + const auto aStart = sentBounds[i - 1].first; + const auto aEnd = sentBounds[i].first; + const auto bStart = sentBounds[i - 1].second; + const auto bEnd = sentBounds[i].second; + const auto rawSent = u16raw.substr(aTokens[aStart].position, aTokens[aEnd - 1].endPos() - aTokens[aStart].position); + const bool isDiff = printDiffTokens(ostr, utf16To8(rawSent), aTokens.data() + aStart, aEnd - aStart, bTokens.data() + bStart, bEnd - bStart, ignoreTag, showSame); + if (isDiff) ++diff; + ++total; + } + } + else + { + const bool isDiff = printDiffTokens(ostr, raw, a.first.data(), a.first.size(), b.first.data(), b.first.size(), ignoreTag, showSame); + if (isDiff) ++diff; + ++total; + } + return { diff, total }; +} + +pair diffInputs(Kiwi& kiwiA, Kiwi& kiwiB, const string& inputs, ostream& ostr, bool sentenceLevel, bool ignoreTag = false, bool showSame = false) +{ + ifstream ifs{ inputs }; + if (!ifs) + { + cerr << "Cannot open " << inputs << endl; + return { 0, 0 }; + } + string line; + deque, future>> futures; + auto* poolA = kiwiA.getThreadPool(); + auto* poolB = kiwiB.getThreadPool(); + size_t diff = 0, total = 0; + + while (getline(ifs, line)) + { + while (futures.size() > kiwiA.getNumThreads() * 2) + { + auto rawInput = move(get<0>(futures.front())); + auto resultA = get<1>(futures.front()).get(); + auto resultB = get<2>(futures.front()).get(); + futures.pop_front(); + + auto p = diffTokens(ostr, rawInput, resultA, resultB, sentenceLevel, ignoreTag, showSame); + diff += p.first; + total += p.second; + } + + futures.emplace_back( + line, + poolA->enqueue([&, line](size_t tid) { return kiwiA.analyze(line, Match::allWithNormalizing);}), + poolB->enqueue([&, line](size_t tid) { return kiwiB.analyze(line, Match::allWithNormalizing);}) + ); + } + + while (!futures.empty()) + { + auto rawInput = move(get<0>(futures.front())); + auto resultA = get<1>(futures.front()).get(); + auto resultB = get<2>(futures.front()).get(); + futures.pop_front(); + + auto p = diffTokens(ostr, rawInput, resultA, resultB, sentenceLevel, ignoreTag, showSame); + diff += p.first; + total += p.second; + } + return { diff, total }; +} + +int main(int argc, const char* argv[]) +{ + tutils::setUTF8Output(); + + CmdLine cmd{ "Kiwi Diff Tokenizations" }; + + ValueArg modelA{ "", "model-a", "Model A path", true, "", "string" }; + ValueArg modelAType{ "", "model-a-type", "Model A Type", false, "none", "string" }; + ValueArg modelB{ "", "model-b", "Model B path", true, "", "string" }; + ValueArg modelBType{ "", "model-b-type", "Model B Type", false, "none", "string" }; + ValueArg output{ "o", "output", "output path", false, "", "string" }; + ValueArg numThreads{ "t", "threads", "number of threads", false, 2, "int" }; + SwitchArg sentence{ "", "sentence", "diff in sentence level", false }; + SwitchArg ignoreTag{ "i", "ignore-tag", "ignore tag", false }; + SwitchArg showSame{ "s", "show-same", "show the same result only", false }; + SwitchArg noNormCoda{ "", "no-normcoda", "without normalizing coda", false }; + SwitchArg noZCoda{ "", "no-zcoda", "without z-coda", false }; + SwitchArg noMulti{ "", "no-multi", "turn off multi dict", false }; + ValueArg typoWeight{ "", "typo", "typo weight", false, 0.f, "float" }; + SwitchArg bTypo{ "", "btypo", "make basic-typo-tolerant model", false }; + SwitchArg cTypo{ "", "ctypo", "make continual-typo-tolerant model", false }; + SwitchArg lTypo{ "", "ltypo", "make lengthening-typo-tolerant model", false }; + UnlabeledMultiArg inputs{ "inputs", "targets", false, "string" }; + + cmd.add(modelA); + cmd.add(modelAType); + cmd.add(modelB); + cmd.add(modelBType); + cmd.add(output); + cmd.add(inputs); + cmd.add(numThreads); + cmd.add(sentence); + cmd.add(ignoreTag); + cmd.add(showSame); + + try + { + cmd.parse(argc, argv); + } + catch (const ArgException& e) + { + cerr << "error: " << e.error() << " for arg " << e.argId() << endl; + return -1; + } + + Kiwi kiwiA = loadKiwiFromArg(modelA, modelAType, numThreads); + Kiwi kiwiB = loadKiwiFromArg(modelB, modelBType, numThreads); + + unique_ptr ofstr; + ostream* ostr = &cout; + if (!output.getValue().empty()) + { + ofstr = std::make_unique(output); + ostr = ofstr.get(); + } + + for (auto& input : inputs) + { + cout << "input: " << input << " "; + cout.flush(); + auto p = diffInputs(kiwiA, kiwiB, input, *ostr, sentence, ignoreTag, showSame); + cout << "(diff: " << p.first << " / " << p.second << ")" << endl; + } +} diff --git a/tools/evaluator_main.cpp b/tools/evaluator_main.cpp index 4913add8..1e967e97 100644 --- a/tools/evaluator_main.cpp +++ b/tools/evaluator_main.cpp @@ -10,119 +10,12 @@ using namespace std; using namespace kiwi; - -int doEvaluate(const string& modelPath, const string& output, const vector& input, - bool normCoda, bool zCoda, bool multiDict, bool useSBG, - float typoCostWeight, bool bTypo, bool cTypo, bool lTypo, - int repeat) -{ - try - { - if (typoCostWeight > 0 && !bTypo && !cTypo && !lTypo) - { - bTypo = true; - } - else if (typoCostWeight == 0) - { - bTypo = false; - cTypo = false; - lTypo = false; - } - - tutils::Timer timer; - auto option = (BuildOption::default_ & ~BuildOption::loadMultiDict) | (multiDict ? BuildOption::loadMultiDict : BuildOption::none); - auto typo = getDefaultTypoSet(DefaultTypoSet::withoutTypo); - - if (bTypo) - { - typo |= getDefaultTypoSet(DefaultTypoSet::basicTypoSet); - } - - if (cTypo) - { - typo |= getDefaultTypoSet(DefaultTypoSet::continualTypoSet); - } - - if (lTypo) - { - typo |= getDefaultTypoSet(DefaultTypoSet::lengtheningTypoSet); - } - - Kiwi kw = KiwiBuilder{ modelPath, 1, option, useSBG }.build( - typo - ); - if (typoCostWeight > 0) kw.setTypoCostWeight(typoCostWeight); - - cout << "Loading Time : " << timer.getElapsed() << " ms" << endl; - cout << "ArchType : " << archToStr(kw.archType()) << endl; - cout << "LM Size : " << (kw.getKnLM()->getMemory().size() / 1024. / 1024.) << " MB" << endl; - cout << "Mem Usage : " << (tutils::getCurrentPhysicalMemoryUsage() / 1024.) << " MB\n" << endl; - - double avgMicro = 0, avgMacro = 0; - double cnt = 0; - for (auto& tf : input) - { - cout << "Test file: " << tf << endl; - try - { - Evaluator test{ tf, &kw, (normCoda ? Match::allWithNormalizing : Match::all) & ~(zCoda ? Match::none : Match::zCoda) }; - tutils::Timer total; - for (int i = 0; i < repeat; ++i) - { - test.run(); - } - double tm = total.getElapsed() / repeat; - auto result = test.evaluate(); - - cout << result.micro << ", " << result.macro << endl; - cout << "Total (" << result.totalCount << " lines) Time : " << tm << " ms" << endl; - cout << "Time per Line : " << tm / result.totalCount << " ms" << endl; - - avgMicro += result.micro; - avgMacro += result.macro; - cnt++; - - if (!output.empty()) - { - const size_t last_slash_idx = tf.find_last_of("\\/"); - string name; - if (last_slash_idx != tf.npos) name = tf.substr(last_slash_idx + 1); - else name = tf; - - ofstream out{ output + "/" + name }; - out << result.micro << ", " << result.macro << endl; - out << "Total (" << result.totalCount << ") Time : " << tm << " ms" << endl; - out << "Time per Unit : " << tm / result.totalCount << " ms" << endl; - for (auto t : test.getErrors()) - { - t.writeResult(out); - } - } - cout << "================" << endl; - } - catch (const std::exception& e) - { - cerr << e.what() << endl; - } - } - - cout << endl << "================" << endl; - cout << "Avg Score" << endl; - cout << avgMicro / cnt << ", " << avgMacro / cnt << endl; - cout << "================" << endl; - return 0; - } - catch (const exception& e) - { - cerr << e.what() << endl; - return -1; - } -} - using namespace TCLAP; int main(int argc, const char* argv[]) { + tutils::setUTF8Output(); + CmdLine cmd{ "Kiwi evaluator" }; ValueArg model{ "m", "model", "Kiwi model path", false, "models/base", "string" }; @@ -130,26 +23,26 @@ int main(int argc, const char* argv[]) SwitchArg noNormCoda{ "", "no-normcoda", "without normalizing coda", false }; SwitchArg noZCoda{ "", "no-zcoda", "without z-coda", false }; SwitchArg noMulti{ "", "no-multi", "turn off multi dict", false }; - SwitchArg useSBG{ "", "sbg", "use SkipBigram", false }; + ValueArg modelType{ "t", "type", "model type", false, "none", "string" }; ValueArg typoWeight{ "", "typo", "typo weight", false, 0.f, "float"}; SwitchArg bTypo{ "", "btypo", "make basic-typo-tolerant model", false }; SwitchArg cTypo{ "", "ctypo", "make continual-typo-tolerant model", false }; SwitchArg lTypo{ "", "ltypo", "make lengthening-typo-tolerant model", false }; ValueArg repeat{ "", "repeat", "repeat evaluation for benchmark", false, 1, "int" }; - UnlabeledMultiArg files{ "files", "evaluation set files", true, "string" }; + UnlabeledMultiArg inputs{ "inputs", "evaluation set (--morph, --disamb)", false, "string" }; cmd.add(model); cmd.add(output); - cmd.add(files); cmd.add(noNormCoda); cmd.add(noZCoda); cmd.add(noMulti); - cmd.add(useSBG); + cmd.add(modelType); cmd.add(typoWeight); cmd.add(bTypo); cmd.add(cTypo); cmd.add(lTypo); cmd.add(repeat); + cmd.add(inputs); try { @@ -160,7 +53,64 @@ int main(int argc, const char* argv[]) cerr << "error: " << e.error() << " for arg " << e.argId() << endl; return -1; } - return doEvaluate(model, output, files.getValue(), - !noNormCoda, !noZCoda, !noMulti, useSBG, typoWeight, bTypo, cTypo, lTypo, repeat); + ModelType kiwiModelType = ModelType::none; + try + { + kiwiModelType = tutils::parseModelType(modelType); + } + catch (const exception& e) + { + cerr << e.what() << endl; + return -1; + } + + vector morphInputs, disambInputs; + + string currentType = ""; + for (auto& input : inputs.getValue()) + { + if (input.size() > 2 && input[0] == '-' && input[1] == '-') + { + currentType = input; + } + else + { + if (currentType == "--morph") + { + morphInputs.emplace_back(input); + } + else if (currentType == "--disamb") + { + disambInputs.emplace_back(input); + } + else + { + cerr << "Unknown argument: " << input << endl; + return -1; + } + } + } + + if (morphInputs.size()) + { + auto evaluator = Evaluator::create("morph"); + (*evaluator)(model, output, morphInputs, + !noNormCoda, !noZCoda, !noMulti, + kiwiModelType, + typoWeight, bTypo, cTypo, lTypo, + repeat); + cout << endl; + } + + if (disambInputs.size()) + { + auto evaluator = Evaluator::create("disamb"); + (*evaluator)(model, output, disambInputs, + !noNormCoda, !noZCoda, !noMulti, + kiwiModelType, + typoWeight, bTypo, cTypo, lTypo, + repeat); + cout << endl; + } } diff --git a/tools/runner.cpp b/tools/runner.cpp index 78954de2..96e1f549 100644 --- a/tools/runner.cpp +++ b/tools/runner.cpp @@ -28,7 +28,7 @@ int run(const string& modelPath, bool benchmark, const string& output, const str { tutils::Timer timer; size_t lines = 0, bytes = 0; - Kiwi kw = KiwiBuilder{ modelPath, 1, BuildOption::default_, sbg }.build(typos > 0 ? DefaultTypoSet::basicTypoSet : DefaultTypoSet::withoutTypo); + Kiwi kw = KiwiBuilder{ modelPath, 1, BuildOption::default_, sbg ? ModelType::sbg : ModelType::knlm }.build(typos > 0 ? DefaultTypoSet::basicTypoSet : DefaultTypoSet::withoutTypo); cout << "Kiwi v" << KIWI_VERSION_STRING << endl; if (tolerance) @@ -46,7 +46,7 @@ int run(const string& modelPath, bool benchmark, const string& output, const str { cout << "Loading Time : " << timer.getElapsed() << " ms" << endl; cout << "ArchType : " << archToStr(kw.archType()) << endl; - cout << "LM Size : " << (kw.getKnLM()->getMemory().size() / 1024. / 1024.) << " MB" << endl; + cout << "LM Size : " << (kw.getLangModel()->getMemorySize() / 1024. / 1024.) << " MB" << endl; cout << "Mem Usage : " << (tutils::getCurrentPhysicalMemoryUsage() / 1024.) << " MB" << endl; cout << "ModelType : " << (sbg ? "sbg" : "knlm") << endl; } diff --git a/tools/toolUtils.h b/tools/toolUtils.h index 04b002ae..327a9c5c 100644 --- a/tools/toolUtils.h +++ b/tools/toolUtils.h @@ -102,4 +102,44 @@ namespace tutils { } #endif -} \ No newline at end of file + + inline kiwi::ModelType parseModelType(const std::string& v) + { + if (v == "none") + { + return kiwi::ModelType::none; + } + else if (v == "knlm") + { + return kiwi::ModelType::knlm; + } + else if (v == "sbg") + { + return kiwi::ModelType::sbg; + } + else if (v == "knlm-transposed") + { + return kiwi::ModelType::knlmTransposed; + } + else if (v == "cong") + { + return kiwi::ModelType::cong; + } + else if (v == "cong-global") + { + return kiwi::ModelType::congGlobal; + } + else if (v == "cong-fp32") + { + return kiwi::ModelType::congFp32; + } + else if (v == "cong-global-fp32") + { + return kiwi::ModelType::congGlobalFp32; + } + else + { + throw std::invalid_argument{ "Invalid model type" }; + } + } +} diff --git a/vsproj/build_cong.vcxproj b/vsproj/build_cong.vcxproj new file mode 100644 index 00000000..b2012b69 --- /dev/null +++ b/vsproj/build_cong.vcxproj @@ -0,0 +1,230 @@ + + + + + Debug + ARM64 + + + Debug + Win32 + + + Release + ARM64 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {C63940BA-24B0-452C-A618-E435888BB45C} + Win32Proj + KiwiRun + 10.0 + build_cong + + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + Application + true + v143 + Unicode + + + Application + true + v143 + Unicode + + + Application + false + v143 + true + Unicode + + + Application + false + v143 + true + Unicode + + + + + + + + + + + + + + + + + + + + + + + + + + + false + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + false + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + true + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + true + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + true + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + false + $(SolutionDir)third_party/mimalloc/include;$(SolutionDir)third_party/tclap/include;$(SolutionDir)include\;$(VC_IncludePath);$(WindowsSDK_IncludePath); + + + + Level3 + NotUsing + MaxSpeed + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreaded + /utf-8 %(AdditionalOptions) + stdcpp17 + + + Console + true + true + + + + + Level3 + NotUsing + MaxSpeed + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreaded + /utf-8 %(AdditionalOptions) + + + Console + true + true + + + + + NotUsing + Level3 + Disabled + WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDebug + /utf-8 %(AdditionalOptions) + stdcpp17 + + + Console + + + + + NotUsing + Level3 + Disabled + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDebug + /utf-8 %(AdditionalOptions) + stdcpp17 + + + Console + + + + + NotUsing + Level3 + Disabled + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreadedDebug + /utf-8 %(AdditionalOptions) + + + Console + + + + + Level3 + NotUsing + MaxSpeed + true + true + WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + MultiThreaded + /utf-8 %(AdditionalOptions) + stdcpp17 + + + Console + true + true + + + + + {f790bc37-2732-4ed1-9ca5-7248bed3588e} + + + + + + + + + \ No newline at end of file diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index aca1d9e3..b2f49933 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -36,10 +36,12 @@ + + @@ -54,21 +56,31 @@ + + + + + - + + + + + + @@ -96,6 +108,14 @@ AdvancedVectorExtensions512 AdvancedVectorExtensions512 + + AdvancedVectorExtensions512 + AdvancedVectorExtensions512 + + + AdvancedVectorExtensions2 + AdvancedVectorExtensions2 + NotSet true @@ -120,7 +140,9 @@ + + @@ -136,6 +158,13 @@ + + + + + + + {F790BC37-2732-4ED1-9CA5-7248BED3588E} @@ -213,7 +242,7 @@ true - $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\include\;$(IncludePath) + $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\third_party/streamvbyte/include;$(ProjectDir)..\include\;$(IncludePath) true @@ -225,7 +254,7 @@ false - $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\include\;$(IncludePath) + $(ProjectDir)..\third_party/json/include;$(ProjectDir)..\third_party/eigen;$(ProjectDir)..\third_party/variant/include;$(ProjectDir)..\third_party/cpuinfo/src;$(ProjectDir)..\third_party/cpuinfo/include;$(ProjectDir)..\third_party/mimalloc/include;$(ProjectDir)..\third_party/cpp-btree;$(ProjectDir)..\third_party/streamvbyte/include;$(ProjectDir)..\include\;$(IncludePath) false @@ -251,7 +280,7 @@ NotUsing Level3 Disabled - KIWI_ARCH_X86_64=1;KIWI_USE_MIMALLOC;_DEBUG;_CONSOLE;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + KIWI_ARCH_X86_64=1;__restrict__=__restrict;KIWI_USE_MIMALLOC;_DEBUG;_CONSOLE;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) true MultiThreadedDebug /Qvec-report:1 /utf-8 /D _CRT_SECURE_NO_WARNINGS=1 /bigobj %(AdditionalOptions) @@ -314,7 +343,7 @@ MaxSpeed true true - KIWI_ARCH_X86_64=1;KIWI_USE_BTREE;KIWI_USE_MIMALLOC;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + KIWI_ARCH_X86_64=1;__restrict__=__restrict;KIWI_USE_BTREE;KIWI_USE_MIMALLOC;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) true AdvancedVectorExtensions2 /Qvec-report:1 /utf-8 /D _CRT_SECURE_NO_WARNINGS=1 /bigobj %(AdditionalOptions)