Skip to content

Commit

Permalink
Fix CI error
Browse files Browse the repository at this point in the history
  • Loading branch information
TamiTakamiya committed Feb 18, 2025
1 parent 7e950a1 commit 23483b5
Showing 1 changed file with 39 additions and 36 deletions.
75 changes: 39 additions & 36 deletions ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,6 @@ class HttpChatBotMetaData(HttpMetaData):
def __init__(self, config: HttpConfiguration):
    # Pure pass-through constructor: all setup (headers, timeouts, config
    # storage) is handled by the HttpMetaData base class.
    super().__init__(config=config)

def prepare_data(self, params: ChatBotParameters):
    """Build the JSON payload for a chatbot query request.

    Always includes ``query``, ``model`` and ``provider``; the optional
    ``conversation_id`` and ``system_prompt`` fields are added (stringified)
    only when they are truthy on *params*.
    """
    payload = {
        "query": params.query,
        "model": params.model_id,
        "provider": params.provider,
    }
    # Optional fields: omit entirely rather than sending null/empty values.
    if params.conversation_id:
        payload["conversation_id"] = str(params.conversation_id)
    if params.system_prompt:
        payload["system_prompt"] = str(params.system_prompt)
    return payload

def self_test(self) -> Optional[HealthCheckSummary]:
summary: HealthCheckSummary = HealthCheckSummary(
{
Expand Down Expand Up @@ -184,10 +166,26 @@ def __init__(self, config: HttpConfiguration):
super().__init__(config=config)

def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
query = params.query
conversation_id = params.conversation_id
provider = params.provider
model_id = params.model_id
system_prompt = params.system_prompt

data = {
"query": query,
"model": model_id,
"provider": provider,
}
if conversation_id:
data["conversation_id"] = str(conversation_id)
if system_prompt:
data["system_prompt"] = str(system_prompt)

response = requests.post(
self.config.inference_url + "/v1/query",
headers=self.headers,
json=self.prepare_data(params),
json=data,
timeout=self.timeout(1),
verify=self.config.verify_ssl,
)
Expand All @@ -213,24 +211,9 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
raise ChatbotInternalServerException(detail=detail)


class HttpStreamingChatBotMetaData(HttpChatBotMetaData):
    """Metadata for the streaming chatbot HTTP pipeline.

    Reuses the non-streaming payload construction and augments it with the
    streaming-specific ``media_type`` field when one is supplied.
    """

    def __init__(self, config: HttpConfiguration):
        # No streaming-specific state; defer entirely to the base class.
        super().__init__(config=config)

    def prepare_data(self, params: StreamingChatBotParameters):
        """Extend the base payload with the optional media type."""
        payload = super().prepare_data(params)
        if params.media_type:
            payload["media_type"] = str(params.media_type)
        return payload


@Register(api_type="http")
class HttpStreamingChatBotPipeline(
HttpStreamingChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
HttpChatBotMetaData, ModelPipelineStreamingChatBot[HttpConfiguration]
):

def __init__(self, config: HttpConfiguration):
Expand All @@ -245,9 +228,29 @@ async def async_invoke(self, params: StreamingChatBotParameters) -> StreamingHtt
"Content-Type": "application/json",
"Accept": "application/json,text/event-stream",
}

query = params.query
conversation_id = params.conversation_id
provider = params.provider
model_id = params.model_id
system_prompt = params.system_prompt
media_type = params.media_type

data = {
"query": query,
"model": model_id,
"provider": provider,
}
if conversation_id:
data["conversation_id"] = str(conversation_id)
if system_prompt:
data["system_prompt"] = str(system_prompt)
if media_type:
data["media_type"] = str(media_type)

async with session.post(
self.config.inference_url + "/v1/streaming_query",
json=self.prepare_data(params),
json=data,
headers=headers,
) as r:
async for chunk in r.content:
Expand Down

0 comments on commit 23483b5

Please sign in to comment.