Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/base64_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from ._base64_utils import (
b64decode,
b64encode,
standard_b64decode,
standard_b64encode,
urlsafe_b64decode,
urlsafe_b64encode,
)

# Public API of the package. Each name is listed exactly once, in sorted order.
__all__ = [
    "b64decode",
    "b64encode",
    "standard_b64decode",
    "standard_b64encode",
    "urlsafe_b64decode",
    "urlsafe_b64encode",
]
8 changes: 6 additions & 2 deletions python/base64_utils/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@ from _typeshed import ReadableBuffer
__version__: str

# Names exported by the package. Each name is listed exactly once, in sorted order.
__all__ = [
    "b64decode",
    "b64encode",
    "standard_b64decode",
    "standard_b64encode",
    "urlsafe_b64decode",
    "urlsafe_b64encode",
]

# Generic codec pair. `altchars`, when given, supplies the two characters used
# in place of "+" and "/".
def b64decode(
    s: str | ReadableBuffer,
    altchars: str | ReadableBuffer | None = None,
    validate: bool = False,
) -> bytes: ...
def b64encode(s: ReadableBuffer, altchars: ReadableBuffer | None = None) -> bytes: ...

# Standard-alphabet variants.
def standard_b64decode(s: str | ReadableBuffer) -> bytes: ...
def standard_b64encode(s: ReadableBuffer) -> bytes: ...

# URL-safe-alphabet variants.
def urlsafe_b64decode(s: str | ReadableBuffer) -> bytes: ...
def urlsafe_b64encode(s: ReadableBuffer) -> bytes: ...
97 changes: 97 additions & 0 deletions src/decoder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use base64_simd::{Out, STANDARD, URL_SAFE, forgiving_decode_inplace};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;

/// Argument type for the decoding functions: either text or raw bytes.
///
/// The derived `FromPyObject` implementation attempts the variants in
/// declaration order, so the `String` extraction is tried before `Bytes`.
/// The `annotation` values are the type names PyO3 uses when reporting a
/// failed extraction.
#[derive(FromPyObject)]
pub enum StringOrBytes {
/// The argument was extracted as a Rust `String`.
#[pyo3(transparent, annotation = "str")]
String(String),
/// The argument was extracted as a `Vec<u8>`.
#[pyo3(transparent, annotation = "bytes")]
Bytes(Vec<u8>),
}

impl StringOrBytes {
fn into_bytes(self) -> Vec<u8> {
match self {
StringOrBytes::String(s) => s.into_bytes(),
StringOrBytes::Bytes(b) => b,
}
}
}

#[pyfunction]
#[pyo3(signature = (s, altchars=None, validate=false))]
pub fn b64decode(
py: Python<'_>,
s: StringOrBytes,
altchars: Option<StringOrBytes>,
validate: bool,
) -> PyResult<Py<PyBytes>> {
let mut input: Vec<u8> = s.into_bytes();

if let Some(alt) = altchars {
let bytes = alt.into_bytes();
if bytes.len() != 2 {
return Err(PyValueError::new_err(
"altchars must be a bytes-like object of length 2",
));
}

for byte in input.iter_mut() {
if *byte == bytes[0] {
*byte = b'+';
} else if *byte == bytes[1] {
*byte = b'/';
}
}
}

if validate {
STANDARD
.check(&input)
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;

let output_len = STANDARD
.decoded_length(&input)
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;

let output: Bound<'_, PyBytes> = PyBytes::new_with(py, output_len, |buf| {
STANDARD
.decode(&input, Out::from_slice(buf))
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
Ok(())
})?;
Ok(output.into())
} else {
let output = forgiving_decode_inplace(&mut input)
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
Ok(PyBytes::new(py, output).into())
}
}

/// Decodes `s` (given as `str` or bytes) using the standard Base64 alphabet.
///
/// # Errors
///
/// Returns `ValueError` if the decoded length cannot be determined for the
/// input or if the decoder rejects the input.
#[pyfunction]
pub fn standard_b64decode(py: Python<'_>, s: StringOrBytes) -> PyResult<Py<PyBytes>> {
    let input: Vec<u8> = s.into_bytes();
    let output_len = STANDARD
        .decoded_length(&input)
        .map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
    let output = PyBytes::new_with(py, output_len, |buf| {
        // The decode result must not be discarded: on failure the buffer
        // would otherwise be handed to Python as if it held valid data.
        STANDARD
            .decode(&input, Out::from_slice(buf))
            .map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
        Ok(())
    })?;
    Ok(output.into())
}

/// Decodes `s` (given as `str` or bytes) using the URL-safe Base64 alphabet.
///
/// # Errors
///
/// Returns `ValueError` if the decoded length cannot be determined for the
/// input or if the decoder rejects the input.
#[pyfunction]
pub fn urlsafe_b64decode(py: Python<'_>, s: StringOrBytes) -> PyResult<Py<PyBytes>> {
    let input: Vec<u8> = s.into_bytes();
    let output_len = URL_SAFE
        .decoded_length(&input)
        .map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
    let output = PyBytes::new_with(py, output_len, |buf| {
        // The decode result must not be discarded: on failure the buffer
        // would otherwise be handed to Python as if it held valid data.
        URL_SAFE
            .decode(&input, Out::from_slice(buf))
            .map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
        Ok(())
    })?;
    Ok(output.into())
}
58 changes: 58 additions & 0 deletions src/encoder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use base64_simd::{Out, STANDARD, URL_SAFE};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;

/// Encodes `s` with the standard Base64 alphabet.
///
/// When `altchars` is supplied it must be exactly two bytes long; its first
/// byte replaces every `+` in the output and its second byte replaces every
/// `/`.
///
/// # Errors
///
/// Returns `ValueError` if `altchars` does not have length 2.
#[pyfunction]
#[pyo3(signature = (s, altchars=None))]
pub fn b64encode(py: Python<'_>, s: &[u8], altchars: Option<&[u8]>) -> PyResult<Py<PyBytes>> {
    let encoded_len = STANDARD.encoded_length(s.len());

    // Validate the alternative alphabet before any output is allocated.
    let substitutes = match altchars {
        None => None,
        Some(&[plus, slash]) => Some((plus, slash)),
        Some(_) => {
            return Err(PyValueError::new_err(
                "altchars must be a bytes-like object of length 2",
            ));
        }
    };

    let encoded = PyBytes::new_with(py, encoded_len, |dst| {
        let _ = STANDARD.encode(s, Out::from_slice(dst));

        if let Some((plus, slash)) = substitutes {
            for byte in dst.iter_mut() {
                if *byte == b'+' {
                    *byte = plus;
                } else if *byte == b'/' {
                    *byte = slash;
                }
            }
        }
        Ok(())
    })?;
    Ok(encoded.into())
}

/// Encodes `s` with the standard Base64 alphabet and returns the result as
/// Python `bytes`.
#[pyfunction]
pub fn standard_b64encode(py: Python<'_>, s: &[u8]) -> PyResult<Py<PyBytes>> {
    let encoded = PyBytes::new_with(py, STANDARD.encoded_length(s.len()), |dst| {
        let _ = STANDARD.encode(s, Out::from_slice(dst));
        Ok(())
    })?;
    Ok(encoded.into())
}

/// Encodes `s` with the URL-safe Base64 alphabet and returns the result as
/// Python `bytes`.
#[pyfunction]
pub fn urlsafe_b64encode(py: Python<'_>, s: &[u8]) -> PyResult<Py<PyBytes>> {
    let encoded = PyBytes::new_with(py, URL_SAFE.encoded_length(s.len()), |dst| {
        let _ = URL_SAFE.encode(s, Out::from_slice(dst));
        Ok(())
    })?;
    Ok(encoded.into())
}
135 changes: 8 additions & 127 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,135 +1,16 @@
use base64_simd::{Out, STANDARD, URL_SAFE, forgiving_decode_inplace};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;

/// Argument type for the decoding functions: either text or raw bytes.
///
/// The derived `FromPyObject` implementation attempts the variants in
/// declaration order, so the `String` extraction is tried before `Bytes`.
/// The `annotation` values are the type names PyO3 uses when reporting a
/// failed extraction.
#[derive(FromPyObject)]
pub enum StringOrBytes {
/// The argument was extracted as a Rust `String`.
#[pyo3(transparent, annotation = "str")]
String(String),
/// The argument was extracted as a `Vec<u8>`.
#[pyo3(transparent, annotation = "bytes")]
Bytes(Vec<u8>),
}

impl StringOrBytes {
fn into_bytes(self) -> Vec<u8> {
match self {
StringOrBytes::String(s) => s.into_bytes(),
StringOrBytes::Bytes(b) => b,
}
}
}

/// Encodes `s` with the standard Base64 alphabet.
///
/// When `altchars` is supplied it must be exactly two bytes long; its first
/// byte replaces every `+` in the output and its second byte replaces every
/// `/`.
///
/// # Errors
///
/// Returns `ValueError` if `altchars` does not have length 2.
#[pyfunction]
#[pyo3(signature = (s, altchars=None))]
pub fn b64encode(py: Python<'_>, s: &[u8], altchars: Option<&[u8]>) -> PyResult<Py<PyBytes>> {
    let encoded_len = STANDARD.encoded_length(s.len());

    // Validate the alternative alphabet before any output is allocated.
    let substitutes = match altchars {
        None => None,
        Some(&[plus, slash]) => Some((plus, slash)),
        Some(_) => {
            return Err(PyValueError::new_err(
                "altchars must be a bytes-like object of length 2",
            ));
        }
    };

    let encoded = PyBytes::new_with(py, encoded_len, |dst| {
        let _ = STANDARD.encode(s, Out::from_slice(dst));

        if let Some((plus, slash)) = substitutes {
            for byte in dst.iter_mut() {
                if *byte == b'+' {
                    *byte = plus;
                } else if *byte == b'/' {
                    *byte = slash;
                }
            }
        }
        Ok(())
    })?;
    Ok(encoded.into())
}

#[pyfunction]
#[pyo3(signature = (s, altchars=None, validate=false))]
pub fn b64decode(
py: Python<'_>,
s: StringOrBytes,
altchars: Option<StringOrBytes>,
validate: bool,
) -> PyResult<Py<PyBytes>> {
let mut input: Vec<u8> = s.into_bytes();

if let Some(alt) = altchars {
let bytes = alt.into_bytes();
if bytes.len() != 2 {
return Err(PyValueError::new_err(
"altchars must be a bytes-like object of length 2",
));
}

for byte in input.iter_mut() {
if *byte == bytes[0] {
*byte = b'+';
} else if *byte == bytes[1] {
*byte = b'/';
}
}
}

if validate {
STANDARD
.check(&input)
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;

let output_len = STANDARD
.decoded_length(&input)
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;

let output: Bound<'_, PyBytes> = PyBytes::new_with(py, output_len, |buf| {
STANDARD
.decode(&input, Out::from_slice(buf))
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
Ok(())
})?;
Ok(output.into())
} else {
let output = forgiving_decode_inplace(&mut input)
.map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?;
Ok(PyBytes::new(py, output).into())
}
}

/// Encodes `s` with the standard Base64 alphabet and returns the result as
/// Python `bytes`.
#[pyfunction]
pub fn standard_b64encode(py: Python<'_>, s: &[u8]) -> PyResult<Py<PyBytes>> {
    let encoded = PyBytes::new_with(py, STANDARD.encoded_length(s.len()), |dst| {
        let _ = STANDARD.encode(s, Out::from_slice(dst));
        Ok(())
    })?;
    Ok(encoded.into())
}

/// Encodes `s` with the URL-safe Base64 alphabet and returns the result as
/// Python `bytes`.
#[pyfunction]
pub fn urlsafe_b64encode(py: Python<'_>, s: &[u8]) -> PyResult<Py<PyBytes>> {
    let encoded = PyBytes::new_with(py, URL_SAFE.encoded_length(s.len()), |dst| {
        let _ = URL_SAFE.encode(s, Out::from_slice(dst));
        Ok(())
    })?;
    Ok(encoded.into())
}
mod decoder;
mod encoder;

/// Native extension module `_base64_utils`.
///
/// Exposes the crate version as `__version__` together with the Base64
/// decoding and encoding functions.
#[pymodule]
fn _base64_utils(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;

    // Each Python-level name is registered exactly once; registering the same
    // name a second time would only replace the earlier entry.
    m.add_function(wrap_pyfunction!(decoder::b64decode, m)?)?;
    m.add_function(wrap_pyfunction!(decoder::standard_b64decode, m)?)?;
    m.add_function(wrap_pyfunction!(decoder::urlsafe_b64decode, m)?)?;
    m.add_function(wrap_pyfunction!(encoder::b64encode, m)?)?;
    m.add_function(wrap_pyfunction!(encoder::standard_b64encode, m)?)?;
    m.add_function(wrap_pyfunction!(encoder::urlsafe_b64encode, m)?)?;
    Ok(())
}
20 changes: 20 additions & 0 deletions tests/test_base64_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,23 @@ def test_b64decode_invalid_data() -> None:
data = b"invalid_base64!!"
with pytest.raises(ValueError):
base64_utils.b64decode(data)


def test_standard_b64decode() -> None:
    """`standard_b64decode` returns bytes equal to the stdlib result."""
    encoded = base64.standard_b64encode(b"some data")

    actual = base64_utils.standard_b64decode(encoded)
    reference = base64.standard_b64decode(encoded)

    assert isinstance(actual, bytes)
    assert reference == actual


def test_urlsafe_b64decode() -> None:
    """`urlsafe_b64decode` returns bytes equal to the stdlib result."""
    encoded = base64.urlsafe_b64encode(b"some data")

    actual = base64_utils.urlsafe_b64decode(encoded)
    reference = base64.urlsafe_b64decode(encoded)

    assert isinstance(actual, bytes)
    assert reference == actual
10 changes: 10 additions & 0 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ def test_b64decode() -> None:
@pytest.mark.benchmark
def test_b64decode_str() -> None:
    """Benchmark `b64decode` on a short `str` payload."""
    payload = "dGVzdA=="
    base64_utils.b64decode(payload)


@pytest.mark.benchmark
def test_standard_b64decode() -> None:
    """Benchmark `standard_b64decode` on a short bytes payload."""
    payload = b"dGVzdA=="
    base64_utils.standard_b64decode(payload)


@pytest.mark.benchmark
def test_urlsafe_b64decode() -> None:
    """Benchmark `urlsafe_b64decode` on a short bytes payload."""
    payload = b"dGVzdA=="
    base64_utils.urlsafe_b64decode(payload)