Skip to content

Commit 31ed2e4

Browse files
machichi, kylebarron, and martindurant
authored
[FEAT] support df.to_parquet and df.read_parquet() (#165)
* feat: add write() for open() in fsspec * temp: upload with iterator * refactor: rename data_li to buffer * feat: buffered write in fsspec * fix: remove unused code * fix: assert mode is either rb or wb * fix: correctly detect file exist for read_parquet * run pre-commit * feat: split bucket name from path in fsspec _open * Update obstore/python/obstore/fsspec.py Co-authored-by: Martin Durant <[email protected]> * fix: move incorrect mode exception into else * fix: remove writer in init and add self.closed=True * fix: self._writer not exist error in close * fix: use info() in AbstractFileSystem * fix: typing and linting * feat: merge BufferedFileWrite/Read together * test: for write to parquet * docs: update docstring * Use underlying reader/writer methods where possible * Updated docs * Override `loc` property * loc setter to allow `__init__` --------- Co-authored-by: Kyle Barron <[email protected]> Co-authored-by: Martin Durant <[email protected]> Co-authored-by: Kyle Barron <[email protected]>
1 parent 34f6d30 commit 31ed2e4

File tree

2 files changed

+256
-26
lines changed

2 files changed

+256
-26
lines changed

obstore/python/obstore/fsspec.py

+232-24
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
integration.
3131
"""
3232

33-
# ruff: noqa: ANN401, PTH123, FBT001, FBT002
33+
# ruff: noqa: ANN401, EM102, PTH123, FBT001, FBT002, S101
3434

3535
from __future__ import annotations
3636

@@ -46,13 +46,13 @@
4646
import fsspec.spec
4747

4848
import obstore as obs
49-
from obstore import Bytes
49+
from obstore import open_reader, open_writer
5050
from obstore.store import from_url
5151

5252
if TYPE_CHECKING:
5353
from collections.abc import Coroutine, Iterable
5454

55-
from obstore import Bytes
55+
from obstore import Attributes, Bytes, ReadableFile, WritableFile
5656
from obstore.store import (
5757
AzureConfig,
5858
AzureConfigInput,
@@ -462,41 +462,246 @@ def _open(
462462
autocommit: Any = True, # noqa: ARG002
463463
cache_options: Any = None, # noqa: ARG002
464464
**kwargs: Any,
465-
) -> BufferedFileSimple:
465+
) -> BufferedFile:
466466
"""Return raw bytes-mode file-like from the file-system."""
467-
return BufferedFileSimple(self, path, mode, **kwargs)
467+
if mode not in ("wb", "rb"):
468+
err_msg = f"Only 'rb' and 'wb' modes supported, got: {mode}"
469+
raise ValueError(err_msg)
470+
471+
return BufferedFile(self, path, mode, **kwargs)
472+
468473

474+
class BufferedFile(fsspec.spec.AbstractBufferedFile):
    """A buffered readable or writable file.

    This is a wrapper around [`obstore.ReadableFile`][] and [`obstore.WritableFile`][].
    If you don't have a need to use the fsspec integration, you may be better served by
    using [`open_reader`][obstore.open_reader] or [`open_writer`][obstore.open_writer]
    directly.
    """

    mode: Literal["rb", "wb"]
    _reader: ReadableFile
    _writer: WritableFile
    _writer_loc: int
    """Stream position.

    Only defined for writers. We use the underlying rust stream position for reading.
    """

    @overload
    def __init__(
        self,
        fs: AsyncFsspecStore,
        path: str,
        mode: Literal["rb"] = "rb",
        *,
        buffer_size: int = 1024 * 1024,
        **kwargs: Any,
    ) -> None: ...
    @overload
    def __init__(
        self,
        fs: AsyncFsspecStore,
        path: str,
        mode: Literal["wb"],
        *,
        buffer_size: int = 10 * 1024 * 1024,
        attributes: Attributes | None = None,
        tags: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None: ...
    def __init__(  # noqa: PLR0913
        self,
        fs: AsyncFsspecStore,
        path: str,
        mode: Literal["rb", "wb"] = "rb",
        *,
        buffer_size: int | None = None,
        attributes: Attributes | None = None,
        tags: dict[str, str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create new buffered file.

        Args:
            fs: The underlying fsspec store to read from.
            path: The path within the store to use.
            mode: `"rb"` for a readable binary file or `"wb"` for a writable binary
                file. Defaults to "rb".

        Keyword Args:
            attributes: Provide a set of `Attributes`. Only used when writing. Defaults
                to `None`.
            buffer_size: Up to `buffer_size` bytes will be buffered in memory. **When
                reading:** The minimum number of bytes to read in a single request.
                **When writing:** If `buffer_size` is exceeded, data will be uploaded
                as a multipart upload in chunks of `buffer_size`. Defaults to None.
            tags: Provide tags for this object. Only used when writing. Defaults to
                `None`.
            kwargs: Keyword arguments passed on to [`fsspec.spec.AbstractBufferedFile`][].

        """  # noqa: E501
        super().__init__(fs, path, mode, **kwargs)

        # Resolve the bucket-qualified path into a concrete store + in-store path.
        bucket, path = fs._split_path(path)  # noqa: SLF001
        store = fs._construct_store(bucket)  # noqa: SLF001

        self.mode = mode

        if self.mode == "rb":
            buffer_size = 1024 * 1024 if buffer_size is None else buffer_size
            self._reader = open_reader(store, path, buffer_size=buffer_size)
        elif self.mode == "wb":
            buffer_size = 10 * 1024 * 1024 if buffer_size is None else buffer_size
            self._writer = open_writer(
                store,
                path,
                attributes=attributes,
                buffer_size=buffer_size,
                tags=tags,
            )

            # The underlying writer is async and exposes no synchronous stream
            # position, so we track it ourselves (see `tell`).
            self._writer_loc = 0
        else:
            raise ValueError(f"Invalid mode: {mode}")

    def read(self, length: int = -1) -> bytes:
        """Return bytes from the remote file.

        Args:
            length: if positive, returns up to this many bytes; if negative, return all
                remaining bytes.

        Returns:
            Data in bytes

        """
        if self.mode != "rb":
            raise ValueError("File not in read mode")
        # Check `closed` before touching `self.size`/`self.tell()`: computing the
        # remaining length must not hit the underlying reader on a closed file.
        if self.closed:
            raise ValueError("I/O operation on closed file.")
        if length < 0:
            length = self.size - self.tell()
        if length == 0:
            # don't even bother calling fetch
            return b""

        out = self._reader.read(length)
        return out.to_bytes()

    def readline(self) -> bytes:
        """Read until first occurrence of newline character."""
        if self.mode != "rb":
            raise ValueError("File not in read mode")

        out = self._reader.readline()
        return out.to_bytes()

    def readlines(self) -> list[bytes]:
        """Return all data, split by the newline character."""
        if self.mode != "rb":
            raise ValueError("File not in read mode")

        out = self._reader.readlines()
        return [b.to_bytes() for b in out]

    def tell(self) -> int:
        """Get current file location."""
        if self.mode == "rb":
            return self._reader.tell()

        if self.mode == "wb":
            # There's no way to get the stream position from the underlying writer
            # because it's async. Here we happen to be using the async writer in a
            # synchronous way, so we keep our own stream position.
            assert self._writer_loc is not None
            return self._writer_loc

        raise ValueError(f"Unexpected mode {self.mode}")

    def seek(self, loc: int, whence: int = 0) -> int:
        """Set current file location.

        Args:
            loc: byte location
            whence: Either
                - `0`: from start of file
                - `1`: current location
                - `2`: end of file

        """
        if self.mode != "rb":
            raise ValueError("Seek only available in read mode.")

        return self._reader.seek(loc, whence)

    def write(self, data: bytes) -> int:
        """Write data to buffer.

        Args:
            data: Set of bytes to be written.

        """
        if not self.writable():
            raise ValueError("File not in write mode")
        if self.closed:
            raise ValueError("I/O operation on closed file.")
        if self.forced:
            raise ValueError("This file has been force-flushed, can only close")

        num_written = self._writer.write(data)
        self._writer_loc += num_written

        return num_written

    def flush(
        self,
        force: bool = False,  # noqa: ARG002
    ) -> None:
        """Write buffered data to backend store.

        Writes the current buffer, if it is larger than the block-size, or if
        the file is being closed.

        Args:
            force: Unused.

        """
        if self.closed:
            raise ValueError("Flush on closed file")

        if self.readable():
            # no-op to flush on read-mode
            return

        self._writer.flush()

    def close(self) -> None:
        """Close file. Ensure flushing the buffer."""
        if self.closed:
            return

        try:
            if self.mode == "rb":
                self._reader.close()
            else:
                self.flush(force=True)
                self._writer.close()
        finally:
            # Mark closed even if the underlying close raised, so subsequent
            # operations fail fast with "closed file" rather than retrying I/O.
            self.closed = True

    @property
    def loc(self) -> int:
        """Get current file location."""
        # Note, we override the `loc` attribute, because for the reader we manage that
        # state in Rust.
        return self.tell()

    @loc.setter
    def loc(self, value: int) -> None:
        # AbstractBufferedFile.__init__ assigns `self.loc = 0`; allow only that.
        if value != 0:
            raise ValueError("Cannot set `.loc`. Use `seek` instead.")
500705

501706

502707
def register(protocol: str | Iterable[str], *, asynchronous: bool = False) -> None:
@@ -513,14 +718,16 @@ def register(protocol: str | Iterable[str], *, asynchronous: bool = False) -> No
513718
asynchronous operations. Defaults to False.
514719
515720
Example:
516-
>>> register("s3")
517-
>>> register("s3", asynchronous=True) # Registers an async store for "s3"
518-
>>> register(["gcs", "abfs"]) # Registers both "gcs" and "abfs"
721+
```py
722+
register("s3")
723+
register("s3", asynchronous=True) # Registers an async store for "s3"
724+
register(["gcs", "abfs"]) # Registers both "gcs" and "abfs"
725+
```
519726
520727
Notes:
521728
- Each protocol gets a dynamically generated subclass named
522-
`AsyncFsspecStore_<protocol>`.
523-
- This avoids modifying the original AsyncFsspecStore class.
729+
`AsyncFsspecStore_<protocol>`. This avoids modifying the original
730+
AsyncFsspecStore class.
524731
525732
"""
526733
if isinstance(protocol, str):
@@ -542,5 +749,6 @@ def _register(protocol: str, *, asynchronous: bool) -> None:
542749
"asynchronous": asynchronous,
543750
}, # Assign protocol dynamically
544751
),
545-
clobber=False,
752+
# Override any existing implementations of the same protocol
753+
clobber=True,
546754
)

tests/test_fsspec.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -225,16 +225,38 @@ async def test_list_async(s3_store_config: S3Config):
225225

226226

227227
@pytest.mark.network
def test_remote_parquet(s3_store_config: S3Config):
    register(["https", "s3"])
    fs = fsspec.filesystem("https")
    fs_s3 = fsspec.filesystem(
        "s3",
        config=s3_store_config,
        client_options={"allow_http": True},
    )

    # Scheme-less URL first...
    url = "github.com/opengeospatial/geoparquet/raw/refs/heads/main/examples/example.parquet"  # noqa: E501
    pq.read_metadata(url, filesystem=fs)

    # ...then the same file addressed with a full https:// URL.
    url = "https://github.com/opengeospatial/geoparquet/raw/refs/heads/main/examples/example.parquet"
    pq.read_metadata(url, filesystem=fs)

    # Pull the remote Parquet file down as a PyArrow table.
    expected = pq.read_table(url, filesystem=fs)
    write_parquet_path = f"{TEST_BUCKET_NAME}/test.parquet"

    # Round-trip it through s3.
    pq.write_table(expected, write_parquet_path, filesystem=fs_s3)

    listing = fs_s3.ls(f"{TEST_BUCKET_NAME}", detail=False)
    assert f"{TEST_BUCKET_NAME}/test.parquet" in listing

    # Re-read from s3 and confirm nothing was lost or altered in transit.
    roundtripped = pq.read_table(write_parquet_path, filesystem=fs_s3)
    assert roundtripped.equals(expected), (
        "Parquet file contents from s3 do not match the original file"
    )
238260

239261
def test_multi_file_ops(fs: AsyncFsspecStore):
240262
data = {

0 commit comments

Comments
 (0)