30
30
integration.
31
31
"""
32
32
33
- # ruff: noqa: ANN401, PTH123, FBT001, FBT002
33
+ # ruff: noqa: ANN401, EM102, PTH123, FBT001, FBT002, S101
34
34
35
35
from __future__ import annotations
36
36
46
46
import fsspec .spec
47
47
48
48
import obstore as obs
49
- from obstore import Bytes
49
+ from obstore import open_reader , open_writer
50
50
from obstore .store import from_url
51
51
52
52
if TYPE_CHECKING :
53
53
from collections .abc import Coroutine , Iterable
54
54
55
- from obstore import Bytes
55
+ from obstore import Attributes , Bytes , ReadableFile , WritableFile
56
56
from obstore .store import (
57
57
AzureConfig ,
58
58
AzureConfigInput ,
@@ -462,41 +462,246 @@ def _open(
462
462
autocommit : Any = True , # noqa: ARG002
463
463
cache_options : Any = None , # noqa: ARG002
464
464
** kwargs : Any ,
465
- ) -> BufferedFileSimple :
465
+ ) -> BufferedFile :
466
466
"""Return raw bytes-mode file-like from the file-system."""
467
- return BufferedFileSimple (self , path , mode , ** kwargs )
467
+ if mode not in ("wb" , "rb" ):
468
+ err_msg = f"Only 'rb' and 'wb' modes supported, got: { mode } "
469
+ raise ValueError (err_msg )
470
+
471
+ return BufferedFile (self , path , mode , ** kwargs )
472
+
468
473
474
+ class BufferedFile (fsspec .spec .AbstractBufferedFile ):
475
+ """A buffered readable or writable file.
476
+
477
+ This is a wrapper around [`obstore.ReadableFile`][] and [`obstore.WritableFile`][].
478
+ If you don't have a need to use the fsspec integration, you may be better served by
479
+ using [`open_reader`][obstore.open_reader] or [`open_writer`][obstore.open_writer]
480
+ directly.
481
+ """
469
482
470
- class BufferedFileSimple (fsspec .spec .AbstractBufferedFile ):
471
- """Implementation of buffered file around `fsspec.spec.AbstractBufferedFile`."""
483
+ mode : Literal ["rb" , "wb" ]
484
+ _reader : ReadableFile
485
+ _writer : WritableFile
486
+ _writer_loc : int
487
+ """Stream position.
472
488
489
+ Only defined for writers. We use the underlying rust stream position for reading.
490
+ """
491
+
492
+ @overload
473
493
def __init__ (
474
494
self ,
475
495
fs : AsyncFsspecStore ,
476
496
path : str ,
477
- mode : str = "rb" ,
497
+ mode : Literal ["rb" ] = "rb" ,
498
+ * ,
499
+ buffer_size : int = 1024 * 1024 ,
500
+ ** kwargs : Any ,
501
+ ) -> None : ...
502
+ @overload
503
+ def __init__ (
504
+ self ,
505
+ fs : AsyncFsspecStore ,
506
+ path : str ,
507
+ mode : Literal ["wb" ],
508
+ * ,
509
+ buffer_size : int = 10 * 1024 * 1024 ,
510
+ attributes : Attributes | None = None ,
511
+ tags : dict [str , str ] | None = None ,
512
+ ** kwargs : Any ,
513
+ ) -> None : ...
514
+ def __init__ ( # noqa: PLR0913
515
+ self ,
516
+ fs : AsyncFsspecStore ,
517
+ path : str ,
518
+ mode : Literal ["rb" , "wb" ] = "rb" ,
519
+ * ,
520
+ buffer_size : int | None = None ,
521
+ attributes : Attributes | None = None ,
522
+ tags : dict [str , str ] | None = None ,
478
523
** kwargs : Any ,
479
524
) -> None :
480
- """Create new buffered file."""
481
- if mode != "rb" :
482
- raise ValueError ("Only 'rb' mode is currently supported" )
525
+ """Create new buffered file.
526
+
527
+ Args:
528
+ fs: The underlying fsspec store to read from.
529
+ path: The path within the store to use.
530
+ mode: `"rb"` for a readable binary file or `"wb"` for a writable binary
531
+ file. Defaults to "rb".
532
+
533
+ Keyword Args:
534
+ attributes: Provide a set of `Attributes`. Only used when writing. Defaults
535
+ to `None`.
536
+ buffer_size: Up to `buffer_size` bytes will be buffered in memory. **When
537
+ reading:** The minimum number of bytes to read in a single request.
538
+ **When writing:** If `buffer_size` is exceeded, data will be uploaded
539
+ as a multipart upload in chunks of `buffer_size`. Defaults to None.
540
+ tags: Provide tags for this object. Only used when writing. Defaults to
541
+ `None`.
542
+ kwargs: Keyword arguments passed on to [`fsspec.spec.AbstractBufferedFile`][].
543
+
544
+ """ # noqa: E501
483
545
super ().__init__ (fs , path , mode , ** kwargs )
484
546
485
- def read (self , length : int = - 1 ) -> Any :
547
+ bucket , path = fs ._split_path (path ) # noqa: SLF001
548
+ store = fs ._construct_store (bucket ) # noqa: SLF001
549
+
550
+ self .mode = mode
551
+
552
+ if self .mode == "rb" :
553
+ buffer_size = 1024 * 1024 if buffer_size is None else buffer_size
554
+ self ._reader = open_reader (store , path , buffer_size = buffer_size )
555
+ elif self .mode == "wb" :
556
+ buffer_size = 10 * 1024 * 1024 if buffer_size is None else buffer_size
557
+ self ._writer = open_writer (
558
+ store ,
559
+ path ,
560
+ attributes = attributes ,
561
+ buffer_size = buffer_size ,
562
+ tags = tags ,
563
+ )
564
+
565
+ self ._writer_loc = 0
566
+ else :
567
+ raise ValueError (f"Invalid mode: { mode } " )
568
+
569
+ def read (self , length : int = - 1 ) -> bytes :
486
570
"""Return bytes from the remote file.
487
571
488
572
Args:
489
573
length: if positive, returns up to this many bytes; if negative, return all
490
574
remaining bytes.
491
575
576
+ Returns:
577
+ Data in bytes
578
+
492
579
"""
580
+ if self .mode != "rb" :
581
+ raise ValueError ("File not in read mode" )
493
582
if length < 0 :
494
- data = self .fs .cat_file (self .path , self .loc , self .size )
495
- self .loc = self .size
496
- else :
497
- data = self .fs .cat_file (self .path , self .loc , self .loc + length )
498
- self .loc += length
499
- return data
583
+ length = self .size - self .tell ()
584
+ if self .closed :
585
+ raise ValueError ("I/O operation on closed file." )
586
+ if length == 0 :
587
+ # don't even bother calling fetch
588
+ return b""
589
+
590
+ out = self ._reader .read (length )
591
+ return out .to_bytes ()
592
+
593
+ def readline (self ) -> bytes :
594
+ """Read until first occurrence of newline character."""
595
+ if self .mode != "rb" :
596
+ raise ValueError ("File not in read mode" )
597
+
598
+ out = self ._reader .readline ()
599
+ return out .to_bytes ()
600
+
601
+ def readlines (self ) -> list [bytes ]:
602
+ """Return all data, split by the newline character."""
603
+ if self .mode != "rb" :
604
+ raise ValueError ("File not in read mode" )
605
+
606
+ out = self ._reader .readlines ()
607
+ return [b .to_bytes () for b in out ]
608
+
609
+ def tell (self ) -> int :
610
+ """Get current file location."""
611
+ if self .mode == "rb" :
612
+ return self ._reader .tell ()
613
+
614
+ if self .mode == "wb" :
615
+ # There's no way to get the stream position from the underlying writer
616
+ # because it's async. Here we happen to be using the async writer in a
617
+ # synchronous way, so we keep our own stream position.
618
+ assert self ._writer_loc is not None
619
+ return self ._writer_loc
620
+
621
+ raise ValueError (f"Unexpected mode { self .mode } " )
622
+
623
+ def seek (self , loc : int , whence : int = 0 ) -> int :
624
+ """Set current file location.
625
+
626
+ Args:
627
+ loc: byte location
628
+ whence: Either
629
+ - `0`: from start of file
630
+ - `1`: current location
631
+ - `2`: end of file
632
+
633
+ """
634
+ if self .mode != "rb" :
635
+ raise ValueError ("Seek only available in read mode." )
636
+
637
+ return self ._reader .seek (loc , whence )
638
+
639
+ def write (self , data : bytes ) -> int :
640
+ """Write data to buffer.
641
+
642
+ Args:
643
+ data: Set of bytes to be written.
644
+
645
+ """
646
+ if not self .writable ():
647
+ raise ValueError ("File not in write mode" )
648
+ if self .closed :
649
+ raise ValueError ("I/O operation on closed file." )
650
+ if self .forced :
651
+ raise ValueError ("This file has been force-flushed, can only close" )
652
+
653
+ num_written = self ._writer .write (data )
654
+ self ._writer_loc += num_written
655
+
656
+ return num_written
657
+
658
+ def flush (
659
+ self ,
660
+ force : bool = False , # noqa: ARG002
661
+ ) -> None :
662
+ """Write buffered data to backend store.
663
+
664
+ Writes the current buffer, if it is larger than the block-size, or if
665
+ the file is being closed.
666
+
667
+ Args:
668
+ force: Unused.
669
+
670
+ """
671
+ if self .closed :
672
+ raise ValueError ("Flush on closed file" )
673
+
674
+ if self .readable ():
675
+ # no-op to flush on read-mode
676
+ return
677
+
678
+ self ._writer .flush ()
679
+
680
+ def close (self ) -> None :
681
+ """Close file. Ensure flushing the buffer."""
682
+ if self .closed :
683
+ return
684
+
685
+ try :
686
+ if self .mode == "rb" :
687
+ self ._reader .close ()
688
+ else :
689
+ self .flush (force = True )
690
+ self ._writer .close ()
691
+ finally :
692
+ self .closed = True
693
+
694
+ @property
695
+ def loc (self ) -> int :
696
+ """Get current file location."""
697
+ # Note, we override the `loc` attribute, because for the reader we manage that
698
+ # state in Rust.
699
+ return self .tell ()
700
+
701
+ @loc .setter
702
+ def loc (self , value : int ) -> None :
703
+ if value != 0 :
704
+ raise ValueError ("Cannot set `.loc`. Use `seek` instead." )
500
705
501
706
502
707
def register (protocol : str | Iterable [str ], * , asynchronous : bool = False ) -> None :
@@ -513,14 +718,16 @@ def register(protocol: str | Iterable[str], *, asynchronous: bool = False) -> No
513
718
asynchronous operations. Defaults to False.
514
719
515
720
Example:
516
- >>> register("s3")
517
- >>> register("s3", asynchronous=True) # Registers an async store for "s3"
518
- >>> register(["gcs", "abfs"]) # Registers both "gcs" and "abfs"
721
+ ```py
722
+ register("s3")
723
+ register("s3", asynchronous=True) # Registers an async store for "s3"
724
+ register(["gcs", "abfs"]) # Registers both "gcs" and "abfs"
725
+ ```
519
726
520
727
Notes:
521
728
- Each protocol gets a dynamically generated subclass named
522
- `AsyncFsspecStore_<protocol>`.
523
- - This avoids modifying the original AsyncFsspecStore class.
729
+ `AsyncFsspecStore_<protocol>`. This avoids modifying the original
730
+ AsyncFsspecStore class.
524
731
525
732
"""
526
733
if isinstance (protocol , str ):
@@ -542,5 +749,6 @@ def _register(protocol: str, *, asynchronous: bool) -> None:
542
749
"asynchronous" : asynchronous ,
543
750
}, # Assign protocol dynamically
544
751
),
545
- clobber = False ,
752
+ # Override any existing implementations of the same protocol
753
+ clobber = True ,
546
754
)
0 commit comments