Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions protollm/rags/rag_core/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,19 @@ def set_collection_names(self, collection_names: list[str]) -> 'RetrievingPipeli
return self

def get_retrieved_docs(self, query: str) -> list[Document]:
if any([self._retrievers is None, self._collection_names is None]):
if self._retrievers is None or self._collection_names is None:
raise ValueError('Either retrievers or collection_names must not be None')

if len(self._retrievers) == len(self._collection_names):
_query = query
docs = self._retrievers[0].retrieve_top(self._collection_names[0], _query)
for i in range(1, len(self._retrievers)):
filter = {'uuid': {'$in': [doc.metadata['uuid'] for doc in docs]}}
docs_next = self._retrievers[i].retrieve_top(self._collection_names[i], _query, filter)
docs = docs_next
else:
raise Exception('The length of retrievers and collection_names must match')
if docs is None:
return []
for i in range(1, len(self._retrievers)):
filter = {'uuid': {'$in': [doc.metadata['uuid'] for doc in docs]}}
docs_next = self._retrievers[i].retrieve_top(self._collection_names[i], _query, filter)
if docs_next is None:
return []
docs = docs_next

return docs
15 changes: 6 additions & 9 deletions protollm/rags/stores/chroma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,16 @@ def insert_documents(collection: Chroma, docs: Iterable[Document]):

:raises KeyError: if there is no key 'source' in the documents' metadata from 'collection' or 'docs'
"""
first_element = next(docs)
docs = list(docs)
if not docs:
return
first_element = docs[0]
if 'source' not in first_element.metadata.keys():
raise KeyError('There is no file name, called <source>, in document metadata')

existing_docs_name = set(get_all_docs_name(collection))
new_docs_name = set([str(doc.metadata['source'].split('\\')[-1]) for doc in docs] +
[str(first_element.metadata['source'].split('\\')[-1])])
new_docs_name = set([str(doc.metadata['source'].split('\\')[-1]) for doc in docs])
docs_name_for_insert = new_docs_name.difference(existing_docs_name)
if first_element.metadata['source'].split('\\')[-1] in docs_name_for_insert:
docs_for_insert = [first_element]
else:
docs_for_insert = []
docs_for_insert += [doc for doc in docs if doc.metadata['source'].split('\\')[-1] in docs_name_for_insert]
docs_for_insert = [doc for doc in docs if doc.metadata['source'].split('\\')[-1] in docs_name_for_insert]
if docs_for_insert:
collection.add_documents(docs_for_insert)

4 changes: 2 additions & 2 deletions protollm_tools/llm-agents-api/protollm_agents/configs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Literal
import logging
from protollm_agents.sdk.models import EmbeddingAPIModel, CompletionModel, MultimodalModel, TokenizerModel, ChatModel
from .sdk.models import EmbeddingAPIModel, CompletionModel, MultimodalModel, TokenizerModel, ChatModel
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from protollm_agents.sdk.vector_stores import ChromaVectorStore
from .sdk.vector_stores import ChromaVectorStore

logger = logging.getLogger(__name__)

Expand Down
8 changes: 4 additions & 4 deletions protollm_tools/llm-agents-api/protollm_agents/sdk/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from langchain_openai import ChatOpenAI

from protollm_agents.sdk.base import AgentAnswer, BaseAgent, Event
from protollm_agents.sdk.context import Context
from protollm_agents.sdk.pipelines.router_pipeline import RouterPipeline
from protollm_agents.sdk.pipelines.ensemble_router_pipeline import EnsembleRouterPipeline
from .base import AgentAnswer, BaseAgent, Event
from .context import Context
from .pipelines.router_pipeline import RouterPipeline
from .pipelines.ensemble_router_pipeline import EnsembleRouterPipeline


class StreamingAgent(BaseAgent, ABC):
Expand Down
4 changes: 2 additions & 2 deletions protollm_tools/llm-agents-api/protollm_agents/sdk/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass, field

from langchain_core.tools import Tool
from protollm_agents.sdk.base import ModelType, VectorStoreType, AgentType
from protollm_agents.sdk.models import TokenizerModel, CompletionModel, ChatModel, MultimodalModel, EmbeddingAPIModel
from .base import ModelType, VectorStoreType, AgentType
from .models import TokenizerModel, CompletionModel, ChatModel, MultimodalModel, EmbeddingAPIModel

@dataclass
class Context:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import Field

from protollm_agents.sdk.base import Event
from .base import Event


class EventType(str, enum.Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from transformers import AutoTokenizer

from protollm_agents.sdk.base import BaseRunnableModel
from .base import BaseRunnableModel


class BaseOpenAIModel(BaseRunnableModel, ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage, SystemMessage

from protollm_agents.sdk.base import Event
from protollm_agents.sdk.events import TextEvent, ErrorEvent, MultiDictEvent
from ..base import Event
from ..events import TextEvent, ErrorEvent, MultiDictEvent


logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage, SystemMessage

from protollm_agents.sdk.base import Event
from protollm_agents.sdk.events import TextEvent, ErrorEvent, MultiDictEvent
from ..base import Event
from ..events import TextEvent, ErrorEvent, MultiDictEvent


logger = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from langchain_community.vectorstores import VectorStore

from protollm_agents.sdk.base import BaseVectorStore
from .base import BaseVectorStore


class ChromaVectorStore(BaseVectorStore):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Issues = "https://github.com/aimclub/ProtoLLM/issues"
[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
pytest-asyncio = "^0.24.0"
pytest-cov = "^7.0.0"

[build-system]
requires = ["poetry-core"]
Expand Down
Loading
Loading