From 948c96deb3b33a50a6758907e6036b321adef8e2 Mon Sep 17 00:00:00 2001
From: Suvendu-UI
Date: Tue, 8 Oct 2024 15:07:08 +0530
Subject: [PATCH] Add tqdm and trange imports and wrap long-running loops in
 tqdm for progress tracking when indexing

---
 benchmarks/retrieval/retrieve.py        |  9 ++--
 benchmarks/retrieval/retrieve_kaggle.py |  3 +-
 sage/chat.py                            |  8 ++--
 sage/chunker.py                         | 15 ++++---
 sage/config.py                          |  1 +
 sage/data_manager.py                    | 21 ++++-----
 sage/embedder.py                        | 59 +++++++++++++------------
 sage/github.py                          | 11 ++---
 sage/index.py                           |  3 +-
 sage/llm.py                             |  1 +
 sage/retriever.py                       |  1 +
 sage/vector_store.py                    |  9 ++--
 12 files changed, 76 insertions(+), 65 deletions(-)

diff --git a/benchmarks/retrieval/retrieve.py b/benchmarks/retrieval/retrieve.py
index 38516be..11495f2 100644
--- a/benchmarks/retrieval/retrieve.py
+++ b/benchmarks/retrieval/retrieve.py
@@ -7,6 +7,7 @@
 import logging
 import os
 import time
+from tqdm import tqdm, trange
 
 import configargparse
 from dotenv import load_dotenv
@@ -69,12 +70,12 @@ def main():
     golden_docs = []  # List of ir_measures.Qrel objects
     retrieved_docs = []  # List of ir_measures.ScoredDoc objects
 
-    for question_idx, item in enumerate(benchmark):
+    for question_idx, item in tqdm(enumerate(benchmark)):
         print(f"Processing question {question_idx}...")
         query_id = str(question_idx)  # Solely needed for ir_measures library.
 
-        for golden_filepath in item[args.gold_field]:
+        for golden_filepath in tqdm(item[args.gold_field]):
             # All the file paths in the golden answer are equally relevant for the query (i.e. the order is irrelevant),
             # so we set relevance=1 for all of them.
             golden_docs.append(Qrel(query_id=query_id, doc_id=golden_filepath, relevance=1))
 
@@ -82,7 +83,7 @@ def main():
         # Make a retrieval call for the current question.
         retrieved = retriever.invoke(item[args.question_field])
         item["retrieved"] = []
-        for doc_idx, doc in enumerate(retrieved):
+        for doc_idx, doc in tqdm(enumerate(retrieved)):
             # The absolute value of the scores below does not affect the metrics; it merely determines the ranking of
             # the retrieved documents. The key of the score varies depending on the underlying retriever. If there's no
             # score, we use 1/(doc_idx+1) since it preserves the order of the documents.
@@ -111,7 +112,7 @@ def main():
     with open(output_file, "w") as f:
         json.dump(out_data, f, indent=4)
 
-    for key in sorted(results.keys()):
+    for key in tqdm(sorted(results.keys())):
         print(f"{key}: {results[key]}")
 
     print(f"Predictions and metrics saved to {output_file}")
diff --git a/benchmarks/retrieval/retrieve_kaggle.py b/benchmarks/retrieval/retrieve_kaggle.py
index da689bd..67fa655 100644
--- a/benchmarks/retrieval/retrieve_kaggle.py
+++ b/benchmarks/retrieval/retrieve_kaggle.py
@@ -8,6 +8,7 @@
 
 import sage.config
 from sage.retriever import build_retriever_from_args
+from tqdm import tqdm, trange
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger()
@@ -36,7 +37,7 @@ def main():
     benchmark = [row for row in benchmark]
 
     outputs = []
-    for question_idx, item in enumerate(benchmark):
+    for question_idx, item in tqdm(enumerate(benchmark)):
         print(f"Processing question {question_idx}...")
 
         retrieved = retriever.invoke(item["question"])
diff --git a/sage/chat.py b/sage/chat.py
index 321103c..786bcb6 100644
--- a/sage/chat.py
+++ b/sage/chat.py
@@ -83,7 +83,7 @@ def main():
     args = parser.parse_args()
 
-    for validator in arg_validators:
+    for validator in tqdm(arg_validators):
         validator(args)
 
     rag_chain = build_rag_chain(args)
@@ -95,14 +95,14 @@ def source_md(file_path: str, url: str) -> str:
     async def _predict(message, history):
         """Performs one RAG operation."""
         history_langchain_format = []
-        for human, ai in history:
+        for human, ai in tqdm(history):
             history_langchain_format.append(HumanMessage(content=human))
             history_langchain_format.append(AIMessage(content=ai))
         history_langchain_format.append(HumanMessage(content=message))
 
         query_rewrite = ""
         response = ""
         async for event in rag_chain.astream_events(
             {
                 "input": message,
                 "chat_history": history_langchain_format,
@@ -110,10 +110,10 @@ async def _predict(message, history):
             },
             version="v1",
         ):
             if event["name"] == "retrieve_documents" and "output" in event["data"]:
-                sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in event["data"]["output"]]
+                sources = [(doc.metadata["file_path"], doc.metadata["url"]) for doc in tqdm(event["data"]["output"])]
                 # Deduplicate while preserving the order.
                sources = list(dict.fromkeys(sources))
-                response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in sources]) + "\n## Response:\n"
+                response += "## Sources:\n" + "\n".join([source_md(s[0], s[1]) for s in tqdm(sources)]) + "\n## Response:\n"
             elif event["event"] == "on_chat_model_stream":
                 chunk = event["data"]["chunk"].content
diff --git a/sage/chunker.py b/sage/chunker.py
index a2f4dd3..0360d68 100644
--- a/sage/chunker.py
+++ b/sage/chunker.py
@@ -13,6 +13,7 @@
 from semchunk import chunk as chunk_via_semchunk
 from tree_sitter import Node
 from tree_sitter_language_pack import get_parser
+from tqdm import tqdm, trange
 
 from sage.constants import TEXT_FIELD
 
@@ -130,17 +131,17 @@ def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> Lis
             return self.text_chunker.chunk(file_content[node.start_byte : node.end_byte], file_metadata)
 
         chunks = []
-        for child in node.children:
+        for child in tqdm(node.children):
             chunks.extend(self._chunk_node(child, file_content, file_metadata))
 
-        for chunk in chunks:
+        for chunk in tqdm(chunks):
             # This should always be true. Otherwise there must be a bug in the code.
             assert chunk.num_tokens <= self.max_tokens
 
         # Merge neighboring chunks if their combined size doesn't exceed max_tokens. The goal is to avoid pathologically
         # small chunks that end up being undeservedly preferred by the retriever.
         merged_chunks = []
-        for chunk in chunks:
+        for chunk in tqdm(chunks):
             if not merged_chunks:
                 merged_chunks.append(chunk)
             elif merged_chunks[-1].num_tokens + chunk.num_tokens < self.max_tokens - 50:
@@ -160,7 +161,7 @@ def _chunk_node(self, node: Node, file_content: str, file_metadata: Dict) -> Lis
                 merged_chunks.append(chunk)
 
         chunks = merged_chunks
-        for chunk in merged_chunks:
+        for chunk in tqdm(merged_chunks):
             # This should always be true. Otherwise there's a bug worth investigating.
             assert chunk.num_tokens <= self.max_tokens
 
@@ -221,7 +222,7 @@ def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
             return []
 
         file_chunks = self._chunk_node(tree.root_node, file_content, file_metadata)
-        for chunk in file_chunks:
+        for chunk in tqdm(file_chunks):
             # Make sure that the chunk has content and doesn't exceed the max_tokens limit. Otherwise there must be
             # a bug in the code.
             assert (
@@ -250,7 +251,7 @@ def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
 
         file_chunks = []
         start = 0
-        for text_chunk in text_chunks:
+        for text_chunk in tqdm(text_chunks):
             # This assertion should always be true. Otherwise there's a bug worth finding.
             assert self.count_tokens(text_chunk) <= self.max_tokens - extra_tokens
@@ -289,7 +290,7 @@ def chunk(self, content: Any, metadata: Dict) -> List[Chunk]:
             tmp_metadata = {"file_path": filename.replace(".ipynb", ".py")}
 
         chunks = self.code_chunker.chunk(python_code, tmp_metadata)
-        for chunk in chunks:
+        for chunk in tqdm(chunks):
             # Update filenames back to .ipynb
             chunk.metadata["file_path"] = filename
         return chunks
diff --git a/sage/config.py b/sage/config.py
index b7866c5..b7a055e 100644
--- a/sage/config.py
+++ b/sage/config.py
@@ -10,6 +10,7 @@
 from configargparse import ArgumentParser
 
 from sage.reranker import RerankerProvider
+from tqdm import tqdm, trange
 
 # Limits defined here: https://ai.google.dev/gemini-api/docs/models/gemini
 # NOTE: MAX_CHUNKS_PER_BATCH isn't documented anywhere but we pick a reasonable value
diff --git a/sage/data_manager.py b/sage/data_manager.py
index 88c5e6b..04648a8 100644
--- a/sage/data_manager.py
+++ b/sage/data_manager.py
@@ -8,6 +8,7 @@
 
 import requests
 from git import GitCommandError, Repo
+from tqdm import tqdm, trange
 
 
 class DataManager:
@@ -130,7 +131,7 @@ def _parse_filter_file(self, file_path: str) -> bool:
             lines = f.readlines()
 
         parsed_data = {"ext": [], "file": [], "dir": []}
-        for line in lines:
+        for line in tqdm(lines):
             if line.startswith("#"):
                 # This is a comment line.
                 continue
@@ -149,7 +150,7 @@ def _should_include(self, file_path: str) -> bool:
             return False
 
         # Exclude hidden files and directories.
-        if any(part.startswith(".") for part in file_path.split(os.path.sep)):
+        if any(part.startswith(".") for part in tqdm(file_path.split(os.path.sep))):
             return False
 
         if not self.inclusions and not self.exclusions:
@@ -165,13 +166,13 @@ def _should_include(self, file_path: str) -> bool:
             return (
                 extension in self.inclusions.get("ext", [])
                 or file_name in self.inclusions.get("file", [])
-                or any(d in dirs for d in self.inclusions.get("dir", []))
+                or any(d in dirs for d in tqdm(self.inclusions.get("dir", [])))
             )
         elif self.exclusions:
             return (
                 extension not in self.exclusions.get("ext", [])
                 and file_name not in self.exclusions.get("file", [])
-                and all(d not in dirs for d in self.exclusions.get("dir", []))
+                and all(d not in dirs for d in tqdm(self.exclusions.get("dir", [])))
             )
         return True
@@ -194,20 +195,20 @@ def walk(self) -> Generator[Tuple[Any, Dict], None, None]:
             os.remove(excluded_log_file)
         logging.info("Logging excluded files at %s", excluded_log_file)
 
-        for root, _, files in os.walk(self.local_path):
-            file_paths = [os.path.join(root, file) for file in files]
-            included_file_paths = [f for f in file_paths if self._should_include(f)]
+        for root, _, files in tqdm(os.walk(self.local_path)):
+            file_paths = [os.path.join(root, file) for file in tqdm(files)]
+            included_file_paths = [f for f in tqdm(file_paths) if self._should_include(f)]
 
             with open(included_log_file, "a") as f:
-                for path in included_file_paths:
+                for path in tqdm(included_file_paths):
                     f.write(path + "\n")
 
             excluded_file_paths = set(file_paths).difference(set(included_file_paths))
             with open(excluded_log_file, "a") as f:
-                for path in excluded_file_paths:
+                for path in tqdm(excluded_file_paths):
                     f.write(path + "\n")
 
-            for file_path in included_file_paths:
+            for file_path in tqdm(included_file_paths):
                 with open(file_path, "r") as f:
                     try:
                         contents = f.read()
diff --git a/sage/embedder.py b/sage/embedder.py
index 28147c6..d2c844e 100644
--- a/sage/embedder.py
+++ b/sage/embedder.py
@@ -17,6 +17,7 @@
 from sage.chunker import Chunk, Chunker
 from sage.constants import TEXT_FIELD
 from sage.data_manager import DataManager
+from tqdm import tqdm, trange
 
 Vector = Tuple[Dict, List[float]]  # (metadata, embedding)
 
@@ -57,16 +58,16 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -
         chunk_count = 0
         dataset_name = self.data_manager.dataset_id.replace("/", "_")
 
-        for content, metadata in self.data_manager.walk():
+        for content, metadata in tqdm(self.data_manager.walk()):
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
 
             if len(batch) > chunks_per_batch:
-                for i in range(0, len(batch), chunks_per_batch):
+                for i in tqdm(range(0, len(batch), chunks_per_batch)):
                     sub_batch = batch[i : i + chunks_per_batch]
                     openai_batch_id = self._issue_job_for_chunks(sub_batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
-                    batch_ids[openai_batch_id] = [chunk.metadata for chunk in sub_batch]
+                    batch_ids[openai_batch_id] = [chunk.metadata for chunk in tqdm(sub_batch)]
                     if max_embedding_jobs and len(batch_ids) >= max_embedding_jobs:
                         logging.info("Reached the maximum number of embedding jobs. Stopping.")
                         return
@@ -75,7 +76,7 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -
         # Finally, commit the last batch.
         if batch:
             openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
-            batch_ids[openai_batch_id] = [chunk.metadata for chunk in batch]
+            batch_ids[openai_batch_id] = [chunk.metadata for chunk in tqdm(batch)]
         logging.info("Issued %d jobs for %d chunks.", len(batch_ids), chunk_count)
 
         timestamp = int(time.time())
@@ -95,9 +96,9 @@ def embeddings_are_ready(self, metadata_file: str) -> bool:
             batch_ids = json.load(f)
 
         job_ids = batch_ids.keys()
-        statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in job_ids]
-        are_ready = all(status.status in ["completed", "failed"] for status in statuses)
-        status_counts = Counter(status.status for status in statuses)
+        statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in tqdm(job_ids)]
+        are_ready = all(status.status in ["completed", "failed"] for status in tqdm(statuses))
+        status_counts = Counter(status.status for status in tqdm(statuses))
         logging.info("Job statuses: %s", status_counts)
         return are_ready
@@ -117,9 +118,9 @@ def download_embeddings(
             batch_ids = json.load(f)
 
         job_ids = batch_ids.keys()
-        statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in job_ids]
+        statuses = [self.client.batches.retrieve(job_id.strip()) for job_id in tqdm(job_ids)]
 
-        for idx, status in enumerate(statuses):
+        for idx, status in tqdm(enumerate(statuses)):
             if status.status == "failed":
                 logging.error("Job failed: %s", status)
                 continue
@@ -134,7 +135,7 @@ def download_embeddings(
             data = json.loads(file_response.text)["response"]["body"]["data"]
             logging.info("Job %s generated %d embeddings.", status.id, len(data))
 
-            for datum in data:
+            for datum in tqdm(data):
                 idx = int(datum["index"])
                 metadata = batch_metadata[idx]
                 if (
@@ -184,7 +185,7 @@ def _export_to_jsonl(list_of_dicts: List[Dict], output_file: str):
     if not os.path.exists(directory):
         os.makedirs(directory)
     with open(output_file, "w") as f:
-        for item in list_of_dicts:
+        for item in tqdm(list_of_dicts):
            json.dump(item, f)
            f.write("\n")
 
@@ -193,7 +194,7 @@ def _chunks_to_request(chunks: List[Chunk], batch_id: str, model: str, dimension
     """Convert a list of chunks to a batch request."""
     body = {
         "model": model,
-        "input": [chunk.content for chunk in chunks],
+        "input": [chunk.content for chunk in tqdm(chunks)],
     }
 
     # These are the only two models that support a dynamic embedding size.
@@ -222,7 +223,7 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         batch = []
         chunk_count = 0
 
-        for content, metadata in self.data_manager.walk():
+        for content, metadata in tqdm(self.data_manager.walk()):
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
@@ -233,11 +234,11 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
                 time.sleep(60)  # Voyage API rate limits to 1m tokens per minute; we'll pause every 900k tokens.
 
             if len(batch) > chunks_per_batch:
-                for i in range(0, len(batch), chunks_per_batch):
+                for i in tqdm(range(0, len(batch), chunks_per_batch)):
                     sub_batch = batch[i : i + chunks_per_batch]
                     logging.info("Embedding %d chunks...", len(sub_batch))
                     result = self._make_batch_request(sub_batch)
-                    for chunk, datum in zip(sub_batch, result["data"]):
+                    for chunk, datum in tqdm(zip(sub_batch, result["data"])):
                         self.embedding_data.append((chunk.metadata, datum["embedding"]))
                 batch = []
@@ -245,7 +246,7 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         if batch:
             logging.info("Embedding %d chunks...", len(batch))
             result = self._make_batch_request(batch)
-            for chunk, datum in zip(batch, result["data"]):
+            for chunk, datum in tqdm(zip(batch, result["data"])):
                 self.embedding_data.append((chunk.metadata, datum["embedding"]))
         logging.info(f"Successfully embedded {chunk_count} chunks.")
@@ -257,7 +258,7 @@ def embeddings_are_ready(self, *args, **kwargs) -> bool:
 
     def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]:
         """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset."""
-        for chunk_metadata, embedding in self.embedding_data:
+        for chunk_metadata, embedding in tqdm(self.embedding_data):
             yield (chunk_metadata, embedding)
 
     @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(6))
@@ -265,7 +266,7 @@ def _make_batch_request(self, chunks: List[Chunk]) -> Dict:
         """Makes a batch request to the Voyage API with exponential backoff when we hit rate limits."""
         url = "https://api.voyageai.com/v1/embeddings"
         headers = {"Authorization": f"Bearer {os.environ['VOYAGE_API_KEY']}", "Content-Type": "application/json"}
-        payload = {"input": [chunk.content for chunk in chunks], "model": self.embedding_model}
+        payload = {"input": [chunk.content for chunk in tqdm(chunks)], "model": self.embedding_model}
 
         response = requests.post(url, json=payload, headers=headers)
         if not response.status_code == 200:
@@ -286,7 +287,7 @@ def __init__(self, data_manager: DataManager, chunker: Chunker, index_name: str,
         self.client = marqo.Client(url=url)
         self.index = self.client.index(index_name)
 
-        all_index_names = [result["indexName"] for result in self.client.get_indexes()["results"]]
+        all_index_names = [result["indexName"] for result in tqdm(self.client.get_indexes()["results"])]
         if not index_name in all_index_names:
             self.client.create_index(index_name, model=model)
@@ -299,17 +300,17 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         batch = []
         job_count = 0
 
-        for content, metadata in self.data_manager.walk():
+        for content, metadata in tqdm(self.data_manager.walk()):
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
 
             if len(batch) > chunks_per_batch:
-                for i in range(0, len(batch), chunks_per_batch):
+                for i in tqdm(range(0, len(batch), chunks_per_batch)):
                     sub_batch = batch[i : i + chunks_per_batch]
logging.info("Indexing %d chunks...", len(sub_batch)) self.index.add_documents( - documents=[chunk.metadata for chunk in sub_batch], + documents=[chunk.metadata for chunk in tqdm(sub_batch)], tensor_fields=[TEXT_FIELD], ) job_count += 1 @@ -321,7 +322,7 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None): # Finally, commit the last batch. if batch: - self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[TEXT_FIELD]) + self.index.add_documents(documents=[chunk.metadata for chunk in tqdm(batch)], tensor_fields=[TEXT_FIELD]) logging.info(f"Successfully embedded {chunk_count} chunks.") def embeddings_are_ready(self) -> bool: @@ -348,7 +349,7 @@ def __init__(self, data_manager: DataManager, chunker: Chunker, embedding_model: def _make_batch_request(self, chunks: List[Chunk]) -> Dict: return genai.embed_content( - model=self.embedding_model, content=[chunk.content for chunk in chunks], task_type="retrieval_document" + model=self.embedding_model, content=[chunk.content for chunk in tqdm(chunks)], task_type="retrieval_document" ) def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None): @@ -359,17 +360,17 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None): request_count = 0 last_request_time = time.time() - for content, metadata in self.data_manager.walk(): + for content, metadata in tqdm(self.data_manager.walk()): chunks = self.chunker.chunk(content, metadata) chunk_count += len(chunks) batch.extend(chunks) if len(batch) > chunks_per_batch: - for i in range(0, len(batch), chunks_per_batch): + for i in tqdm(range(0, len(batch), chunks_per_batch)): sub_batch = batch[i : i + chunks_per_batch] logging.info("Embedding %d chunks...", len(sub_batch)) result = self._make_batch_request(sub_batch) - for chunk, embedding in zip(sub_batch, result["embedding"]): + for chunk, embedding in tqdm(zip(sub_batch, result["embedding"])): self.embedding_data.append((chunk.metadata, embedding)) request_count += 1 @@ -393,7 +394,7 @@ def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None): if batch: logging.info("Embedding %d chunks...", len(batch)) result = self._make_batch_request(batch) - for chunk, embedding in zip(batch, result["embedding"]): + for chunk, embedding in tqdm(zip(batch, result["embedding"]):) self.embedding_data.append((chunk.metadata, embedding)) logging.info(f"Successfully embedded {chunk_count} chunks.") @@ -404,7 +405,7 @@ def embeddings_are_ready(self, *args, **kwargs) -> bool: def download_embeddings(self, *args, **kwargs) -> Generator[Vector, None, None]: """Yields (chunk_metadata, embedding) pairs for each chunk in the dataset.""" - for chunk_metadata, embedding in self.embedding_data: + for chunk_metadata, embedding in tqdm(self.embedding_data): yield chunk_metadata, embedding diff --git a/sage/github.py b/sage/github.py index 934ed58..10ebbc1 100644 --- a/sage/github.py +++ b/sage/github.py @@ -10,6 +10,7 @@ from sage.chunker import Chunk, Chunker from sage.constants import TEXT_FIELD from sage.data_manager import DataManager +from tqdm import tqdm, trange tokenizer = tiktoken.get_encoding("cl100k_base") @@ -64,7 +65,7 @@ def download(self) -> bool: logging.info(f"Fetching issues from {url}") response = self._get_page_of_issues(url) response.raise_for_status() - for issue in response.json(): + for issue in tqdm(response.json()): if not "pull_request" in issue: self.issues.append( GitHubIssue( @@ -83,7 +84,7 @@ def download(self) -> bool: def walk(self) -> 
Generator[Tuple[Any, Dict], None, None]: """Yields a tuple of (issue_content, issue_metadata) for each GitHub issue in the repository.""" - for issue in self.issues: + for issue in tqdm(self.issues): yield issue, {} # empty metadata @staticmethod @@ -98,7 +99,7 @@ def _get_next_link_from_header(response): link_header = response.headers.get("link") if link_header: links = link_header.split(", ") - for link in links: + for link in tqdm(links): url, rel = link.split("; ") url = url[1:-1] # The URL is enclosed in angle brackets rel = rel[5:-1] # e.g. rel="next" -> next @@ -130,7 +131,7 @@ def _get_comments(self, comments_url) -> List[GitHubIssueComment]: logging.warn(f"Timeout fetching comments from {comments_url}") return [] comments = [] - for comment in response.json(): + for comment in tqdm(response.json()): comments.append( GitHubIssueComment( url=comment["url"], @@ -222,7 +223,7 @@ def chunk(self, content: Any, metadata: Dict) -> List[Chunk]: chunks.append(issue_body_chunk) - for comment_idx, comment in enumerate(issue.comments): + for comment_idx, comment in tqdm(enumerate(issue.comments)): # This is just approximate, because when we actually add a comment to the chunk there might be some extra # tokens, like a "Comment:" prefix. approx_comment_size = len(tokenizer.encode(comment.body, disallowed_special=())) + 20 # 20 for buffer diff --git a/sage/index.py b/sage/index.py index 80b23bb..eba46a6 100644 --- a/sage/index.py +++ b/sage/index.py @@ -12,6 +12,7 @@ from sage.embedder import build_batch_embedder_from_flags from sage.github import GitHubIssuesChunker, GitHubIssuesManager from sage.vector_store import build_vector_store_from_args +from tqdm import tqdm, trange logging.basicConfig(level=logging.INFO) logger = logging.getLogger() @@ -33,7 +34,7 @@ def main(): args = parser.parse_args() - for validator in arg_validators: + for validator in tqdm(arg_validators): validator(args) # Additionally validate embedder and vector store compatibility. 
diff --git a/sage/llm.py b/sage/llm.py
index 8b33460..88d000d 100644
--- a/sage/llm.py
+++ b/sage/llm.py
@@ -3,6 +3,7 @@
 from langchain_anthropic import ChatAnthropic
 from langchain_ollama import ChatOllama
 from langchain_openai import ChatOpenAI
+from tqdm import tqdm, trange
 
 
 def build_llm_via_langchain(provider: str, model: str):
diff --git a/sage/retriever.py b/sage/retriever.py
index 1f88954..dc0e0f7 100644
--- a/sage/retriever.py
+++ b/sage/retriever.py
@@ -10,6 +10,7 @@
 from sage.llm import build_llm_via_langchain
 from sage.reranker import build_reranker
 from sage.vector_store import build_vector_store_from_args
+from tqdm import tqdm, trange
 
 
 def build_retriever_from_args(args, data_manager: Optional[DataManager] = None):
diff --git a/sage/vector_store.py b/sage/vector_store.py
index 4ed2568..6b8b8c3 100644
--- a/sage/vector_store.py
+++ b/sage/vector_store.py
@@ -5,6 +5,7 @@
 from abc import ABC, abstractmethod
 from functools import cached_property
 from typing import Dict, Generator, List, Optional, Tuple
+from tqdm import tqdm, trange
 
 import marqo
 import nltk
@@ -45,7 +46,7 @@ def upsert_batch(self, vectors: List[Vector], namespace: str):
     def upsert(self, vectors: Generator[Vector, None, None], namespace: str):
         """Upserts in batches of 100, since vector stores have a limit on upsert size."""
         batch = []
-        for metadata, embedding in vectors:
+        for metadata, embedding in tqdm(vectors):
             batch.append((metadata, embedding))
             if len(batch) == 100:
                 self.upsert_batch(batch, namespace)
@@ -104,7 +105,7 @@ def index(self):
 
         def patched_query(*args, **kwargs):
             result = original_query(*args, **kwargs)
-            for res in result["matches"]:
+            for res in tqdm(result["matches"]):
                 if TEXT_FIELD in res["metadata"]:
                     res["metadata"]["context"] = res["metadata"][TEXT_FIELD]
             return result
@@ -124,7 +125,7 @@ def ensure_exists(self):
 
     def upsert_batch(self, vectors: List[Vector], namespace: str):
         pinecone_vectors = []
-        for i, (metadata, embedding) in enumerate(vectors):
+        for i, (metadata, embedding) in tqdm(enumerate(vectors)):
             vector = {"id": metadata.get("id", str(i)), "values": embedding, "metadata": metadata}
             if self.bm25_encoder:
                 vector["sparse_values"] = self.bm25_encoder.encode_documents(metadata[TEXT_FIELD])
@@ -172,7 +173,7 @@ def as_retriever(self, top_k: int, embeddings: Embeddings = None, namespace: str
         # the result, and instead take the "filename" directly from the result.
         def patched_method(self, results):
             documents: List[Document] = []
-            for result in results["hits"]:
+            for result in tqdm(results["hits"]):
                 content = result.pop(TEXT_FIELD)
                 documents.append(Document(page_content=content, metadata=result))
             return documents
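
Note on the tqdm patterns used above: the plain tqdm class only wraps
synchronous iterables, which is why the "async for ... astream_events(...)"
loop in sage/chat.py is left unwrapped in this patch; tqdm ships a separate
asyncio-aware variant for that case. The sketch below is a minimal
illustration, not code from this repository: "events", "consume" and
"index_all" are hypothetical stand-ins, and it assumes a tqdm release recent
enough to include the tqdm.asyncio module.

    import asyncio

    from tqdm import tqdm                   # synchronous progress bar
    from tqdm.asyncio import tqdm as atqdm  # asyncio-aware variant

    async def events():
        # Hypothetical stand-in for an async stream such as
        # rag_chain.astream_events(...).
        for i in range(3):
            yield i
            await asyncio.sleep(0.1)

    async def consume():
        # Plain tqdm(...) has no __aiter__, so it cannot be the target of an
        # "async for"; the tqdm.asyncio variant accepts async iterables.
        async for event in atqdm(events()):
            pass

    def index_all(items):
        # Wrapping enumerate(...) hides the collection's length; wrapping the
        # sized iterable itself lets tqdm show a total, percentage, and ETA.
        for idx, item in enumerate(tqdm(items, desc="Indexing")):
            pass

    asyncio.run(consume())
    index_all(["a", "b", "c"])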