Skip to content

Commit 34a2343

Browse files
committed
ruff formatting
1 parent fcc246c commit 34a2343

15 files changed

Lines changed: 305 additions & 149 deletions

File tree

examples/rag/evaluation/rag_evaluator_example.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,44 @@
66

77
load_dotenv()
88

9-
MOCK_EVALUATOR_CONFIG = './examples/rag/evaluation/rag_eval_example_config.yaml'
10-
MOCK_INDEXER_CONFIG = './examples/rag/evaluation/indexer_eval_example_config.yaml'
11-
MOCK_RAG_CONFIG = './examples/rag/evaluation/rag_evaluated_example_config.yaml'
9+
MOCK_EVALUATOR_CONFIG = "./examples/rag/evaluation/rag_eval_example_config.yaml"
10+
MOCK_INDEXER_CONFIG = "./examples/rag/evaluation/indexer_eval_example_config.yaml"
11+
MOCK_RAG_CONFIG = "./examples/rag/evaluation/rag_evaluated_example_config.yaml"
12+
1213

1314
def get_args():
14-
parser = argparse.ArgumentParser(description='Run RAG Evaluation pipeline with specified parameters or use default mock data')
15-
parser.add_argument('--eval-config', type=str, default=MOCK_EVALUATOR_CONFIG, help='Path to a rag evaluator config file.')
16-
parser.add_argument('--indexer-config', type=str, default=MOCK_INDEXER_CONFIG, help='Path to an Indexer config file.')
17-
parser.add_argument('--rag-config', type=str, default=MOCK_RAG_CONFIG, help='Path to a rag config file.')
15+
parser = argparse.ArgumentParser(
16+
description="Run RAG Evaluation pipeline with specified parameters or use default mock data"
17+
)
18+
parser.add_argument(
19+
"--eval-config",
20+
type=str,
21+
default=MOCK_EVALUATOR_CONFIG,
22+
help="Path to a rag evaluator config file.",
23+
)
24+
parser.add_argument(
25+
"--indexer-config",
26+
type=str,
27+
default=MOCK_INDEXER_CONFIG,
28+
help="Path to an Indexer config file.",
29+
)
30+
parser.add_argument(
31+
"--rag-config",
32+
type=str,
33+
default=MOCK_RAG_CONFIG,
34+
help="Path to a rag config file.",
35+
)
1836

1937
return parser.parse_args()
2038

39+
2140
if __name__ == "__main__":
2241
args = get_args()
2342

2443
# Instantiate RAGEvaluator
2544
evaluator = RAGEvaluator.from_config(args.eval_config)
2645

2746
# Run the evaluation
28-
result = evaluator(
29-
indexer_config = args.indexer_config,
30-
rag_config = args.rag_config
31-
)
47+
result = evaluator(indexer_config=args.indexer_config, rag_config=args.rag_config)
3248

33-
print(result)
49+
print(result)

scripts/data_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020

2121
# Extract 100 files, and copy them in '0000_small' folder
2222
for i in range(100):
23-
os.system(f'cp 0000/{os.listdir("0000")[i]} 0000_small')
23+
os.system(f"cp 0000/{os.listdir('0000')[i]} 0000_small")

src/mmore/index/indexer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ def _create_collection_with_schema(self, collection_name: str):
115115
FieldSchema(
116116
name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=128
117117
),
118-
FieldSchema(
119-
name="document_id", dtype=DataType.VARCHAR, max_length=128
120-
),
118+
FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=128),
121119
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
122120
FieldSchema(
123121
name="dense_embedding",

src/mmore/process/execution_state.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ def initialize(distributed_mode=False, client=None):
3131
"""
3232
if ExecutionState._use_dask is not None:
3333
raise Exception("Execution state already initialized")
34-
assert (
35-
distributed_mode is not None
36-
), "Distributed mode must be set to True or False"
34+
assert distributed_mode is not None, (
35+
"Distributed mode must be set to True or False"
36+
)
3737
ExecutionState._use_dask = distributed_mode
3838

3939
if distributed_mode:
40-
assert (
41-
client is not None
42-
), "You must be in the context of a dask client to use distributed mode"
40+
assert client is not None, (
41+
"You must be in the context of a dask client to use distributed mode"
42+
)
4343
ExecutionState._dask_var = Variable("should_stop_execution", client=client)
4444
ExecutionState._dask_var.set(False)
4545
logger.info("Execution state initialized (distributed mode)")

src/mmore/process/post_processor/chunker/multimodal.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ def chunk(self, sample: MultimodalSample) -> List[MultimodalSample]:
8181
chunks = []
8282
for i, (chunk, mods) in enumerate(zip(text_chunks, modalities_chunks)):
8383
s = MultimodalSample(
84-
text=chunk.text, modalities=mods, metadata=sample.metadata, id=f"{sample.id}+{i}"
84+
text=chunk.text,
85+
modalities=mods,
86+
metadata=sample.metadata,
87+
id=f"{sample.id}+{i}",
8588
)
8689
chunks.append(s)
8790

src/mmore/process/post_processor/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _log_plan(self):
4949
logger.info("-" * 50)
5050
logger.info("PP Pipeline:")
5151
for i, processor in enumerate(self.post_processors):
52-
logger.info(f" > {i+1}. {processor.name}")
52+
logger.info(f" > {i + 1}. {processor.name}")
5353
logger.info("-" * 50)
5454

5555
@classmethod
@@ -75,7 +75,7 @@ def run(self, samples: List[MultimodalSample]) -> List[MultimodalSample]:
7575
for i, processor in enumerate(self.post_processors):
7676
samples = processor.batch_process(samples)
7777
if self.output_config.save_each_step:
78-
self.save_results(samples, f"{i+1}___{processor.name}.jsonl")
78+
self.save_results(samples, f"{i + 1}___{processor.name}.jsonl")
7979
self.save_results(samples, "final_pp.jsonl")
8080
return samples
8181

src/mmore/process/processors/media_processor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def accepts(cls, file: FileDescriptor) -> bool:
4242

4343
@staticmethod
4444
def load_models(
45-
self=None, fast_mode=False # pyright: ignore[reportSelfClsParameterName]
45+
self=None, # pyright: ignore[reportSelfClsParameterName]
46+
fast_mode=False,
4647
):
4748
if self:
4849
model_name = (

src/mmore/process/processors/pdf_processor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121

2222
class PDFProcessor(Processor):
23-
artifact_dict = None #create_model_dict()
23+
artifact_dict = None
2424

2525
def __init__(self, config=None):
2626
super().__init__(config=config or ProcessorConfig())
@@ -34,7 +34,7 @@ def accepts(cls, file: FileDescriptor) -> bool:
3434
def load_models(disable_image_extraction: bool = False):
3535
if PDFProcessor.artifact_dict is None:
3636
PDFProcessor.artifact_dict = create_model_dict()
37-
37+
3838
marker_config = {
3939
"disable_image_extraction": disable_image_extraction,
4040
"languages": None,
@@ -46,9 +46,9 @@ def load_models(disable_image_extraction: bool = False):
4646
artifact_dict=PDFProcessor.artifact_dict,
4747
config=config_parser.generate_config_dict(),
4848
)
49-
49+
5050
converter.initialize_processors(converter.default_processors)
51-
51+
5252
return converter
5353

5454
# overwriting the process_batch
@@ -178,8 +178,8 @@ def _extract_images(pdf_doc, xref) -> Optional[Image.Image]:
178178
if self.config.custom_config.get("extract_images", True):
179179
for img_info in page.get_images(full=False):
180180
image = _extract_images(pdf_doc, img_info[0])
181-
if image and clean_image(
182-
image
181+
if (
182+
image and clean_image(image)
183183
): # clean image filters images below size 512x512 and variance below 100, these are defaults and can be changed
184184
embedded_images.append(image)
185185
all_text.append(self.config.attachment_tag)
@@ -209,7 +209,7 @@ def _process_parallel(
209209
):
210210
try:
211211
torch.cuda.set_device(gpu_id)
212-
212+
213213
if PDFProcessor.artifact_dict is None:
214214
PDFProcessor.artifact_dict = create_model_dict()
215215

src/mmore/rag/llm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def __post_init__(self):
7474
else (
7575
"COHERE"
7676
if self.llm_name in _COHERE_MODELS
77-
else "HF" if self.base_url is None else None
77+
else "HF"
78+
if self.base_url is None
79+
else None
7880
)
7981
)
8082
)

src/mmore/rag/retriever.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ def retrieve(
122122
return []
123123

124124
# Validate that the specified search type is allowed
125-
assert search_type in get_args(
126-
self._search_types
127-
), f"Invalid search_type: {search_type}. Must be 'dense', 'sparse', or 'hybrid'"
125+
assert search_type in get_args(self._search_types), (
126+
f"Invalid search_type: {search_type}. Must be 'dense', 'sparse', or 'hybrid'"
127+
)
128128

129129
# Determine the weight used to combine dense and sparse search scores
130130
search_weight = self._search_weights.get(search_type, self.hybrid_search_weight)

0 commit comments

Comments
 (0)