Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Added local Repo Support #106

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion sage/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

load_dotenv()


def build_rag_chain(args):
"""Builds a RAG chain via LangChain."""
llm = build_llm_via_langchain(args.llm_provider, args.llm_model)
Expand Down
15 changes: 12 additions & 3 deletions sage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def add_repo_args(parser: ArgumentParser) -> Callable:
default="repos",
help="The local directory to store the repository",
)
parser.add(
"--repo-mode",
default = "remote",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also define the possible values (e.g. choices=["local", "remote"])

choices=["local", "remote"],
help="Define where is the repo present"
)
return validate_repo_args


Expand Down Expand Up @@ -248,9 +254,12 @@ def validate_all(args):


def validate_repo_args(args):
"""Validates the configuration of the repository."""
if not re.match(r"^[^/]+/[^/]+$", args.repo_id):
raise ValueError("repo_id must be in the format 'owner/repo'")
"""Validates the configuration of the repository.
For remote repositories, validates that repo_id is in 'owner/repo' format.
For local mode, accepts a single directory path.
"""
if args.repo_mode != "local" and not re.match(r"^[^/]+/[^/]+$", args.repo_id):
raise ValueError("repo_id must be in the format 'owner/repo' for remote repositories")


def _validate_openai_embedding_args(args):
Expand Down
3 changes: 3 additions & 0 deletions sage/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def download(self) -> bool:
"""Clones the repository to the local directory, if it's not already cloned."""
if os.path.exists(self.local_path):
# The repository is already cloned.
logging.info("Repository already exists..")
return True

if not self.is_public and not self.access_token:
Expand Down Expand Up @@ -254,3 +255,5 @@ def from_args(args: Dict):
"For private repositories, please set the GITHUB_TOKEN variable in your environment."
)
return repo_manager


2 changes: 1 addition & 1 deletion sage/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,4 +420,4 @@ def build_batch_embedder_from_flags(data_manager: DataManager, chunker: Chunker,
elif args.embedding_provider == "gemini":
return GeminiBatchEmbedder(data_manager, chunker, embedding_model=args.embedding_model)
else:
raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
raise ValueError(f"Unrecognized embedder type {args.embedding_provider}")
11 changes: 7 additions & 4 deletions sage/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,30 @@ def main():
]

args = parser.parse_args()

for validator in arg_validators:
validator(args)

if args.llm_retriever:
logging.warning("The LLM retriever does not require indexing, so this script is a no-op.")
return

# Additionally validate embedder and vector store compatibility.
if args.embedding_provider == "openai" and args.vector_store_provider != "pinecone":
parser.error("When using OpenAI embedder, the vector store type must be Pinecone.")
if args.embedding_provider == "marqo" and args.vector_store_provider != "marqo":
parser.error("When using the marqo embedder, the vector store type must also be marqo.")

if args.repo_mode == "local" and args.local_dir == "repos":
parser.error("You must not store the local repo inside the repos folder")

######################
# Step 1: Embeddings #
######################

# Index the repository.
repo_embedder = None
if args.index_repo:
# Check the repo-mode
logging.info("Cloning the repository...")
repo_manager = GitHubRepoManager.from_args(args)
logging.info("Embedding the repo...")
Expand Down
Loading