Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions enums.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, List, Optional, Dict
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union
from enum import Enum


Expand Down Expand Up @@ -1036,6 +1038,7 @@ class TimedExecutionKey(Enum):
class ETLSplitStrategy(EnumKern):
CHUNK = "CHUNK"
SHRINK = "SHRINK"
NONE = "NONE"


class ETLFileType(Enum):
Expand Down Expand Up @@ -1169,7 +1172,7 @@ def from_mimetype(value: str):
return ETLFileType.DEFAULT

@classmethod
def get_default_extractor(cls, file_type: Optional["ETLFileType"] = None):
def get_default_extractor(cls, file_type: Optional[ETLFileType] = None) -> ETLExtractorEnum:
if file_type == ETLFileType.MD:
return ETLExtractorMD.FILESYSTEM
elif file_type == ETLFileType.PDF:
Expand All @@ -1192,7 +1195,7 @@ def get_default_extractor(cls, file_type: Optional["ETLFileType"] = None):
return ETLExtractorTxt.LANGCHAIN
raise ValueError(f"No default extractor for given file type {file_type}")

def get_extractor_from_string(self, extractor: Optional[str] = None) -> EnumKern:
def get_extractor_from_string(self, extractor: Optional[str] = None) -> ETLExtractorEnum:
if extractor is None:
return self.get_default_extractor(self)
if self == ETLFileType.MD:
Expand Down Expand Up @@ -1298,6 +1301,20 @@ class ETLExtractorJson(EnumKern):
PANDAS = "PANDAS"


ETLExtractorEnum = Union[
ETLExtractorMD,
ETLExtractorPDF,
ETLExtractorWord,
ETLExtractorExcel,
ETLExtractorPowerpoint,
ETLExtractorImg,
ETLExtractorTxt,
ETLExtractorCsv,
ETLExtractorTsv,
ETLExtractorJson,
]


class ETLExtractors:
def get_all_extractors() -> Dict[EnumKern, List[str]]:
all_extractors = {}
Expand Down
100 changes: 52 additions & 48 deletions etl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_full_config_and_tokenizer_from_config_id(
for_dataset = True

etl_preset_item = etl_config_presets_db_co.get(
etl_config_id or file_reference.meta_data.get("etl_config_id")
etl_config_id or (file_reference.meta_data or {}).get("etl_config_id")
)
extraction_config, etl_file_type = get_extraction_config_for_file_type(
etl_preset_item, content_type or file_reference.content_type
Expand Down Expand Up @@ -365,7 +365,7 @@ def get_download_key(org_id: str, download_id: str) -> Path:
def get_extraction_key(
org_id: str,
download_id: str,
extractor: enums.ETLExtractorPDF,
extractor: enums.ETLExtractorEnum,
llm_config: Dict[str, Any],
) -> Path:
extraction_key = Path(org_id) / download_id / "extract" / extractor.value
Expand Down Expand Up @@ -400,87 +400,91 @@ def get_extraction_key(
def get_splitting_key(
org_id: str,
download_id: str,
extractor: enums.ETLExtractorPDF,
extractor: enums.ETLExtractorEnum,
split_strategy: Optional[enums.ETLSplitStrategy] = None,
llm_config: Optional[Dict[str, Any]] = None,
) -> Path:
extraction_key = Path(org_id) / download_id / "split" / extractor.value

resolved_strategy = split_strategy or enums.ETLSplitStrategy.NONE
extraction_key = get_extraction_key(
org_id, download_id, extractor, llm_config or {}
)
splitting_key = extraction_key / "split" / resolved_strategy.value
if llm_config:
llm_identifier = enums.LLMProvider.from_string(llm_config.get("llmIdentifier"))
extraction_key = extraction_key / llm_identifier.as_key()

if llm_identifier == enums.LLMProvider.AZURE:
engine = llm_config.get("engine", "")
api_base = llm_config.get("apiBase", "")
api_version = llm_config.get("apiVersion", "")
api_hash = get_hashed_string(api_base, api_version)
extraction_key = extraction_key / engine / api_hash
elif llm_identifier == enums.LLMProvider.OPENAI:
model = llm_config.get("model")
extraction_key = extraction_key / model

if overwrite_vision_prompt := llm_config.get("overwriteVisionPrompt"):
prompt_hash = get_hashed_string(overwrite_vision_prompt)
extraction_key = extraction_key / prompt_hash
else:
extraction_key = extraction_key / "DEFAULT_PROMPT"

return extraction_key
splitting_key = splitting_key / _llm_config_cache_path_suffix(
extractor, llm_config, ""
)
return splitting_key


def get_transformation_key(
org_id: str,
download_id: str,
extractor: enums.ETLExtractorPDF,
extractor: enums.ETLExtractorEnum,
llm_config: Dict[str, Any],
prompt: Optional[str] = "",
transformation_type: Optional[
enums.ETLTransformerType
] = enums.ETLTransformerType.NO_TRANSFORMATION,
split_strategy: Optional[enums.ETLSplitStrategy] = None,
) -> Path:
llm_identifier = enums.LLMProvider.from_string(llm_config.get("llmIdentifier"))
resolved_split = split_strategy or enums.ETLSplitStrategy.NONE
splitting_key = get_splitting_key(
org_id,
download_id,
extractor,
resolved_split,
llm_config,
)
transformation_key = (
Path(org_id)
/ download_id
splitting_key
/ "transform"
/ transformation_type.value
/ llm_identifier.as_key()
/ (_llm_config_cache_path_suffix(extractor, llm_config, prompt or ""))
)

return transformation_key


def get_hashed_string(*args, delimiter: str = "_", from_bytes: bool = False) -> str:
if not from_bytes:
_hash = delimiter.join(map(str, args)).encode()
else:
try:
_hash = next(map(bytes, args))
except StopIteration:
raise ValueError("ERROR: A 'bytes' argument is required to hash")

hasher = hashlib.sha256(_hash)
return hasher.hexdigest()


def _llm_config_cache_path_suffix(
extractor: enums.ETLExtractorEnum,
llm_config: Dict[str, Any],
prompt: str,
) -> Path:
llm_identifier = enums.LLMProvider.from_string(llm_config.get("llmIdentifier"))
path = Path(llm_identifier.as_key())
if llm_identifier == enums.LLMProvider.AZURE:
engine = llm_config.get("engine", "")
api_base = llm_config.get("apiBase", "")
api_version = llm_config.get("apiVersion", "")
api_hash = get_hashed_string(extractor.value, api_base, api_version, prompt)
transformation_key = transformation_key / engine / api_hash
path = path / engine / api_hash
elif llm_identifier == enums.LLMProvider.AZURE_FOUNDRY:
model = llm_config.get("model", "")
api_hash = get_hashed_string(
extractor.value, llm_config.get("apiBase", ""), prompt
)
transformation_key = transformation_key / model / api_hash
path = path / model / api_hash
elif (
llm_identifier == enums.LLMProvider.OPENAI
or llm_identifier == enums.LLMProvider.PRIVATEMODE_AI
):
model = llm_config.get("model")
extractor_hash = get_hashed_string(extractor.value, prompt)
transformation_key = transformation_key / model / extractor_hash

return transformation_key


def get_hashed_string(*args, delimiter: str = "_", from_bytes: bool = False) -> str:
if not from_bytes:
_hash = delimiter.join(map(str, args)).encode()
else:
try:
_hash = next(map(bytes, args))
except StopIteration:
raise ValueError("ERROR: A 'bytes' argument is required to hash")

hasher = hashlib.sha256(_hash)
return hasher.hexdigest()
path = path / model / extractor_hash
return path


def get_extraction_config_for_file_type(
Expand Down