Skip to content

Commit 1c214fd

Browse files
committed
Add faster base58 encode/decode
Add optional gmpy2.mpz for even faster encode/decode Add longer random benchmark
1 parent 578b01f commit 1c214fd

File tree

2 files changed

+103
-24
lines changed

2 files changed

+103
-24
lines changed

base58/__init__.py

Lines changed: 101 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
from functools import lru_cache
1313
from hashlib import sha256
1414
from typing import Mapping, Union
15+
from math import log
16+
17+
try:
18+
from gmpy2 import mpz
19+
except ImportError:
20+
mpz = None
1521

1622
__version__ = '2.1.1'
1723

@@ -20,6 +26,10 @@
2026
b'123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz'
2127
RIPPLE_ALPHABET = b'rpshnaf39wBUDNEGHJKLM4PQRST7VWXYZ2bcdeCg65jkm8oFqi1tuvAxyz'
2228
XRP_ALPHABET = RIPPLE_ALPHABET
29+
POWERS = {
30+
45: {2 ** i: 45 ** (2 ** i) for i in range(4, 20)},
31+
58: {2 ** i: 58 ** (2 ** i) for i in range(4, 20)}
32+
}
2333

2434
# Retro compatibility
2535
alphabet = BITCOIN_ALPHABET
@@ -32,19 +42,64 @@ def scrub_input(v: Union[str, bytes]) -> bytes:
3242
return v
3343

3444

45+
def _encode_int(i: int, base: int = 58, alphabet: bytes = BITCOIN_ALPHABET) -> bytes:
46+
"""
47+
Encode integer to bytes with base 58 alphabet by powers of 58
48+
"""
49+
min_val = POWERS[base][2**8]
50+
if i <= min_val:
51+
string = bytearray()
52+
while i:
53+
i, idx = divmod(i, base)
54+
string.append(idx)
55+
return string[::-1]
56+
else:
57+
origlen0 = int(log(i, 58))//2
58+
try:
59+
split_num = POWERS[base][2**origlen0]
60+
except KeyError:
61+
POWERS[base][2**origlen0] = split_num = base ** origlen0
62+
i1, i0 = divmod(i, split_num)
63+
64+
v1 = _encode_int(i1, base, alphabet)
65+
v0 = _encode_int(i0, base, alphabet)
66+
newlen0 = len(v0)
67+
if newlen0 < origlen0:
68+
v0[:0] = b'\0' * (origlen0 - newlen0)
69+
70+
return v1 + v0
71+
72+
73+
def _mpz_encode(i: int, alphabet: bytes) -> bytes:
74+
"""
75+
Encode an integer to arbitrary base using gmpy2 mpz
76+
"""
77+
base = len(alphabet)
78+
79+
raw: bytes = mpz(i).digits(base).encode()
80+
tr_bytes = bytes.maketrans(''.join([mpz(x).digits(base) for x in range(base)]).encode(), alphabet)
81+
encoded: bytes = raw.translate(tr_bytes)
82+
83+
return encoded
84+
85+
3586
def b58encode_int(
3687
i: int, default_one: bool = True, alphabet: bytes = BITCOIN_ALPHABET
3788
) -> bytes:
3889
"""
3990
Encode an integer using Base58
4091
"""
41-
if not i and default_one:
42-
return alphabet[0:1]
43-
string = b""
92+
if not i:
93+
if default_one:
94+
return alphabet[0:1]
95+
return b''
96+
if mpz:
97+
return _mpz_encode(i, alphabet)
98+
4499
base = len(alphabet)
45-
while i:
46-
i, idx = divmod(i, base)
47-
string = alphabet[idx:idx+1] + string
100+
raw_string = _encode_int(i, base, alphabet)
101+
string = raw_string.translate(bytes.maketrans(bytearray(range(len(alphabet))), alphabet))
102+
48103
return string
49104

50105

@@ -82,6 +137,24 @@ def _get_base58_decode_map(alphabet: bytes,
82137
return invmap
83138

84139

140+
def _decode(data: bytes, min_split: int = 256, base: int = 58) -> int:
141+
"""
142+
Decode larger data blocks recursively
143+
"""
144+
if len(data) <= min_split:
145+
ret_int = 0
146+
for val in data:
147+
ret_int = base * ret_int + val
148+
return ret_int
149+
else:
150+
split_len = 2**(len(data).bit_length()-2)
151+
try:
152+
base_pow = POWERS[base][split_len]
153+
except KeyError:
154+
POWERS[base] = base_pow = base ** split_len
155+
return (base_pow * _decode(data[:-split_len])) + _decode(data[-split_len:])
156+
157+
85158
def b58decode_int(
86159
v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET, *,
87160
autofix: bool = False
@@ -93,18 +166,29 @@ def b58decode_int(
93166
v = v.rstrip()
94167
v = scrub_input(v)
95168

169+
base = len(alphabet)
96170
map = _get_base58_decode_map(alphabet, autofix=autofix)
171+
if mpz:
172+
tr_bytes = bytes.maketrans(bytearray(map.keys()), ''.join([mpz(x).digits(base) for x in map.values()]).encode())
173+
else:
174+
tr_bytes = bytes.maketrans(bytearray(map.keys()), bytearray(map.values()))
175+
del_chars = bytes(bytearray(x for x in range(256) if x not in map))
97176

98-
decimal = 0
99-
base = len(alphabet)
100-
try:
101-
for char in v:
102-
decimal = decimal * base + map[char]
103-
except KeyError as e:
104-
raise ValueError(
105-
"Invalid character {!r}".format(chr(e.args[0]))
106-
) from None
107-
return decimal
177+
cv = v.translate(tr_bytes, delete=del_chars)
178+
if len(v) != len(cv):
179+
err_char = chr(next(c for c in v if c not in map))
180+
raise ValueError("Invalid character {!r}".format(err_char))
181+
182+
if cv == b'':
183+
return 0
184+
185+
if mpz:
186+
try:
187+
return int(mpz(cv, base=base))
188+
except ValueError:
189+
raise ValueError(cv, base)
190+
191+
return _decode(cv, base=base)
108192

109193

110194
def b58decode(
@@ -123,12 +207,7 @@ def b58decode(
123207

124208
acc = b58decode_int(v, alphabet=alphabet, autofix=autofix)
125209

126-
result = []
127-
while acc > 0:
128-
acc, mod = divmod(acc, 256)
129-
result.append(mod)
130-
131-
return b'\0' * (origlen - newlen) + bytes(reversed(result))
210+
return acc.to_bytes(origlen - newlen + (acc.bit_length() + 7) // 8, "big")
132211

133212

134213
def b58encode_check(

test_base58.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,14 @@ def test_invalid_input():
134134
raises(ValueError, "Invalid character '\\\\x08'"))
135135

136136

137-
@pytest.mark.parametrize('length', [8, 32, 256, 1024])
137+
@pytest.mark.parametrize('length', [8, 32, 256, 1024, 8192])
138138
def test_encode_random(benchmark, length) -> None:
139139
data = getrandbits(length * 8).to_bytes(length, byteorder='big')
140140
encoded = benchmark(lambda: b58encode(data))
141141
assert_that(b58decode(encoded), equal_to(data))
142142

143143

144-
@pytest.mark.parametrize('length', [8, 32, 256, 1024])
144+
@pytest.mark.parametrize('length', [8, 32, 256, 1024, 8192])
145145
def test_decode_random(benchmark, length) -> None:
146146
origdata = getrandbits(length * 8).to_bytes(length, byteorder='big')
147147
encoded = b58encode(origdata)

0 commit comments

Comments
 (0)