diff --git a/docker/compose.full.yaml b/docker/compose.full.yaml
index 62763a177..cf6dc97fb 100644
--- a/docker/compose.full.yaml
+++ b/docker/compose.full.yaml
@@ -5,14 +5,16 @@ volumes:
     name: hatchet_config
   hatchet_api_key:
     name: hatchet_api_key
-  postgres_data:
-    name: postgres_data
   hatchet_rabbitmq_data:
     name: hatchet_rabbitmq_data
   hatchet_rabbitmq_conf:
     name: hatchet_rabbitmq_conf
   hatchet_postgres_data:
    name: hatchet_postgres_data
+  minio_data:
+    name: minio_data
+  postgres_data:
+    name: postgres_data
 
 services:
   postgres:
@@ -34,6 +36,24 @@ services:
       postgres
       -c max_connections=1024
+
+  minio:
+    image: minio/minio
+    profiles: [minio]
+    env_file:
+      - ./env/minio.env
+    volumes:
+      - minio_data:/data
+    ports:
+      - "9000:9000"
+      - "9001:9001"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: server /data --console-address ":9001"
+
   hatchet-postgres:
     image: postgres:latest
     env_file:
diff --git a/docker/compose.yaml b/docker/compose.yaml
index d499475a4..8d2704d5d 100644
--- a/docker/compose.yaml
+++ b/docker/compose.yaml
@@ -1,6 +1,8 @@
 volumes:
   postgres_data:
     name: postgres_data
+  minio_data:
+    name: minio_data
 
 services:
   postgres:
@@ -22,6 +24,24 @@
       postgres
       -c max_connections=1024
 
+  minio:
+    image: minio/minio
+    profiles: [minio]
+    env_file:
+      - ./env/minio.env
+    volumes:
+      - minio_data:/data
+    ports:
+      - "9000:9000"
+      - "9001:9001"
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+    restart: on-failure
+    command: server /data --console-address ":9001"
+
   graph_clustering:
     image: ragtoriches/cluster-prod
     ports:
diff --git a/docker/env/minio.env b/docker/env/minio.env
new file mode 100644
index 000000000..95ab4d156
--- /dev/null
+++ b/docker/env/minio.env
@@ -0,0 +1,2 @@
+MINIO_ROOT_USER=minioadmin
+MINIO_ROOT_PASSWORD=minioadmin
diff --git a/py/all_possible_config.toml b/py/all_possible_config.toml
index 7b00620b6..7f310896b 100644
--- a/py/all_possible_config.toml
+++ b/py/all_possible_config.toml
@@ -211,6 +211,12 @@ concurrent_request_limit = 256
 ################################################################################
 [file]
 provider = "postgres"
+# If using S3
+bucket_name = ""
+endpoint_url = ""
+region_name = ""
+aws_access_key_id = ""
+aws_secret_access_key = ""
 
 ################################################################################
 # Ingestion Settings (IngestionConfig and nested settings)
diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py
index a670f3486..cc0455691 100644
--- a/py/core/base/__init__.py
+++ b/py/core/base/__init__.py
@@ -89,18 +89,21 @@
     # Crypto provider
     "CryptoConfig",
     "CryptoProvider",
-    # Email provider
-    "EmailConfig",
-    "EmailProvider",
     # Database providers
     "LimitSettings",
     "DatabaseConfig",
     "DatabaseProvider",
     "Handler",
     "PostgresConfigurationSettings",
+    # Email provider
+    "EmailConfig",
+    "EmailProvider",
     # Embedding provider
     "EmbeddingConfig",
     "EmbeddingProvider",
+    # File provider
+    "FileConfig",
+    "FileProvider",
     # Ingestion provider
     "IngestionConfig",
     "IngestionProvider",
diff --git a/py/core/base/providers/__init__.py b/py/core/base/providers/__init__.py
index 5cd947f12..36c449835 100644
--- a/py/core/base/providers/__init__.py
+++ b/py/core/base/providers/__init__.py
@@ -11,6 +11,7 @@
 )
 from .email import EmailConfig, EmailProvider
 from .embedding import EmbeddingConfig, EmbeddingProvider
+from .file import FileConfig, FileProvider
 from .ingestion import (
     ChunkingStrategy,
     IngestionConfig,
@@ -29,16 +30,9 @@
     "AppConfig",
     "Provider",
     "ProviderConfig",
-    # Ingestion provider
-    "IngestionConfig",
-    "IngestionProvider",
-    "ChunkingStrategy",
     # Crypto provider
     "CryptoConfig",
     "CryptoProvider",
-    # Email provider
-    "EmailConfig",
-    "EmailProvider",
     # Database providers
     "DatabaseConnectionManager",
     "DatabaseConfig",
@@ -46,9 +40,19 @@
     "PostgresConfigurationSettings",
     "DatabaseProvider",
     "Handler",
+    # Email provider
+    "EmailConfig",
+    "EmailProvider",
     # Embedding provider
     "EmbeddingConfig",
     "EmbeddingProvider",
+    # File provider
+    "FileConfig",
+    "FileProvider",
+    # Ingestion provider
+    "IngestionConfig",
+    "IngestionProvider",
+    "ChunkingStrategy",
     # LLM provider
     "CompletionConfig",
     "CompletionProvider",
diff --git a/py/core/base/providers/file.py b/py/core/base/providers/file.py
new file mode 100644
index 000000000..3ce81e6b5
--- /dev/null
+++ b/py/core/base/providers/file.py
@@ -0,0 +1,110 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from datetime import datetime
+from io import BytesIO
+from typing import BinaryIO, Optional
+from uuid import UUID
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class FileConfig(ProviderConfig):
+    """
+    Configuration for file storage providers.
+    """
+
+    provider: Optional[str] = None
+
+    # S3-specific configuration
+    bucket_name: Optional[str] = None
+    aws_access_key_id: Optional[str] = None
+    aws_secret_access_key: Optional[str] = None
+    region_name: Optional[str] = None
+    endpoint_url: Optional[str] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """
+        List of supported file storage providers.
+        """
+        return [
+            "postgres",
+            "s3",
+        ]
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Unsupported file provider: {self.provider}")
+
+        if self.provider == "s3" and (
+            not self.bucket_name and not os.getenv("S3_BUCKET_NAME")
+        ):
+            raise ValueError(
+                "S3 bucket name is required when using S3 provider"
+            )
+
+
+class FileProvider(Provider, ABC):
+    """
+    Base abstract class for file storage providers.
+    """
+
+    def __init__(self, config: FileConfig):
+        if not isinstance(config, FileConfig):
+            raise ValueError(
+                "FileProvider must be initialized with a `FileConfig`."
+            )
+        super().__init__(config)
+        self.config: FileConfig = config
+
+    @abstractmethod
+    async def initialize(self) -> None:
+        """Initialize the file provider."""
+        pass
+
+    @abstractmethod
+    async def store_file(
+        self,
+        document_id: UUID,
+        file_name: str,
+        file_content: BytesIO,
+        file_type: Optional[str] = None,
+    ) -> None:
+        """Store a file."""
+        pass
+
+    @abstractmethod
+    async def retrieve_file(
+        self, document_id: UUID
+    ) -> Optional[tuple[str, BinaryIO, int]]:
+        """Retrieve a file."""
+        pass
+
+    @abstractmethod
+    async def retrieve_files_as_zip(
+        self,
+        document_ids: Optional[list[UUID]] = None,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+    ) -> tuple[str, BinaryIO, int]:
+        """Retrieve multiple files as a zip."""
+        pass
+
+    @abstractmethod
+    async def delete_file(self, document_id: UUID) -> bool:
+        """Delete a file."""
+        pass
+
+    @abstractmethod
+    async def get_files_overview(
+        self,
+        offset: int,
+        limit: int,
+        filter_document_ids: Optional[list[UUID]] = None,
+        filter_file_names: Optional[list[str]] = None,
+    ) -> list[dict]:
+        """Get an overview of stored files."""
+        pass
diff --git a/py/core/main/abstractions.py b/py/core/main/abstractions.py
index 6f08e3221..bdb24b013 100644
--- a/py/core/main/abstractions.py
+++ b/py/core/main/abstractions.py
@@ -19,9 +19,11 @@
     OpenAICompletionProvider,
     OpenAIEmbeddingProvider,
     PostgresDatabaseProvider,
+    PostgresFileProvider,
     R2RAuthProvider,
     R2RCompletionProvider,
     R2RIngestionProvider,
+    S3FileProvider,
     SendGridEmailProvider,
     SimpleOrchestrationProvider,
     SupabaseAuthProvider,
@@ -59,6 +61,7 @@ class R2RProviders(BaseModel):
         | OpenAIEmbeddingProvider
         | OllamaEmbeddingProvider
     )
+    file: PostgresFileProvider | S3FileProvider
     completion_embedding: (
         LiteLLMEmbeddingProvider
         | OpenAIEmbeddingProvider
diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py
index 589d78821..87880ca7b 100644
--- a/py/core/main/api/v3/documents_router.py
+++ b/py/core/main/api/v3/documents_router.py
@@ -517,7 +517,7 @@ async def create_document(
                 }
 
                 file_name = file_data["filename"]
-                await self.providers.database.files_handler.store_file(
+                await self.providers.file.store_file(
                     document_id,
                     file_name,
                     file_content,
diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py
index 0c3e25126..b01795107 100644
--- a/py/core/main/assembly/factory.py
+++ b/py/core/main/assembly/factory.py
@@ -12,6 +12,7 @@
     EmailConfig,
     EmbeddingConfig,
     EmbeddingProvider,
+    FileConfig,
     IngestionConfig,
     OCRConfig,
     OrchestrationConfig,
@@ -213,20 +214,40 @@ async def create_database_provider(
         quantization_type = (
             self.config.embedding.quantization_settings.quantization_type
         )
-        if db_config.provider == "postgres":
-            database_provider = PostgresDatabaseProvider(
-                db_config,
-                dimension,
-                crypto_provider=crypto_provider,
-                quantization_type=quantization_type,
-            )
-            await database_provider.initialize()
-            return database_provider
-        else:
+        if db_config.provider != "postgres":
             raise ValueError(
                 f"Database provider {db_config.provider} not supported"
             )
+        database_provider = PostgresDatabaseProvider(
+            db_config,
+            dimension,
+            crypto_provider=crypto_provider,
+            quantization_type=quantization_type,
+        )
+        await database_provider.initialize()
+        return database_provider
+
+    @staticmethod
+    def create_file_provider(
+        config: FileConfig, database_provider=None, *args, **kwargs
+    ):
+        if config.provider == "postgres":
+            from core.providers import PostgresFileProvider
+
+            return PostgresFileProvider(
+                config=config,
+                project_name=database_provider.project_name,
+                connection_manager=database_provider.connection_manager,
+            )
+
+        elif config.provider == "s3":
+            from core.providers import S3FileProvider
+
+            return S3FileProvider(config)
+        else:
+            raise ValueError(f"File provider {config.provider} not supported")
+
     @staticmethod
     def create_embedding_provider(
         embedding: EmbeddingConfig, *args, **kwargs
@@ -407,6 +428,11 @@ async def create_providers(
             )
         )
 
+        file_provider = self.create_file_provider(
+            config=self.config.file, database_provider=database_provider
+        )
+        await file_provider.initialize()
+
         ocr_provider = ocr_provider_override or self.create_ocr_provider(
             self.config.ocr
         )
@@ -454,12 +480,13 @@ async def create_providers(
 
         return R2RProviders(
             auth=auth_provider,
+            completion_embedding=completion_embedding_provider,
             database=database_provider,
+            email=email_provider,
             embedding=embedding_provider,
-            completion_embedding=completion_embedding_provider,
+            file=file_provider,
             ingestion=ingestion_provider,
             llm=llm_provider,
-            email=email_provider,
             ocr=ocr_provider,
             orchestration=orchestration_provider,
             scheduler=scheduler_provider,
diff --git a/py/core/main/config.py b/py/core/main/config.py
index e9fbbb1c3..cbb344507 100644
--- a/py/core/main/config.py
+++ b/py/core/main/config.py
@@ -15,6 +15,7 @@
 from ..base.providers.database import DatabaseConfig
 from ..base.providers.email import EmailConfig
 from ..base.providers.embedding import EmbeddingConfig
+from ..base.providers.file import FileConfig
 from ..base.providers.ingestion import IngestionConfig
 from ..base.providers.llm import CompletionConfig
 from ..base.providers.ocr import OCRConfig
@@ -60,7 +61,7 @@ class R2RConfig:
             "batch_size",
             "add_title_as_prefix",
         ],
-        # TODO - deprecated, remove
+        "file": ["provider"],
         "ingestion": ["provider"],
         "logging": ["provider", "log_table"],
         "database": ["provider"],
@@ -70,16 +71,17 @@ class R2RConfig:
         "scheduler": ["provider"],
     }
 
+    agent: RAGAgentConfig
     app: AppConfig
     auth: AuthConfig
     completion: CompletionConfig
+    completion_embedding: EmbeddingConfig
     crypto: CryptoConfig
     database: DatabaseConfig
-    embedding: EmbeddingConfig
-    completion_embedding: EmbeddingConfig
     email: EmailConfig
+    embedding: EmbeddingConfig
+    file: FileConfig
     ingestion: IngestionConfig
-    agent: RAGAgentConfig
     ocr: OCRConfig
     orchestration: OrchestrationConfig
     scheduler: SchedulerConfig
@@ -114,9 +116,10 @@ def __init__(self, config_data: dict[str, Any]):
             **self.completion, app=self.app
         )  # type: ignore
         self.crypto = CryptoConfig.create(**self.crypto, app=self.app)  # type: ignore
-        self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
         self.database = DatabaseConfig.create(**self.database, app=self.app)  # type: ignore
+        self.email = EmailConfig.create(**self.email, app=self.app)  # type: ignore
         self.embedding = EmbeddingConfig.create(**self.embedding, app=self.app)  # type: ignore
+        self.file = FileConfig.create(**self.file, app=self.app)  # type: ignore
         self.completion_embedding = EmbeddingConfig.create(
             **self.completion_embedding, app=self.app
         )  # type: ignore
diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py
index 25a46fba9..cd49df7f7 100644
--- a/py/core/main/services/ingestion_service.py
+++ b/py/core/main/services/ingestion_service.py
@@ -226,10 +226,8 @@ async def parse_file(
 
         try:
             # Pull file from DB
-            retrieved = (
-                await self.providers.database.files_handler.retrieve_file(
-                    document_info.id
-                )
+            retrieved = await self.providers.file.retrieve_file(
+                document_info.id
             )
             if not retrieved:
                 # No file found in the DB, can't parse
diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py
index 8b4c83e79..f5d3809f0 100644
--- a/py/core/main/services/management_service.py
+++ b/py/core/main/services/management_service.py
@@ -222,9 +222,7 @@ def transform_chunk_id_to_id(
     async def download_file(
         self, document_id: UUID
     ) -> Optional[Tuple[str, BinaryIO, int]]:
-        if result := await self.providers.database.files_handler.retrieve_file(
-            document_id
-        ):
+        if result := await self.providers.file.retrieve_file(document_id):
             return result
         return None
 
@@ -234,12 +232,10 @@ async def export_files(
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
     ) -> tuple[str, BinaryIO, int]:
-        return (
-            await self.providers.database.files_handler.retrieve_files_as_zip(
-                document_ids=document_ids,
-                start_date=start_date,
-                end_date=end_date,
-            )
+        return await self.providers.file.retrieve_files_as_zip(
+            document_ids=document_ids,
+            start_date=start_date,
+            end_date=end_date,
         )
 
     async def export_collections(
diff --git a/py/core/providers/__init__.py b/py/core/providers/__init__.py
index 216965d04..a8b9f7b6a 100644
--- a/py/core/providers/__init__.py
+++ b/py/core/providers/__init__.py
@@ -22,6 +22,10 @@
     OllamaEmbeddingProvider,
     OpenAIEmbeddingProvider,
 )
+from .file import (
+    PostgresFileProvider,
+    S3FileProvider,
+)
 from .ingestion import (  # type: ignore
     R2RIngestionConfig,
     R2RIngestionProvider,
@@ -61,17 +65,20 @@
     "BcryptCryptoConfig",
     "NaClCryptoConfig",
     "NaClCryptoProvider",
+    # Database
+    "PostgresDatabaseProvider",
     # Embeddings
     "LiteLLMEmbeddingProvider",
     "OllamaEmbeddingProvider",
     "OpenAIEmbeddingProvider",
-    # Database
-    "PostgresDatabaseProvider",
     # Email
     "AsyncSMTPEmailProvider",
     "ConsoleMockEmailProvider",
     "SendGridEmailProvider",
    "MailerSendEmailProvider",
+    # File
+    "PostgresFileProvider",
+    "S3FileProvider",
     # LLM
     "AnthropicCompletionProvider",
     "OpenAICompletionProvider",
diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py
index e99965dea..d1ad61f1e 100644
--- a/py/core/providers/database/postgres.py
+++ b/py/core/providers/database/postgres.py
@@ -14,7 +14,6 @@
 from .collections import PostgresCollectionsHandler
 from .conversations import PostgresConversationsHandler
 from .documents import PostgresDocumentsHandler
-from .files import PostgresFilesHandler
 from .graphs import (
     PostgresCommunitiesHandler,
     PostgresEntitiesHandler,
@@ -66,7 +65,6 @@ class PostgresDatabaseProvider(DatabaseProvider):
     relationships_handler: PostgresRelationshipsHandler
     graphs_handler: PostgresGraphsHandler
     prompts_handler: PostgresPromptsHandler
-    files_handler: PostgresFilesHandler
     conversations_handler: PostgresConversationsHandler
     limits_handler: PostgresLimitsHandler
     maintenance_handler: PostgresMaintenanceHandler
@@ -194,9 +192,6 @@ def __init__(
         self.prompts_handler = PostgresPromptsHandler(
             self.project_name, self.connection_manager
         )
-        self.files_handler = PostgresFilesHandler(
-            self.project_name, self.connection_manager
-        )
         self.limits_handler = PostgresLimitsHandler(
             project_name=self.project_name,
             connection_manager=self.connection_manager,
@@ -233,7 +228,6 @@ async def initialize(self):
         await self.users_handler.create_tables()
         await self.chunks_handler.create_tables()
         await self.prompts_handler.create_tables()
-        await self.files_handler.create_tables()
         await self.graphs_handler.create_tables()
         await self.communities_handler.create_tables()
         await self.entities_handler.create_tables()
diff --git a/py/core/providers/file/__init__.py b/py/core/providers/file/__init__.py
new file mode 100644
index 000000000..4cfd28918
--- /dev/null
+++ b/py/core/providers/file/__init__.py
@@ -0,0 +1,7 @@
+from .postgres import PostgresFileProvider
+from .s3 import S3FileProvider
+
+__all__ = [
+    "PostgresFileProvider",
+    "S3FileProvider",
+]
diff --git a/py/core/providers/database/files.py b/py/core/providers/file/postgres.py
similarity index 89%
rename from py/core/providers/database/files.py
rename to py/core/providers/file/postgres.py
index dc349a7eb..25feffd32 100644
--- a/py/core/providers/database/files.py
+++ b/py/core/providers/file/postgres.py
@@ -9,24 +9,32 @@
 import asyncpg
 from fastapi import HTTPException
 
-from core.base import Handler, R2RException
-
-from .base import PostgresConnectionManager
+from core.base import FileConfig, FileProvider, R2RException
 
 logger = logging.getLogger()
 
 
-class PostgresFilesHandler(Handler):
-    """PostgreSQL implementation of the FileHandler."""
-
-    TABLE_NAME = "files"
+class PostgresFileProvider(FileProvider):
+    """PostgreSQL implementation of the FileProvider."""
 
-    connection_manager: PostgresConnectionManager
-
-    async def create_tables(self) -> None:
+    def __init__(
+        self,
+        config: FileConfig,
+        project_name: str,
+        connection_manager,  # PostgresConnectionManager
+    ):
+        super().__init__(config)
+        self.table_name = "files"
+        self.project_name = project_name
+        self.connection_manager = connection_manager
+
+    def _get_table_name(self, base_name: str) -> str:
+        return f"{self.project_name}.{base_name}"
+
+    async def initialize(self) -> None:
         """Create the necessary tables for file storage."""
         query = f"""
-        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresFilesHandler.TABLE_NAME)} (
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(self.table_name)} (
             document_id UUID PRIMARY KEY,
             name TEXT NOT NULL,
             oid OID NOT NULL,
@@ -46,10 +54,10 @@ async def create_tables(self) -> None:
         $$ LANGUAGE plpgsql;
 
         DROP TRIGGER IF EXISTS update_files_updated_at
-        ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)};
+        ON {self._get_table_name(self.table_name)};
 
         CREATE TRIGGER update_files_updated_at
-            BEFORE UPDATE ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+            BEFORE UPDATE ON {self._get_table_name(self.table_name)}
             FOR EACH ROW
             EXECUTE FUNCTION {self.project_name}.update_files_updated_at();
         """
@@ -65,7 +73,7 @@ async def upsert_file(
     ) -> None:
         """Add or update a file entry in storage."""
         query = f"""
-        INSERT INTO {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        INSERT INTO {self._get_table_name(self.table_name)}
         (document_id, name, oid, size, type)
         VALUES ($1, $2, $3, $4, $5)
         ON CONFLICT (document_id) DO UPDATE SET
@@ -130,7 +138,7 @@ async def retrieve_file(
         """Retrieve a file from storage."""
         query = f"""
         SELECT name, oid, size
-        FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        FROM {self._get_table_name(self.table_name)}
         WHERE document_id = $1
         """
 
@@ -163,7 +171,7 @@ async def retrieve_files_as_zip(
 
         query = f"""
         SELECT document_id, name, oid, size
-        FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        FROM {self._get_table_name(self.table_name)}
         WHERE 1=1
         """
         params: list = []
@@ -256,7 +264,7 @@ async def _read_lobject(self, conn, oid: int) -> bytes:
     async def delete_file(self, document_id: UUID) -> bool:
         """Delete a file from storage."""
         query = f"""
-        SELECT oid FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+        SELECT oid FROM {self._get_table_name(self.table_name)}
         WHERE document_id = $1
         """
 
@@ -272,7 +280,7 @@ async def delete_file(self, document_id: UUID) -> bool:
                 await self._delete_lobject(conn, oid)
 
                 delete_query = f"""
-                DELETE FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+                DELETE FROM {self._get_table_name(self.table_name)}
                 WHERE document_id = $1
                 """
                 await conn.execute(delete_query, document_id)
@@ -295,7 +303,7 @@ async def get_files_overview(
         params: list[str | list[str] | int] = []
         query = f"""
             SELECT document_id, name, oid, size, type, created_at, updated_at
-            FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+            FROM {self._get_table_name(self.table_name)}
         """
 
         if filter_document_ids:
diff --git a/py/core/providers/file/s3.py b/py/core/providers/file/s3.py
new file mode 100644
index 000000000..fa69cdd2c
--- /dev/null
+++ b/py/core/providers/file/s3.py
@@ -0,0 +1,327 @@
+import logging
+import os
+import zipfile
+from datetime import datetime
+from io import BytesIO
+from typing import BinaryIO, Optional
+from uuid import UUID
+
+import boto3
+from botocore.exceptions import ClientError
+
+from core.base import FileConfig, FileProvider, R2RException
+
+logger = logging.getLogger()
+
+
+class S3FileProvider(FileProvider):
+    """S3 implementation of the FileProvider."""
+
+    def __init__(self, config: FileConfig):
+        super().__init__(config)
+
+        self.bucket_name = self.config.bucket_name or os.getenv(
+            "S3_BUCKET_NAME"
+        )
+        aws_access_key_id = self.config.aws_access_key_id or os.getenv(
+            "AWS_ACCESS_KEY_ID"
+        )
+        aws_secret_access_key = self.config.aws_secret_access_key or os.getenv(
+            "AWS_SECRET_ACCESS_KEY"
+        )
+        region_name = self.config.region_name or os.getenv("AWS_REGION")
+        endpoint_url = self.config.endpoint_url or os.getenv("S3_ENDPOINT_URL")
+
+        # Initialize S3 client
+        self.s3_client = boto3.client(
+            "s3",
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            region_name=region_name,
+            endpoint_url=endpoint_url,
+        )
+
+    def _get_s3_key(self, document_id: UUID) -> str:
+        """Generate a unique S3 key for a document."""
+        return f"documents/{document_id}"
+
+    async def initialize(self) -> None:
+        """Initialize S3 bucket."""
+        try:
+            self.s3_client.head_bucket(Bucket=self.bucket_name)
+            logger.info(f"Using existing S3 bucket: {self.bucket_name}")
+        except ClientError as e:
+            error_code = e.response.get("Error", {}).get("Code")
+            if error_code == "404":
+                logger.info(f"Creating S3 bucket: {self.bucket_name}")
+                self.s3_client.create_bucket(Bucket=self.bucket_name)
+            else:
+                logger.error(f"Error accessing S3 bucket: {e}")
+                raise R2RException(
+                    status_code=500,
+                    message=f"Failed to initialize S3 bucket: {e}",
+                ) from e
+
+    async def store_file(
+        self,
+        document_id: UUID,
+        file_name: str,
+        file_content: BytesIO,
+        file_type: Optional[str] = None,
+    ) -> None:
+        """Store a file in S3."""
+        try:
+            # Generate S3 key
+            s3_key = self._get_s3_key(document_id)
+
+            # Upload to S3
+            file_content.seek(0)  # Reset pointer to beginning
+            self.s3_client.upload_fileobj(
+                file_content,
+                self.bucket_name,
+                s3_key,
+                ExtraArgs={
+                    "ContentType": file_type or "application/octet-stream",
+                    "Metadata": {
+                        "filename": file_name,
+                        "document_id": str(document_id),
+                    },
+                },
+            )
+
+        except Exception as e:
+            logger.error(f"Error storing file in S3: {e}")
+            raise R2RException(
+                status_code=500, message=f"Failed to store file in S3: {e}"
+            ) from e
+
+    async def retrieve_file(
+        self, document_id: UUID
+    ) -> Optional[tuple[str, BinaryIO, int]]:
+        """Retrieve a file from S3."""
+        s3_key = self._get_s3_key(document_id)
+
+        try:
+            # Get file metadata from S3
+            response = self.s3_client.head_object(
+                Bucket=self.bucket_name, Key=s3_key
+            )
+
+            file_name = response.get("Metadata", {}).get(
+                "filename", f"file-{document_id}"
+            )
+            file_size = response.get("ContentLength", 0)
+
+            # Download file from S3
+            file_content = BytesIO()
+            self.s3_client.download_fileobj(
+                self.bucket_name, s3_key, file_content
+            )
+
+            file_content.seek(0)  # Reset pointer to beginning
+            return file_name, file_content, file_size
+
+        except ClientError as e:
+            error_code = e.response.get("Error", {}).get("Code")
+            if error_code in ["NoSuchKey", "404"]:
+                raise R2RException(
+                    status_code=404,
+                    message=f"File for document {document_id} not found",
+                ) from e
+            else:
+                raise R2RException(
+                    status_code=500,
+                    message=f"Error retrieving file from S3: {e}",
+                ) from e
+
+    async def retrieve_files_as_zip(
+        self,
+        document_ids: Optional[list[UUID]] = None,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
+    ) -> tuple[str, BinaryIO, int]:
+        """Retrieve multiple files from S3 and return them as a zip file."""
+        if not document_ids:
+            raise R2RException(
+                status_code=400,
+                message="Document IDs must be provided for S3 file retrieval",
+            )
+
+        zip_buffer = BytesIO()
+
+        with zipfile.ZipFile(
+            zip_buffer, "w", zipfile.ZIP_DEFLATED
+        ) as zip_file:
+            for doc_id in document_ids:
+                try:
+                    # Get file information - note that retrieve_file won't return None here
+                    # since any errors will raise exceptions
+                    result = await self.retrieve_file(doc_id)
+                    if result:
+                        file_name, file_content, _ = result
+
+                        # Read the content into a bytes object
+                        if hasattr(file_content, "getvalue"):
+                            content_bytes = file_content.getvalue()
+                        else:
+                            # For BinaryIO objects that don't have getvalue()
+                            file_content.seek(0)
+                            content_bytes = file_content.read()
+
+                        # Add file to zip
+                        zip_file.writestr(file_name, content_bytes)
+
+                except R2RException as e:
+                    if e.status_code == 404:
+                        # Skip files that don't exist
+                        logger.warning(
+                            f"File for document {doc_id} not found, skipping"
+                        )
+                        continue
+                    else:
+                        raise
+
+        zip_buffer.seek(0)
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        zip_filename = f"files_export_{timestamp}.zip"
+        zip_size = zip_buffer.getbuffer().nbytes
+
+        if zip_size == 0:
+            raise R2RException(
+                status_code=404,
+                message="No files found for the specified document IDs",
+            )
+
+        return zip_filename, zip_buffer, zip_size
+
+    async def delete_file(self, document_id: UUID) -> bool:
+        """Delete a file from S3."""
+        s3_key = self._get_s3_key(document_id)
+
+        try:
+            # Check if file exists first
+            self.s3_client.head_object(Bucket=self.bucket_name, Key=s3_key)
+
+            # Delete from S3
+            self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
+
+            return True
+
+        except ClientError as e:
+            error_code = e.response.get("Error", {}).get("Code")
+            if error_code in ["NoSuchKey", "404"]:
+                raise R2RException(
+                    status_code=404,
+                    message=f"File for document {document_id} not found",
+                ) from e
+            logger.error(f"Error deleting file from S3: {e}")
+            raise R2RException(
+                status_code=500, message=f"Failed to delete file from S3: {e}"
+            ) from e
+
+    async def get_files_overview(
+        self,
+        offset: int,
+        limit: int,
+        filter_document_ids: Optional[list[UUID]] = None,
+        filter_file_names: Optional[list[str]] = None,
+    ) -> list[dict]:
+        """
+        Get an overview of stored files.
+
+        Note: Since S3 doesn't have native query capabilities like a database,
+        this implementation works best when document IDs are provided.
+        """
+        results = []
+
+        if filter_document_ids:
+            # We can efficiently get specific files by document ID
+            for doc_id in filter_document_ids:
+                s3_key = self._get_s3_key(doc_id)
+                try:
+                    # Get metadata for this file
+                    response = self.s3_client.head_object(
+                        Bucket=self.bucket_name, Key=s3_key
+                    )
+
+                    file_info = {
+                        "document_id": doc_id,
+                        "file_name": response.get("Metadata", {}).get(
+                            "filename", f"file-{doc_id}"
+                        ),
+                        "file_key": s3_key,
+                        "file_size": response.get("ContentLength", 0),
+                        "file_type": response.get("ContentType"),
+                        "created_at": response.get("LastModified"),
+                        "updated_at": response.get("LastModified"),
+                    }
+
+                    results.append(file_info)
+                except ClientError:
+                    # Skip files that don't exist
+                    continue
+        else:
+            # This is a list operation on the bucket, which is less efficient
+            # We list objects with the documents/ prefix
+            try:
+                response = self.s3_client.list_objects_v2(
+                    Bucket=self.bucket_name,
+                    Prefix="documents/",
+                )
+
+                if "Contents" in response:
+                    # Apply pagination manually
+                    page_items = response["Contents"][offset : offset + limit]
+
+                    for item in page_items:
+                        # Extract document ID from the key
+                        key = item["Key"]
+                        doc_id_str = key.split("/")[-1]
+
+                        try:
+                            doc_id = UUID(doc_id_str)
+
+                            # Get detailed metadata
+                            obj_response = self.s3_client.head_object(
+                                Bucket=self.bucket_name, Key=key
+                            )
+
+                            file_name = obj_response.get("Metadata", {}).get(
+                                "filename", f"file-{doc_id}"
+                            )
+
+                            # Apply filename filter if provided
+                            if (
+                                filter_file_names
+                                and file_name not in filter_file_names
+                            ):
+                                continue
+
+                            file_info = {
+                                "document_id": doc_id,
+                                "file_name": file_name,
+                                "file_key": key,
+                                "file_size": item.get("Size", 0),
+                                "file_type": obj_response.get("ContentType"),
+                                "created_at": item.get("LastModified"),
+                                "updated_at": item.get("LastModified"),
+                            }
+
+                            results.append(file_info)
+                        except ValueError:
+                            # Skip if the key doesn't contain a valid UUID
+                            continue
+            except ClientError as e:
+                logger.error(f"Error listing files in S3 bucket: {e}")
+                raise R2RException(
+                    status_code=500,
+                    message=f"Failed to list files from S3: {e}",
+                ) from e
+
+        if not results:
+            raise R2RException(
+                status_code=404,
+                message="No files found with the given filters",
+            )
+
+        return results
diff --git a/py/r2r/r2r.toml b/py/r2r/r2r.toml
index a6a244922..bfe377284 100644
--- a/py/r2r/r2r.toml
+++ b/py/r2r/r2r.toml
@@ -58,6 +58,9 @@ request_timeout = 60
 [crypto]
 provider = "bcrypt"
 
+[file]
+provider = "postgres"
+
 [database]
 default_collection_name = "Default"
 default_collection_description = "Your default collection."
diff --git a/py/tests/unit/app/test_routes.py b/py/tests/unit/app/test_routes.py
index 2694b2369..a0860ee77 100644
--- a/py/tests/unit/app/test_routes.py
+++ b/py/tests/unit/app/test_routes.py
@@ -22,6 +22,7 @@
 from core.providers.database import PostgresDatabaseProvider
 from core.providers.email import ConsoleMockEmailProvider
 from core.providers.embeddings import OpenAIEmbeddingProvider
+from core.providers.file import PostgresFileProvider
 from core.providers.ingestion import R2RIngestionProvider
 from core.providers.llm import OpenAICompletionProvider
 from core.providers.orchestration import SimpleOrchestrationProvider
@@ -56,6 +57,8 @@ def mock_providers():
     mock_embedding.config = Mock()
     mock_completion_embedding = create_autospec(OpenAIEmbeddingProvider)
     mock_completion_embedding.config = Mock()
+    mock_file = create_autospec(PostgresFileProvider)
+    mock_file.config = Mock()
     mock_llm = create_autospec(OpenAICompletionProvider)
     mock_llm.config = Mock()
     mock_ocr = create_autospec(MistralOCRProvider)
@@ -76,6 +79,7 @@ def mock_providers():
         database=mock_db,
         email=mock_email,
         embedding=mock_embedding,
+        file=mock_file,
         ingestion=mock_ingestion,
         llm=mock_llm,
         ocr=mock_ocr,