From d76d2db612b414860c0b91ae1f73486c6ccdabae Mon Sep 17 00:00:00 2001 From: georgiedekker Date: Wed, 9 Apr 2025 19:19:35 +0200 Subject: [PATCH 1/2] fixed hirag --- hirag/_llm.py | 336 +++++++++++++++++++++++++++++++++++++++++++++++-- hirag/hirag.py | 93 ++++++++++++-- 2 files changed, 411 insertions(+), 18 deletions(-) diff --git a/hirag/_llm.py b/hirag/_llm.py index 843a695..e9590d2 100644 --- a/hirag/_llm.py +++ b/hirag/_llm.py @@ -1,6 +1,8 @@ import numpy as np from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError +import aiohttp +import json from tenacity import ( retry, @@ -15,22 +17,85 @@ global_openai_async_client = None global_azure_openai_async_client = None +global_deepseek_session = None +global_ollama_session = None def get_openai_async_client_instance(): global global_openai_async_client if global_openai_async_client is None: - global_openai_async_client = AsyncOpenAI() + # Check for environment variables for custom OpenAI configuration + base_url = os.environ.get("OPENAI_API_BASE", os.environ.get("OPENAI_BASE_URL", None)) + api_key = os.environ.get("OPENAI_API_KEY", 'ollama') # Default to ollama if not set + + # Create the client with the environment variables if they exist + if base_url: + global_openai_async_client = AsyncOpenAI(base_url=base_url, api_key=api_key) + else: + global_openai_async_client = AsyncOpenAI() return global_openai_async_client def get_azure_openai_async_client_instance(): global global_azure_openai_async_client if global_azure_openai_async_client is None: - global_azure_openai_async_client = AsyncAzureOpenAI() + # Check for environment variables for custom Azure configuration + api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2023-05-15") + azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", None) + azure_key = os.environ.get("AZURE_OPENAI_API_KEY", None) + + # If Azure configuration is available, use it + if azure_endpoint and azure_key: + global_azure_openai_async_client = AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=azure_endpoint, + api_key=azure_key + ) + else: + # Fall back to the OpenAI client with Azure-compatible settings + base_url = os.environ.get("OPENAI_API_BASE", os.environ.get("OPENAI_BASE_URL", None)) + api_key = os.environ.get("OPENAI_API_KEY", "ollama") + + if base_url: + global_azure_openai_async_client = AsyncOpenAI(base_url=base_url, api_key=api_key) + else: + global_azure_openai_async_client = AsyncOpenAI() return global_azure_openai_async_client +def get_deepseek_session(): + """Get or create a DeepSeek API session""" + global global_deepseek_session + if global_deepseek_session is None: + global_deepseek_session = aiohttp.ClientSession( + headers={ + "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY', '')}", + "Content-Type": "application/json" + } + ) + return global_deepseek_session + + +def get_ollama_session(): + """Get or create an Ollama API session""" + global global_ollama_session + if global_ollama_session is None: + ollama_base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434") + + # Set up basic authentication if provided + headers = {"Content-Type": "application/json"} + auth = None + if os.environ.get("OLLAMA_API_KEY"): + headers["Authorization"] = f"Bearer {os.environ.get('OLLAMA_API_KEY')}" + + global_ollama_session = aiohttp.ClientSession( + base_url=ollama_base_url, + headers=headers, + auth=auth + ) + return global_ollama_session + + @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), @@ 
-67,8 +132,9 @@ async def openai_complete_if_cache( async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + model = os.environ.get("OPENAI_MODEL", "gpt-4o") return await openai_complete_if_cache( - "gpt-4o", + model, prompt, system_prompt=system_prompt, history_messages=history_messages, @@ -78,8 +144,10 @@ async def gpt_4o_complete( async def gpt_35_turbo_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + # Use a model from env if available, otherwise fallback to GPT-3.5 + model = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo") return await openai_complete_if_cache( - "gpt-3.5-turbo", + model, prompt, system_prompt=system_prompt, history_messages=history_messages, @@ -90,8 +158,96 @@ async def gpt_35_turbo_complete( async def gpt_4o_mini_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + # Use a model from env if available, otherwise fallback to GPT-4o-mini + model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini") + return await openai_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + +async def gpt_custom_model_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + # Get model name from environment variables or fallback to a default + model_name = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", "llama3")) return await openai_complete_if_cache( - "gpt-4o-mini", + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def deepseek_complete_if_cache( + model=None, prompt=None, system_prompt=None, history_messages=[], **kwargs +) -> str: + session = get_deepseek_session() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + + # Get model name from environment variables or use default + if model is None: + model = os.environ.get("DEEPSEEK_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-chat")) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + # Check cache if available + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + # Prepare request + payload = { + "model": model, + "messages": messages, + } + + # Add additional parameters + for key, value in kwargs.items(): + if key in ["temperature", "top_p", "max_tokens", "stream"]: + payload[key] = value + + # Make API request + deepseek_base_url = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com") + async with session.post(f"{deepseek_base_url}/v1/chat/completions", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"DeepSeek API error: {response.status} - {error_text}") + + response_json = await response.json() + completion = response_json["choices"][0]["message"]["content"] + + # Cache the response if enabled + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": completion, "model": model}} + ) + await hashing_kv.index_done_callback() + + return completion + + +async def deepseek_complete( + prompt, system_prompt=None, history_messages=[], 
**kwargs +) -> str: + # Model can be overridden with OPENAI_MODEL_NAME for consistency with other API calls + model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("DEEPSEEK_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-chat"))) + return await deepseek_complete_if_cache( + model, prompt, system_prompt=system_prompt, history_messages=history_messages, @@ -99,7 +255,79 @@ async def gpt_4o_mini_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def ollama_complete_if_cache( + model=None, prompt=None, system_prompt=None, history_messages=[], **kwargs +) -> str: + session = get_ollama_session() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + + # Get model name from environment variables or use default + if model is None: + model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3"))) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + # Check cache if available + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + # Prepare request + payload = { + "model": model, + "messages": messages, + "stream": False + } + + # Add additional parameters + for key, value in kwargs.items(): + if key in ["temperature", "top_p", "num_predict"]: + payload[key] = value + + # Make API request + async with session.post("/api/chat", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Ollama API error: {response.status} - {error_text}") + + response_json = await response.json() + completion = response_json["message"]["content"] + + # Cache the response if enabled + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": completion, "model": model}} + ) + await hashing_kv.index_done_callback() + + return completion + + +async def ollama_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3"))) + return await ollama_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -107,8 +335,10 @@ async def gpt_4o_mini_complete( ) async def openai_embedding(texts: list[str]) -> np.ndarray: openai_async_client = get_openai_async_client_instance() + # Use model from env if available + model = os.environ.get("OPENAI_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "text-embedding-3-small")) response = await openai_async_client.embeddings.create( - model="text-embedding-3-small", input=texts, encoding_format="float" + model=model, input=texts, encoding_format="float" ) return np.array([dp.embedding for dp in response.data]) @@ -154,8 +384,10 @@ async def azure_openai_complete_if_cache( async def azure_gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + # Use model from env if available + model = 
os.environ.get("OPENAI_MODEL", "gpt-4o") return await azure_openai_complete_if_cache( - "gpt-4o", + model, prompt, system_prompt=system_prompt, history_messages=history_messages, @@ -166,8 +398,23 @@ async def azure_gpt_4o_complete( async def azure_gpt_4o_mini_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: + # Use model from env if available + model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini") + return await azure_openai_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + +async def azure_openai_custom_model_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + # Get model name from environment variables or fallback to a default + model_name = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", "llama3")) return await azure_openai_complete_if_cache( - "gpt-4o-mini", + model_name, prompt, system_prompt=system_prompt, history_messages=history_messages, @@ -175,7 +422,7 @@ async def azure_gpt_4o_mini_complete( ) -@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10), @@ -183,7 +430,74 @@ async def azure_gpt_4o_mini_complete( ) async def azure_openai_embedding(texts: list[str]) -> np.ndarray: azure_openai_client = get_azure_openai_async_client_instance() + # Use model from env if available + model = os.environ.get("OPENAI_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "text-embedding-3-small")) response = await azure_openai_client.embeddings.create( - model="text-embedding-3-small", input=texts, encoding_format="float" + model=model, input=texts, encoding_format="float" ) return np.array([dp.embedding for dp in response.data]) + + +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def deepseek_embedding(texts: list[str]) -> np.ndarray: + session = get_deepseek_session() + + # Get embedding model from environment variables + model = os.environ.get("DEEPSEEK_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-embedding")) + + # Prepare request payload + payload = { + "model": model, + "input": texts, + "encoding_format": "float" + } + + # Make API request + deepseek_base_url = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com") + async with session.post(f"{deepseek_base_url}/v1/embeddings", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"DeepSeek Embedding API error: {response.status} - {error_text}") + + response_json = await response.json() + embeddings = [data["embedding"] for data in response_json["data"]] + + return np.array(embeddings) + + +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def ollama_embedding(texts: list[str]) -> np.ndarray: + session = get_ollama_session() + + # Get embedding model from environment variables + # Ollama might use the same model for completion and embedding + model = os.environ.get("OLLAMA_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3"))) + + # Ollama can only process one text at a time for embeddings + all_embeddings = [] + for text in texts: + # Prepare 
request payload + payload = { + "model": model, + "prompt": text, + } + + # Make API request + async with session.post("/api/embeddings", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Ollama Embedding API error: {response.status} - {error_text}") + + response_json = await response.json() + embeddings = response_json["embedding"] + all_embeddings.append(embeddings) + + return np.array(all_embeddings) diff --git a/hirag/hirag.py b/hirag/hirag.py index 722148f..15846f3 100644 --- a/hirag/hirag.py +++ b/hirag/hirag.py @@ -16,6 +16,12 @@ azure_gpt_4o_complete, azure_openai_embedding, azure_gpt_4o_mini_complete, + gpt_custom_model_complete, + azure_openai_custom_model_complete, + deepseek_embedding, + deepseek_complete, + ollama_embedding, + ollama_complete ) from ._op import ( chunking_by_token_size, @@ -52,6 +58,20 @@ ) +def determine_default_provider(): + """Determine the default provider based on environment variables""" + if 'GLM_MODEL' in os.environ: + return "glm" + elif 'DEEPSEEK_API_KEY' in os.environ and os.environ.get('DEEPSEEK_API_KEY'): + return "deepseek" + elif os.environ.get('OPENAI_API_BASE', '').find('ollama') >= 0 or os.environ.get('OPENAI_API_KEY') == 'ollama': + return "ollama" + elif os.environ.get('AZURE_OPENAI_API_KEY'): + return "azure" + else: + return "openai" + + @dataclass class HiRAG: working_dir: str = field( @@ -75,7 +95,7 @@ class HiRAG: ] = chunking_by_token_size chunk_token_size: int = 1200 chunk_overlap_token_size: int = 100 - tiktoken_model_name: str = "gpt-4o" + tiktoken_model_name: str = "cl100k_base" # Use cl100k_base as default tokenizer # entity extraction entity_extract_max_gleaning: int = 1 @@ -90,7 +110,7 @@ class HiRAG: node_embedding_algorithm: str = "node2vec" node2vec_params: dict = field( default_factory=lambda: { - "dimensions": 1536, + "dimensions": 3584, # Use 3584 as default for gte-qwen2-7b "num_walks": 10, "walk_length": 40, "num_walks": 10, @@ -106,18 +126,57 @@ class HiRAG: ) # text embedding - embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) + embedding_func: EmbeddingFunc = field( + default_factory=lambda: ( + ollama_embedding + if os.environ.get("PROVIDER", determine_default_provider()).lower() in ("ollama", "glm") + else ( + deepseek_embedding + if os.environ.get("PROVIDER", determine_default_provider()).lower() == "deepseek" + else openai_embedding + ) + ) + ) embedding_batch_num: int = 32 embedding_func_max_async: int = 8 query_better_than_threshold: float = 0.2 # LLM - using_azure_openai: bool = False - # best_model_func: callable = gpt_35_turbo_complete - best_model_func: callable = gpt_4o_mini_complete + using_azure_openai: bool = field( + default_factory=lambda: os.environ.get("PROVIDER", determine_default_provider()).lower() == "azure" + ) + best_model_func: callable = field( + default_factory=lambda: ( + # Determine provider from environment variable + ollama_complete + if os.environ.get("PROVIDER", determine_default_provider()).lower() in ("ollama", "glm") + else ( + deepseek_complete + if os.environ.get("PROVIDER", determine_default_provider()).lower() == "deepseek" + else ( + # For OpenAI and Azure, use the custom model if specified + gpt_custom_model_complete + if os.environ.get("OPENAI_MODEL_NAME") or os.environ.get("OPENAI_MODEL") + else gpt_4o_mini_complete + ) + ) + ) + ) best_model_max_token_size: int = 32768 best_model_max_async: int = 8 - cheap_model_func: callable = gpt_35_turbo_complete + + cheap_model_func: 
callable = field( + default_factory=lambda: ( + # Determine provider from environment variable + ollama_complete + if os.environ.get("PROVIDER", determine_default_provider()).lower() in ("ollama", "glm") + else ( + deepseek_complete + if os.environ.get("PROVIDER", determine_default_provider()).lower() == "deepseek" + else gpt_35_turbo_complete + ) + ) + ) cheap_model_max_token_size: int = 32768 cheap_model_max_async: int = 8 @@ -145,6 +204,10 @@ def __post_init__(self): # If there's no OpenAI API key, use Azure OpenAI if self.best_model_func == gpt_4o_complete: self.best_model_func = azure_gpt_4o_complete + elif self.best_model_func == gpt_custom_model_complete: + self.best_model_func = azure_openai_custom_model_complete + elif self.best_model_func == gpt_4o_mini_complete: + self.best_model_func = azure_gpt_4o_mini_complete if self.cheap_model_func == gpt_4o_mini_complete: self.cheap_model_func = azure_gpt_4o_mini_complete if self.embedding_func == openai_embedding: @@ -152,6 +215,22 @@ def __post_init__(self): logger.info( "Switched the default openai funcs to Azure OpenAI if you didn't set any of it" ) + + # Log which provider we're using + provider = os.environ.get("PROVIDER", determine_default_provider()).lower() + model_name = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", + os.environ.get("GLM_MODEL", "default model"))) + + if provider == "ollama": + logger.info(f"Using Ollama as provider with model: {model_name}") + elif provider == "glm": + logger.info(f"Using GLM as provider with model: {model_name}") + elif provider == "deepseek": + logger.info(f"Using DeepSeek as provider with model: {model_name}") + elif provider == "azure": + logger.info(f"Using Azure OpenAI as provider with model: {model_name}") + else: + logger.info(f"Using OpenAI as provider with model: {model_name}") if not os.path.exists(self.working_dir) and self.always_create_working_dir: logger.info(f"Creating working directory {self.working_dir}") From 717fa5c420e843b06f64747cdc1f04da55eb66f3 Mon Sep 17 00:00:00 2001 From: georgiedekker Date: Tue, 15 Apr 2025 12:54:25 +0200 Subject: [PATCH 2/2] native cohere and ollama support, to be tested still --- config.yaml | 23 +- hi_Search_cohere.py | 261 +++++++++++++++ hi_Search_ollama.py | 168 ++++++++++ hirag/_llm.py | 737 +++++++++++++++++-------------------------- hirag/_llm_backup.py | 503 +++++++++++++++++++++++++++++ hirag/hirag.py | 4 +- 6 files changed, 1251 insertions(+), 445 deletions(-) create mode 100644 hi_Search_cohere.py create mode 100644 hi_Search_ollama.py create mode 100644 hirag/_llm_backup.py diff --git a/config.yaml b/config.yaml index a6ad54f..e85e73b 100644 --- a/config.yaml +++ b/config.yaml @@ -5,6 +5,16 @@ openai: api_key: "***" base_url: "***" +# Example config.yaml structure +ollama: + base_url: http://localhost:11434 # Default Ollama URL + embedding_model: None # Or another suitable embedding model + chat_model: rjmalagon/gte-qwen2-7b-instruct:f16 + embedding_dim: 768 # Dimension for nomic-embed-text, adjust if using another model +# ... other configs like model_params, hirag ... 
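+# The same settings can also be supplied as environment variables, which is what
+# the patched hirag/_llm.py and hirag/hirag.py read at runtime. A minimal sketch,
+# with variable names taken from the patch and purely illustrative values:
+#   PROVIDER=ollama                          # one of: openai, azure, ollama, glm, deepseek
+#   OLLAMA_BASE_URL=http://localhost:11434
+#   OPENAI_API_BASE=http://localhost:11434   # OpenAI-compatible base URL, if used
+#   OPENAI_MODEL_NAME=llama3
+#   OPENAI_EMBEDDING_MODEL=nomic-embed-text
+#   DEEPSEEK_API_KEY=***                     # only needed when PROVIDER=deepseek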
+model_params:
+  max_token_size: 8192 # Example, adjust as needed
+
 # GLM Configuration
 glm:
   model: "glm-4-plus"
@@ -18,17 +28,24 @@ deepseek:
   api_key: "***"
   base_url: "https://api.deepseek.com"
 
+cohere:
+  api_key: "***"
+  model: "command-r"
+  embedding_model: "embed-english-v3.0"
+  embedding_dim: 1024
+
 # Model Parameters
 model_params:
   openai_embedding_dim: 1536
   glm_embedding_dim: 2048
+  cohere_embedding_dim: 1024
   max_token_size: 8192
 
 # HiRAG Configuration
 hirag:
-  working_dir: "your_work_dir"
-  enable_llm_cache: false
+  working_dir: "./hirag_index_cohere"
+  enable_llm_cache: true
   enable_hierachical_mode: true
   embedding_batch_num: 6
   embedding_func_max_async: 8
-  enable_naive_rag: true
\ No newline at end of file
+  enable_naive_rag: false
\ No newline at end of file
diff --git a/hi_Search_cohere.py b/hi_Search_cohere.py
new file mode 100644
index 0000000..abeea26
--- /dev/null
+++ b/hi_Search_cohere.py
@@ -0,0 +1,261 @@
+import os
+import logging
+import numpy as np
+import yaml
+import cohere
+import asyncio
+from hirag import HiRAG, QueryParam
+from dataclasses import dataclass
+from hirag.base import BaseKVStorage
+from hirag._utils import compute_args_hash
+from typing import List, Dict, Any, Optional
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# Load configuration from YAML file
+try:
+    with open('config.yaml', 'r') as file:
+        config = yaml.safe_load(file)
+except FileNotFoundError:
+    logger.error("Error: config.yaml not found. Please create it with Cohere and HiRAG settings.")
+    exit(1)
+except yaml.YAMLError as e:
+    logger.error(f"Error parsing config.yaml: {e}")
+    exit(1)
+
+# Extract Cohere configurations
+try:
+    COHERE_API_KEY = config['cohere']['api_key']
+    COHERE_CHAT_MODEL = config['cohere']['model']
+    COHERE_EMBEDDING_MODEL = config['cohere']['embedding_model']
+    COHERE_EMBEDDING_DIM = config['cohere']['embedding_dim']
+    # Optional: Use environment variables as fallback or override
+    COHERE_API_KEY = os.environ.get("COHERE_API_KEY", COHERE_API_KEY)
+except KeyError as e:
+    logger.error(f"Missing key in config.yaml under 'cohere': {e}")
+    exit(1)
+
+if not COHERE_API_KEY:
+    logger.error("Cohere API key not found in config.yaml or COHERE_API_KEY environment variable.")
+    exit(1)
+
+# Extract HiRAG configurations
+try:
+    HIRAG_WORKING_DIR = config['hirag']['working_dir']
+    HIRAG_ENABLE_LLM_CACHE = config['hirag'].get('enable_llm_cache', True)
+    HIRAG_ENABLE_HIERARCHICAL_MODE = config['hirag'].get('enable_hierachical_mode', True)
+    HIRAG_EMBEDDING_BATCH_NUM = config['hirag'].get('embedding_batch_num', 16)
+    HIRAG_EMBEDDING_FUNC_MAX_ASYNC = config['hirag'].get('embedding_func_max_async', 4)
+    HIRAG_ENABLE_NAIVE_RAG = config['hirag'].get('enable_naive_rag', False)
+    # Optional input file path from config
+    INPUT_FILE_PATH = config.get('input_file', None)
+except KeyError as e:
+    logger.error(f"Missing key in config.yaml under 'hirag': {e}")
+    exit(1)
+
+
+# --- Embedding Function ---
+
+@dataclass
+class EmbeddingFunc:
+    embedding_dim: int
+    # Cohere doesn't explicitly publish a max token size for embed v3 like OpenAI does for its models.
+    # We'll omit it here unless specific constraints are needed.
+ # max_token_size: int + func: callable + + async def __call__(self, *args, **kwargs) -> np.ndarray: + return await self.func(*args, **kwargs) + +def wrap_embedding_func_with_attrs(**kwargs): + """Wrap an async function with attributes required by HiRAG.""" + def final_decorator(func) -> EmbeddingFunc: + # Ensure the function is async + if not asyncio.iscoroutinefunction(func): + raise TypeError(f"The decorated function {func.__name__} must be async.") + new_func = EmbeddingFunc(**kwargs, func=func) + return new_func + return final_decorator + +@wrap_embedding_func_with_attrs(embedding_dim=COHERE_EMBEDDING_DIM) +async def COHERE_embedding(texts: list[str]) -> np.ndarray: + """Generates embeddings for a list of texts using Cohere API.""" + # Note: Cohere recommends using AsyncClient for concurrent requests + co_async = cohere.AsyncClient(api_key=COHERE_API_KEY) + try: + # Determine input type based on typical HiRAG usage: 'search_document' for indexing. + # HiRAG might call this for queries too; Cohere recommends 'search_query' for queries. + # For simplicity here, we use 'search_document'. A more robust implementation + # might inspect the call context or pass an input_type hint. + response = await co_async.embed( + texts=texts, + model=COHERE_EMBEDDING_MODEL, + input_type="search_document" # Use "search_query" when embedding single queries + ) + # Ensure embeddings are numpy arrays + embeddings = np.array(response.embeddings, dtype=np.float32) + if embeddings.shape[0] != len(texts) or embeddings.shape[1] != COHERE_EMBEDDING_DIM: + logger.error(f"Unexpected embedding shape: {embeddings.shape}. Expected ({len(texts)}, {COHERE_EMBEDDING_DIM})") + # Handle error appropriately, maybe raise or return empty array + raise ValueError("Embedding dimension mismatch or incorrect number of embeddings returned.") + return embeddings + except cohere.CohereError as e: + logger.error(f"Cohere API error during embedding: {e}") + # Re-raise or handle as needed; returning empty array might cause issues downstream + raise + except Exception as e: + logger.error(f"Unexpected error during embedding: {e}") + raise + finally: + # Ensure the async client session is closed + await co_async.close() + + +# --- Model (Chat) Function --- + +def _format_history_for_cohere(history_messages: List[Dict[str, str]]) -> List[Dict[str, str]]: + """Converts OpenAI-style history to Cohere format.""" + cohere_history = [] + for msg in history_messages: + role = msg.get("role", "").lower() + content = msg.get("content", "") + if role == "user": + cohere_history.append({"role": "USER", "message": content}) + elif role == "assistant" or role == "model": # HiRAG might use 'model' + cohere_history.append({"role": "CHATBOT", "message": content}) + # Silently ignore system messages here, handled by 'preamble' in co.chat + return cohere_history + +async def COHERE_model_if_cache( + prompt: str, + system_prompt: Optional[str] = None, + history_messages: List[Dict[str, str]] = [], + **kwargs +) -> str: + """Uses Cohere Chat API, checking cache first.""" + co_async = cohere.AsyncClient(api_key=COHERE_API_KEY) + hashing_kv: Optional[BaseKVStorage] = kwargs.pop("hashing_kv", None) + cache_key = None + + # Prepare request details for hashing and API call + chat_history = _format_history_for_cohere(history_messages) + # For hashing, combine relevant parts. Use a simplified representation. 
+ hash_payload = { + "model": COHERE_CHAT_MODEL, + "message": prompt, + "chat_history": chat_history, + "preamble": system_prompt, + # Include other relevant kwargs if they affect the output significantly + "temperature": kwargs.get("temperature", 0.3) # Example + } + + # Check cache + if hashing_kv is not None: + cache_key = compute_args_hash(hash_payload) + logger.debug(f"Checking cache for key: {cache_key}") + cached_response = await hashing_kv.get_by_id(cache_key) + if cached_response is not None and "return" in cached_response: + logger.info(f"Cache hit for key: {cache_key}") + await co_async.close() # Close client if returning from cache + return cached_response["return"] + else: + logger.info(f"Cache miss for key: {cache_key}") + + + # Call Cohere API + try: + logger.debug(f"Calling Cohere chat model: {COHERE_CHAT_MODEL}") + response = await co_async.chat( + model=COHERE_CHAT_MODEL, + message=prompt, + chat_history=chat_history, + preamble=system_prompt, + temperature=kwargs.get("temperature", 0.3), # Pass through relevant params + # max_tokens=kwargs.get("max_tokens", None) # Example if needed + ) + result_text = response.text + + # Store in cache if enabled + if hashing_kv is not None and cache_key is not None: + logger.debug(f"Storing response in cache for key: {cache_key}") + await hashing_kv.upsert( + {cache_key: {"return": result_text, "model": COHERE_CHAT_MODEL}} + ) + + return result_text + + except cohere.CohereError as e: + logger.error(f"Cohere API error during chat: {e}") + raise # Re-raise to signal failure + except Exception as e: + logger.error(f"Unexpected error during chat: {e}") + raise + finally: + # Ensure the async client session is closed + await co_async.close() + + +# --- Main Execution Logic --- + +async def main(): + """Initializes HiRAG with Cohere and performs indexing/querying.""" + + logger.info("Initializing HiRAG with Cohere backend...") + graph_func = HiRAG(working_dir=HIRAG_WORKING_DIR, + enable_llm_cache=HIRAG_ENABLE_LLM_CACHE, + embedding_func=COHERE_embedding, + best_model_func=COHERE_model_if_cache, # Use Cohere for both best and cheap + cheap_model_func=COHERE_model_if_cache, + enable_hierachical_mode=HIRAG_ENABLE_HIERARCHICAL_MODE, + embedding_batch_num=HIRAG_EMBEDDING_BATCH_NUM, + embedding_func_max_async=HIRAG_EMBEDDING_FUNC_MAX_ASYNC, + enable_naive_rag=HIRAG_ENABLE_NAIVE_RAG) + + # --- Indexing --- + # Check if the working directory exists and might already be indexed. + # HiRAG's insert might handle this, but explicit checks can be useful. + if INPUT_FILE_PATH: + if not os.path.exists(INPUT_FILE_PATH): + logger.error(f"Input file not found: {INPUT_FILE_PATH}") + return # Exit if input file specified but not found + + # Check if indexing might be needed (e.g., based on existence of index files) + # For simplicity, we'll just run insert. Add more sophisticated checks if needed. + logger.info(f"Indexing data from: {INPUT_FILE_PATH}") + try: + with open(INPUT_FILE_PATH, 'r', encoding='utf-8') as f: + text_content = f.read() + # Assuming insert is idempotent or handles re-indexing appropriately + await graph_func.insert(text_content) # Use await for async insert if available + logger.info("Indexing complete.") + except Exception as e: + logger.error(f"Error during indexing: {e}") + return # Stop if indexing fails + else: + logger.warning("No input_file specified in config.yaml. 
Skipping indexing.") + logger.warning("Ensure the working directory contains a pre-built index or run indexing manually.") + + + # --- Querying --- + query_text = "What are the main capabilities of this system?" # Example query + logger.info(f"Performing HiRAG query: '{query_text}'") + try: + # Assuming query is async or HiRAG handles the async calls internally + # If graph_func.query itself needs await: result = await graph_func.query(...) + result = graph_func.query(query_text, param=QueryParam(mode="hi" if HIRAG_ENABLE_HIERARCHICAL_MODE else "naive")) + logger.info("Query Result:") + print(result) # Print the result directly + except Exception as e: + logger.error(f"Error during query: {e}") + + +if __name__ == "__main__": + # Ensure event loop is running for async operations + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Execution interrupted by user.") + except Exception as e: + logger.error(f"An unexpected error occurred in main execution: {e}") diff --git a/hi_Search_ollama.py b/hi_Search_ollama.py new file mode 100644 index 0000000..6d09a6b --- /dev/null +++ b/hi_Search_ollama.py @@ -0,0 +1,168 @@ +import os +import logging +import numpy as np +import yaml +from hirag import HiRAG, QueryParam +import ollama # Import the ollama library +from dataclasses import dataclass +from hirag.base import BaseKVStorage +from hirag._utils import compute_args_hash +import asyncio + +# Load configuration from YAML file +# Ensure your config.yaml has an 'ollama' section with base_url, embedding_model, chat_model, embedding_dim +# and a 'model_params' section with max_token_size +try: + with open('config.yaml', 'r') as file: + config = yaml.safe_load(file) +except FileNotFoundError: + print("Error: config.yaml not found. Please create it with necessary ollama and model_params sections.") + exit(1) +except KeyError as e: + print(f"Error: Missing key in config.yaml: {e}. Ensure ollama and model_params sections are complete.") + exit(1) + + +# Extract Ollama configurations +OLLAMA_EMBEDDING_MODEL = config['ollama']['embedding_model'] +OLLAMA_CHAT_MODEL = config['ollama']['chat_model'] +OLLAMA_URL = config['ollama'].get('base_url', 'http://localhost:11434') # Use default if not specified +OLLAMA_EMBEDDING_DIM = config['ollama']['embedding_dim'] +MAX_TOKEN_SIZE = config['model_params']['max_token_size'] + +@dataclass +class EmbeddingFunc: + embedding_dim: int + max_token_size: int + func: callable + + async def __call__(self, *args, **kwargs) -> np.ndarray: + return await self.func(*args, **kwargs) + +def wrap_embedding_func_with_attrs(**kwargs): + """Wrap a function with attributes""" + + def final_decro(func) -> EmbeddingFunc: + new_func = EmbeddingFunc(**kwargs, func=func) + return new_func + + return final_decro + +# Define the async Ollama client +# Note: Client instantiation might be better outside the functions if reused heavily, +# but keeping it simple here based on examples. +async def get_ollama_async_client(): + return ollama.AsyncClient(host=OLLAMA_URL) + +@wrap_embedding_func_with_attrs(embedding_dim=OLLAMA_EMBEDDING_DIM, max_token_size=MAX_TOKEN_SIZE) +async def OLLAMA_embedding(texts: list[str]) -> np.ndarray: + """Generates embeddings using the configured Ollama embedding model.""" + client = await get_ollama_async_client() + embeddings = [] + for text in texts: + # ollama.embed currently doesn't support batching in the library, process one by one + # Keep an eye on library updates for potential batch support. 
+ response = await client.embed(model=OLLAMA_EMBEDDING_MODEL, input=text) + embeddings.append(response['embedding']) + return np.array(embeddings, dtype=np.float32) + + +async def OLLAMA_model_if_cache( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + """Sends a chat request to the configured Ollama chat model, using HiRAG cache.""" + client = await get_ollama_async_client() + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + # Get the cached response if available------------------- + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) # history_messages should already be in {"role": "...", "content": "..."} format + messages.append({"role": "user", "content": prompt}) + + if hashing_kv is not None: + # Use the specific Ollama chat model name for hashing + args_hash = compute_args_hash(OLLAMA_CHAT_MODEL, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + logging.info(f"Cache hit for hash: {args_hash}") + return if_cache_return["return"] + logging.info(f"Cache miss for hash: {args_hash}") + # ----------------------------------------------------- + + # Ensure kwargs passed to ollama.chat are valid for its API + # Filter out hashing_kv if it was passed initially + valid_ollama_kwargs = {k: v for k, v in kwargs.items() if k != "hashing_kv"} + + response = await client.chat( + model=OLLAMA_CHAT_MODEL, messages=messages, **valid_ollama_kwargs + ) + + response_content = response['message']['content'] + + # Cache the response ----------------------------- + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response_content, "model": OLLAMA_CHAT_MODEL}} + ) + logging.info(f"Cached response for hash: {args_hash}") + # ----------------------------------------------------- + return response_content + + +# Initialize HiRAG with Ollama functions +# Ensure hirag section in config.yaml exists and has necessary keys +try: + graph_func = HiRAG( + working_dir=config['hirag']['working_dir'], + enable_llm_cache=config['hirag']['enable_llm_cache'], + embedding_func=OLLAMA_embedding, + best_model_func=OLLAMA_model_if_cache, # Use Ollama for both best and cheap + cheap_model_func=OLLAMA_model_if_cache, + enable_hierachical_mode=config['hirag']['enable_hierachical_mode'], + embedding_batch_num=config['hirag']['embedding_batch_num'], # Consider Ollama's performance + embedding_func_max_async=config['hirag']['embedding_func_max_async'], # Adjust based on Ollama setup + enable_naive_rag=config['hirag']['enable_naive_rag'] + ) +except KeyError as e: + print(f"Error: Missing key in config.yaml under 'hirag': {e}") + exit(1) + +async def main(): + # --- Insertion Phase --- + # Comment out this block if the working directory has already been indexed + try: + print("Attempting to insert data...") + # Replace "your_data.txt" with the actual path to your text file + file_path = "your_data.txt" + with open(file_path, "r") as f: + data = f.read() + await graph_func.insert(data) # HiRAG insert is now async + print(f"Data from {file_path} inserted successfully into {config['hirag']['working_dir']}") + except FileNotFoundError: + print(f"Error: Data file '{file_path}' not found. 
Please create it or comment out the insertion block.") + # Decide if you want to exit or continue without insertion + # exit(1) + print("Continuing without data insertion.") + except Exception as e: + print(f"An error occurred during insertion: {e}") + # exit(1) # Optional: exit if insertion fails + + + # --- Query Phase --- + print("\nPerforming hi search using Ollama:") + query_text = "What are the key concepts discussed?" # Example query + try: + # HiRAG query is now async + result = await graph_func.query(query_text, param=QueryParam(mode="hi")) + print(f"\nQuery: {query_text}") + print(f"Result:\n{result}") + except Exception as e: + print(f"\nAn error occurred during query: {e}") + print("Ensure the Ollama server is running and models are available.") + +if __name__ == "__main__": + # Setup basic logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + asyncio.run(main()) diff --git a/hirag/_llm.py b/hirag/_llm.py index e9590d2..ce6ede1 100644 --- a/hirag/_llm.py +++ b/hirag/_llm.py @@ -1,503 +1,358 @@ import numpy as np - -from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError -import aiohttp -import json - +import os +import cohere +from cohere import CohereAPIError, CohereConnectionError, CohereRateLimitError from tenacity import ( retry, stop_after_attempt, wait_exponential, retry_if_exception_type, ) -import os from ._utils import compute_args_hash, wrap_embedding_func_with_attrs from .base import BaseKVStorage +import logging -global_openai_async_client = None -global_azure_openai_async_client = None -global_deepseek_session = None -global_ollama_session = None - +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) -def get_openai_async_client_instance(): - global global_openai_async_client - if global_openai_async_client is None: - # Check for environment variables for custom OpenAI configuration - base_url = os.environ.get("OPENAI_API_BASE", os.environ.get("OPENAI_BASE_URL", None)) - api_key = os.environ.get("OPENAI_API_KEY", 'ollama') # Default to ollama if not set - - # Create the client with the environment variables if they exist - if base_url: - global_openai_async_client = AsyncOpenAI(base_url=base_url, api_key=api_key) - else: - global_openai_async_client = AsyncOpenAI() - return global_openai_async_client +global_cohere_async_client = None -def get_azure_openai_async_client_instance(): - global global_azure_openai_async_client - if global_azure_openai_async_client is None: - # Check for environment variables for custom Azure configuration - api_version = os.environ.get("AZURE_OPENAI_API_VERSION", "2023-05-15") - azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", None) - azure_key = os.environ.get("AZURE_OPENAI_API_KEY", None) +def get_cohere_async_client_instance(): + """Get or create an asynchronous Cohere client instance.""" + global global_cohere_async_client + if global_cohere_async_client is None: + api_key = os.environ.get("COHERE_API_KEY") + if not api_key: + logger.warning("COHERE_API_KEY environment variable not set. Cohere calls will fail.") + # Allow creation to proceed, but calls will likely fail, providing feedback. 
+ # Alternatively, raise an error: raise ValueError("COHERE_API_KEY not set") - # If Azure configuration is available, use it - if azure_endpoint and azure_key: - global_azure_openai_async_client = AsyncAzureOpenAI( - api_version=api_version, - azure_endpoint=azure_endpoint, - api_key=azure_key - ) - else: - # Fall back to the OpenAI client with Azure-compatible settings - base_url = os.environ.get("OPENAI_API_BASE", os.environ.get("OPENAI_BASE_URL", None)) - api_key = os.environ.get("OPENAI_API_KEY", "ollama") - - if base_url: - global_azure_openai_async_client = AsyncOpenAI(base_url=base_url, api_key=api_key) - else: - global_azure_openai_async_client = AsyncOpenAI() - return global_azure_openai_async_client - - -def get_deepseek_session(): - """Get or create a DeepSeek API session""" - global global_deepseek_session - if global_deepseek_session is None: - global_deepseek_session = aiohttp.ClientSession( - headers={ - "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY', '')}", - "Content-Type": "application/json" - } + # Initialize the async client + # Add timeout configurations if needed, e.g., timeout=60 + global_cohere_async_client = cohere.AsyncClient( + api_key=api_key, + # Consider adding client-side timeouts if appropriate + # timeout=(10, 60) # (connect timeout, read timeout) ) - return global_deepseek_session + logger.info("Cohere AsyncClient initialized.") + return global_cohere_async_client -def get_ollama_session(): - """Get or create an Ollama API session""" - global global_ollama_session - if global_ollama_session is None: - ollama_base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434") - - # Set up basic authentication if provided - headers = {"Content-Type": "application/json"} - auth = None - if os.environ.get("OLLAMA_API_KEY"): - headers["Authorization"] = f"Bearer {os.environ.get('OLLAMA_API_KEY')}" - - global_ollama_session = aiohttp.ClientSession( - base_url=ollama_base_url, - headers=headers, - auth=auth - ) - return global_ollama_session +def _format_chat_history_for_cohere(history_messages: list[dict]) -> list[dict]: + """Converts a list of messages from OpenAI format to Cohere format.""" + cohere_history = [] + role_map = {"user": "USER", "assistant": "CHATBOT", "system": "SYSTEM"} # SYSTEM role may not map directly, handled by preamble + for msg in history_messages: + role = role_map.get(msg.get("role")) + content = msg.get("content") + if role and content and role != "SYSTEM": # System messages handled by preamble + cohere_history.append({"role": role, "message": content}) + elif role == "SYSTEM": + logger.warning("System messages in history are ignored; use the 'system_prompt' parameter instead for Cohere preamble.") + return cohere_history @retry( stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((RateLimitError, APIConnectionError)), + retry=retry_if_exception_type((CohereRateLimitError, CohereConnectionError, CohereAPIError)), + reraise=True # Reraise the exception after retries are exhausted ) -async def openai_complete_if_cache( - model, prompt, system_prompt=None, history_messages=[], **kwargs +async def cohere_complete_if_cache( + model: str | None = None, + prompt: str | None = None, + system_prompt: str | None = None, + history_messages: list[dict] | None = None, + **kwargs ) -> str: - openai_async_client = get_openai_async_client_instance() + """ + Generates a completion using the Cohere API, with caching support. 
+ + Args: + model (str | None): The Cohere model ID (e.g., 'command-r'). Defaults to COHERE_CHAT_MODEL env var or 'command-r'. + prompt (str | None): The user's prompt/message. + system_prompt (str | None): The system prompt (preamble for Cohere). + history_messages (list[dict] | None): A list of previous messages in OpenAI format [{'role': 'user'|'assistant', 'content': ...}]. + **kwargs: Additional arguments passed to the Cohere client's chat method (e.g., temperature, max_tokens) + and 'hashing_kv' for caching. + + Returns: + str: The generated text content. + + Raises: + CohereAPIError, CohereConnectionError, CohereRateLimitError: If API calls fail after retries. + ValueError: If prompt is None. + """ + if prompt is None: + raise ValueError("Prompt cannot be None for cohere_complete_if_cache") + + cohere_async_client = get_cohere_async_client_instance() hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) - if hashing_kv is not None: - args_hash = compute_args_hash(model, messages) - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] - - response = await openai_async_client.chat.completions.create( - model=model, messages=messages, **kwargs - ) - - if hashing_kv is not None: - await hashing_kv.upsert( - {args_hash: {"return": response.choices[0].message.content, "model": model}} - ) - await hashing_kv.index_done_callback() - return response.choices[0].message.content - - -async def gpt_4o_complete( - prompt, system_prompt=None, history_messages=[], **kwargs -) -> str: - model = os.environ.get("OPENAI_MODEL", "gpt-4o") - return await openai_complete_if_cache( - model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - **kwargs, - ) - -async def gpt_35_turbo_complete( - prompt, system_prompt=None, history_messages=[], **kwargs -) -> str: - # Use a model from env if available, otherwise fallback to GPT-3.5 - model = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo") - return await openai_complete_if_cache( - model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - **kwargs, - ) - - -async def gpt_4o_mini_complete( - prompt, system_prompt=None, history_messages=[], **kwargs -) -> str: - # Use a model from env if available, otherwise fallback to GPT-4o-mini - model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini") - return await openai_complete_if_cache( - model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - **kwargs, - ) -async def gpt_custom_model_complete( - prompt, system_prompt=None, history_messages=[], **kwargs -) -> str: - # Get model name from environment variables or fallback to a default - model_name = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", "llama3")) - return await openai_complete_if_cache( - model_name, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - **kwargs, - ) - - -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=2, max=10), -) -async def deepseek_complete_if_cache( - model=None, prompt=None, system_prompt=None, history_messages=[], **kwargs -) -> str: - session = get_deepseek_session() - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - - # Get model name from environment variables or use default - if model is None: - 
model = os.environ.get("DEEPSEEK_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-chat")) + # Determine model, preferring explicit > env var > default + effective_model = model or os.environ.get("COHERE_CHAT_MODEL", "command-r") - # Prepare messages - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) - - # Check cache if available - if hashing_kv is not None: - args_hash = compute_args_hash(model, messages) - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] - - # Prepare request - payload = { - "model": model, - "messages": messages, + # Format history messages for Cohere API + chat_history = _format_chat_history_for_cohere(history_messages or []) + + # Arguments for hashing and API call (excluding non-API kwargs like hashing_kv) + api_args = { + "model": effective_model, + "message": prompt, + "preamble": system_prompt, + "chat_history": chat_history, + **kwargs # Pass through other Cohere-specific args like temperature, max_tokens } - # Add additional parameters - for key, value in kwargs.items(): - if key in ["temperature", "top_p", "max_tokens", "stream"]: - payload[key] = value - - # Make API request - deepseek_base_url = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com") - async with session.post(f"{deepseek_base_url}/v1/chat/completions", json=payload) as response: - if response.status != 200: - error_text = await response.text() - raise RuntimeError(f"DeepSeek API error: {response.status} - {error_text}") - - response_json = await response.json() - completion = response_json["choices"][0]["message"]["content"] - - # Cache the response if enabled - if hashing_kv is not None: - await hashing_kv.upsert( - {args_hash: {"return": completion, "model": model}} - ) - await hashing_kv.index_done_callback() - - return completion - + # Filter out None values before hashing/calling API + api_args_filtered = {k: v for k, v in api_args.items() if v is not None} -async def deepseek_complete( - prompt, system_prompt=None, history_messages=[], **kwargs -) -> str: - # Model can be overridden with OPENAI_MODEL_NAME for consistency with other API calls - model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("DEEPSEEK_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-chat"))) - return await deepseek_complete_if_cache( - model, - prompt, - system_prompt=system_prompt, - history_messages=history_messages, - **kwargs, - ) - -@retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=2, max=10), -) -async def ollama_complete_if_cache( - model=None, prompt=None, system_prompt=None, history_messages=[], **kwargs -) -> str: - session = get_ollama_session() - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - - # Get model name from environment variables or use default - if model is None: - model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3"))) - - # Prepare messages - messages = [] - if system_prompt: - messages.append({"role": "system", "content": system_prompt}) - messages.extend(history_messages) - messages.append({"role": "user", "content": prompt}) - - # Check cache if available if hashing_kv is not None: - args_hash = compute_args_hash(model, messages) + # Use filtered args for hashing to ensure consistency + args_hash = compute_args_hash(api_args_filtered) if_cache_return = await 
hashing_kv.get_by_id(args_hash)
         if if_cache_return is not None:
+            logger.debug(f"Cache hit for Cohere completion (hash: {args_hash})")
             return if_cache_return["return"]
-
-    # Prepare request
-    payload = {
-        "model": model,
-        "messages": messages,
-        "stream": False
-    }
-
-    # Add additional parameters
-    for key, value in kwargs.items():
-        if key in ["temperature", "top_p", "num_predict"]:
-            payload[key] = value
-
-    # Make API request
-    async with session.post("/api/chat", json=payload) as response:
-        if response.status != 200:
-            error_text = await response.text()
-            raise RuntimeError(f"Ollama API error: {response.status} - {error_text}")
-
-        response_json = await response.json()
-        completion = response_json["message"]["content"]
-
-    # Cache the response if enabled
-    if hashing_kv is not None:
-        await hashing_kv.upsert(
-            {args_hash: {"return": completion, "model": model}}
-        )
-        await hashing_kv.index_done_callback()
-
-    return completion
+        logger.debug(f"Cache miss for Cohere completion (hash: {args_hash})")
+
+    try:
+        logger.debug(f"Calling Cohere chat API with args: {api_args_filtered}")
+        response = await cohere_async_client.chat(**api_args_filtered)
+        completion_text = response.text
+        logger.debug(f"Received Cohere chat response. Text length: {len(completion_text)}")
-async def ollama_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3")))
-    return await ollama_complete_if_cache(
-        model,
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192)
-@retry(
-    stop=stop_after_attempt(5),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
-)
-async def openai_embedding(texts: list[str]) -> np.ndarray:
-    openai_async_client = get_openai_async_client_instance()
-    # Use model from env if available
-    model = os.environ.get("OPENAI_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "text-embedding-3-small"))
-    response = await openai_async_client.embeddings.create(
-        model=model, input=texts, encoding_format="float"
-    )
-    return np.array([dp.embedding for dp in response.data])
-
-
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
-)
-async def azure_openai_complete_if_cache(
-    deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    azure_openai_client = get_azure_openai_async_client_instance()
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    messages.extend(history_messages)
-    messages.append({"role": "user", "content": prompt})
-    if hashing_kv is not None:
-        args_hash = compute_args_hash(deployment_name, messages)
-        if_cache_return = await hashing_kv.get_by_id(args_hash)
-        if if_cache_return is not None:
-            return if_cache_return["return"]
-
-    response = await azure_openai_client.chat.completions.create(
-        model=deployment_name, messages=messages, **kwargs
-    )
+    except (CohereAPIError, CohereConnectionError, CohereRateLimitError) as e:
+        logger.error(f"Cohere API error during chat completion: {e}")
+        raise  # Reraise to trigger tenacity retry or final failure
+    except Exception as e:
+        logger.exception(f"An unexpected error occurred during Cohere chat completion: {e}")
+        raise  # Reraise unexpected errors
 
     if hashing_kv is not None:
         await hashing_kv.upsert(
-            {
-                args_hash: {
-                    "return": response.choices[0].message.content,
-                    "model": deployment_name,
-                }
-            }
+            {args_hash: {"return": completion_text, "model": effective_model}}
         )
+        # Assuming index_done_callback is for batching/finalizing writes
         await hashing_kv.index_done_callback()
-    return response.choices[0].message.content
-
+        logger.debug(f"Cached Cohere completion result (hash: {args_hash})")
 
-async def azure_gpt_4o_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    # Use model from env if available
-    model = os.environ.get("OPENAI_MODEL", "gpt-4o")
-    return await azure_openai_complete_if_cache(
-        model,
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
+    return completion_text
 
 
-async def azure_gpt_4o_mini_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+async def cohere_complete(
+    prompt: str,
+    system_prompt: str | None = None,
+    history_messages: list[dict] | None = None,
+    model: str | None = None,
+    **kwargs
 ) -> str:
-    # Use model from env if available
-    model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
-    return await azure_openai_complete_if_cache(
-        model,
-        prompt,
+    """
+    High-level wrapper for Cohere chat completion using default settings.
+
+    Args:
+        prompt (str): The user's prompt/message.
+        system_prompt (str | None): The system prompt (preamble for Cohere).
+        history_messages (list[dict] | None): List of previous messages.
+        model (str | None): Specific Cohere model to use. Overrides defaults.
+        **kwargs: Additional arguments for cohere_complete_if_cache (including hashing_kv).
+
+    Returns:
+        str: The generated text content.
+    """
+    # Model resolution happens inside cohere_complete_if_cache
+    return await cohere_complete_if_cache(
+        model=model,
+        prompt=prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
-async def azure_openai_custom_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
-) -> str:
-    # Get model name from environment variables or fallback to a default
-    model_name = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", "llama3"))
-    return await azure_openai_complete_if_cache(
-        model_name,
-        prompt,
-        system_prompt=system_prompt,
-        history_messages=history_messages,
-        **kwargs,
-    )
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192)
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
+# --- Embedding Function ---
+
+# Example dimensions for common Cohere v3 models
+# See: https://docs.cohere.com/reference/embed
+COHERE_EMBED_DIMS = {
+    "embed-english-v3.0": 1024,
+    "embed-multilingual-v3.0": 1024,
+    "embed-english-light-v3.0": 384,
+    "embed-multilingual-light-v3.0": 384,
+    "embed-english-v2.0": 4096,
+    "embed-english-light-v2.0": 1024,
+    "embed-multilingual-v2.0": 768,
+}
+# Recommended max tokens (not a hard limit enforced by API)
+COHERE_EMBED_MAX_TOKENS = 512  # Cohere recommends under 512 for optimal quality
+
+@wrap_embedding_func_with_attrs(  # Decorator might need adjustment based on actual model used
+    embedding_dim=COHERE_EMBED_DIMS.get(os.environ.get("COHERE_EMBEDDING_MODEL", "embed-english-v3.0"), 1024),  # Default to common model dim
+    max_token_size=COHERE_EMBED_MAX_TOKENS  # Use recommended token size
 )
-async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
-    azure_openai_client = get_azure_openai_async_client_instance()
-    # Use model from env if available
-    model = os.environ.get("OPENAI_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "text-embedding-3-small"))
-    response = await azure_openai_client.embeddings.create(
-        model=model, input=texts, encoding_format="float"
-    )
-    return np.array([dp.embedding for dp in response.data])
-
-
-@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192)
 @retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=2, max=10),
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((CohereRateLimitError, CohereConnectionError, CohereAPIError)),
+    reraise=True
 )
-async def deepseek_embedding(texts: list[str]) -> np.ndarray:
-    session = get_deepseek_session()
+async def cohere_embedding(
+    texts: list[str],
+    model: str | None = None,
+    input_type: str = "search_document",
+    embedding_types: list[str] | None = None,  # Allow overriding default ['float']
+    hashing_kv: BaseKVStorage | None = None  # Added for potential caching
+) -> np.ndarray:
+    """
+    Generates embeddings for a list of texts using the Cohere API.
+
+    Args:
+        texts (list[str]): A list of strings to embed.
+        model (str | None): The Cohere embedding model ID. Defaults to COHERE_EMBEDDING_MODEL env var or 'embed-english-v3.0'.
+        input_type (str): Specifies the type of input passed to the model (v3+).
+                          Examples: "search_document", "search_query", "classification", "clustering".
+                          Defaults to "search_document".
+        embedding_types (list[str] | None): Specifies the desired embedding types (e.g., ['float', 'int8']).
+                                            Defaults to ['float'].
+        hashing_kv (BaseKVStorage | None): Optional KV store for caching results.
+
+    Returns:
+        np.ndarray: A numpy array where each row is the embedding for the corresponding text.
+                    Returns only the 'float' embeddings if multiple types are requested but caching is not implemented for multiple types.
+
+    Raises:
+        CohereAPIError, CohereConnectionError, CohereRateLimitError: If API calls fail after retries.
+        ValueError: If texts list is empty.
+    """
+    if not texts:
+        logger.warning("Received empty list of texts for embedding. Returning empty array.")
+        return np.array([])
+        # Alternatively: raise ValueError("Texts list cannot be empty for cohere_embedding")
+
+    cohere_async_client = get_cohere_async_client_instance()
+
+    # Determine model, preferring explicit > env var > default
+    effective_model = model or os.environ.get("COHERE_EMBEDDING_MODEL", "embed-english-v3.0")
-
-    # Get embedding model from environment variables
-    model = os.environ.get("DEEPSEEK_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-embedding"))
+    # Default embedding types if not specified
+    effective_embedding_types = embedding_types or ["float"]
-
-    # Prepare request payload
-    payload = {
-        "model": model,
-        "input": texts,
-        "encoding_format": "float"
+    # Arguments for hashing and API call
+    api_args = {
+        "model": effective_model,
+        "texts": texts,
+        "input_type": input_type,
+        "embedding_types": effective_embedding_types,
+        # Add truncate parameter if needed, e.g., "truncate": "END"
     }
-
-    # Make API request
-    deepseek_base_url = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
-    async with session.post(f"{deepseek_base_url}/v1/embeddings", json=payload) as response:
-        if response.status != 200:
-            error_text = await response.text()
-            raise RuntimeError(f"DeepSeek Embedding API error: {response.status} - {error_text}")
-
-        response_json = await response.json()
-        embeddings = [data["embedding"] for data in response_json["data"]]
-
-    return np.array(embeddings)
+    # --- Caching Logic (Optional for Embeddings) ---
+    args_hash = None
+    if hashing_kv is not None:
+        args_hash = compute_args_hash(api_args)  # Hash includes all relevant params
+        if_cache_return = await hashing_kv.get_by_id(args_hash)
+        if if_cache_return is not None:
+            logger.debug(f"Cache hit for Cohere embedding (hash: {args_hash})")
+            # Assuming cache stores numpy array directly or can reconstruct it
+            # This might need adjustment based on how caching_kv stores/retrieves complex types
+            cached_data = if_cache_return.get("return")
+            if isinstance(cached_data, list):  # Simple check if it was stored as list
+                return np.array(cached_data)
+            elif isinstance(cached_data, np.ndarray):
+                return cached_data
+            else:
+                logger.warning(f"Cached embedding data format unexpected (hash: {args_hash}). Re-fetching.")
+                # Fall through to fetch if format is wrong
+        else:
+            logger.debug(f"Cache miss for Cohere embedding (hash: {args_hash})")
+    # --- End Caching Logic ---
 
-@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192)
-@retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=2, max=10),
-)
-async def ollama_embedding(texts: list[str]) -> np.ndarray:
-    session = get_ollama_session()
-
-    # Get embedding model from environment variables
-    # Ollama might use the same model for completion and embedding
-    model = os.environ.get("OLLAMA_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3")))
-
-    # Ollama can only process one text at a time for embeddings
-    all_embeddings = []
-    for text in texts:
-        # Prepare request payload
-        payload = {
-            "model": model,
-            "prompt": text,
-        }
+    try:
+        logger.debug(f"Calling Cohere embed API with model '{effective_model}', {len(texts)} texts, input_type '{input_type}'")
+        response = await cohere_async_client.embed(**api_args)
-
-        # Make API request
-        async with session.post("/api/embeddings", json=payload) as response:
-            if response.status != 200:
-                error_text = await response.text()
-                raise RuntimeError(f"Ollama Embedding API error: {response.status} - {error_text}")
-
-            response_json = await response.json()
-            embeddings = response_json["embedding"]
-            all_embeddings.append(embeddings)
+        # Extract embeddings - prioritizing 'float' if available
+        if hasattr(response, 'embeddings') and 'float' in response.embeddings:
+            embeddings_list = response.embeddings['float']
+        elif hasattr(response, 'embeddings') and effective_embedding_types[0] in response.embeddings:
+            # Fallback to the first requested type if float isn't there (e.g., if only 'int8' was requested)
+            embeddings_list = response.embeddings[effective_embedding_types[0]]
+            logger.warning(f"Returning '{effective_embedding_types[0]}' embeddings as 'float' was not found in response.")
+        elif isinstance(response.embeddings, list):  # Handle older API or potential variations
+            embeddings_list = response.embeddings
+            logger.warning("Cohere embed response format unexpected (expected dict with types), using direct list.")
+        else:
+            logger.error(f"Could not extract embeddings from Cohere response. Response keys: {list(response.embeddings.keys()) if hasattr(response, 'embeddings') and isinstance(response.embeddings, dict) else 'N/A'}")
+            raise ValueError("Failed to extract embeddings from Cohere API response.")
+
+        result_array = np.array(embeddings_list)
+        logger.debug(f"Received Cohere embeddings. Shape: {result_array.shape}")
+
+    except (CohereAPIError, CohereConnectionError, CohereRateLimitError) as e:
+        logger.error(f"Cohere API error during embedding: {e}")
+        raise
+    except Exception as e:
+        logger.exception(f"An unexpected error occurred during Cohere embedding: {e}")
+        raise
+
+    # --- Caching Save Logic ---
+    if hashing_kv is not None and args_hash is not None:
+        # Store as list for broader compatibility, can be adjusted
+        await hashing_kv.upsert(
+            {args_hash: {"return": result_array.tolist(), "model": effective_model}}
+        )
+        await hashing_kv.index_done_callback()
+        logger.debug(f"Cached Cohere embedding result (hash: {args_hash})")
+    # --- End Caching Save Logic ---
 
-    return np.array(all_embeddings)
+    # Dynamically update the decorator's attributes based on the actual model used, if needed
+    # This part is complex as the decorator is applied at definition time.
+    # A simpler approach is to ensure the decorator uses the default model's info,
+    # or remove the dimension/token checks if they cause issues with dynamic models.
+    # For now, we rely on the initial decorator values based on environment or defaults.
+    # Optionally, log a warning if the used model's known dim differs from decorator's:
+    # known_dim = COHERE_EMBED_DIMS.get(effective_model)
+    # if known_dim and known_dim != cohere_embedding.embedding_dim:
+    #     logger.warning(f"Model '{effective_model}' has dimension {known_dim}, but decorator uses {cohere_embedding.embedding_dim}.")
+
+    return result_array
+
+# Example Usage (can be removed or placed under if __name__ == "__main__":)
+async def example_main():
+    # Ensure COHERE_API_KEY is set as an environment variable
+    if not os.environ.get("COHERE_API_KEY"):
+        print("Please set the COHERE_API_KEY environment variable.")
+        return
+
+    print("--- Testing Cohere Completion ---")
+    try:
+        completion = await cohere_complete(
+            prompt="What is the capital of France?",
+            # system_prompt="Respond concisely.",  # Optional Preamble
+            # model="command-r-plus"  # Optional: override default
+        )
+        print(f"Completion Result: {completion}")
+    except Exception as e:
+        print(f"Completion failed: {e}")
+
+    print("--- Testing Cohere Embedding ---")
+    try:
+        texts_to_embed = ["hello world", "large language model"]
+        embeddings = await cohere_embedding(
+            texts=texts_to_embed,
+            input_type="search_document",  # Or "search_query", "classification", etc.
+            # model="embed-english-v3.0"  # Optional: override default
+        )
+        print(f"Embedding Result Shape: {embeddings.shape}")
+        # print(f"First Embedding (first 5 dims): {embeddings[0][:5]}")
+    except Exception as e:
+        print(f"Embedding failed: {e}")
+
+# if __name__ == "__main__":
+#     import asyncio
+#     # Note: Running top-level async requires asyncio.run()
+#     # asyncio.run(example_main())
+#     pass  # Keep clean for import
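+
+# Caching sketch: cohere_complete (via **kwargs) and cohere_embedding both accept a
+# `hashing_kv` argument (a BaseKVStorage instance). When provided, results are keyed
+# by a hash of the request arguments and identical calls are served from the cache,
+# e.g. (my_kv_storage is a hypothetical, caller-supplied store, not defined here):
+#
+#     answer = await cohere_complete("What is HiRAG?", hashing_kv=my_kv_storage)
+#     vectors = await cohere_embedding(["hello world"], hashing_kv=my_kv_storage)
+#
+# A repeated call with the same arguments then skips the Cohere API entirely.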
os.environ.get("AZURE_OPENAI_API_KEY", None) + + # If Azure configuration is available, use it + if azure_endpoint and azure_key: + global_azure_openai_async_client = AsyncAzureOpenAI( + api_version=api_version, + azure_endpoint=azure_endpoint, + api_key=azure_key + ) + else: + # Fall back to the OpenAI client with Azure-compatible settings + base_url = os.environ.get("OPENAI_API_BASE", os.environ.get("OPENAI_BASE_URL", None)) + api_key = os.environ.get("OPENAI_API_KEY", "ollama") + + if base_url: + global_azure_openai_async_client = AsyncOpenAI(base_url=base_url, api_key=api_key) + else: + global_azure_openai_async_client = AsyncOpenAI() + return global_azure_openai_async_client + + +def get_deepseek_session(): + """Get or create a DeepSeek API session""" + global global_deepseek_session + if global_deepseek_session is None: + global_deepseek_session = aiohttp.ClientSession( + headers={ + "Authorization": f"Bearer {os.environ.get('DEEPSEEK_API_KEY', '')}", + "Content-Type": "application/json" + } + ) + return global_deepseek_session + + +def get_ollama_session(): + """Get or create an Ollama API session""" + global global_ollama_session + if global_ollama_session is None: + ollama_base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434") + + # Set up basic authentication if provided + headers = {"Content-Type": "application/json"} + auth = None + if os.environ.get("OLLAMA_API_KEY"): + headers["Authorization"] = f"Bearer {os.environ.get('OLLAMA_API_KEY')}" + + global_ollama_session = aiohttp.ClientSession( + base_url=ollama_base_url, + headers=headers, + auth=auth + ) + return global_ollama_session + + +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def openai_complete_if_cache( + model, prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + openai_async_client = get_openai_async_client_instance() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response.choices[0].message.content, "model": model}} + ) + await hashing_kv.index_done_callback() + return response.choices[0].message.content + + +async def gpt_4o_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + model = os.environ.get("OPENAI_MODEL", "gpt-4o") + return await openai_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + +async def gpt_35_turbo_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + # Use a model from env if available, otherwise fallback to GPT-3.5 + model = os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo") + return await openai_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +async def gpt_4o_mini_complete( + prompt, system_prompt=None, history_messages=[], **kwargs 
+) -> str: + # Use a model from env if available, otherwise fallback to GPT-4o-mini + model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini") + return await openai_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + +async def gpt_custom_model_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + # Get model name from environment variables or fallback to a default + model_name = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", "llama3")) + return await openai_complete_if_cache( + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def deepseek_complete_if_cache( + model=None, prompt=None, system_prompt=None, history_messages=[], **kwargs +) -> str: + session = get_deepseek_session() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + + # Get model name from environment variables or use default + if model is None: + model = os.environ.get("DEEPSEEK_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-chat")) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + # Check cache if available + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + # Prepare request + payload = { + "model": model, + "messages": messages, + } + + # Add additional parameters + for key, value in kwargs.items(): + if key in ["temperature", "top_p", "max_tokens", "stream"]: + payload[key] = value + + # Make API request + deepseek_base_url = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com") + async with session.post(f"{deepseek_base_url}/v1/chat/completions", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"DeepSeek API error: {response.status} - {error_text}") + + response_json = await response.json() + completion = response_json["choices"][0]["message"]["content"] + + # Cache the response if enabled + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": completion, "model": model}} + ) + await hashing_kv.index_done_callback() + + return completion + + +async def deepseek_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + # Model can be overridden with OPENAI_MODEL_NAME for consistency with other API calls + model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("DEEPSEEK_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-chat"))) + return await deepseek_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def ollama_complete_if_cache( + model=None, prompt=None, system_prompt=None, history_messages=[], **kwargs +) -> str: + session = get_ollama_session() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + + # Get model name from environment variables or use default + if model is None: + model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3"))) + + # 
Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + # Check cache if available + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + # Prepare request + payload = { + "model": model, + "messages": messages, + "stream": False + } + + # Add additional parameters + for key, value in kwargs.items(): + if key in ["temperature", "top_p", "num_predict"]: + payload[key] = value + + # Make API request + async with session.post("/api/chat", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Ollama API error: {response.status} - {error_text}") + + response_json = await response.json() + completion = response_json["message"]["content"] + + # Cache the response if enabled + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": completion, "model": model}} + ) + await hashing_kv.index_done_callback() + + return completion + + +async def ollama_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + model = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3"))) + return await ollama_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def openai_embedding(texts: list[str]) -> np.ndarray: + openai_async_client = get_openai_async_client_instance() + # Use model from env if available + model = os.environ.get("OPENAI_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "text-embedding-3-small")) + response = await openai_async_client.embeddings.create( + model=model, input=texts, encoding_format="float" + ) + return np.array([dp.embedding for dp in response.data]) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def azure_openai_complete_if_cache( + deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + azure_openai_client = get_azure_openai_async_client_instance() + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(deployment_name, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + response = await azure_openai_client.chat.completions.create( + model=deployment_name, messages=messages, **kwargs + ) + + if hashing_kv is not None: + await hashing_kv.upsert( + { + args_hash: { + "return": response.choices[0].message.content, + "model": deployment_name, + } + } + ) + await hashing_kv.index_done_callback() + return response.choices[0].message.content + + +async def azure_gpt_4o_complete( + prompt, system_prompt=None, 
history_messages=[], **kwargs +) -> str: + # Use model from env if available + model = os.environ.get("OPENAI_MODEL", "gpt-4o") + return await azure_openai_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +async def azure_gpt_4o_mini_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + # Use model from env if available + model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini") + return await azure_openai_complete_if_cache( + model, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + +async def azure_openai_custom_model_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + # Get model name from environment variables or fallback to a default + model_name = os.environ.get("OPENAI_MODEL_NAME", os.environ.get("OPENAI_MODEL", "llama3")) + return await azure_openai_complete_if_cache( + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((RateLimitError, APIConnectionError)), +) +async def azure_openai_embedding(texts: list[str]) -> np.ndarray: + azure_openai_client = get_azure_openai_async_client_instance() + # Use model from env if available + model = os.environ.get("OPENAI_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "text-embedding-3-small")) + response = await azure_openai_client.embeddings.create( + model=model, input=texts, encoding_format="float" + ) + return np.array([dp.embedding for dp in response.data]) + + +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def deepseek_embedding(texts: list[str]) -> np.ndarray: + session = get_deepseek_session() + + # Get embedding model from environment variables + model = os.environ.get("DEEPSEEK_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", "deepseek-embedding")) + + # Prepare request payload + payload = { + "model": model, + "input": texts, + "encoding_format": "float" + } + + # Make API request + deepseek_base_url = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com") + async with session.post(f"{deepseek_base_url}/v1/embeddings", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"DeepSeek Embedding API error: {response.status} - {error_text}") + + response_json = await response.json() + embeddings = [data["embedding"] for data in response_json["data"]] + + return np.array(embeddings) + + +@wrap_embedding_func_with_attrs(embedding_dim=3584, max_token_size=8192) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), +) +async def ollama_embedding(texts: list[str]) -> np.ndarray: + session = get_ollama_session() + + # Get embedding model from environment variables + # Ollama might use the same model for completion and embedding + model = os.environ.get("OLLAMA_EMBEDDING_MODEL", os.environ.get("OPENAI_MODEL", os.environ.get("GLM_MODEL", "llama3"))) + + # Ollama can only process one text at a time for embeddings + all_embeddings = [] + for text in texts: + # Prepare request payload + payload = { + "model": model, + "prompt": text, + } + + # Make API request + async with 
session.post("/api/embeddings", json=payload) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Ollama Embedding API error: {response.status} - {error_text}") + + response_json = await response.json() + embeddings = response_json["embedding"] + all_embeddings.append(embeddings) + + return np.array(all_embeddings) diff --git a/hirag/hirag.py b/hirag/hirag.py index 15846f3..b6ee088 100644 --- a/hirag/hirag.py +++ b/hirag/hirag.py @@ -21,7 +21,9 @@ deepseek_embedding, deepseek_complete, ollama_embedding, - ollama_complete + ollama_complete, + cohere_embedding, + cohere_complete ) from ._op import ( chunking_by_token_size,