Skip to content

Commit dbd90bc

Browse files
committed
Fix compilation errors
1 parent dbb7582 commit dbd90bc

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

Diff for: src/CoNgramModel.cpp

+11-6
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ namespace kiwi
135135
{
136136
nextLmStates.resize(prevLmStates.size() * nextWids.size());
137137
scores.resize(prevLmStates.size() * nextWids.size());
138-
langMdl->template progressMatrix<windowSize>(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data());
138+
langMdl->progressMatrix(prevLmStates.data(), nextWids.data(), prevLmStates.size(), nextWids.size(), nextDistantWids.size(), nextLmStates.data(), scores.data());
139139
}
140140
}
141141

@@ -1211,11 +1211,12 @@ namespace kiwi
12111211
}
12121212

12131213
template<ArchType arch, class KeyType, class VlKeyType, size_t windowSize, bool quantized>
1214-
template<size_t _windowSize>
12151214
void CoNgramModel<arch, KeyType, VlKeyType, windowSize, quantized>::progressMatrix(
1216-
const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds,
1215+
const LmStateType* prevStates, const KeyType* nextIds,
12171216
size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens,
12181217
LmStateType* outStates, float* outScores) const
1218+
{
1219+
if constexpr (windowSize > 0)
12191220
{
12201221
thread_local TLSForProgressMatrix tls;
12211222
if (prevStateSize <= (quantized ? 16 : 8) && nextIdSize <= 16)
@@ -1225,13 +1226,17 @@ namespace kiwi
12251226
else
12261227
{
12271228
return progressMatrixWSort(tls, prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores);
1229+
}
1230+
}
1231+
else
1232+
{
1233+
return progressMatrixNoWindow(prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores);
12281234
}
12291235
}
12301236

12311237
template<ArchType arch, class KeyType, class VlKeyType, size_t windowSize, bool quantized>
1232-
template<size_t _windowSize>
1233-
void CoNgramModel<arch, KeyType, VlKeyType, windowSize, quantized>::progressMatrix(
1234-
const typename std::enable_if<_windowSize == 0, LmStateType>::type* prevStates, const KeyType* nextIds,
1238+
void CoNgramModel<arch, KeyType, VlKeyType, windowSize, quantized>::progressMatrixNoWindow(
1239+
const LmStateType* prevStates, const KeyType* nextIds,
12351240
size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens,
12361241
LmStateType* outStates, float* outScores) const
12371242
{

Diff for: src/CoNgramModel.hpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,7 @@ namespace kiwi
311311
* 새 상태값은 outStates에 저장되고, 각 상태별 확률값은 outScores에 저장된다.
312312
* nextIdSize개의 다음 토큰 중 마지막 numValidDistantTokens개의 토큰은 유효한 distant 토큰으로 처리된다.
313313
*/
314-
template<size_t _windowSize>
315-
void progressMatrix(const typename std::enable_if<(_windowSize > 0), LmStateType>::type* prevStates, const KeyType* nextIds,
314+
void progressMatrix(const LmStateType* prevStates, const KeyType* nextIds,
316315
size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens,
317316
LmStateType* outStates, float* outScores) const;
318317

@@ -326,8 +325,7 @@ namespace kiwi
326325
size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens,
327326
LmStateType* outStates, float* outScores) const;
328327

329-
template<size_t _windowSize>
330-
void progressMatrix(const typename std::enable_if<(_windowSize == 0), LmStateType>::type* prevStates, const KeyType* nextIds,
328+
void progressMatrixNoWindow(const LmStateType* prevStates, const KeyType* nextIds,
331329
size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens,
332330
LmStateType* outStates, float* outScores) const;
333331
};

0 commit comments

Comments
 (0)