diff --git a/base58/__init__.py b/base58/__init__.py index 929014f..2dd4940 100644 --- a/base58/__init__.py +++ b/base58/__init__.py @@ -1,8 +1,8 @@ -'''Base58 encoding +"""Base58 encoding Implementations of Base58 and Base58Check encodings that are compatible with the bitcoin network. -''' +""" # This module is based upon base58 snippets found scattered over many bitcoin # tools written in python. From what I gather the original source is from a @@ -11,15 +11,25 @@ from functools import lru_cache from hashlib import sha256 -from typing import Mapping, Union +from typing import Dict, Tuple, Union +from math import log -__version__ = '2.1.1' +try: + from gmpy2 import mpz +except ImportError: + mpz = None + +__version__ = "2.1.1" # 58 character alphabet used -BITCOIN_ALPHABET = \ - b'123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' -RIPPLE_ALPHABET = b'rpshnaf39wBUDNEGHJKLM4PQRST7VWXYZ2bcdeCg65jkm8oFqi1tuvAxyz' +BITCOIN_ALPHABET = b"123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" +RIPPLE_ALPHABET = b"rpshnaf39wBUDNEGHJKLM4PQRST7VWXYZ2bcdeCg65jkm8oFqi1tuvAxyz" XRP_ALPHABET = RIPPLE_ALPHABET +_MPZ_ALPHABET = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +POWERS = { + 45: {2**i: 45 ** (2**i) for i in range(4, 20)}, + 58: {2**i: 58 ** (2**i) for i in range(4, 20)}, +} # type: Dict[int, Dict[int, int]] # Retro compatibility alphabet = BITCOIN_ALPHABET @@ -27,89 +37,161 @@ def scrub_input(v: Union[str, bytes]) -> bytes: if isinstance(v, str): - v = v.encode('ascii') + v = v.encode("ascii") return v +def _encode_int(i: int, base: int = 58, alphabet: bytes = BITCOIN_ALPHABET) -> bytes: + """ + Encode integer to bytes with base 58 alphabet by powers of 58 + """ + min_val = POWERS[base][2**8] + if i <= min_val: + string = bytearray() + while i: + i, idx = divmod(i, base) + string.append(idx) + return bytes(string[::-1]) + else: + origlen0 = int(log(i, 58)) // 2 + try: + split_num = POWERS[base][2**origlen0] + except KeyError: + POWERS[base][2**origlen0] = split_num = base**origlen0 + i1, i0 = divmod(i, split_num) + + v1 = _encode_int(i1, base, alphabet) + v0 = _encode_int(i0, base, alphabet) + newlen0 = len(v0) + if newlen0 < origlen0: + v0 = b"\0" * (origlen0 - newlen0) + v0 + return v1 + v0 + + +def _mpz_encode(i: int, alphabet: bytes) -> bytes: + """ + Encode an integer to arbitrary base using gmpy2 mpz + """ + base = len(alphabet) + + raw: bytes = mpz(i).digits(base).encode() + tr_bytes = bytes.maketrans(_MPZ_ALPHABET[:base], alphabet) + encoded: bytes = raw.translate(tr_bytes) + + return encoded + + def b58encode_int( i: int, default_one: bool = True, alphabet: bytes = BITCOIN_ALPHABET ) -> bytes: """ Encode an integer using Base58 """ - if not i and default_one: - return alphabet[0:1] - string = b"" + if not i: + if default_one: + return alphabet[0:1] + return b"" + if mpz: + return _mpz_encode(i, alphabet) + base = len(alphabet) - while i: - i, idx = divmod(i, base) - string = alphabet[idx:idx+1] + string + raw_string = _encode_int(i, base, alphabet) + string = raw_string.translate( + bytes.maketrans(bytearray(range(len(alphabet))), alphabet) + ) + return string -def b58encode( - v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET -) -> bytes: +def b58encode(v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET) -> bytes: """ Encode a string using Base58 """ v = scrub_input(v) origlen = len(v) - v = v.lstrip(b'\0') + v = v.lstrip(b"\0") newlen = len(v) - acc = int.from_bytes(v, byteorder='big') # first byte is most significant + acc = int.from_bytes(v, byteorder="big") # first byte is most significant result = b58encode_int(acc, default_one=False, alphabet=alphabet) return alphabet[0:1] * (origlen - newlen) + result @lru_cache() -def _get_base58_decode_map(alphabet: bytes, - autofix: bool) -> Mapping[int, int]: +def _get_base58_decode_map(alphabet: bytes, autofix: bool) -> Tuple[bytes, bytes]: invmap = {char: index for index, char in enumerate(alphabet)} - + base = len(alphabet) if autofix: - groups = [b'0Oo', b'Il1'] + groups = [b"0Oo", b"Il1"] for group in groups: pivots = [c for c in group if c in invmap] if len(pivots) == 1: for alternative in group: invmap[alternative] = invmap[pivots[0]] - return invmap + del_chars = bytes(bytearray(x for x in range(256) if x not in invmap)) + + if mpz is not None: + mpz_alphabet = "".join([mpz(x).digits(base) for x in invmap.values()]).encode() + tr_bytes = bytes.maketrans(bytearray(invmap.keys()), mpz_alphabet) + return tr_bytes, del_chars + + tr_bytes = bytes.maketrans(bytearray(invmap.keys()), bytearray(invmap.values())) + return tr_bytes, del_chars + + +def _decode(data: bytes, min_split: int = 256, base: int = 58) -> int: + """ + Decode larger data blocks recursively + """ + if len(data) <= min_split: + ret_int = 0 + for val in data: + ret_int = base * ret_int + val + return ret_int + else: + split_len = 2 ** (len(data).bit_length() - 2) + try: + base_pow = POWERS[base][split_len] + except KeyError: + POWERS[base] = base_pow = base**split_len + return (base_pow * _decode(data[:-split_len])) + _decode(data[-split_len:]) def b58decode_int( - v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET, *, - autofix: bool = False + v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET, *, autofix: bool = False ) -> int: """ Decode a Base58 encoded string as an integer """ - if b' ' not in alphabet: + if b" " not in alphabet: v = v.rstrip() v = scrub_input(v) - map = _get_base58_decode_map(alphabet, autofix=autofix) - - decimal = 0 base = len(alphabet) - try: - for char in v: - decimal = decimal * base + map[char] - except KeyError as e: - raise ValueError( - "Invalid character {!r}".format(chr(e.args[0])) - ) from None - return decimal + tr_bytes, del_chars = _get_base58_decode_map(alphabet, autofix=autofix) + cv = v.translate(tr_bytes, delete=del_chars) + if len(v) != len(cv): + err_char = chr(next(c for c in v if c in del_chars)) + raise ValueError("Invalid character {!r}".format(err_char)) + + if cv == b"": + return 0 + + if mpz: + try: + return int(mpz(cv, base=base)) + except ValueError: + raise ValueError(cv, base) + + return _decode(cv, base=base) def b58decode( - v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET, *, - autofix: bool = False + v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET, *, autofix: bool = False ) -> bytes: """ Decode a Base58 encoded string @@ -123,17 +205,10 @@ def b58decode( acc = b58decode_int(v, alphabet=alphabet, autofix=autofix) - result = [] - while acc > 0: - acc, mod = divmod(acc, 256) - result.append(mod) - - return b'\0' * (origlen - newlen) + bytes(reversed(result)) + return acc.to_bytes(origlen - newlen + (acc.bit_length() + 7) // 8, "big") -def b58encode_check( - v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET -) -> bytes: +def b58encode_check(v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET) -> bytes: """ Encode a string using Base58 with a 4 character checksum """ @@ -144,10 +219,9 @@ def b58encode_check( def b58decode_check( - v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET, *, - autofix: bool = False + v: Union[str, bytes], alphabet: bytes = BITCOIN_ALPHABET, *, autofix: bool = False ) -> bytes: - '''Decode and verify the checksum of a Base58 encoded string''' + """Decode and verify the checksum of a Base58 encoded string""" result = b58decode(v, alphabet=alphabet, autofix=autofix) result, check = result[:-4], result[-4:] diff --git a/base58/__main__.py b/base58/__main__.py index e18a1bd..0c86c1b 100644 --- a/base58/__main__.py +++ b/base58/__main__.py @@ -20,24 +20,22 @@ def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - 'file', - metavar='FILE', - nargs='?', - type=argparse.FileType('r'), - help=( - "File to encode or decode. If no file is provided standard " - "input is used instead"), - default='-') + "file", + metavar="FILE", + nargs="?", + type=argparse.FileType("r"), + help="File to encode or decode. If no file is provided standard input is used instead", + default="-", + ) parser.add_argument( - '-d', '--decode', - action='store_true', - help="decode data instead of encoding") + "-d", "--decode", action="store_true", help="decode data instead of encoding" + ) parser.add_argument( - '-c', '--check', - action='store_true', - help=( - "calculate a checksum and append to encoded data or verify " - "existing checksum when decoding")) + "-c", + "--check", + action="store_true", + help="calculate a checksum and append to encoded data or verify existing checksum when decoding", + ) args = parser.parse_args() fun = _fmap[(args.decode, args.check)] @@ -52,5 +50,5 @@ def main() -> None: stdout.write(result) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/setup.cfg b/setup.cfg index a3f4a82..26f4166 100644 --- a/setup.cfg +++ b/setup.cfg @@ -57,3 +57,5 @@ ignore_missing_imports = True [mypy-pytest.*] ignore_missing_imports = True +[mypy-gmpy2.*] +ignore_missing_imports = True diff --git a/test_base45.py b/test_base45.py index 9b72901..b68badc 100644 --- a/test_base45.py +++ b/test_base45.py @@ -1,90 +1,89 @@ from hamcrest import assert_that, equal_to, calling, raises -from base58 import (b58encode, b58decode, b58encode_check, b58decode_check) +from base58 import b58encode, b58decode, b58encode_check, b58decode_check BASE45_ALPHABET = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ $%*+-./:" def test_simple_encode(): - data = b58encode(b'hello world', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'K3*J+EGLBVAYYB36')) + data = b58encode(b"hello world", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"K3*J+EGLBVAYYB36")) def test_leadingz_encode(): - data = b58encode(b'\0\0hello world', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'00K3*J+EGLBVAYYB36')) + data = b58encode(b"\0\0hello world", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"00K3*J+EGLBVAYYB36")) def test_encode_empty(): - data = b58encode(b'', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'')) + data = b58encode(b"", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"")) def test_simple_decode(): - data = b58decode('K3*J+EGLBVAYYB36', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'hello world')) + data = b58decode("K3*J+EGLBVAYYB36", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"hello world")) def test_simple_decode_bytes(): - data = b58decode(b'K3*J+EGLBVAYYB36', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'hello world')) + data = b58decode(b"K3*J+EGLBVAYYB36", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"hello world")) def test_autofix_decode_bytes(): - data = b58decode( - b'K3*J+EGLBVAYYB36', alphabet=BASE45_ALPHABET, autofix=True) - assert_that(data, equal_to(b'hello world')) + data = b58decode(b"K3*J+EGLBVAYYB36", alphabet=BASE45_ALPHABET, autofix=True) + assert_that(data, equal_to(b"hello world")) def test_leadingz_decode(): - data = b58decode('00K3*J+EGLBVAYYB36', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'\0\0hello world')) + data = b58decode("00K3*J+EGLBVAYYB36", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"\0\0hello world")) def test_leadingz_decode_bytes(): - data = b58decode(b'00K3*J+EGLBVAYYB36', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'\0\0hello world')) + data = b58decode(b"00K3*J+EGLBVAYYB36", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"\0\0hello world")) def test_empty_decode(): - data = b58decode('1', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'\x01')) + data = b58decode("1", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"\x01")) def test_empty_decode_bytes(): - data = b58decode(b'1', alphabet=BASE45_ALPHABET) - assert_that(data, equal_to(b'\x01')) + data = b58decode(b"1", alphabet=BASE45_ALPHABET) + assert_that(data, equal_to(b"\x01")) def test_check_str(): - data = 'hello world' + data = "hello world" out = b58encode_check(data, alphabet=BASE45_ALPHABET) - assert_that(out, equal_to(b'AHN49RN6G8B%AWUALA8K2D')) + assert_that(out, equal_to(b"AHN49RN6G8B%AWUALA8K2D")) back = b58decode_check(out, alphabet=BASE45_ALPHABET) - assert_that(back, equal_to(b'hello world')) + assert_that(back, equal_to(b"hello world")) def test_autofix_check_str(): - data = 'AHN49RN6G8B%AWUALA8K2D' + data = "AHN49RN6G8B%AWUALA8K2D" back = b58decode_check(data, alphabet=BASE45_ALPHABET, autofix=True) - assert_that(back, equal_to(b'hello world')) + assert_that(back, equal_to(b"hello world")) def test_autofix_not_applicable_check_str(): - charset = BASE45_ALPHABET.replace(b'x', b'l') - msg = b'hello world' + charset = BASE45_ALPHABET.replace(b"x", b"l") + msg = b"hello world" enc = b58encode_check(msg, alphabet=BASE45_ALPHABET) - modified = enc.replace(b'x', b'l').replace(b'o', b'0') + modified = enc.replace(b"x", b"l").replace(b"o", b"0") back = b58decode_check(modified, alphabet=charset, autofix=True) assert_that(back, equal_to(msg)) def test_check_failure(): - data = '3vQB7B6MrGQZaxCuFg4oH' + data = "3vQB7B6MrGQZaxCuFg4oH" assert_that(calling(b58decode_check).with_args(data), raises(ValueError)) def test_invalid_input(): - data = 'xyz0' # 0 is not part of the bitcoin base58 alphabet + data = "xyz0" # 0 is not part of the bitcoin base58 alphabet assert_that( - calling(b58decode).with_args(data), - raises(ValueError, "Invalid character '0'")) + calling(b58decode).with_args(data), raises(ValueError, "Invalid character '0'") + ) diff --git a/test_base58.py b/test_base58.py index d458c9a..65ea93f 100644 --- a/test_base58.py +++ b/test_base58.py @@ -3,10 +3,15 @@ from random import getrandbits from hamcrest import assert_that, equal_to, calling, raises from base58 import ( - b58encode, b58decode, b58encode_check, b58decode_check, b58encode_int, + b58encode, + b58decode, + b58encode_check, + b58decode_check, + b58encode_int, b58decode_int, BITCOIN_ALPHABET, - XRP_ALPHABET) + XRP_ALPHABET, +) BASE45_ALPHABET = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ $%*+-./:" @@ -18,99 +23,96 @@ def alphabet(request) -> str: def test_simple_encode(): - data = b58encode(b'hello world') - assert_that(data, equal_to(b'StV1DL6CwTryKyV')) + data = b58encode(b"hello world") + assert_that(data, equal_to(b"StV1DL6CwTryKyV")) def test_leadingz_encode(): - data = b58encode(b'\0\0hello world') - assert_that(data, equal_to(b'11StV1DL6CwTryKyV')) + data = b58encode(b"\0\0hello world") + assert_that(data, equal_to(b"11StV1DL6CwTryKyV")) def test_encode_empty(): - data = b58encode(b'') - assert_that(data, equal_to(b'')) + data = b58encode(b"") + assert_that(data, equal_to(b"")) def test_simple_decode(): - data = b58decode('StV1DL6CwTryKyV') - assert_that(data, equal_to(b'hello world')) + data = b58decode("StV1DL6CwTryKyV") + assert_that(data, equal_to(b"hello world")) def test_simple_decode_bytes(): - data = b58decode(b'StV1DL6CwTryKyV') - assert_that(data, equal_to(b'hello world')) + data = b58decode(b"StV1DL6CwTryKyV") + assert_that(data, equal_to(b"hello world")) def test_autofix_decode_bytes(): - data = b58decode(b'StVlDL6CwTryKyV', autofix=True) - assert_that(data, equal_to(b'hello world')) + data = b58decode(b"StVlDL6CwTryKyV", autofix=True) + assert_that(data, equal_to(b"hello world")) def test_leadingz_decode(): - data = b58decode('11StV1DL6CwTryKyV') - assert_that(data, equal_to(b'\0\0hello world')) + data = b58decode("11StV1DL6CwTryKyV") + assert_that(data, equal_to(b"\0\0hello world")) def test_leadingz_decode_bytes(): - data = b58decode(b'11StV1DL6CwTryKyV') - assert_that(data, equal_to(b'\0\0hello world')) + data = b58decode(b"11StV1DL6CwTryKyV") + assert_that(data, equal_to(b"\0\0hello world")) def test_empty_decode(): - data = b58decode('1') - assert_that(data, equal_to(b'\0')) + data = b58decode("1") + assert_that(data, equal_to(b"\0")) def test_empty_decode_bytes(): - data = b58decode(b'1') - assert_that(data, equal_to(b'\0')) + data = b58decode(b"1") + assert_that(data, equal_to(b"\0")) def test_check_str(): - data = 'hello world' + data = "hello world" out = b58encode_check(data) - assert_that(out, equal_to(b'3vQB7B6MrGQZaxCuFg4oh')) + assert_that(out, equal_to(b"3vQB7B6MrGQZaxCuFg4oh")) back = b58decode_check(out) - assert_that(back, equal_to(b'hello world')) + assert_that(back, equal_to(b"hello world")) def test_autofix_check_str(): - data = '3vQB7B6MrGQZaxCuFg4Oh' + data = "3vQB7B6MrGQZaxCuFg4Oh" back = b58decode_check(data, autofix=True) - assert_that(back, equal_to(b'hello world')) + assert_that(back, equal_to(b"hello world")) def test_autofix_not_applicable_check_str(): - charset = BITCOIN_ALPHABET.replace(b'x', b'l') - msg = b'hello world' - enc = b58encode_check(msg).replace(b'x', b'l').replace(b'o', b'0') + charset = BITCOIN_ALPHABET.replace(b"x", b"l") + msg = b"hello world" + enc = b58encode_check(msg).replace(b"x", b"l").replace(b"o", b"0") back = b58decode_check(enc, alphabet=charset, autofix=True) assert_that(back, equal_to(msg)) def test_check_failure(): - data = '3vQB7B6MrGQZaxCuFg4oH' + data = "3vQB7B6MrGQZaxCuFg4oH" assert_that(calling(b58decode_check).with_args(data), raises(ValueError)) def test_check_identity(alphabet): - data = b'hello world' - out = b58decode_check( - b58encode_check(data, alphabet=alphabet), - alphabet=alphabet - ) + data = b"hello world" + out = b58decode_check(b58encode_check(data, alphabet=alphabet), alphabet=alphabet) assert_that(out, equal_to(data)) def test_round_trips(alphabet): - possible_bytes = [b'\x00', b'\x01', b'\x10', b'\xff'] + possible_bytes = [b"\x00", b"\x01", b"\x10", b"\xff"] for length in range(0, 5): for bytes_to_test in product(possible_bytes, repeat=length): - bytes_in = b''.join(bytes_to_test) + bytes_in = b"".join(bytes_to_test) bytes_out = b58decode( - b58encode(bytes_in, alphabet=alphabet), - alphabet=alphabet) + b58encode(bytes_in, alphabet=alphabet), alphabet=alphabet + ) assert_that(bytes_in, equal_to(bytes_out)) @@ -122,28 +124,29 @@ def test_simple_integers(alphabet): def test_large_integer(): - number = 0x111d38e5fc9071ffcd20b4a763cc9ae4f252bb4e48fd66a835e252ada93ff480d6dd43dc62a641155a5 # noqa + number = 0x111D38E5FC9071FFCD20B4A763CC9AE4F252BB4E48FD66A835E252ADA93FF480D6DD43DC62A641155A5 # noqa assert_that(b58decode_int(BITCOIN_ALPHABET), equal_to(number)) assert_that(b58encode_int(number), equal_to(BITCOIN_ALPHABET[1:])) def test_invalid_input(): - data = 'xyz\b' # backspace is not part of the bitcoin base58 alphabet + data = "xyz\b" # backspace is not part of the bitcoin base58 alphabet assert_that( calling(b58decode).with_args(data), - raises(ValueError, "Invalid character '\\\\x08'")) + raises(ValueError, "Invalid character '\\\\x08'"), + ) -@pytest.mark.parametrize('length', [8, 32, 256, 1024]) +@pytest.mark.parametrize("length", [8, 32, 256, 1024, 8192]) def test_encode_random(benchmark, length) -> None: - data = getrandbits(length * 8).to_bytes(length, byteorder='big') + data = getrandbits(length * 8).to_bytes(length, byteorder="big") encoded = benchmark(lambda: b58encode(data)) assert_that(b58decode(encoded), equal_to(data)) -@pytest.mark.parametrize('length', [8, 32, 256, 1024]) +@pytest.mark.parametrize("length", [8, 32, 256, 1024, 8192]) def test_decode_random(benchmark, length) -> None: - origdata = getrandbits(length * 8).to_bytes(length, byteorder='big') + origdata = getrandbits(length * 8).to_bytes(length, byteorder="big") encoded = b58encode(origdata) data = benchmark(lambda: b58decode(encoded)) assert_that(data, equal_to(origdata))