Skip to content

Commit 005210a

Browse files
committed
replace median_improve
1 parent afffe65 commit 005210a

File tree

4 files changed

+133
-236
lines changed

4 files changed

+133
-236
lines changed

src/Levenshtein/Levenshtein-c/_levenshtein.hpp

Lines changed: 96 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#define LEVENSHTEIN_H
44

55
#include "Python.h"
6+
#include <cstdint>
67
#include <numeric>
78
#include <memory>
89
#include <vector>
@@ -401,73 +402,72 @@ static inline std::basic_string<uint32_t> lev_greedy_median(const std::vector<RF
401402
*
402403
* string1, len1 are already shortened.
403404
*/
404-
template <typename CharT>
405-
double finish_distance_computations(size_t len1, CharT* string1,
406-
size_t n, const size_t* lengths,
407-
const CharT** strings,
408-
const double *weights, std::vector<std::unique_ptr<size_t[]>>& rows,
405+
static inline double finish_distance_computations(size_t len1, uint32_t* string1,
406+
const std::vector<RF_String>& strings,
407+
const std::vector<double>& weights, std::vector<std::unique_ptr<size_t[]>>& rows,
409408
std::unique_ptr<size_t[]>& row)
410409
{
411410
size_t *end;
412411
size_t i, j;
413412
size_t offset; /* row[0]; offset + len1 give together real len of string1 */
414413
double distsum = 0.0; /* sum of distances */
415414

416-
/* catch trivia case */
415+
/* catch trivial case */
417416
if (len1 == 0) {
418-
for (j = 0; j < n; j++)
419-
distsum += (double)rows[j][lengths[j]]*weights[j];
417+
for (j = 0; j < strings.size(); j++)
418+
distsum += (double)rows[j][strings[j].length]*weights[j];
420419
return distsum;
421420
}
422421

423422
/* iterate through the strings and sum the distances */
424-
for (j = 0; j < n; j++) {
425-
size_t* rowi = rows[j].get(); /* current row */
426-
size_t leni = lengths[j]; /* current length */
427-
size_t len = len1; /* temporary len1 for suffix stripping */
428-
const CharT* stringi = strings[j]; /* current string */
429-
430-
/* strip common suffix (prefix CAN'T be stripped) */
431-
while (len && leni && stringi[leni-1] == string1[len-1]) {
432-
len--;
433-
leni--;
434-
}
423+
for (j = 0; j < strings.size(); j++) {
424+
visit(strings[j], [&](auto first1, auto last1){
425+
size_t* rowi = rows[j].get(); /* current row */
426+
size_t leni = (size_t)std::distance(first1, last1); /* current length */
427+
size_t len = len1; /* temporary len1 for suffix stripping */
428+
429+
/* strip common suffix (prefix CAN'T be stripped) */
430+
while (len && leni && first1[leni-1] == string1[len-1]) {
431+
len--;
432+
leni--;
433+
}
435434

436-
/* catch trivial cases */
437-
if (len == 0) {
438-
distsum += (double)rowi[leni]*weights[j];
439-
continue;
440-
}
441-
offset = rowi[0];
442-
if (leni == 0) {
443-
distsum += (double)(offset + len)*weights[j];
444-
continue;
445-
}
435+
/* catch trivial cases */
436+
if (len == 0) {
437+
distsum += (double)rowi[leni]*weights[j];
438+
return;
439+
}
440+
offset = rowi[0];
441+
if (leni == 0) {
442+
distsum += (double)(offset + len)*weights[j];
443+
return;
444+
}
446445

447-
/* complete the matrix */
448-
memcpy(row.get(), rowi, (leni + 1)*sizeof(size_t));
449-
end = row.get() + leni;
450-
451-
for (i = 1; i <= len; i++) {
452-
size_t* p = row.get() + 1;
453-
const CharT char1 = string1[i - 1];
454-
const CharT* char2p = stringi;
455-
size_t D, x;
456-
457-
D = x = i + offset;
458-
while (p <= end) {
459-
size_t c3 = --D + (char1 != *(char2p++));
460-
x++;
461-
if (x > c3)
462-
x = c3;
463-
D = *p;
464-
D++;
465-
if (x > D)
466-
x = D;
467-
*(p++) = x;
446+
/* complete the matrix */
447+
memcpy(row.get(), rowi, (leni + 1)*sizeof(size_t));
448+
end = row.get() + leni;
449+
450+
for (i = 1; i <= len; i++) {
451+
size_t* p = row.get() + 1;
452+
const uint32_t char1 = string1[i - 1];
453+
auto char2p = first1;
454+
size_t D, x;
455+
456+
D = x = i + offset;
457+
while (p <= end) {
458+
size_t c3 = --D + (char1 != *(char2p++));
459+
x++;
460+
if (x > c3)
461+
x = c3;
462+
D = *p;
463+
D++;
464+
if (x > D)
465+
x = D;
466+
*(p++) = x;
467+
}
468468
}
469-
}
470-
distsum += weights[j]*(double)(*end);
469+
distsum += weights[j]*(double)(*end);
470+
});
471471
}
472472

473473
return distsum;
@@ -491,40 +491,45 @@ double finish_distance_computations(size_t len1, CharT* string1,
491491
*
492492
* Returns: The improved generalized median
493493
**/
494-
template <typename CharT>
495-
std::basic_string<CharT> lev_median_improve(size_t len, const CharT* s, size_t n, const size_t* lengths,
496-
const CharT** strings, const double *weights)
494+
static inline std::basic_string<uint32_t> lev_median_improve2(const RF_String& string,
495+
const std::vector<RF_String>& strings, const std::vector<double>& weights)
497496
{
498497
/* find all symbols */
499-
std::vector<CharT> symlist = make_symlist(n, lengths, strings);
498+
std::vector<uint32_t> symlist = make_symlist(strings);
500499
if (symlist.empty()) {
501-
return std::basic_string<CharT>();
500+
return std::basic_string<uint32_t>();
502501
}
503502

504503
/* allocate and initialize per-string matrix rows and a common work buffer */
505-
std::vector<std::unique_ptr<size_t[]>> rows(n);
506-
size_t maxlen = *std::max_element(lengths, lengths + n);
504+
std::vector<std::unique_ptr<size_t[]>> rows(strings.size());
505+
size_t maxlen = 0;
506+
for (const auto& str : strings) {
507+
maxlen = std::max(maxlen, (size_t)str.length);
508+
}
507509

508-
for (size_t i = 0; i < n; i++) {
509-
size_t leni = lengths[i];
510+
for (size_t i = 0; i < strings.size(); i++) {
511+
size_t leni = (size_t)strings[i].length;
510512
rows[i] = std::make_unique<size_t[]>(leni + 1);
511513
std::iota(rows[i].get(), rows[i].get() + leni + 1, 0);
512514
}
515+
513516
size_t stoplen = 2*maxlen + 1;
514517
auto row = std::make_unique<size_t[]>(stoplen + 1);
515518

516519
/* initialize median to given string */
517-
auto _median = std::make_unique<CharT[]>(stoplen + 1);
518-
CharT* median = _median.get() + 1; /* we need -1st element for insertions a pos 0 */
519-
size_t medlen = len;
520-
memcpy(median, s, (medlen)*sizeof(CharT));
521-
double minminsum = finish_distance_computations(medlen, median,
522-
n, lengths, strings,
523-
weights, rows, row);
520+
auto _median = std::make_unique<uint32_t[]>(stoplen + 1);
521+
uint32_t* median = _median.get() + 1; /* we need -1st element for insertions a pos 0 */
522+
size_t medlen = (size_t)string.length;
523+
524+
visit(string, [&](auto first1, auto last1){
525+
std::copy(first1, last1, median);
526+
});
527+
528+
double minminsum = finish_distance_computations(medlen, median, strings, weights, rows, row);
524529

525530
/* sequentially try perturbations on all positions */
526531
for (size_t pos = 0; pos <= medlen; ) {
527-
CharT orig_symbol, symbol;
532+
uint32_t orig_symbol, symbol;
528533
LevEditType operation;
529534
double sum;
530535

@@ -538,9 +543,7 @@ std::basic_string<CharT> lev_median_improve(size_t len, const CharT* s, size_t n
538543
if (symlist[j] == orig_symbol)
539544
continue;
540545
median[pos] = symlist[j];
541-
sum = finish_distance_computations(medlen - pos, median + pos,
542-
n, lengths, strings,
543-
weights, rows, row);
546+
sum = finish_distance_computations(medlen - pos, median + pos, strings, weights, rows, row);
544547
if (sum < minminsum) {
545548
minminsum = sum;
546549
symbol = symlist[j];
@@ -555,9 +558,7 @@ std::basic_string<CharT> lev_median_improve(size_t len, const CharT* s, size_t n
555558
orig_symbol = *(median + pos - 1);
556559
for (size_t j = 0; j < symlist.size(); j++) {
557560
*(median + pos - 1) = symlist[j];
558-
sum = finish_distance_computations(medlen - pos + 1, median + pos - 1,
559-
n, lengths, strings,
560-
weights, rows, row);
561+
sum = finish_distance_computations(medlen - pos + 1, median + pos - 1, strings, weights, rows, row);
561562
if (sum < minminsum) {
562563
minminsum = sum;
563564
symbol = symlist[j];
@@ -568,9 +569,7 @@ std::basic_string<CharT> lev_median_improve(size_t len, const CharT* s, size_t n
568569
/* IF pos < medlen: try to delete the symbol at pos, if it lowers
569570
* the total distance remember it (decrease medlen) */
570571
if (pos < medlen) {
571-
sum = finish_distance_computations(medlen - pos - 1, median + pos + 1,
572-
n, lengths, strings,
573-
weights, rows, row);
572+
sum = finish_distance_computations(medlen - pos - 1, median + pos + 1, strings, weights, rows, row);
574573
if (sum < minminsum) {
575574
minminsum = sum;
576575
operation = LEV_EDIT_DELETE;
@@ -584,14 +583,14 @@ std::basic_string<CharT> lev_median_improve(size_t len, const CharT* s, size_t n
584583

585584
case LEV_EDIT_INSERT:
586585
memmove(median+pos+1, median+pos,
587-
(medlen - pos)*sizeof(CharT));
586+
(medlen - pos)*sizeof(uint32_t));
588587
median[pos] = symbol;
589588
medlen++;
590589
break;
591590

592591
case LEV_EDIT_DELETE:
593592
memmove(median+pos, median + pos+1,
594-
(medlen - pos-1)*sizeof(CharT));
593+
(medlen - pos-1)*sizeof(uint32_t));
595594
medlen--;
596595
break;
597596

@@ -603,26 +602,28 @@ std::basic_string<CharT> lev_median_improve(size_t len, const CharT* s, size_t n
603602
if (operation != LEV_EDIT_DELETE) {
604603
symbol = median[pos];
605604
row[0] = pos + 1;
606-
for (size_t i = 0; i < n; i++) {
607-
const CharT* stri = strings[i];
608-
size_t* oldrow = rows[i].get();
609-
size_t leni = lengths[i];
610-
/* compute a row of Levenshtein matrix */
611-
for (size_t k = 1; k <= leni; k++) {
612-
size_t c1 = oldrow[k] + 1;
613-
size_t c2 = row[k - 1] + 1;
614-
size_t c3 = oldrow[k - 1] + (symbol != stri[k - 1]);
615-
row[k] = c2 > c3 ? c3 : c2;
616-
if (row[k] > c1)
617-
row[k] = c1;
618-
}
619-
memcpy(oldrow, row.get(), (leni + 1)*sizeof(size_t));
605+
606+
for (size_t i = 0; i < strings.size(); i++) {
607+
visit(strings[i], [&](auto first1, auto last1){
608+
size_t* oldrow = rows[i].get();
609+
size_t leni = (size_t)std::distance(first1, last1);
610+
/* compute a row of Levenshtein matrix */
611+
for (size_t k = 1; k <= leni; k++) {
612+
size_t c1 = oldrow[k] + 1;
613+
size_t c2 = row[k - 1] + 1;
614+
size_t c3 = oldrow[k - 1] + (symbol != first1[k - 1]);
615+
row[k] = c2 > c3 ? c3 : c2;
616+
if (row[k] > c1)
617+
row[k] = c1;
618+
}
619+
memcpy(oldrow, row.get(), (leni + 1)*sizeof(size_t));
620+
});
620621
}
621622
pos++;
622623
}
623624
}
624625

625-
return std::basic_string<CharT>(median, medlen);
626+
return std::basic_string<uint32_t>(median, medlen);
626627
}
627628

628629
std::basic_string<lev_byte>

src/Levenshtein/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from rapidfuzz.distance.JaroWinkler import similarity as jaro_winkler
2626

2727
from Levenshtein._levenshtein import (
28-
median_improve,
2928
quickmedian
3029
)
3130

@@ -37,6 +36,7 @@
3736
subtract_edit,
3837
apply_edit,
3938
median,
39+
median_improve,
4040
setmedian,
4141
setratio,
4242
seqratio

0 commit comments

Comments
 (0)