7
7
from contextlib import asynccontextmanager
8
8
from csv import DictReader
9
9
from glob import glob
10
- from io import BufferedReader , IOBase , TextIOWrapper
10
+ from io import BufferedReader , BytesIO , IOBase , TextIOWrapper
11
11
from logging import getLogger
12
12
from pathlib import Path
13
13
from typing import (
18
18
Dict ,
19
19
Iterable ,
20
20
List ,
21
+ Optional ,
21
22
Tuple ,
22
23
)
23
24
28
29
from ...model import JsonLikeDocument
29
30
from ...pluggable import Pluggable
30
31
from ...subclass_registry import MissingFromRegistryError , SubclassRegistry
32
+ from .credential_utils import AwsClientFactory
31
33
from .extractor import Extractor
32
34
33
35
SUPPORTED_FILE_FORMAT_REGISTRY = SubclassRegistry ()
@@ -496,7 +498,123 @@ def describe(self) -> str:
496
498
return f"{ len (self .urls )} remote files"
497
499
498
500
499
- class UnifiedFileExtractor (Extractor ):
501
class S3File(ReadableFile):
    """A readable file backed by a single object in S3.

    Wraps one S3 object, identified by a bucket and key, and exposes it to
    the pipeline through ``as_reader``. When ``archive_dir`` is provided,
    the object is moved into that directory after it has been read.
    """

    def __init__(
        self,
        key: str,
        s3_client,
        bucket: str,
        archive_dir: str | None,
        object_format: str | None,
    ) -> None:
        self.logger = getLogger(__name__)
        self.key = key
        self.s3_client = s3_client
        self.bucket = bucket
        self.archive_dir = archive_dir
        self.object_format = object_format

    def archive_if_required(self, key: str):
        """Move the object under ``archive_dir``; no-op when archiving is off."""
        if not self.archive_dir:
            return

        self.logger.info("Archiving S3 Object", extra={"key": key})
        # S3 has no rename: copy to the archive location, then delete the original.
        destination = f"{self.archive_dir}/{Path(key).name}"
        self.s3_client.copy(
            Bucket=self.bucket,
            Key=destination,
            CopySource={"Bucket": self.bucket, "Key": key},
        )
        self.s3_client.delete_object(Bucket=self.bucket, Key=key)

    def path_like(self) -> Path:
        """Return a ``Path`` view of the key, with ``object_format`` overriding the suffix."""
        original = Path(self.key)
        suffix = self.object_format if self.object_format else original.suffix
        return original.with_suffix(suffix)

    @asynccontextmanager
    async def as_reader(self, reader: IOBase):
        """Yield ``reader`` wrapped around the object's bytes, then archive if configured."""
        # NOTE(review): the whole body is buffered in memory; presumably objects
        # are small enough for this — confirm for large-file use cases.
        response = self.s3_client.get_object(Bucket=self.bucket, Key=self.key)
        buffered = BytesIO(response["Body"].read())
        yield reader(buffered)
        self.archive_if_required(self.key)
551
+
552
+
553
class S3FileSource(FileSource, alias="s3"):
    """A source of pipeline files stored in an S3 bucket.

    Lists the objects under ``prefix`` in ``bucket`` (skipping anything that
    already sits under ``archive_dir``) and yields ``S3File`` instances the
    pipeline can read. Archiving of read files is delegated to ``S3File``.
    """

    @classmethod
    def from_file_data(
        cls,
        bucket: str,
        prefix: Optional[str] = None,
        archive_dir: Optional[str] = None,
        object_format: Optional[str] = None,
        **aws_client_args,
    ):
        """Build a source from declarative file data, creating the S3 client."""
        client = AwsClientFactory(**aws_client_args).make_client("s3")
        return cls(
            bucket=bucket,
            s3_client=client,
            archive_dir=archive_dir,
            object_format=object_format,
            prefix=prefix,
        )

    def __init__(
        self,
        bucket: str,
        s3_client,
        archive_dir: Optional[str] = None,
        object_format: Optional[str] = None,
        prefix: Optional[str] = None,
    ):
        self.bucket = bucket
        self.s3_client = s3_client
        self.archive_dir = archive_dir
        self.object_format = object_format
        # An absent prefix means "list the whole bucket".
        self.prefix = "" if prefix is None else prefix

    def object_is_in_archive(self, key: str) -> bool:
        """True when ``key`` lives under the archive directory."""
        if not self.archive_dir:
            return False
        return key.startswith(self.archive_dir)

    def find_keys_in_bucket(self) -> Iterable[str]:
        """Yield every key under the prefix that has not been archived."""
        paginator = self.s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            for entry in page.get("Contents", []):
                key = entry["Key"]
                if not self.object_is_in_archive(key):
                    yield key

    async def get_files(self):
        """Yield an ``S3File`` for each matching object in the bucket."""
        for key in self.find_keys_in_bucket():
            yield S3File(
                key=key,
                s3_client=self.s3_client,
                bucket=self.bucket,
                archive_dir=self.archive_dir,
                object_format=self.object_format,
            )
615
+
616
+
617
+ class FileExtractor (Extractor ):
500
618
"""A class that extracts records from files.
501
619
502
620
This class is used to extract records from files. The class takes a list
@@ -507,7 +625,31 @@ class UnifiedFileExtractor(Extractor):
507
625
"""
508
626
509
627
@classmethod
510
- def from_file_data (cls , sources : List [Dict [str , Any ]]) -> "UnifiedFileExtractor" :
628
+ def local (cls , globs : Iterable [str ]):
629
+ return FileExtractor .from_file_data ([{"type" : "local" , "globs" : globs }])
630
+
631
+ @classmethod
632
+ def s3 (cls , ** kwargs ):
633
+ return cls ([S3FileSource .from_file_data (** kwargs )])
634
+
635
+ @classmethod
636
+ def remote (
637
+ cls ,
638
+ urls : Iterable [str ],
639
+ memory_spooling_max_size_in_mb : int = 10 ,
640
+ ):
641
+ return FileExtractor .from_file_data (
642
+ [
643
+ {
644
+ "type" : "http" ,
645
+ "urls" : urls ,
646
+ "memory_spooling_max_size_in_mb" : memory_spooling_max_size_in_mb ,
647
+ }
648
+ ]
649
+ )
650
+
651
+ @classmethod
652
+ def from_file_data (cls , sources : List [Dict [str , Any ]]) -> "FileExtractor" :
511
653
return cls (
512
654
[FileSource .from_file_data_with_type_label (source ) for source in sources ]
513
655
)
@@ -570,51 +712,3 @@ async def extract_records(self) -> AsyncGenerator[Any, Any]:
570
712
self .logger .warning (
571
713
f"No files found for source: { file_source .describe ()} "
572
714
)
573
-
574
-
575
- # DEPRECATED CODE BELOW ##
576
- #
577
- # The classes below are slated to be removed in the future.
578
- # Additionally, there are aliases from the old class names to the new class
579
- # names to ensure backwards compatibility. These aliases will be removed in
580
- # the future.
581
-
582
-
583
- class FileExtractor (UnifiedFileExtractor ):
584
- """A class that extracts records from local files.
585
-
586
- This class is slated to be removed in the future. It is a subclass of
587
- UnifiedFileExtractor that is used to extract records from local files
588
- """
589
-
590
- @classmethod
591
- def from_file_data (cls , globs : Iterable [str ]):
592
- return UnifiedFileExtractor .from_file_data ([{"type" : "local" , "globs" : globs }])
593
-
594
-
595
- class RemoteFileExtractor (UnifiedFileExtractor ):
596
- """A class that extracts records from remote files.
597
-
598
- This class is slated to be removed in the future. It is a subclass of
599
- UnifiedFileExtractor that is used to extract records from remote files.
600
- """
601
-
602
- @classmethod
603
- def from_file_data (
604
- cls ,
605
- urls : Iterable [str ],
606
- memory_spooling_max_size_in_mb : int = 10 ,
607
- ):
608
- return UnifiedFileExtractor .from_file_data (
609
- [
610
- {
611
- "type" : "http" ,
612
- "urls" : urls ,
613
- "memory_spooling_max_size_in_mb" : memory_spooling_max_size_in_mb ,
614
- }
615
- ]
616
- )
617
-
618
-
619
- SupportedFileFormat = FileCodec
620
- SupportedCompressedFileFormat = CompressionCodec
0 commit comments