Skip to content

Commit dcdb0db

Browse files
committed
Access violation fix
1 parent 247abb4 commit dcdb0db

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

operators/tokenizer/case_encoder.h

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -170,49 +170,71 @@ class CaseEncoder {
170170
void PostProcess(std::string* normalized, std::vector<size_t>* norm_to_orig) {
171171
if (!seen_three_spans_) return;
172172

173+
// Safety check: ensure norm_to_orig has enough elements
174+
if (norm_to_orig->size() < normalized->size()) {
175+
return; // Cannot safely process - sizes don't match
176+
}
177+
173178
std::string normalized_temp;
174179
normalized_temp.reserve(normalized->size());
175180

176181
std::vector<size_t> norm_to_orig_temp;
177182
norm_to_orig_temp.reserve(norm_to_orig->size());
178183

179184
const char* sig_it = signature_.data();
185+
const char* sig_end = signature_.data() + signature_.length();
180186

181187
auto nrm_it = normalized->cbegin();
188+
auto nrm_end = normalized->cend();
182189
auto n2o_it = norm_to_orig->cbegin();
190+
auto n2o_end = norm_to_orig->cend();
183191

184192
for (const auto& span : Search(signature_)) {
185193
size_t len = std::distance(sig_it, span.first);
186194

195+
// Bounds check before advancing iterators
196+
if (std::distance(nrm_it, nrm_end) < static_cast<ptrdiff_t>(len) ||
197+
std::distance(n2o_it, n2o_end) < static_cast<ptrdiff_t>(len)) {
198+
break; // Not enough elements remaining
199+
}
200+
187201
normalized_temp.insert(normalized_temp.end(), nrm_it, nrm_it + len);
188202
norm_to_orig_temp.insert(norm_to_orig_temp.end(), n2o_it, n2o_it + len);
189203

190204
sig_it += len;
191205
nrm_it += len;
192206
n2o_it += len;
207+
208+
// Bounds check before dereferencing
209+
if (n2o_it == n2o_end) break;
210+
193211
normalized_temp.push_back(cAllUppercase);
194212
norm_to_orig_temp.push_back(*n2o_it);
195213

196214
while (sig_it != span.second) {
215+
if (sig_it >= sig_end || nrm_it >= nrm_end || n2o_it >= n2o_end) break;
216+
197217
if (*sig_it == cUppercase) {
198218
sig_it++;
199219
nrm_it++;
200220
n2o_it++;
201221
}
222+
if (sig_it >= sig_end || nrm_it >= nrm_end || n2o_it >= n2o_end) break;
223+
202224
sig_it++;
203225
normalized_temp.push_back(*nrm_it++);
204226
norm_to_orig_temp.push_back(*n2o_it++);
205227
}
206-
if (sig_it != signature_.data() + signature_.length()) {
207-
if (*sig_it != cUppercase) {
228+
if (sig_it != sig_end) {
229+
if (*sig_it != cUppercase && n2o_it != n2o_end) {
208230
normalized_temp.push_back(cLowercase);
209231
norm_to_orig_temp.push_back(*n2o_it);
210232
}
211233
}
212234
}
213235

214-
if (nrm_it != normalized->cend()) normalized_temp.insert(normalized_temp.end(), nrm_it, normalized->cend());
215-
if (n2o_it != norm_to_orig->cend()) norm_to_orig_temp.insert(norm_to_orig_temp.end(), n2o_it, norm_to_orig->cend());
236+
if (nrm_it != nrm_end) normalized_temp.insert(normalized_temp.end(), nrm_it, nrm_end);
237+
if (n2o_it != n2o_end) norm_to_orig_temp.insert(norm_to_orig_temp.end(), n2o_it, n2o_end);
216238

217239
normalized->swap(normalized_temp);
218240
norm_to_orig->swap(norm_to_orig_temp);

operators/tokenizer/ugm_kernels.hpp

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -537,7 +537,9 @@ struct SpmUgmTokenizer {
537537
std::string NmtNormalize(const std::string& input) const {
538538
std::string normalized;
539539
normalized.reserve(input.size() * 3);
540-
std::vector<size_t> norm_to_orig(input.size() * 3);
540+
// Use a vector that tracks original positions - reserve capacity but start empty
541+
std::vector<size_t> norm_to_orig;
542+
norm_to_orig.reserve(input.size() * 3);
541543

542544
const std::string space = tokenizer_escape_whitespaces_ ? std::string(spm_escaped_space) : " ";
543545

@@ -551,40 +553,60 @@ struct SpmUgmTokenizer {
551553
size_t input_len = input.size();
552554

553555
std::string_view input_view(input);
554-
int consumed = 0;
556+
size_t orig_offset = 0;
555557

556558
while (!input_view.empty()) {
557559
auto p = case_encoder_->NormalizePrefix(input_view);
558560

561+
// Safety check: if nothing was consumed and nothing was returned, skip one UTF-8 character (clamped to the remaining input) to avoid an infinite loop
562+
if (p.second == 0 && p.first.empty()) {
563+
// Advance by one UTF-8 character to prevent infinite loop
564+
size_t skip = std::min(ustring::UTF8Len(input_view[0]), input_view.size());
565+
orig_offset += skip;
566+
input_view.remove_prefix(skip);
567+
continue;
568+
}
569+
559570
for (size_t i = 0; i < p.first.size(); i++) {
560571
char c = p.first[i];
561572
if (c != ' ') {
562573
if (!processing_non_ws) {
563574
processing_non_ws = true;
564575
if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) {
565576
normalized.append(space);
577+
// Track original position for space characters
578+
for (size_t j = 0; j < space.size(); j++) {
579+
norm_to_orig.push_back(orig_offset);
580+
}
566581
is_space_prepended = true;
567582
}
568583
}
569584
normalized.push_back(c);
585+
norm_to_orig.push_back(orig_offset + i);
570586
} else {
571587
if (processing_non_ws) {
572588
processing_non_ws = false;
573589
}
574590
if (!shall_merge_spaces) {
575591
normalized.append(space);
592+
for (size_t j = 0; j < space.size(); j++) {
593+
norm_to_orig.push_back(orig_offset + i);
594+
}
576595
}
577596
}
578597
}
579598

580-
consumed += p.second;
581-
input_view.remove_prefix(p.second);
599+
orig_offset += static_cast<size_t>(p.second);
600+
input_view.remove_prefix(static_cast<size_t>(p.second));
582601
}
583602

584603
case_encoder_->PostProcess(&normalized, &norm_to_orig);
585604

586605
if (shall_append_space) {
587606
normalized.append(space);
607+
for (size_t j = 0; j < space.size(); j++) {
608+
norm_to_orig.push_back(input.size());
609+
}
588610
}
589611

590612
return normalized;

0 commit comments

Comments
 (0)