Skip to content

Commit d5eaa31

Browse files
authored
Merge pull request #28 from dmaniloff/openai-api-downgrade-pandas
Update Ragas wrappers to use OpenAI APIs (inference & embeddings)
2 parents 8b2788f + 34cbd69 commit d5eaa31

File tree

9 files changed

+154
-165
lines changed

9 files changed

+154
-165
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "llama-stack-provider-ragas"
7-
version = "0.3.6"
7+
version = "0.4.0"
88
description = "Ragas evaluation as an out-of-tree Llama Stack provider"
99
readme = "README.md"
1010
requires-python = ">=3.12"
@@ -25,7 +25,7 @@ authors = [
2525
keywords = ["llama-stack", "ragas", "evaluation"]
2626
dependencies = [
2727
"setuptools-scm",
28-
"llama-stack==0.2.23",
28+
"llama-stack>=0.2.23",
2929
"greenlet==3.2.4", # inline/files/localfs errors saying greenlet not found
3030
"ragas==0.3.0",
3131
"pandas<2.3.0",

src/llama_stack_provider_ragas/inline/wrappers_inline.py

Lines changed: 38 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
from langchain_core.language_models.llms import Generation, LLMResult
55
from langchain_core.prompt_values import PromptValue
6-
from llama_stack.apis.inference import EmbeddingTaskType
6+
from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy
77
from ragas.embeddings.base import BaseRagasEmbeddings
88
from ragas.llms.base import BaseRagasLLM
99
from ragas.run_config import RunConfig
@@ -39,25 +39,23 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
3939
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
4040
"""Embed documents using Llama Stack inference API."""
4141
try:
42-
response = await self.inference_api.embeddings(
43-
model_id=self.embedding_model_id,
44-
contents=texts,
45-
task_type=EmbeddingTaskType.document,
42+
response = await self.inference_api.openai_embeddings(
43+
model=self.embedding_model_id,
44+
input=texts,
4645
)
47-
return response.embeddings # type: ignore
46+
return [data.embedding for data in response.data]
4847
except Exception as e:
4948
logger.error(f"Document embedding failed: {str(e)}")
5049
raise
5150

5251
async def aembed_query(self, text: str) -> list[float]:
5352
"""Embed query using Llama Stack inference API."""
5453
try:
55-
response = await self.inference_api.embeddings(
56-
model_id=self.embedding_model_id,
57-
contents=[text],
58-
task_type=EmbeddingTaskType.query,
54+
response = await self.inference_api.openai_embeddings(
55+
model=self.embedding_model_id,
56+
input=text,
5957
)
60-
return response.embeddings[0] # type: ignore
58+
return response.data[0].embedding # type: ignore
6159
except Exception as e:
6260
logger.error(f"Query embedding failed: {str(e)}")
6361
raise
@@ -70,39 +68,14 @@ def __init__(
7068
self,
7169
inference_api,
7270
model_id: str,
73-
sampling_params,
71+
sampling_params: SamplingParams | None = None,
7472
run_config: RunConfig = RunConfig(),
7573
multiple_completion_supported: bool = True,
7674
):
7775
super().__init__(run_config, multiple_completion_supported)
7876
self.inference_api = inference_api
7977
self.model_id = model_id
8078
self.sampling_params = sampling_params
81-
self.enable_prompt_logging = True
82-
self.prompt_counter = 0
83-
84-
def _estimate_tokens(self, text: str) -> int:
85-
"""Estimate token count for a given text.
86-
87-
This is a rough estimation - for accurate counts, you'd need the actual tokenizer.
88-
"""
89-
# Rough estimation: ~4 characters per token for English text
90-
return len(text) // 4
91-
92-
def _log_prompt(self, prompt_text: str, prompt_type: str = "evaluation") -> None:
93-
"""Log prompt details if enabled."""
94-
if not self.enable_prompt_logging:
95-
return
96-
97-
self.prompt_counter += 1
98-
estimated_tokens = self._estimate_tokens(prompt_text)
99-
100-
logger.info(f"=== RAGAS PROMPT #{self.prompt_counter} ({prompt_type}) ===")
101-
logger.info(f"Estimated tokens: {estimated_tokens}")
102-
logger.info(f"Character count: {len(prompt_text)}")
103-
logger.info(f"Prompt preview: {prompt_text[:200]}...")
104-
logger.info(f"Full prompt:\n{prompt_text}")
105-
logger.info("=" * 50)
10679

10780
def generate_text(
10881
self,
@@ -126,64 +99,56 @@ async def agenerate_text(
12699
) -> LLMResult:
127100
"""Asynchronous text generation using Llama Stack inference API."""
128101
try:
129-
# Convert PromptValue to string
130-
prompt_text = prompt.to_string()
131-
132-
# Log the prompt if enabled
133-
self._log_prompt(prompt_text)
134-
135-
# Create sampling params for this generation
136-
gen_sampling_params = self.sampling_params
137-
if temperature is not None:
138-
# Update temperature if provided
139-
gen_sampling_params = (
140-
gen_sampling_params.copy()
141-
if hasattr(gen_sampling_params, "copy")
142-
else gen_sampling_params
143-
)
144-
if hasattr(gen_sampling_params, "temperature"):
145-
gen_sampling_params.temperature = temperature
146-
147-
# Generate responses (handle multiple completions if n > 1)
148102
generations = []
149103
llm_output = {
150104
"llama_stack_responses": [],
151105
"model_id": self.model_id,
152106
"provider": "llama_stack",
153107
}
154108

109+
# sampling params for this generation should be set via the benchmark config
110+
# we will ignore the temperature and stop params passed in here
155111
for _ in range(n):
156-
response = await self.inference_api.completion(
157-
model_id=self.model_id,
158-
content=prompt_text,
159-
sampling_params=gen_sampling_params,
112+
response = await self.inference_api.openai_completion(
113+
model=self.model_id,
114+
prompt=prompt.to_string(),
115+
max_tokens=self.sampling_params.max_tokens
116+
if self.sampling_params
117+
else None,
118+
temperature=self.sampling_params.strategy.temperature
119+
if self.sampling_params
120+
and isinstance(self.sampling_params.strategy, TopPSamplingStrategy)
121+
else None,
122+
top_p=self.sampling_params.strategy.top_p
123+
if self.sampling_params
124+
and isinstance(self.sampling_params.strategy, TopPSamplingStrategy)
125+
else None,
126+
stop=self.sampling_params.stop if self.sampling_params else None,
160127
)
161128

129+
if not response.choices:
130+
logger.warning("Completion response returned no choices")
131+
132+
# Extract text from OpenAI completion response
133+
choice = response.choices[0] if response.choices else None
134+
text = choice.text if choice else ""
135+
162136
# Store Llama Stack response info in llm_output
163137
llama_stack_info = {
164-
"stop_reason": (
165-
response.stop_reason.value if response.stop_reason else None
166-
),
167-
"content_length": len(response.content),
168-
"has_logprobs": response.logprobs is not None,
169-
"logprobs_count": (
170-
len(response.logprobs) if response.logprobs else 0
171-
),
138+
"stop_reason": (choice.finish_reason if choice else None),
139+
"content_length": len(text),
140+
"has_logprobs": choice.logprobs is not None if choice else False,
172141
}
173142
llm_output["llama_stack_responses"].append(llama_stack_info) # type: ignore
174143

175-
generations.append(Generation(text=response.content))
144+
generations.append(Generation(text=text))
176145

177146
return LLMResult(generations=[generations], llm_output=llm_output)
178147

179148
except Exception as e:
180149
logger.error(f"LLM generation failed: {str(e)}")
181150
raise
182151

183-
def get_temperature(self, n: int) -> float:
184-
"""Get temperature based on number of completions."""
185-
return 0.3 if n > 1 else 1e-8
186-
187152
# TODO: revisit this
188153
# def is_finished(self, response: LLMResult) -> bool:
189154
# """

src/llama_stack_provider_ragas/remote/kubeflow/components.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,7 @@ def run_ragas_evaluation(
8585
import logging
8686

8787
import pandas as pd
88+
from llama_stack.apis.inference import SamplingParams
8889
from ragas import EvaluationDataset, evaluate
8990
from ragas.dataset_schema import EvaluationResult
9091
from ragas.run_config import RunConfig
@@ -99,10 +100,14 @@ def run_ragas_evaluation(
99100
logger = logging.getLogger(__name__)
100101
logger.setLevel(logging.INFO)
101102

103+
# sampling_params is passed in from the benchmark config as model_dump()
104+
# we need to convert it back to a SamplingParams object
105+
sampling_params_obj = SamplingParams.model_validate(sampling_params)
106+
102107
llm = LlamaStackRemoteLLM(
103108
base_url=llama_stack_base_url,
104109
model_id=model,
105-
sampling_params=sampling_params,
110+
sampling_params=sampling_params_obj,
106111
)
107112
embeddings = LlamaStackRemoteEmbeddings(
108113
base_url=llama_stack_base_url,

src/llama_stack_provider_ragas/remote/ragas_remote_eval.py

Lines changed: 3 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -181,18 +181,6 @@ async def run_eval(
181181
async def _submit_to_kubeflow(self, job: RagasEvaluationJob) -> str:
182182
from .kubeflow.pipeline import ragas_evaluation_pipeline
183183

184-
# temperature = (
185-
# job.runtime_config.benchmark_config.sampling_params.temperature
186-
# if job.runtime_config.benchmark_config.sampling_params.strategy.type
187-
# == "top_p"
188-
# else None
189-
# )
190-
191-
# sampling_params = {
192-
# "temperature": temperature,
193-
# "max_tokens": job.runtime_config.benchmark_config.sampling_params.max_tokens,
194-
# }
195-
196184
pipeline_args = {
197185
"dataset_id": job.runtime_config.benchmark.dataset_id,
198186
"llama_stack_base_url": job.runtime_config.kubeflow_config.llama_stack_url,
@@ -202,7 +190,9 @@ async def _submit_to_kubeflow(self, job: RagasEvaluationJob) -> str:
202190
else -1
203191
),
204192
"model": job.runtime_config.benchmark_config.eval_candidate.model,
205-
"sampling_params": job.runtime_config.benchmark_config.eval_candidate.sampling_params.model_dump(),
193+
"sampling_params": job.runtime_config.benchmark_config.eval_candidate.sampling_params.model_dump(
194+
exclude_none=True
195+
),
206196
"embedding_model": self.config.embedding_model,
207197
"metrics": job.runtime_config.benchmark.scoring_functions,
208198
"result_s3_location": job.result_s3_location,

0 commit comments

Comments (0)