Skip to content

Commit d242407

Browse files
committed
Fix character capitalization issue in marian tokenizer
1 parent 7387a4e commit d242407

File tree

1 file changed

+39
-2
lines changed

1 file changed

+39
-2
lines changed

operators/tokenizer/ugm_kernels.hpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,43 @@ class SpmUgmDecoder {
879879
token = prefix + suffix;
880880
}
881881

882+
// UTF-8 aware uppercase for the whole token
883+
void UppercaseUTF8(std::string& text) const {
884+
if (text.empty()) return;
885+
886+
std::string result;
887+
result.reserve(text.size());
888+
889+
size_t i = 0;
890+
while (i < text.size()) {
891+
// Decode next codepoint from current position
892+
wchar_t codepoint;
893+
size_t char_len = 0;
894+
895+
// Create a view starting at i; if decoding fails, copy raw byte
896+
std::string remaining = text.substr(i);
897+
if (!DecodeFirstUTF8Codepoint(remaining, codepoint, char_len) || char_len == 0) {
898+
result.push_back(text[i]);
899+
++i;
900+
continue;
901+
}
902+
903+
// Cyrillic special cases
904+
if (codepoint >= L'а' && codepoint <= L'я') {
905+
codepoint = codepoint - (L'а' - L'А');
906+
} else if (codepoint == L'ё') {
907+
codepoint = L'Ё';
908+
} else {
909+
codepoint = std::towupper(codepoint);
910+
}
911+
912+
result += EncodeUTF8(codepoint);
913+
i += char_len;
914+
}
915+
916+
text.swap(result);
917+
}
918+
882919
OrtxStatus Id2Token(extTokenId_t id, std::string& token, TokenizerDecodingState** state, bool skip_special_tokens /* only used by BPE; placeholder for UGM */ = true) const {
883920
std::unique_ptr<TokenizerDecodingState> decoding_state;
884921
if (*state == nullptr) {
@@ -932,7 +969,7 @@ class SpmUgmDecoder {
932969
switch (signature) {
933970
case normalizer::cUppercase:
934971
case normalizer::cAllUppercase:
935-
std::transform(token.begin(), token.end(), token.begin(), ::toupper);
972+
UppercaseUTF8(token);
936973
break;
937974
case normalizer::cTitlecase:
938975
TitlecaseFirstCharacter(token);
@@ -953,7 +990,7 @@ class SpmUgmDecoder {
953990
switch (first_char) {
954991
case normalizer::cUppercase:
955992
case normalizer::cAllUppercase:
956-
std::transform(token.begin(), token.end(), token.begin(), ::toupper);
993+
UppercaseUTF8(token);
957994
break;
958995
case normalizer::cTitlecase:
959996
TitlecaseFirstCharacter(token);

0 commit comments

Comments
 (0)