Skip to content

Add S3 File Provider #2164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions docker/compose.full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ volumes:
name: hatchet_config
hatchet_api_key:
name: hatchet_api_key
postgres_data:
name: postgres_data
hatchet_rabbitmq_data:
name: hatchet_rabbitmq_data
hatchet_rabbitmq_conf:
name: hatchet_rabbitmq_conf
hatchet_postgres_data:
name: hatchet_postgres_data
minio_data:
name: minio_data
postgres_data:
name: postgres_data

services:
postgres:
Expand All @@ -34,6 +36,24 @@ services:
postgres
-c max_connections=1024

minio:
  image: minio/minio
  profiles: [minio]
  env_file:
    - ./env/minio.env
  volumes:
    - minio_data:/data
  ports:
    # 9000 = S3 API, 9001 = web console (see --console-address below).
    - "9000:9000"
    - "9001:9001"
  healthcheck:
    # Current minio/minio images do not include curl, so a curl-based probe
    # always fails. The image bundles the `mc` client, whose `ready` command
    # checks the local server's readiness.
    test: ["CMD", "mc", "ready", "local"]
    interval: 10s
    timeout: 5s
    retries: 5
  restart: on-failure
  command: server /data --console-address ":9001"

hatchet-postgres:
image: postgres:latest
env_file:
Expand Down
20 changes: 20 additions & 0 deletions docker/compose.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
volumes:
postgres_data:
name: postgres_data
minio_data:
name: minio_data

services:
postgres:
Expand All @@ -22,6 +24,24 @@ services:
postgres
-c max_connections=1024

minio:
  image: minio/minio
  profiles: [minio]
  env_file:
    - ./env/minio.env
  volumes:
    - minio_data:/data
  ports:
    # 9000 = S3 API, 9001 = web console (see --console-address below).
    - "9000:9000"
    - "9001:9001"
  healthcheck:
    # Current minio/minio images do not include curl, so a curl-based probe
    # always fails. The image bundles the `mc` client, whose `ready` command
    # checks the local server's readiness.
    test: ["CMD", "mc", "ready", "local"]
    interval: 10s
    timeout: 5s
    retries: 5
  restart: on-failure
  command: server /data --console-address ":9001"

graph_clustering:
image: ragtoriches/cluster-prod
ports:
Expand Down
2 changes: 2 additions & 0 deletions docker/env/minio.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
MINIO_ROOT_USER=minioadmin
MINIO_ROOT_PASSWORD=minioadmin
6 changes: 6 additions & 0 deletions py/all_possible_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ concurrent_request_limit = 256
################################################################################
[file]
provider = "postgres"
# S3-specific settings; only relevant when provider = "s3"
bucket_name = ""
endpoint_url = ""
region_name = ""
aws_access_key_id = ""
aws_secret_access_key = ""

################################################################################
# Ingestion Settings (IngestionConfig and nested settings)
Expand Down
9 changes: 6 additions & 3 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,21 @@
# Crypto provider
"CryptoConfig",
"CryptoProvider",
# Email provider
"EmailConfig",
"EmailProvider",
# Database providers
"LimitSettings",
"DatabaseConfig",
"DatabaseProvider",
"Handler",
"PostgresConfigurationSettings",
# Email provider
"EmailConfig",
"EmailProvider",
# Embedding provider
"EmbeddingConfig",
"EmbeddingProvider",
# File provider
"FileConfig",
"FileProvider",
# Ingestion provider
"IngestionConfig",
"IngestionProvider",
Expand Down
18 changes: 11 additions & 7 deletions py/core/base/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from .email import EmailConfig, EmailProvider
from .embedding import EmbeddingConfig, EmbeddingProvider
from .file import FileConfig, FileProvider
from .ingestion import (
ChunkingStrategy,
IngestionConfig,
Expand All @@ -29,26 +30,29 @@
"AppConfig",
"Provider",
"ProviderConfig",
# Ingestion provider
"IngestionConfig",
"IngestionProvider",
"ChunkingStrategy",
# Crypto provider
"CryptoConfig",
"CryptoProvider",
# Email provider
"EmailConfig",
"EmailProvider",
# Database providers
"DatabaseConnectionManager",
"DatabaseConfig",
"LimitSettings",
"PostgresConfigurationSettings",
"DatabaseProvider",
"Handler",
# Email provider
"EmailConfig",
"EmailProvider",
# Embedding provider
"EmbeddingConfig",
"EmbeddingProvider",
# File provider
"FileConfig",
"FileProvider",
# Ingestion provider
"IngestionConfig",
"IngestionProvider",
"ChunkingStrategy",
# LLM provider
"CompletionConfig",
"CompletionProvider",
Expand Down
110 changes: 110 additions & 0 deletions py/core/base/providers/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging
import os
from abc import ABC, abstractmethod
from datetime import datetime
from io import BytesIO
from typing import BinaryIO, Optional
from uuid import UUID

from .base import Provider, ProviderConfig

# Module-scoped logger so emitted records carry this module's name rather
# than being attributed to the root logger.
logger = logging.getLogger(__name__)


class FileConfig(ProviderConfig):
"""
Configuration for file storage providers.
"""

provider: Optional[str] = None

# S3-specific configuration
bucket_name: Optional[str] = None
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
region_name: Optional[str] = None
endpoint_url: Optional[str] = None

@property
def supported_providers(self) -> list[str]:
"""
List of supported file storage providers.
"""
return [
"postgres",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A shot in the dark here, but should this be "supabase"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll keep this to s3, since it should support any s3 compatible solution (not just Supabase).

"s3",
]

def validate_config(self) -> None:
    """
    Check that the selected provider is usable.

    Raises:
        ValueError: If `provider` is not listed in `supported_providers`,
            or if `provider` is "s3" and no bucket name is available from
            either `bucket_name` or the `S3_BUCKET_NAME` environment
            variable.
    """
    chosen = self.provider
    if chosen not in self.supported_providers:
        raise ValueError(f"Unsupported file provider: {chosen}")

    # Only the S3 backend has additional required settings.
    if chosen != "s3":
        return

    # The environment variable is consulted only when the config field is
    # unset or empty.
    if self.bucket_name or os.getenv("S3_BUCKET_NAME"):
        return

    raise ValueError("S3 bucket name is required when using S3 provider")


class FileProvider(Provider, ABC):
    """
    Abstract interface shared by all file storage backends.

    Subclasses implement the asynchronous store / retrieve / delete / list
    operations declared below.
    """

    def __init__(self, config: FileConfig):
        """
        Hand the configuration to the base provider and keep a typed
        reference to it.

        Raises:
            ValueError: If `config` is not a `FileConfig` instance.
        """
        if not isinstance(config, FileConfig):
            raise ValueError(
                "FileProvider must be initialized with a `FileConfig`."
            )
        super().__init__(config)
        self.config: FileConfig = config

    @abstractmethod
    async def initialize(self) -> None:
        """Prepare the provider for use; backend-specific."""

    @abstractmethod
    async def store_file(
        self,
        document_id: UUID,
        file_name: str,
        file_content: BytesIO,
        file_type: Optional[str] = None,
    ) -> None:
        """
        Persist the content of one file under the given document ID.

        Args:
            document_id: Identifier of the document the file belongs to.
            file_name: Name to record for the file.
            file_content: In-memory buffer holding the file's bytes.
            file_type: Optional type label for the file.
        """

    @abstractmethod
    async def retrieve_file(
        self, document_id: UUID
    ) -> Optional[tuple[str, BinaryIO, int]]:
        """
        Fetch the file stored for `document_id`.

        Returns:
            A `(str, BinaryIO, int)` tuple, or `None`.
            NOTE(review): presumably (file name, content stream, size) and
            `None` when no file exists; confirm against the implementations.
        """

    @abstractmethod
    async def retrieve_files_as_zip(
        self,
        document_ids: Optional[list[UUID]] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> tuple[str, BinaryIO, int]:
        """
        Bundle several stored files into a single zip archive.

        Args:
            document_ids: Optional list of document IDs to select.
            start_date: Optional lower date bound for the selection.
            end_date: Optional upper date bound for the selection.

        Returns:
            A `(str, BinaryIO, int)` tuple.
            NOTE(review): presumably (archive name, archive stream, size);
            confirm against the implementations.
        """

    @abstractmethod
    async def delete_file(self, document_id: UUID) -> bool:
        """
        Remove the file stored for `document_id`.

        Returns:
            A boolean result.
            NOTE(review): presumably True when a file was deleted; confirm.
        """

    @abstractmethod
    async def get_files_overview(
        self,
        offset: int,
        limit: int,
        filter_document_ids: Optional[list[UUID]] = None,
        filter_file_names: Optional[list[str]] = None,
    ) -> list[dict]:
        """
        List stored files as dictionaries, one page at a time.

        Args:
            offset: Pagination offset.
            limit: Pagination limit.
            filter_document_ids: Optional document IDs to filter by.
            filter_file_names: Optional file names to filter by.
        """
3 changes: 3 additions & 0 deletions py/core/main/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
OpenAICompletionProvider,
OpenAIEmbeddingProvider,
PostgresDatabaseProvider,
PostgresFileProvider,
R2RAuthProvider,
R2RCompletionProvider,
R2RIngestionProvider,
S3FileProvider,
SendGridEmailProvider,
SimpleOrchestrationProvider,
SupabaseAuthProvider,
Expand Down Expand Up @@ -59,6 +61,7 @@ class R2RProviders(BaseModel):
| OpenAIEmbeddingProvider
| OllamaEmbeddingProvider
)
file: PostgresFileProvider | S3FileProvider
completion_embedding: (
LiteLLMEmbeddingProvider
| OpenAIEmbeddingProvider
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/api/v3/documents_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ async def create_document(
}

file_name = file_data["filename"]
await self.providers.database.files_handler.store_file(
await self.providers.file.store_file(
document_id,
file_name,
file_content,
Expand Down
51 changes: 39 additions & 12 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
EmailConfig,
EmbeddingConfig,
EmbeddingProvider,
FileConfig,
IngestionConfig,
OCRConfig,
OrchestrationConfig,
Expand Down Expand Up @@ -213,20 +214,40 @@ async def create_database_provider(
quantization_type = (
self.config.embedding.quantization_settings.quantization_type
)
if db_config.provider == "postgres":
database_provider = PostgresDatabaseProvider(
db_config,
dimension,
crypto_provider=crypto_provider,
quantization_type=quantization_type,
)
await database_provider.initialize()
return database_provider
else:
if db_config.provider != "postgres":
raise ValueError(
f"Database provider {db_config.provider} not supported"
)

database_provider = PostgresDatabaseProvider(
db_config,
dimension,
crypto_provider=crypto_provider,
quantization_type=quantization_type,
)
await database_provider.initialize()
return database_provider

@staticmethod
def create_file_provider(
    config: FileConfig, database_provider=None, *args, **kwargs
):
    """
    Build the file provider selected by `config.provider`.

    The returned provider is constructed but not initialized here.

    Args:
        config: File configuration; its `provider` field picks the backend.
        database_provider: Needed for the "postgres" backend, which reuses
            its `project_name` and `connection_manager`. Not used for "s3".
        *args: Accepted for signature compatibility; unused.
        **kwargs: Accepted for signature compatibility; unused.

    Returns:
        A `PostgresFileProvider` or an `S3FileProvider`.

    Raises:
        ValueError: If `config.provider` is not supported, or if "postgres"
            is selected without a database provider.
    """
    if config.provider == "postgres":
        # Fail with a clear message instead of an AttributeError raised by
        # dereferencing None below.
        if database_provider is None:
            raise ValueError(
                "A database provider is required to create the postgres "
                "file provider"
            )

        # Imported lazily, as in the original, so the backend module is
        # only loaded when it is actually selected.
        from core.providers import PostgresFileProvider

        return PostgresFileProvider(
            config=config,
            project_name=database_provider.project_name,
            connection_manager=database_provider.connection_manager,
        )

    if config.provider == "s3":
        from core.providers import S3FileProvider

        return S3FileProvider(config)

    raise ValueError(f"File provider {config.provider} not supported")

@staticmethod
def create_embedding_provider(
embedding: EmbeddingConfig, *args, **kwargs
Expand Down Expand Up @@ -407,6 +428,11 @@ async def create_providers(
)
)

file_provider = self.create_file_provider(
config=self.config.file, database_provider=database_provider
)
await file_provider.initialize()

ocr_provider = ocr_provider_override or self.create_ocr_provider(
self.config.ocr
)
Expand Down Expand Up @@ -454,12 +480,13 @@ async def create_providers(

return R2RProviders(
auth=auth_provider,
completion_embedding=completion_embedding_provider,
database=database_provider,
email=email_provider,
embedding=embedding_provider,
completion_embedding=completion_embedding_provider,
file=file_provider,
ingestion=ingestion_provider,
llm=llm_provider,
email=email_provider,
ocr=ocr_provider,
orchestration=orchestration_provider,
scheduler=scheduler_provider,
Expand Down
Loading