33
44from langchain_core .language_models .llms import Generation , LLMResult
55from langchain_core .prompt_values import PromptValue
6- from llama_stack .apis .inference import EmbeddingTaskType
6+ from llama_stack .apis .inference import SamplingParams , TopPSamplingStrategy
77from ragas .embeddings .base import BaseRagasEmbeddings
88from ragas .llms .base import BaseRagasLLM
99from ragas .run_config import RunConfig
@@ -39,25 +39,23 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
3939 async def aembed_documents (self , texts : list [str ]) -> list [list [float ]]:
4040 """Embed documents using Llama Stack inference API."""
4141 try :
42- response = await self .inference_api .embeddings (
43- model_id = self .embedding_model_id ,
44- contents = texts ,
45- task_type = EmbeddingTaskType .document ,
42+ response = await self .inference_api .openai_embeddings (
43+ model = self .embedding_model_id ,
44+ input = texts ,
4645 )
47- return response . embeddings # type: ignore
46+ return [ data . embedding for data in response . data ]
4847 except Exception as e :
4948 logger .error (f"Document embedding failed: { str (e )} " )
5049 raise
5150
5251 async def aembed_query (self , text : str ) -> list [float ]:
5352 """Embed query using Llama Stack inference API."""
5453 try :
55- response = await self .inference_api .embeddings (
56- model_id = self .embedding_model_id ,
57- contents = [text ],
58- task_type = EmbeddingTaskType .query ,
54+ response = await self .inference_api .openai_embeddings (
55+ model = self .embedding_model_id ,
56+ input = text ,
5957 )
60- return response .embeddings [0 ] # type: ignore
58+ return response .data [0 ]. embedding # type: ignore
6159 except Exception as e :
6260 logger .error (f"Query embedding failed: { str (e )} " )
6361 raise
@@ -70,39 +68,14 @@ def __init__(
7068 self ,
7169 inference_api ,
7270 model_id : str ,
73- sampling_params ,
71+ sampling_params : SamplingParams | None = None ,
7472 run_config : RunConfig = RunConfig (),
7573 multiple_completion_supported : bool = True ,
7674 ):
7775 super ().__init__ (run_config , multiple_completion_supported )
7876 self .inference_api = inference_api
7977 self .model_id = model_id
8078 self .sampling_params = sampling_params
81- self .enable_prompt_logging = True
82- self .prompt_counter = 0
83-
84- def _estimate_tokens (self , text : str ) -> int :
85- """Estimate token count for a given text.
86-
87- This is a rough estimation - for accurate counts, you'd need the actual tokenizer.
88- """
89- # Rough estimation: ~4 characters per token for English text
90- return len (text ) // 4
91-
92- def _log_prompt (self , prompt_text : str , prompt_type : str = "evaluation" ) -> None :
93- """Log prompt details if enabled."""
94- if not self .enable_prompt_logging :
95- return
96-
97- self .prompt_counter += 1
98- estimated_tokens = self ._estimate_tokens (prompt_text )
99-
100- logger .info (f"=== RAGAS PROMPT #{ self .prompt_counter } ({ prompt_type } ) ===" )
101- logger .info (f"Estimated tokens: { estimated_tokens } " )
102- logger .info (f"Character count: { len (prompt_text )} " )
103- logger .info (f"Prompt preview: { prompt_text [:200 ]} ..." )
104- logger .info (f"Full prompt:\n { prompt_text } " )
105- logger .info ("=" * 50 )
10679
10780 def generate_text (
10881 self ,
@@ -126,64 +99,56 @@ async def agenerate_text(
12699 ) -> LLMResult :
127100 """Asynchronous text generation using Llama Stack inference API."""
128101 try :
129- # Convert PromptValue to string
130- prompt_text = prompt .to_string ()
131-
132- # Log the prompt if enabled
133- self ._log_prompt (prompt_text )
134-
135- # Create sampling params for this generation
136- gen_sampling_params = self .sampling_params
137- if temperature is not None :
138- # Update temperature if provided
139- gen_sampling_params = (
140- gen_sampling_params .copy ()
141- if hasattr (gen_sampling_params , "copy" )
142- else gen_sampling_params
143- )
144- if hasattr (gen_sampling_params , "temperature" ):
145- gen_sampling_params .temperature = temperature
146-
147- # Generate responses (handle multiple completions if n > 1)
148102 generations = []
149103 llm_output = {
150104 "llama_stack_responses" : [],
151105 "model_id" : self .model_id ,
152106 "provider" : "llama_stack" ,
153107 }
154108
109+ # sampling params for this generation should be set via the benchmark config
110+ # we will ignore the temperature and stop params passed in here
155111 for _ in range (n ):
156- response = await self .inference_api .completion (
157- model_id = self .model_id ,
158- content = prompt_text ,
159- sampling_params = gen_sampling_params ,
112+ response = await self .inference_api .openai_completion (
113+ model = self .model_id ,
114+ prompt = prompt .to_string (),
115+ max_tokens = self .sampling_params .max_tokens
116+ if self .sampling_params
117+ else None ,
118+ temperature = self .sampling_params .strategy .temperature
119+ if self .sampling_params
120+ and isinstance (self .sampling_params .strategy , TopPSamplingStrategy )
121+ else None ,
122+ top_p = self .sampling_params .strategy .top_p
123+ if self .sampling_params
124+ and isinstance (self .sampling_params .strategy , TopPSamplingStrategy )
125+ else None ,
126+ stop = self .sampling_params .stop if self .sampling_params else None ,
160127 )
161128
129+ if not response .choices :
130+ logger .warning ("Completion response returned no choices" )
131+
132+ # Extract text from OpenAI completion response
133+ choice = response .choices [0 ] if response .choices else None
134+ text = choice .text if choice else ""
135+
162136 # Store Llama Stack response info in llm_output
163137 llama_stack_info = {
164- "stop_reason" : (
165- response .stop_reason .value if response .stop_reason else None
166- ),
167- "content_length" : len (response .content ),
168- "has_logprobs" : response .logprobs is not None ,
169- "logprobs_count" : (
170- len (response .logprobs ) if response .logprobs else 0
171- ),
138+ "stop_reason" : (choice .finish_reason if choice else None ),
139+ "content_length" : len (text ),
140+ "has_logprobs" : choice .logprobs is not None if choice else False ,
172141 }
173142 llm_output ["llama_stack_responses" ].append (llama_stack_info ) # type: ignore
174143
175- generations .append (Generation (text = response . content ))
144+ generations .append (Generation (text = text ))
176145
177146 return LLMResult (generations = [generations ], llm_output = llm_output )
178147
179148 except Exception as e :
180149 logger .error (f"LLM generation failed: { str (e )} " )
181150 raise
182151
183- def get_temperature (self , n : int ) -> float :
184- """Get temperature based on number of completions."""
185- return 0.3 if n > 1 else 1e-8
186-
187152 # TODO: revisit this
188153 # def is_finished(self, response: LLMResult) -> bool:
189154 # """
0 commit comments