diff --git a/obstore/python/obstore/fsspec.py b/obstore/python/obstore/fsspec.py
index d704860a..3cff4616 100644
--- a/obstore/python/obstore/fsspec.py
+++ b/obstore/python/obstore/fsspec.py
@@ -30,7 +30,7 @@
 integration.
 """
 
-# ruff: noqa: ANN401, PTH123, FBT001, FBT002
+# ruff: noqa: ANN401, EM102, PTH123, FBT001, FBT002, S101
 
 from __future__ import annotations
 
@@ -46,13 +46,13 @@
 import fsspec.spec
 
 import obstore as obs
-from obstore import Bytes
+from obstore import open_reader, open_writer
 from obstore.store import from_url
 
 if TYPE_CHECKING:
     from collections.abc import Coroutine, Iterable
 
-    from obstore import Bytes
+    from obstore import Attributes, Bytes, ReadableFile, WritableFile
     from obstore.store import (
         AzureConfig,
         AzureConfigInput,
@@ -462,41 +462,253 @@ def _open(
         autocommit: Any = True,  # noqa: ARG002
         cache_options: Any = None,  # noqa: ARG002
         **kwargs: Any,
-    ) -> BufferedFileSimple:
+    ) -> BufferedFile:
         """Return raw bytes-mode file-like from the file-system."""
-        return BufferedFileSimple(self, path, mode, **kwargs)
+        if mode not in ("wb", "rb"):
+            err_msg = f"Only 'rb' and 'wb' modes supported, got: {mode}"
+            raise ValueError(err_msg)
+
+        return BufferedFile(self, path, mode, **kwargs)
 
 
-class BufferedFileSimple(fsspec.spec.AbstractBufferedFile):
-    """Implementation of buffered file around `fsspec.spec.AbstractBufferedFile`."""
+class BufferedFile(fsspec.spec.AbstractBufferedFile):
+    """A buffered readable or writable file.
+
+    This is a wrapper around [`obstore.ReadableFile`][] and [`obstore.WritableFile`][].
+    If you don't need the fsspec integration, you may be better served by using
+    [`open_reader`][obstore.open_reader] or [`open_writer`][obstore.open_writer]
+    directly.
+    """
+
+    mode: Literal["rb", "wb"]
+    _reader: ReadableFile
+    _writer: WritableFile
+    _writer_loc: int
+    """Stream position.
+
+    Only defined for writers. We use the underlying Rust stream position for reading.
+    """
 
+    @overload
     def __init__(
         self,
         fs: AsyncFsspecStore,
         path: str,
-        mode: str = "rb",
+        mode: Literal["rb"] = "rb",
+        *,
+        buffer_size: int = 1024 * 1024,
+        **kwargs: Any,
+    ) -> None: ...
+    @overload
+    def __init__(
+        self,
+        fs: AsyncFsspecStore,
+        path: str,
+        mode: Literal["wb"],
+        *,
+        buffer_size: int = 10 * 1024 * 1024,
+        attributes: Attributes | None = None,
+        tags: dict[str, str] | None = None,
+        **kwargs: Any,
+    ) -> None: ...
+    def __init__(  # noqa: PLR0913
+        self,
+        fs: AsyncFsspecStore,
+        path: str,
+        mode: Literal["rb", "wb"] = "rb",
+        *,
+        buffer_size: int | None = None,
+        attributes: Attributes | None = None,
+        tags: dict[str, str] | None = None,
         **kwargs: Any,
     ) -> None:
-        """Create new buffered file."""
-        if mode != "rb":
-            raise ValueError("Only 'rb' mode is currently supported")
+        """Create new buffered file.
+
+        Args:
+            fs: The underlying fsspec store to read from or write to.
+            path: The path within the store to use.
+            mode: `"rb"` for a readable binary file or `"wb"` for a writable binary
+                file. Defaults to `"rb"`.
+
+        Keyword Args:
+            attributes: Provide a set of `Attributes`. Only used when writing. Defaults
+                to `None`.
+            buffer_size: Up to `buffer_size` bytes will be buffered in memory. **When
+                reading:** The minimum number of bytes to read in a single request.
+                **When writing:** If `buffer_size` is exceeded, data will be uploaded
+                as a multipart upload in chunks of `buffer_size`. Defaults to `None`,
+                which is interpreted as 1 MiB when reading and 10 MiB when writing.
+            tags: Provide tags for this object. Only used when writing. Defaults to
+                `None`.
+            kwargs: Keyword arguments passed on to [`fsspec.spec.AbstractBufferedFile`][].
+
+        """  # noqa: E501
         super().__init__(fs, path, mode, **kwargs)
 
-    def read(self, length: int = -1) -> Any:
+        bucket, path = fs._split_path(path)  # noqa: SLF001
+        store = fs._construct_store(bucket)  # noqa: SLF001
+
+        self.mode = mode
+
+        if self.mode == "rb":
+            buffer_size = 1024 * 1024 if buffer_size is None else buffer_size
+            self._reader = open_reader(store, path, buffer_size=buffer_size)
+        elif self.mode == "wb":
+            buffer_size = 10 * 1024 * 1024 if buffer_size is None else buffer_size
+            self._writer = open_writer(
+                store,
+                path,
+                attributes=attributes,
+                buffer_size=buffer_size,
+                tags=tags,
+            )
+
+            self._writer_loc = 0
+        else:
+            raise ValueError(f"Invalid mode: {mode}")
+
+    def read(self, length: int = -1) -> bytes:
         """Return bytes from the remote file.
 
         Args:
             length: if positive, returns up to this many bytes; if negative, return all
                 remaining bytes.
 
+        Returns:
+            The requested data as `bytes`.
+
         """
+        if self.mode != "rb":
+            raise ValueError("File not in read mode")
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
         if length < 0:
-            data = self.fs.cat_file(self.path, self.loc, self.size)
-            self.loc = self.size
-        else:
-            data = self.fs.cat_file(self.path, self.loc, self.loc + length)
-            self.loc += length
-        return data
+            length = self.size - self.tell()
+        if length == 0:
+            # don't even bother calling fetch
+            return b""
+
+        out = self._reader.read(length)
+        return out.to_bytes()
+
+    def readline(self) -> bytes:
+        """Read until the first occurrence of a newline character."""
+        if self.mode != "rb":
+            raise ValueError("File not in read mode")
+
+        out = self._reader.readline()
+        return out.to_bytes()
+
+    def readlines(self) -> list[bytes]:
+        """Return all data, split by the newline character."""
+        if self.mode != "rb":
+            raise ValueError("File not in read mode")
+
+        out = self._reader.readlines()
+        return [b.to_bytes() for b in out]
+
+    def tell(self) -> int:
+        """Get current file location."""
+        if self.mode == "rb":
+            return self._reader.tell()
+
+        if self.mode == "wb":
+            # There's no way to get the stream position from the underlying writer
+            # because it's async. Here we happen to be using the async writer in a
+            # synchronous way, so we keep our own stream position.
+            assert self._writer_loc is not None
+            return self._writer_loc
+
+        raise ValueError(f"Unexpected mode {self.mode}")
+
+    def seek(self, loc: int, whence: int = 0) -> int:
+        """Set current file location.
+
+        Args:
+            loc: Byte location.
+            whence: Either
+                - `0`: from start of file
+                - `1`: current location
+                - `2`: end of file
+
+        Returns:
+            The new absolute stream position.
+
+        """
+        if self.mode != "rb":
+            raise ValueError("Seek only available in read mode.")
+
+        return self._reader.seek(loc, whence)
+
+    def write(self, data: bytes) -> int:
+        """Write data to the internal buffer.
+
+        Args:
+            data: The bytes to write.
+
+        Returns:
+            The number of bytes written.
+
+        """
+        if not self.writable():
+            raise ValueError("File not in write mode")
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        if self.forced:
+            raise ValueError("This file has been force-flushed, can only close")
+
+        num_written = self._writer.write(data)
+        self._writer_loc += num_written
+
+        return num_written
+
+    def flush(
+        self,
+        force: bool = False,  # noqa: ARG002
+    ) -> None:
+        """Write buffered data to the backend store.
+
+        In write mode, this flushes the underlying writer's internal buffer. In
+        read mode, this is a no-op.
+
+        Args:
+            force: Unused.
+
+        """
+        if self.closed:
+            raise ValueError("Flush on closed file")
+
+        if self.readable():
+            # flush is a no-op in read mode
+            return
+
+        self._writer.flush()
+
+    def close(self) -> None:
+        """Close the file, ensuring the buffer is flushed."""
+        if self.closed:
+            return
+
+        try:
+            if self.mode == "rb":
+                self._reader.close()
+            else:
+                self.flush(force=True)
+                self._writer.close()
+        finally:
+            self.closed = True
+
+    @property
+    def loc(self) -> int:
+        """Get current file location."""
+        # Note: we override the `loc` attribute because, for the reader, we manage
+        # that state in Rust.
+        return self.tell()
+
+    @loc.setter
+    def loc(self, value: int) -> None:
+        if value != 0:
+            raise ValueError("Cannot set `.loc`. Use `seek` instead.")
 
 
 def register(protocol: str | Iterable[str], *, asynchronous: bool = False) -> None:
@@ -513,14 +725,16 @@ def register(protocol: str | Iterable[str], *, asynchronous: bool = False) -> None:
         asynchronous operations. Defaults to False.
 
     Example:
-    >>> register("s3")
-    >>> register("s3", asynchronous=True)  # Registers an async store for "s3"
-    >>> register(["gcs", "abfs"])  # Registers both "gcs" and "abfs"
+    ```py
+    register("s3")
+    register("s3", asynchronous=True)  # Registers an async store for "s3"
+    register(["gcs", "abfs"])  # Registers both "gcs" and "abfs"
+    ```
 
     Notes:
-        - Each protocol gets a dynamically generated subclass named
-          `AsyncFsspecStore_`.
-        - This avoids modifying the original AsyncFsspecStore class.
+        Each protocol gets a dynamically generated subclass named
+        `AsyncFsspecStore_`. This avoids modifying the original
+        AsyncFsspecStore class.
 
     """
     if isinstance(protocol, str):
@@ -542,5 +756,6 @@ def _register(protocol: str, *, asynchronous: bool) -> None:
             "asynchronous": asynchronous,
         },  # Assign protocol dynamically
     ),
-        clobber=False,
+        # Override any existing implementations of the same protocol
+        clobber=True,
     )
diff --git a/tests/test_fsspec.py b/tests/test_fsspec.py
index 629f04bc..e94e808e 100644
--- a/tests/test_fsspec.py
+++ b/tests/test_fsspec.py
@@ -225,9 +225,15 @@ async def test_list_async(s3_store_config: S3Config):
 
 
 @pytest.mark.network
-def test_remote_parquet():
-    register("https")
+def test_remote_parquet(s3_store_config: S3Config):
+    register(["https", "s3"])
     fs = fsspec.filesystem("https")
+    fs_s3 = fsspec.filesystem(
+        "s3",
+        config=s3_store_config,
+        client_options={"allow_http": True},
+    )
+
     url = "github.com/opengeospatial/geoparquet/raw/refs/heads/main/examples/example.parquet"  # noqa: E501
     pq.read_metadata(url, filesystem=fs)
 
@@ -235,6 +241,22 @@ def test_remote_parquet():
     url = "https://github.com/opengeospatial/geoparquet/raw/refs/heads/main/examples/example.parquet"
     pq.read_metadata(url, filesystem=fs)
 
+    # Read the remote Parquet file into a PyArrow table
+    table = pq.read_table(url, filesystem=fs)
+    write_parquet_path = f"{TEST_BUCKET_NAME}/test.parquet"
+
+    # Write the table to S3
+    pq.write_table(table, write_parquet_path, filesystem=fs_s3)
+
+    out = fs_s3.ls(f"{TEST_BUCKET_NAME}", detail=False)
+    assert f"{TEST_BUCKET_NAME}/test.parquet" in out
+
+    # Read the Parquet file back from S3 and verify its contents
+    parquet_table = pq.read_table(write_parquet_path, filesystem=fs_s3)
+    assert parquet_table.equals(table), (
+        "Parquet file contents from S3 do not match the original file"
+    )
+
 
 def test_multi_file_ops(fs: AsyncFsspecStore):
     data = {
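
A usage sketch of the new read/write paths through fsspec. The bucket name `my-bucket` and the object key are hypothetical, and credentials are assumed to be configured in the environment:

```py
import fsspec

from obstore.fsspec import register

# With clobber=True, registering now overrides any existing "s3"
# implementation instead of raising.
register("s3")

fs = fsspec.filesystem("s3")

# "wb" routes through BufferedFile -> obstore.open_writer; once the 10 MiB
# default buffer_size is exceeded, data is uploaded as a multipart upload.
with fs.open("my-bucket/data.bin", "wb") as f:
    f.write(b"hello world")

# "rb" routes through BufferedFile -> obstore.open_reader (1 MiB default
# buffer); seek/readline/readlines are delegated to the Rust reader.
with fs.open("my-bucket/data.bin", "rb") as f:
    assert f.read(5) == b"hello"
    f.seek(0)
    assert f.readline() == b"hello world"
```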
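Per the `BufferedFile` docstring's suggestion, the same flow without fsspec, using `open_reader`/`open_writer` directly; a minimal sketch assuming obstore's in-memory `MemoryStore` backend:

```py
from obstore import open_reader, open_writer
from obstore.store import MemoryStore

store = MemoryStore()

# WritableFile: write() buffers in memory and returns the byte count.
writer = open_writer(store, "data.bin", buffer_size=10 * 1024 * 1024)
writer.write(b"hello world")
writer.close()

# ReadableFile: read() returns obstore Bytes; convert explicitly.
reader = open_reader(store, "data.bin", buffer_size=1024 * 1024)
assert reader.read(5).to_bytes() == b"hello"
reader.close()
```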