
Commit 5036dd5

Merge pull request #387 from nodestream-proj/0.14
0.14 Development
2 parents 60d5851 + cb5bc11 commit 5036dd5


45 files changed, +2678 −832 lines

nodestream/cli/commands/run.py

+5 lines changed

@@ -52,6 +52,11 @@ class Run(NodestreamCommand):
             description="Ensure all specified targets are migrated before running specified pipelines",
             flag=True,
         ),
+        option(
+            "storage-backend",
+            description="Storage backend to use for checkpointing",
+            flag=False,
+        ),
         *PROMETHEUS_OPTIONS,
     ]

nodestream/cli/operations/run_copy.py

+4 −1 lines changed

@@ -2,6 +2,7 @@
 
 from ...databases import Copier, GraphDatabaseWriter
 from ...pipeline import Pipeline
+from ...pipeline.object_storage import ObjectStore
 from ...project import Project, Target
 from ..commands.nodestream_command import NodestreamCommand
 from .operation import Operation
@@ -29,7 +30,9 @@ async def perform(self, command: NodestreamCommand):
     def build_pipeline(self) -> Pipeline:
         copier = self.build_copier()
         writer = self.build_writer()
-        return Pipeline([copier, writer], step_outbox_size=10000)
+        return Pipeline(
+            [copier, writer], step_outbox_size=10000, object_store=ObjectStore.null()
+        )
 
     def build_copier(self) -> Copier:
         return Copier(
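
Copy pipelines are built with ObjectStore.null(), which by its name suggests a no-op store, so the checkpointing added elsewhere in this commit presumably persists nothing for nodestream copy runs.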

nodestream/cli/operations/run_pipeline.py

+4 lines changed

@@ -92,6 +92,9 @@ def print_effective_config(config):
             command.line("<info>Effective configuration:</info>")
             command.line(f"<info>{safe_dump(config)}</info>")
 
+        storage_name = command.option("storage-backend")
+        object_store = self.project.get_object_storage_by_name(storage_name)
+
         return RunRequest(
             pipeline_name=pipeline.name,
             initialization_arguments=PipelineInitializationArguments(
@@ -101,6 +104,7 @@ def print_effective_config(config):
                 extra_steps=list(
                     self.get_writer_steps_for_specified_targets(command, pipeline)
                 ),
+                object_store=object_store,
             ),
             progress_reporter=self.create_progress_reporter(command, pipeline.name),
         )
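
The option value is passed straight to Project.get_object_storage_by_name, and the resulting store is handed to PipelineInitializationArguments, so the extractor checkpointing below reads and writes against whichever backend the user named (or whatever the project returns when the option is omitted).
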
nodestream/pipeline/extractors/__init__.py

+2 −3 lines changed

@@ -1,14 +1,13 @@
 from .apis import SimpleApiExtractor
 from .extractor import Extractor
-from .files import FileExtractor, RemoteFileExtractor
+from .files import FileExtractor
 from .iterable import IterableExtractor
 from .ttls import TimeToLiveConfigurationExtractor
 
 __all__ = (
     "Extractor",
     "IterableExtractor",
-    "FileExtractor",
-    "RemoteFileExtractor",
     "TimeToLiveConfigurationExtractor",
     "SimpleApiExtractor",
+    "FileExtractor",
 )
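
RemoteFileExtractor is no longer exported; callers that imported it can presumably switch to the FileExtractor.remote(...) classmethod added in files.py below, and FileExtractor itself is now the unified extractor rather than a deprecated alias.
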
nodestream/pipeline/extractors/extractor.py

+38 −6 lines changed

@@ -1,20 +1,52 @@
 from abc import abstractmethod
-from typing import Any, AsyncGenerator
+from typing import Any, AsyncGenerator, Generic, TypeVar
 
-from ..step import Step
+from ..step import Step, StepContext
 
+R = TypeVar("R")
+T = TypeVar("T")
+CHECKPOINT_OBJECT_KEY = "extractor_progress_checkpoint"
 
-class Extractor(Step):
+
+class Extractor(Step, Generic[R, T]):
     """Extractors represent the source of a set of records.
 
     They are like any other step. However, they ignore the incoming record
     stream and instead produce their own stream of records. For this reason
     they generally should only be set at the beginning of a pipeline.
     """
 
-    def emit_outstanding_records(self):
-        return self.extract_records()
+    CHECKPOINT_INTERVAL = 1000
+
+    async def start(self, context: StepContext):
+        if checkpoint := context.object_store.get_pickled(CHECKPOINT_OBJECT_KEY):
+            context.info("Found Checkpoint For Extractor. Signaling to resume from it.")
+            await self.resume_from_checkpoint(checkpoint)
+
+    async def finish(self, context: StepContext):
+        context.debug("Clearing checkpoint for extractor since extractor is finished.")
+        context.object_store.delete(CHECKPOINT_OBJECT_KEY)
+
+    async def make_checkpoint(self) -> T:
+        return None
+
+    async def resume_from_checkpoint(self, checkpoint_object: T):
+        pass
+
+    async def commit_checkpoint(self, context: StepContext) -> None:
+        if checkpoint := await self.make_checkpoint():
+            context.object_store.put_picklable(CHECKPOINT_OBJECT_KEY, checkpoint)
+
+    async def emit_outstanding_records(
+        self, context: StepContext
+    ) -> AsyncGenerator[R, None]:
+        items_generated = 0
+        async for record in self.extract_records():
+            yield record
+            items_generated += 1
+            if items_generated % self.CHECKPOINT_INTERVAL == 0:
+                await self.commit_checkpoint(context)
 
     @abstractmethod
-    async def extract_records(self) -> AsyncGenerator[Any, Any]:
+    def extract_records(self) -> AsyncGenerator[R, Any]:
        raise NotImplementedError
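
To make the new checkpoint contract concrete, here is a minimal sketch of a subclass that opts in. The class name and record shape are made up; the overridden methods are the ones the base class defines above (make_checkpoint is committed every CHECKPOINT_INTERVAL records, and resume_from_checkpoint is invoked from start() when a pickled checkpoint is found in the object store).

from typing import Any, AsyncGenerator

from nodestream.pipeline.extractors import Extractor


class PageScanExtractor(Extractor):
    """Hypothetical extractor that checkpoints the last page it finished."""

    def __init__(self, total_pages: int) -> None:
        self.total_pages = total_pages
        self.next_page = 0

    async def extract_records(self) -> AsyncGenerator[Any, Any]:
        # Start from whatever page resume_from_checkpoint restored (0 by default).
        for page in range(self.next_page, self.total_pages):
            yield {"page": page}
            self.next_page = page + 1

    async def make_checkpoint(self) -> int:
        # Pickled into the object store by commit_checkpoint.
        return self.next_page

    async def resume_from_checkpoint(self, checkpoint_object: int) -> None:
        # Called from start() when a prior checkpoint exists in the object store.
        self.next_page = checkpoint_object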

nodestream/pipeline/extractors/files.py

+145 −51 lines changed

@@ -7,7 +7,7 @@
 from contextlib import asynccontextmanager
 from csv import DictReader
 from glob import glob
-from io import BufferedReader, IOBase, TextIOWrapper
+from io import BufferedReader, BytesIO, IOBase, TextIOWrapper
 from logging import getLogger
 from pathlib import Path
 from typing import (
@@ -18,6 +18,7 @@
     Dict,
     Iterable,
     List,
+    Optional,
     Tuple,
 )
 
@@ -28,6 +29,7 @@
 from ...model import JsonLikeDocument
 from ...pluggable import Pluggable
 from ...subclass_registry import MissingFromRegistryError, SubclassRegistry
+from .credential_utils import AwsClientFactory
 from .extractor import Extractor
 
 SUPPORTED_FILE_FORMAT_REGISTRY = SubclassRegistry()
@@ -496,7 +498,123 @@ def describe(self) -> str:
         return f"{len(self.urls)} remote files"
 
 
-class UnifiedFileExtractor(Extractor):
+class S3File(ReadableFile):
+    """A readable file that is stored in S3.
+
+    This class is used to read files from S3. The class takes a key, bucket,
+    and an S3 client. The class uses the S3 client to get the object from S3
+    and yield an instance of the file that can be read by the pipeline.
+
+    The class also has a method to archive the file after it has been read.
+
+    """
+
+    def __init__(
+        self,
+        key: str,
+        s3_client,
+        bucket: str,
+        archive_dir: str | None,
+        object_format: str | None,
+    ) -> None:
+        self.logger = getLogger(__name__)
+        self.key = key
+        self.s3_client = s3_client
+        self.bucket = bucket
+        self.archive_dir = archive_dir
+        self.object_format = object_format
+
+    def archive_if_required(self, key: str):
+        if not self.archive_dir:
+            return
+
+        self.logger.info("Archiving S3 Object", extra=dict(key=key))
+        filename = Path(key).name
+        self.s3_client.copy(
+            Bucket=self.bucket,
+            Key=f"{self.archive_dir}/{filename}",
+            CopySource={"Bucket": self.bucket, "Key": key},
+        )
+        self.s3_client.delete_object(Bucket=self.bucket, Key=key)
+
+    def path_like(self) -> Path:
+        path = Path(self.key)
+        return path.with_suffix(self.object_format or path.suffix)
+
+    @asynccontextmanager
+    async def as_reader(self, reader: IOBase):
+        streaming_body = self.s3_client.get_object(Bucket=self.bucket, Key=self.key)[
+            "Body"
+        ]
+        yield reader(BytesIO(streaming_body.read()))
+        self.archive_if_required(self.key)
+
+
+class S3FileSource(FileSource, alias="s3"):
+    """A class that represents a source of files stored in S3.
+
+    This class is used to read files from S3. The class takes a bucket, prefix,
+    and an S3 client. The class uses the S3 client to list the objects in the
+    bucket and yield instances of S3File that can be read by the pipeline.
+
+    The class also has a method to archive the file after it has been read.
+    """
+
+    @classmethod
+    def from_file_data(
+        cls,
+        bucket: str,
+        prefix: Optional[str] = None,
+        archive_dir: Optional[str] = None,
+        object_format: Optional[str] = None,
+        **aws_client_args,
+    ):
+        return cls(
+            bucket=bucket,
+            prefix=prefix,
+            archive_dir=archive_dir,
+            object_format=object_format,
+            s3_client=AwsClientFactory(**aws_client_args).make_client("s3"),
+        )
+
+    def __init__(
+        self,
+        bucket: str,
+        s3_client,
+        archive_dir: Optional[str] = None,
+        object_format: Optional[str] = None,
+        prefix: Optional[str] = None,
+    ):
+        self.bucket = bucket
+        self.s3_client = s3_client
+        self.archive_dir = archive_dir
+        self.object_format = object_format
+        self.prefix = prefix or ""
+
+    def object_is_in_archive(self, key: str) -> bool:
+        return key.startswith(self.archive_dir) if self.archive_dir else False
+
+    def find_keys_in_bucket(self) -> Iterable[str]:
+        # Returns all keys in the bucket that are not in the archive dir
+        # and have the prefix.
+        paginator = self.s3_client.get_paginator("list_objects_v2")
+        page_iterator = paginator.paginate(Bucket=self.bucket, Prefix=self.prefix)
+        for page in page_iterator:
+            keys = (obj["Key"] for obj in page.get("Contents", []))
+            yield from filter(lambda k: not self.object_is_in_archive(k), keys)
+
+    async def get_files(self):
+        for key in self.find_keys_in_bucket():
+            yield S3File(
+                key=key,
+                s3_client=self.s3_client,
+                bucket=self.bucket,
+                archive_dir=self.archive_dir,
+                object_format=self.object_format,
+            )
+
+
+class FileExtractor(Extractor):
     """A class that extracts records from files.
 
     This class is used to extract records from files. The class takes a list
@@ -507,7 +625,31 @@ class UnifiedFileExtractor(Extractor):
     """
 
     @classmethod
-    def from_file_data(cls, sources: List[Dict[str, Any]]) -> "UnifiedFileExtractor":
+    def local(cls, globs: Iterable[str]):
+        return FileExtractor.from_file_data([{"type": "local", "globs": globs}])
+
+    @classmethod
+    def s3(cls, **kwargs):
+        return cls([S3FileSource.from_file_data(**kwargs)])
+
+    @classmethod
+    def remote(
+        cls,
+        urls: Iterable[str],
+        memory_spooling_max_size_in_mb: int = 10,
+    ):
+        return FileExtractor.from_file_data(
+            [
+                {
+                    "type": "http",
+                    "urls": urls,
+                    "memory_spooling_max_size_in_mb": memory_spooling_max_size_in_mb,
+                }
+            ]
+        )
+
+    @classmethod
+    def from_file_data(cls, sources: List[Dict[str, Any]]) -> "FileExtractor":
         return cls(
             [FileSource.from_file_data_with_type_label(source) for source in sources]
         )
@@ -570,51 +712,3 @@ async def extract_records(self) -> AsyncGenerator[Any, Any]:
                 self.logger.warning(
                     f"No files found for source: {file_source.describe()}"
                 )
-
-
-# DEPRECATED CODE BELOW ##
-#
-# The classes below are slated to be removed in the future.
-# Additionally, there are aliases from the old class names to the new class
-# names to ensure backwards compatibility. These aliases will be removed in
-# the future.
-
-
-class FileExtractor(UnifiedFileExtractor):
-    """A class that extracts records from local files.
-
-    This class is slated to be removed in the future. It is a subclass of
-    UnifiedFileExtractor that is used to extract records from local files
-    """
-
-    @classmethod
-    def from_file_data(cls, globs: Iterable[str]):
-        return UnifiedFileExtractor.from_file_data([{"type": "local", "globs": globs}])
-
-
-class RemoteFileExtractor(UnifiedFileExtractor):
-    """A class that extracts records from remote files.
-
-    This class is slated to be removed in the future. It is a subclass of
-    UnifiedFileExtractor that is used to extract records from remote files.
-    """
-
-    @classmethod
-    def from_file_data(
-        cls,
-        urls: Iterable[str],
-        memory_spooling_max_size_in_mb: int = 10,
-    ):
-        return UnifiedFileExtractor.from_file_data(
-            [
-                {
-                    "type": "http",
-                    "urls": urls,
-                    "memory_spooling_max_size_in_mb": memory_spooling_max_size_in_mb,
-                }
-            ]
-        )
-
-
-SupportedFileFormat = FileCodec
-SupportedCompressedFileFormat = CompressionCodec
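
As a rough usage sketch (the bucket, prefix, and archive names below are invented), the new S3 source can be reached either through the s3 classmethod or through from_file_data with a "type": "s3" entry; extra keyword arguments flow into AwsClientFactory when it builds the boto3 client.

from nodestream.pipeline.extractors import FileExtractor

# Hypothetical bucket and prefix; archive_dir makes each object get moved
# into that folder after it has been read.
extractor = FileExtractor.s3(
    bucket="my-data-bucket",
    prefix="exports/2024/",
    archive_dir="processed",
    object_format=".json",
)

# Equivalent construction via the generic source list (alias "s3" above):
extractor = FileExtractor.from_file_data(
    [{"type": "s3", "bucket": "my-data-bucket", "prefix": "exports/2024/"}]
)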

nodestream/pipeline/extractors/iterable.py

+15 −1 lines changed

@@ -1,3 +1,4 @@
+from logging import getLogger
 from typing import Any, AsyncGenerator, Iterable
 
 from .extractor import Extractor
@@ -12,7 +13,20 @@ def range(cls, start=0, stop=100, step=1):
 
     def __init__(self, iterable: Iterable[Any]) -> None:
         self.iterable = iterable
+        self.index = 0
+        self.logger = getLogger(self.__class__.__name__)
 
     async def extract_records(self) -> AsyncGenerator[Any, Any]:
-        for record in self.iterable:
+        for index, record in enumerate(self.iterable):
+            if index < self.index:
+                continue
+            self.index = index
             yield record
+
+    async def make_checkpoint(self):
+        return self.index
+
+    async def resume_from_checkpoint(self, checkpoint):
+        if isinstance(checkpoint, int):
+            self.index = checkpoint
+            self.logger.info(f"Resuming from checkpoint {checkpoint}")
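
A quick sketch of what the resume behaviour looks like when driven by hand; in a real run, start() feeds the stored checkpoint to resume_from_checkpoint, so calling it directly here is only for illustration.

import asyncio

from nodestream.pipeline.extractors import IterableExtractor


async def main():
    extractor = IterableExtractor(range(10))
    # Pretend a checkpoint of 4 was previously committed to the object store.
    await extractor.resume_from_checkpoint(4)
    records = [record async for record in extractor.extract_records()]
    print(records)  # [4, 5, 6, 7, 8, 9] -- indexes below the checkpoint are skipped


asyncio.run(main())
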
nodestream/pipeline/extractors/stores/aws/__init__.py

@@ -1,5 +1,4 @@
 from .athena_extractor import AthenaExtractor
 from .dynamodb_extractor import DynamoDBExtractor
-from .s3_extractor import S3Extractor
 
-__all__ = ("AthenaExtractor", "S3Extractor", "DynamoDBExtractor")
+__all__ = ("AthenaExtractor", "DynamoDBExtractor")
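
S3Extractor is dropped from the AWS store extractors; its role is presumably taken over by the S3FileSource/S3File classes added to files.py above, reachable through FileExtractor.s3(...).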
