@@ -135,7 +135,7 @@ namespace kiwi
135
135
{
136
136
nextLmStates.resize (prevLmStates.size () * nextWids.size ());
137
137
scores.resize (prevLmStates.size () * nextWids.size ());
138
- langMdl->template progressMatrix <windowSize> (prevLmStates.data (), nextWids.data (), prevLmStates.size (), nextWids.size (), nextDistantWids.size (), nextLmStates.data (), scores.data ());
138
+ langMdl->progressMatrix (prevLmStates.data (), nextWids.data (), prevLmStates.size (), nextWids.size (), nextDistantWids.size (), nextLmStates.data (), scores.data ());
139
139
}
140
140
}
141
141
@@ -1211,11 +1211,12 @@ namespace kiwi
1211
1211
}
1212
1212
1213
1213
template <ArchType arch, class KeyType , class VlKeyType , size_t windowSize, bool quantized>
1214
- template <size_t _windowSize>
1215
1214
void CoNgramModel<arch, KeyType, VlKeyType, windowSize, quantized>::progressMatrix(
1216
- const typename std::enable_if<(_windowSize > 0 ), LmStateType>::type * prevStates, const KeyType* nextIds,
1215
+ const LmStateType* prevStates, const KeyType* nextIds,
1217
1216
size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens,
1218
1217
LmStateType* outStates, float * outScores) const
1218
+ {
1219
+ if constexpr (windowSize > 0 )
1219
1220
{
1220
1221
thread_local TLSForProgressMatrix tls;
1221
1222
if (prevStateSize <= (quantized ? 16 : 8 ) && nextIdSize <= 16 )
@@ -1225,13 +1226,17 @@ namespace kiwi
1225
1226
else
1226
1227
{
1227
1228
return progressMatrixWSort (tls, prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores);
1229
+ }
1230
+ }
1231
+ else
1232
+ {
1233
+ return progressMatrixNoWindow (prevStates, nextIds, prevStateSize, nextIdSize, numValidDistantTokens, outStates, outScores);
1228
1234
}
1229
1235
}
1230
1236
1231
1237
template <ArchType arch, class KeyType , class VlKeyType , size_t windowSize, bool quantized>
1232
- template <size_t _windowSize>
1233
- void CoNgramModel<arch, KeyType, VlKeyType, windowSize, quantized>::progressMatrix(
1234
- const typename std::enable_if<_windowSize == 0 , LmStateType>::type* prevStates, const KeyType* nextIds,
1238
+ void CoNgramModel<arch, KeyType, VlKeyType, windowSize, quantized>::progressMatrixNoWindow(
1239
+ const LmStateType* prevStates, const KeyType* nextIds,
1235
1240
size_t prevStateSize, size_t nextIdSize, size_t numValidDistantTokens,
1236
1241
LmStateType* outStates, float * outScores) const
1237
1242
{
0 commit comments