Skip to content

Commit 7101a61

Browse files
authored
Add b64decode (#5)
1 parent 2d2cf83 commit 7101a61

9 files changed

Lines changed: 174 additions & 11 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ b'data to be encoded'
4040
| b64encode (100 KB data) | 0.307 | 0.325 | 0.318 | 0.047 (6.6x) | 0.061 (5.3x) | 0.050 (6.4x) |
4141
| b64encode (1 MB data) | 3.383 | 3.456 | 3.411 | 0.447 (7.6x) | 0.487 (7.1x) | 0.467 (7.3x) |
4242
| b64encode (altchars + 100 KB data) | 0.472 | 0.490 | 0.483 | 0.303 (1.6x) | 0.320 (1.5x) | 0.313 (1.5x) |
43+
| b64decode (100 KB data) | 0.512 | 0.569 | 0.538 | 0.110 (4.7x) | 0.125 (4.5x) | 0.117 (4.6x) |
4344

4445
## How to develop locally
4546

benchmarks/bench_encode.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,40 +4,59 @@
44

55

66
ITERATIONS = 1_000
7+
78
SMALL_DATA = b"t" * 1_000 # 1 KB
89
MEDIUM_DATA = b"t" * 100_000 # 100 KB
910
LARGE_DATA = b"t" * 1_000_000 # 1 MB
1011

12+
SMALL_DATA_ENCODED = base64.b64encode(SMALL_DATA)
13+
MEDIUM_DATA_ENCODED = base64.b64encode(MEDIUM_DATA)
14+
LARGE_DATA_ENCODED = base64.b64encode(LARGE_DATA)
15+
1116

12-
def stdlib_base64encode(data, altchars=None) -> None:
17+
def stdlib_b64encode(data, altchars=None) -> None:
1318
for _ in range(ITERATIONS):
1419
base64.b64encode(data, altchars=altchars)
1520

1621

17-
def base64_utils_base64encode(data, altchars=None) -> None:
22+
def base64_utils_b64encode(data, altchars=None) -> None:
1823
for _ in range(ITERATIONS):
1924
base64_utils.b64encode(data, altchars=altchars)
2025

26+
def stdlib_b64decode(data, altchars=None, validate=False) -> None:
27+
for _ in range(ITERATIONS):
28+
base64.b64decode(data, altchars=altchars, validate=validate)
29+
30+
31+
def base64_utils_b64decode(data, altchars=None, validate=False) -> None:
32+
for _ in range(ITERATIONS):
33+
base64_utils.b64decode(data, altchars=altchars, validate=validate)
34+
2135

2236
__benchmarks__ = [
2337
(
24-
lambda: stdlib_base64encode(SMALL_DATA),
25-
lambda: base64_utils_base64encode(SMALL_DATA),
38+
lambda: stdlib_b64encode(SMALL_DATA),
39+
lambda: base64_utils_b64encode(SMALL_DATA),
2640
"b64encode (1 KB data)",
2741
),
2842
(
29-
lambda: stdlib_base64encode(MEDIUM_DATA),
30-
lambda: base64_utils_base64encode(MEDIUM_DATA),
43+
lambda: stdlib_b64encode(MEDIUM_DATA),
44+
lambda: base64_utils_b64encode(MEDIUM_DATA),
3145
"b64encode (100 KB data)",
3246
),
3347
(
34-
lambda: stdlib_base64encode(LARGE_DATA),
35-
lambda: base64_utils_base64encode(LARGE_DATA),
48+
lambda: stdlib_b64encode(LARGE_DATA),
49+
lambda: base64_utils_b64encode(LARGE_DATA),
3650
"b64encode (1 MB data)",
3751
),
3852
(
39-
lambda: stdlib_base64encode(MEDIUM_DATA, altchars=b"-_"),
40-
lambda: base64_utils_base64encode(MEDIUM_DATA, altchars=b"-_"),
53+
lambda: stdlib_b64encode(MEDIUM_DATA, altchars=b"-_"),
54+
lambda: base64_utils_b64encode(MEDIUM_DATA, altchars=b"-_"),
4155
"b64encode (altchars + 100 KB data)",
4256
),
57+
(
58+
lambda: stdlib_b64decode(MEDIUM_DATA_ENCODED),
59+
lambda: base64_utils_b64decode(MEDIUM_DATA_ENCODED),
60+
"b64decode (100 KB data)",
61+
)
4362
]

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ b'data to be encoded'
4747
| b64encode (100 KB data) | 0.307 | 0.325 | 0.318 | 0.047 (6.6x) | 0.061 (5.3x) | 0.050 (6.4x) |
4848
| b64encode (1 MB data) | 3.383 | 3.456 | 3.411 | 0.447 (7.6x) | 0.487 (7.1x) | 0.467 (7.3x) |
4949
| b64encode (altchars + 100 KB data) | 0.472 | 0.490 | 0.483 | 0.303 (1.6x) | 0.320 (1.5x) | 0.313 (1.5x) |
50+
| b64decode (100 KB data) | 0.512 | 0.569 | 0.538 | 0.110 (4.7x) | 0.125 (4.5x) | 0.117 (4.6x) |
5051

5152
## How to develop locally
5253

python/base64_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from ._base64_utils import (
2+
b64decode,
23
b64encode,
34
standard_b64encode,
45
urlsafe_b64encode,
56
)
67

78
__all__ = [
89
"b64encode",
10+
"b64decode",
911
"standard_b64encode",
1012
"urlsafe_b64encode",
1113
]

python/base64_utils/__init__.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@ __version__: str
44

55
__all__ = [
66
"b64encode",
7+
"b64decode",
78
"standard_b64encode",
89
"urlsafe_b64encode",
910
]
1011

1112
def b64encode(s: ReadableBuffer, altchars: ReadableBuffer | None = None) -> bytes: ...
13+
def b64decode(
14+
s: str | ReadableBuffer,
15+
altchars: str | ReadableBuffer | None = None,
16+
validate: bool = False,
17+
) -> bytes: ...
1218
def standard_b64encode(s: ReadableBuffer) -> bytes: ...
1319
def urlsafe_b64encode(s: ReadableBuffer) -> bytes: ...

src/lib.rs

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1-
use base64_simd::{Out, STANDARD, URL_SAFE};
1+
use base64_simd::{Out, STANDARD, URL_SAFE, forgiving_decode_inplace};
22
use pyo3::exceptions::PyValueError;
33
use pyo3::prelude::*;
44
use pyo3::types::PyBytes;
55

6+
#[derive(FromPyObject)]
7+
pub enum StringOrBytes {
8+
#[pyo3(transparent, annotation = "str")]
9+
String(String),
10+
#[pyo3(transparent, annotation = "bytes")]
11+
Bytes(Vec<u8>),
12+
}
13+
14+
impl StringOrBytes {
15+
fn into_bytes(self) -> Vec<u8> {
16+
match self {
17+
StringOrBytes::String(s) => s.into_bytes(),
18+
StringOrBytes::Bytes(b) => b,
19+
}
20+
}
21+
}
22+
623
#[pyfunction]
724
#[pyo3(signature = (s, altchars=None))]
825
pub fn b64encode(py: Python<'_>, s: &[u8], altchars: Option<&[u8]>) -> PyResult<Py<PyBytes>> {
@@ -37,6 +54,56 @@ pub fn b64encode(py: Python<'_>, s: &[u8], altchars: Option<&[u8]>) -> PyResult<
3754
Ok(output.into())
3855
}
3956

57+
#[pyfunction]
58+
#[pyo3(signature = (s, altchars=None, validate=false))]
59+
pub fn b64decode(
60+
py: Python<'_>,
61+
s: StringOrBytes,
62+
altchars: Option<StringOrBytes>,
63+
validate: bool,
64+
) -> PyResult<Py<PyBytes>> {
65+
let mut input: Vec<u8> = s.into_bytes();
66+
67+
if let Some(alt) = altchars {
68+
let bytes = alt.into_bytes();
69+
if bytes.len() != 2 {
70+
return Err(PyValueError::new_err(
71+
"altchars must be a bytes-like object of length 2",
72+
));
73+
}
74+
75+
for byte in input.iter_mut() {
76+
if *byte == bytes[0] {
77+
*byte = b'+';
78+
} else if *byte == bytes[1] {
79+
*byte = b'/';
80+
}
81+
}
82+
}
83+
84+
if validate {
85+
STANDARD
86+
.check(&input)
87+
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
88+
89+
let output_len = STANDARD
90+
.decoded_length(&input)
91+
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
92+
93+
let output: Bound<'_, PyBytes> = PyBytes::new_with(py, output_len, |buf| {
94+
STANDARD
95+
.decode(&input, Out::from_slice(buf))
96+
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
97+
Ok(())
98+
})?;
99+
Ok(output.into())
100+
} else {
101+
let output = forgiving_decode_inplace(&mut input)
102+
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
103+
Ok(PyBytes::new(py, output).into())
104+
}
105+
}
106+
40107
#[pyfunction]
41108
pub fn standard_b64encode(py: Python<'_>, s: &[u8]) -> PyResult<Py<PyBytes>> {
42109
let output_len = STANDARD.encoded_length(s.len());
@@ -61,6 +128,7 @@ pub fn urlsafe_b64encode(py: Python<'_>, s: &[u8]) -> PyResult<Py<PyBytes>> {
61128
fn _base64_utils(m: &Bound<'_, PyModule>) -> PyResult<()> {
62129
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
63130
m.add_function(wrap_pyfunction!(b64encode, m)?)?;
131+
m.add_function(wrap_pyfunction!(b64decode, m)?)?;
64132
m.add_function(wrap_pyfunction!(standard_b64encode, m)?)?;
65133
m.add_function(wrap_pyfunction!(urlsafe_b64encode, m)?)?;
66134
Ok(())

tests/test_base64_decode.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import base64
2+
3+
import base64_utils
4+
import pytest
5+
6+
7+
def test_b64decode() -> None:
8+
data = b"dGVzdA=="
9+
decoded = base64_utils.b64decode(data)
10+
expected = base64.b64decode(data)
11+
12+
assert isinstance(decoded, bytes)
13+
assert expected == decoded
14+
15+
16+
def test_b64decode_str() -> None:
17+
data = "dGVzdA=="
18+
decoded = base64_utils.b64decode(data)
19+
expected = base64.b64decode(data)
20+
21+
assert isinstance(decoded, bytes)
22+
assert expected == decoded
23+
24+
25+
def test_b64decode_altchars() -> None:
26+
data = b"dGVzdA+/"
27+
altchars = b"-_"
28+
decoded = base64_utils.b64decode(data, altchars=altchars)
29+
expected = base64.b64decode(data, altchars=altchars)
30+
31+
assert isinstance(decoded, bytes)
32+
assert expected == decoded
33+
34+
35+
def test_b64decode_altchars_invalid() -> None:
36+
with pytest.raises(ValueError):
37+
base64_utils.b64decode(b"dGVzdA+/", altchars=b"-")
38+
39+
40+
def test_b64decode_validate() -> None:
41+
data_with_spaces = b"dGVz dA==" # "test" with a space in the middle
42+
decoded = base64_utils.b64decode(data_with_spaces, validate=False)
43+
expected = base64.b64decode(data_with_spaces, validate=False)
44+
assert decoded == expected
45+
assert decoded == b"test"
46+
47+
with pytest.raises(ValueError):
48+
base64_utils.b64decode(data_with_spaces, validate=True)
49+
with pytest.raises(ValueError):
50+
base64.b64decode(data_with_spaces, validate=True)
51+
52+
53+
def test_b64decode_invalid_data() -> None:
54+
data = b"invalid_base64!!"
55+
with pytest.raises(ValueError):
56+
base64_utils.b64decode(data)

tests/test_benchmarks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,13 @@ def test_standard_b64encode() -> None:
2222
@pytest.mark.benchmark
2323
def test_urlsafe_b64encode() -> None:
2424
base64_utils.urlsafe_b64encode(b"test data")
25+
26+
27+
@pytest.mark.benchmark
28+
def test_b64decode() -> None:
29+
base64_utils.b64decode(b"dGVzdA==")
30+
31+
32+
@pytest.mark.benchmark
33+
def test_b64decode_str() -> None:
34+
base64_utils.b64decode("dGVzdA==")

0 commit comments

Comments
 (0)