|
59 | 59 | import elasticapm |
60 | 60 | import regex |
61 | 61 | from blake3 import blake3 |
62 | | -from editdistancek_rs import distance |
63 | 62 | from elastic_transport import ApiError, TransportError |
64 | 63 | from elasticsearch import AsyncElasticsearch |
65 | 64 | from geoip import geolite2 # type: ignore[import-untyped] |
66 | 65 | from openmoji_dist import VERSION as OPENMOJI_VERSION |
| 66 | +from rapidfuzz.distance.Levenshtein import distance |
67 | 67 | from redis.asyncio import Redis |
68 | 68 | from tornado.web import HTTPError, RequestHandler |
69 | 69 | from typed_stream import Stream |
@@ -300,6 +300,16 @@ def bool_to_str(val: bool) -> str: |
300 | 300 | return "sure" if val else "nope" |
301 | 301 |
|
302 | 302 |
|
def bounded_edit_distance(s1: str, s2: str, /, k: int) -> int:
    """Return the edit distance between two strings, capped at k.

    The distance is computed with ``score_cutoff=k``. The backend reports
    a distance above the cutoff as ``k + 1``; that sentinel is folded back
    to ``k`` so the result is never larger than ``k``.
    """
    raw = distance(s1, s2, score_cutoff=k)
    # k + 1 means "further apart than k"; report it as the cap itself.
    return k if raw == k + 1 else raw
| 311 | + |
| 312 | + |
def country_code_to_flag(code: str) -> str:
    """Convert a two-letter ISO country code to a flag emoji.

    Each character of the upper-cased code is shifted by a fixed code-point
    offset and the shifted characters are concatenated.
    """
    # 23 * 29 * 191 == 127397 == 0x1F1E6 - ord("A"), i.e. the distance from
    # "A" to the first regional indicator symbol.
    offset = 23 * 29 * 191
    shifted = [chr(ord(letter) + offset) for letter in code.upper()]
    return "".join(shifted)
@@ -526,7 +536,9 @@ def get_close_matches( # based on difflib.get_close_matches |
526 | 536 | result: list[tuple[float, str]] = [] |
527 | 537 | for possibility in possibilities: |
528 | 538 | if max_dist := max(word_len, len(possibility)): |
529 | | - dist = distance(possibility, word, 1 + int(cutoff * max_dist)) |
| 539 | + dist = bounded_edit_distance( |
| 540 | + possibility, word, 1 + int(cutoff * max_dist) |
| 541 | + ) |
530 | 542 | if (ratio := dist / max_dist) <= cutoff: |
531 | 543 | bisect.insort(result, (ratio, possibility)) |
532 | 544 | if len(result) > count: |
|
0 commit comments