Skip to content

Commit d231cea

Browse files
committed
Another fix
1 parent 247d52d commit d231cea

File tree

2 files changed

+32
-20
lines changed

2 files changed

+32
-20
lines changed

operators/tokenizer/case_encoder.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ class CaseEncoder {
3333
virtual ~CaseEncoder() {}
3434
void SetNormalizer(Normalizer normalizer) { normalizer_ = normalizer; }
3535

36+
// Reset all state for a new tokenization call
37+
void Reset() {
38+
buffer_.clear();
39+
buffer_queue_.clear();
40+
signature_.clear();
41+
offset_ = 0;
42+
dump_buffer_from_ = -1;
43+
state_ = 0;
44+
spans_ = 0;
45+
seen_three_spans_ = false;
46+
}
47+
3648
public:
3749
CaseEncoder(bool remove_extra_white_space) : remove_extra_white_space_(remove_extra_white_space) {}
3850

@@ -98,10 +110,6 @@ class CaseEncoder {
98110
buffer_.clear();
99111
buffer_queue_.clear();
100112
offset_ = 0;
101-
// Reset all state for new tokenization
102-
signature_.clear();
103-
spans_ = 0;
104-
seen_three_spans_ = false;
105113
}
106114

107115
if (isUpper(sp)) {
@@ -174,9 +182,10 @@ class CaseEncoder {
174182
void PostProcess(std::string* normalized, std::vector<size_t>* norm_to_orig) {
175183
if (!seen_three_spans_) return;
176184

177-
// Safety check: ensure norm_to_orig has enough elements
178-
if (norm_to_orig->size() < normalized->size()) {
179-
return; // Cannot safely process - sizes don't match
185+
// Ensure norm_to_orig has at least as many elements as normalized
186+
// Pad with zeros if needed to prevent out-of-bounds access
187+
while (norm_to_orig->size() < normalized->size()) {
188+
norm_to_orig->push_back(0);
180189
}
181190

182191
std::string normalized_temp;

operators/tokenizer/ugm_kernels.hpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -535,9 +535,12 @@ struct SpmUgmTokenizer {
535535
}
536536

537537
std::string NmtNormalize(const std::string& input) const {
538+
// Reset the case encoder state before starting new normalization
539+
case_encoder_->Reset();
540+
538541
std::string normalized;
539542
normalized.reserve(input.size() * 3);
540-
// Use a vector that tracks original positions - reserve capacity but start empty
543+
// Track norm_to_orig to match normalized string exactly
541544
std::vector<size_t> norm_to_orig;
542545
norm_to_orig.reserve(input.size() * 3);
543546

@@ -550,10 +553,8 @@ struct SpmUgmTokenizer {
550553
bool is_space_prepended = false;
551554
bool processing_non_ws = false;
552555

553-
size_t input_len = input.size();
554-
555556
std::string_view input_view(input);
556-
size_t orig_offset = 0;
557+
size_t orig_pos = 0; // Current position in original input
557558

558559
while (!input_view.empty()) {
559560
auto p = case_encoder_->NormalizePrefix(input_view);
@@ -562,7 +563,7 @@ struct SpmUgmTokenizer {
562563
if (p.second == 0 && p.first.empty()) {
563564
// Advance by one UTF-8 character to prevent infinite loop
564565
size_t skip = std::min(ustring::UTF8Len(input_view[0]), input_view.size());
565-
orig_offset += skip;
566+
orig_pos += skip;
566567
input_view.remove_prefix(skip);
567568
continue;
568569
}
@@ -574,39 +575,41 @@ struct SpmUgmTokenizer {
574575
processing_non_ws = true;
575576
if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) {
576577
normalized.append(space);
577-
// Track original position for space characters
578578
for (size_t j = 0; j < space.size(); j++) {
579-
norm_to_orig.push_back(orig_offset);
579+
norm_to_orig.push_back(orig_pos);
580580
}
581581
is_space_prepended = true;
582582
}
583583
}
584584
normalized.push_back(c);
585-
norm_to_orig.push_back(orig_offset + i);
585+
norm_to_orig.push_back(orig_pos);
586586
} else {
587587
if (processing_non_ws) {
588588
processing_non_ws = false;
589589
}
590590
if (!shall_merge_spaces) {
591591
normalized.append(space);
592592
for (size_t j = 0; j < space.size(); j++) {
593-
norm_to_orig.push_back(orig_offset + i);
593+
norm_to_orig.push_back(orig_pos);
594594
}
595595
}
596596
}
597597
}
598598

599-
orig_offset += static_cast<size_t>(p.second);
599+
orig_pos += static_cast<size_t>(p.second);
600600
input_view.remove_prefix(static_cast<size_t>(p.second));
601601
}
602602

603+
// Ensure norm_to_orig matches normalized size before PostProcess
604+
// This is critical for PostProcess to work correctly
605+
while (norm_to_orig.size() < normalized.size()) {
606+
norm_to_orig.push_back(orig_pos > 0 ? orig_pos - 1 : 0);
607+
}
608+
603609
case_encoder_->PostProcess(&normalized, &norm_to_orig);
604610

605611
if (shall_append_space) {
606612
normalized.append(space);
607-
for (size_t j = 0; j < space.size(); j++) {
608-
norm_to_orig.push_back(input.size());
609-
}
610613
}
611614

612615
return normalized;

0 commit comments

Comments
 (0)