Skip to content

Commit aaa8d6a

Browse files
authored
Enable pickling Bytes (#295)
* Support pickling bytes * Add test * Enable hashable * revert hash
1 parent 2bd5686 commit aaa8d6a

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

obstore/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ fn _obstore(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
5555
pyo3_object_store::register_exceptions_module(py, m, "obstore")?;
5656

5757
m.add_class::<pyo3_bytes::PyBytes>()?;
58+
// Set the value of `__module__` correctly on PyBytes
59+
m.getattr("Bytes")?.setattr("__module__", "obstore")?;
60+
5861
m.add_wrapped(wrap_pyfunction!(buffered::open_reader))?;
5962
m.add_wrapped(wrap_pyfunction!(buffered::open_reader_async))?;
6063
m.add_wrapped(wrap_pyfunction!(buffered::open_writer))?;

pyo3-bytes/src/bytes.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use bytes::{Bytes, BytesMut};
88
use pyo3::buffer::PyBuffer;
99
use pyo3::exceptions::{PyIndexError, PyValueError};
1010
use pyo3::prelude::*;
11-
use pyo3::types::PySlice;
11+
use pyo3::types::{PyDict, PySlice, PyTuple};
1212
use pyo3::{ffi, IntoPyObjectExt};
1313

1414
/// A wrapper around a [`bytes::Bytes`][].
@@ -161,6 +161,13 @@ impl PyBytes {
161161
buf
162162
}
163163

164+
fn __getnewargs_ex__(&self, py: Python) -> PyResult<PyObject> {
165+
let py_bytes = self.to_bytes(py);
166+
let args = PyTuple::new(py, vec![py_bytes])?.into_py_any(py)?;
167+
let kwargs = PyDict::new(py);
168+
PyTuple::new(py, [args, kwargs.into_py_any(py)?])?.into_py_any(py)
169+
}
170+
164171
/// The number of bytes in this Bytes
165172
fn __len__(&self) -> usize {
166173
self.0.len()

tests/test_bytes.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import pickle
34
from typing import TYPE_CHECKING
45

56
import pytest
@@ -93,3 +94,8 @@ def _bytes_slices(
9394
for stop in indices_range
9495
for step in steps
9596
)
97+
98+
99+
def test_pickle():
100+
b = Bytes(b"hello_world")
101+
assert b == pickle.loads(pickle.dumps(b))

0 commit comments

Comments
 (0)