
Commit a2086b5

committed
hack
1 parent 8abca70 commit a2086b5

5 files changed, +185 -175 lines changed

ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py (+39 -72)
@@ -14,10 +14,11 @@
 
 import json
 import logging
-from json import JSONDecodeError
+
+# from json import JSONDecodeError
 from typing import AsyncGenerator
 
-import aiohttp
+# import aiohttp
 import requests
 from django.http import StreamingHttpResponse
 from health_check.exceptions import ServiceUnavailable
@@ -238,74 +239,40 @@ def get_streaming_http_response(
         )
 
     async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerator:
-        async with aiohttp.ClientSession(raise_for_status=True) as session:
-            headers = {
-                "Content-Type": "application/json",
-                "Accept": "application/json,text/event-stream",
-            }
 
-            query = params.query
-            conversation_id = params.conversation_id
-            provider = params.provider
-            model_id = params.model_id
-            system_prompt = params.system_prompt
-            media_type = params.media_type
-
-            data = {
-                "query": query,
-                "model": model_id,
-                "provider": provider,
-            }
-            if conversation_id:
-                data["conversation_id"] = str(conversation_id)
-            if system_prompt:
-                data["system_prompt"] = str(system_prompt)
-            if media_type:
-                data["media_type"] = str(media_type)
-
-            async with session.post(
-                self.config.inference_url + "/v1/streaming_query",
-                json=data,
-                headers=headers,
-                raise_for_status=False,
-            ) as response:
-                if response.status == 200:
-                    async for chunk in response.content:
-                        try:
-                            if chunk:
-                                s = chunk.decode("utf-8").strip()
-                                if s and s.startswith("data: "):
-                                    o = json.loads(s[len("data: ") :])
-                                    if o["event"] == "error":
-                                        default_data = {
-                                            "response": "(not provided)",
-                                            "cause": "(not provided)",
-                                        }
-                                        data = o.get("data", default_data)
-                                        logger.error(
-                                            "An error received in chat streaming content:"
-                                            + " response="
-                                            + data.get("response")
-                                            + ", cause="
-                                            + data.get("cause")
-                                        )
-                        except JSONDecodeError:
-                            pass
-                        logger.debug(chunk)
-                        yield chunk
-                else:
-                    logging.error(
-                        "Streaming query API returned status code="
-                        + str(response.status)
-                        + ", reason="
-                        + str(response.reason)
-                    )
-                    error = {
-                        "event": "error",
-                        "data": {
-                            "response": f"Non-200 status code ({response.status}) was received.",
-                            "cause": response.reason,
-                        },
-                    }
-                    yield json.dumps(error).encode("utf-8")
-                    return
+        query = params.query
+        # conversation_id = params.conversation_id
+        # provider = params.provider
+        # model_id = params.model_id
+        # system_prompt = params.system_prompt
+        # media_type = params.media_type
+        chatbackend_config = params.chatbackend_config
+        # client = chatbackend_config.get("client")
+        agent = chatbackend_config.get("agent")
+
+        # agent = AsyncAgent(client, chatbackend_config.get('agent_config'))
+        session_id = await agent.create_session("lightspeed-session")
+
+        response = await agent.create_turn(
+            # agent_id=chatbackend_config.get('agent_id'),
+            messages=[
+                {
+                    "role": "user",
+                    "content": query,
+                }
+            ],
+            stream=True,
+            session_id=session_id,
+        )
+        async for chunk in response:
+            if hasattr(chunk, "event"):
+                print(chunk.event)
+            if (
+                hasattr(chunk, "event")
+                and hasattr(chunk.event, "payload")
+                and hasattr(chunk.event.payload, "event_type")
+                and chunk.event.payload.event_type == "turn_complete"
+            ):
+                print(" *** ")
+                print(chunk.event.payload.turn.output_message)
+            yield chunk
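
Taken together, the new async_invoke body reduces to the llama-stack-client agent loop. Below is a minimal, self-contained sketch of that flow, assuming a Llama Stack server at http://localhost:8321 serving llama3.2:3b-instruct-fp16 (the values hard-coded in views.py later in this commit); the query string is illustrative, and the create_session/create_turn calls mirror the diff rather than any other project API.

# Minimal sketch (not project code) of the agent streaming flow that
# async_invoke now delegates to. Assumes a Llama Stack server at
# http://localhost:8321 with llama3.2:3b-instruct-fp16 available.
import asyncio

from llama_stack_client import AsyncLlamaStackClient
from llama_stack_client.lib.agents.agent import AsyncAgent
from llama_stack_client.types.agent_create_params import AgentConfig


async def main() -> None:
    client = AsyncLlamaStackClient(base_url="http://localhost:8321")
    agent = AsyncAgent(
        client,
        AgentConfig(
            model="llama3.2:3b-instruct-fp16",
            instructions="You are a helpful Ansible Automation Platform assistant.",
            enable_session_persistence=False,
        ),
    )
    session_id = await agent.create_session("lightspeed-session")
    response = await agent.create_turn(
        messages=[{"role": "user", "content": "How do I install nginx?"}],
        stream=True,
        session_id=session_id,
    )
    # Chunks arrive incrementally; the "turn_complete" event carries the
    # final assembled assistant message.
    async for chunk in response:
        payload = getattr(getattr(chunk, "event", None), "payload", None)
        if payload is not None and getattr(payload, "event_type", None) == "turn_complete":
            print(payload.turn.output_message)


asyncio.run(main())

Note that, unlike the removed aiohttp path, the new generator has no equivalent of the non-200 error event it used to yield to the client; that handling would need to come back before this leaves hack status.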

ansible_ai_connect/ai/api/model_pipelines/pipelines.py (+3 -0)
@@ -249,6 +249,7 @@ def init(
 @define
 class StreamingChatBotParameters(ChatBotParameters):
     media_type: str
+    chatbackend_config: dict[str, Any]
 
     @classmethod
     def init(
@@ -259,6 +260,7 @@ def init(
         conversation_id: Optional[str] = None,
         system_prompt: Optional[str] = None,
         media_type: Optional[str] = None,
+        chatbackend_config: dict[str, Any] = {},
     ):
         return cls(
             query=query,
@@ -267,6 +269,7 @@ def init(
             conversation_id=conversation_id,
             system_prompt=system_prompt,
             media_type=media_type,
+            chatbackend_config=chatbackend_config,
         )
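The new chatbackend_config field is how the view threads its shared llama-stack client and agent down into the pipeline. A sketch of the resulting call shape, mirroring the views.py call site further down; the query, provider, and media_type values are placeholders:

# Sketch only, mirroring the call site added to views.py in this commit.
# `client` and `agent` are the AsyncLlamaStackClient / AsyncAgent pair
# built in the view's __init__; the other values are placeholders.
params = StreamingChatBotParameters.init(
    query="How do I install nginx?",
    provider=None,
    conversation_id=None,
    media_type="application/json",
    chatbackend_config={
        "client": client,
        "agent": agent,
    },
)

One design note: the mutable {} default on init is shared across calls (the classic Python mutable-default pitfall); a None default resolved to {} inside the body would be safer.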
ansible_ai_connect/ai/api/views.py (+51 -13)
@@ -24,6 +24,9 @@
 from django.http import StreamingHttpResponse
 from django_prometheus.conf import NAMESPACE
 from drf_spectacular.utils import OpenApiResponse, extend_schema
+from llama_stack_client import AsyncLlamaStackClient
+from llama_stack_client.lib.agents.agent import AsyncAgent
+from llama_stack_client.types.agent_create_params import AgentConfig
 from oauth2_provider.contrib.rest_framework import IsAuthenticatedOrTokenHasScope
 from prometheus_client import Histogram
 from rest_framework import permissions, serializers
@@ -103,7 +106,7 @@
 )
 from ansible_ai_connect.users.models import User
 
-from ...main.permissions import IsAAPUser, IsRHInternalUser, IsTestUser
+# from ...main.permissions import IsAAPUser, IsRHInternalUser, IsTestUser
 from ...users.throttling import EndpointRateThrottle
 from ..feature_flags import FeatureFlags
 from .data.data_model import ContentMatchPayloadData, ContentMatchResponseDto
@@ -246,12 +249,13 @@ def finalize_response(self, request, response, *args, **kwargs):
         response = super().finalize_response(request, response, *args, **kwargs)
 
         try:
-            model_meta_data: MetaData = apps.get_app_config("ai").get_model_pipeline(MetaData)
-            user = request.user
-            org_id = hasattr(user, "organization") and user.organization and user.organization.id
-            self.event.modelName = self.event.modelName or model_meta_data.get_model_id(
-                request.user, org_id, self.req_model_id
-            )
+            pass
+            # model_meta_data: MetaData = apps.get_app_config("ai").get_model_pipeline(MetaData)
+            # user = request.user
+            # org_id = hasattr(user, "organization") and user.organization and user.organization.id
+            # self.event.modelName = self.event.modelName or model_meta_data.get_model_id(
+            #     request.user, org_id, self.req_model_id
+            # )
         except (WcaNoDefaultModelId, WcaModelIdNotFound, WcaSecretManagerError):
             pass
         self.event.set_response(response)
@@ -1044,9 +1048,9 @@ class ChatEndpointThrottle(EndpointRateThrottle):
         scope = "chat"
 
     permission_classes = [
-        permissions.IsAuthenticated,
-        IsAuthenticatedOrTokenHasScope,
-        IsRHInternalUser | IsTestUser | IsAAPUser,
+        # permissions.IsAuthenticated,
+        # IsAuthenticatedOrTokenHasScope,
+        # IsRHInternalUser | IsTestUser | IsAAPUser,
     ]
     required_scopes = ["read", "write"]
     schema1_event = schema1.ChatBotOperationalEvent
@@ -1142,9 +1146,9 @@ class StreamingChatEndpointThrottle(EndpointRateThrottle):
         scope = "chat"
 
     permission_classes = [
-        permissions.IsAuthenticated,
-        IsAuthenticatedOrTokenHasScope,
-        IsRHInternalUser | IsTestUser | IsAAPUser,
+        # permissions.IsAuthenticated,
+        # IsAuthenticatedOrTokenHasScope,
+        # IsRHInternalUser | IsTestUser | IsAAPUser,
     ]
     required_scopes = ["read", "write"]
     schema1_event = schema1.StreamingChatBotOperationalEvent
@@ -1167,6 +1171,33 @@ def __init__(self):
         else:
             logger.debug("Chatbot is not enabled.")
 
+        agent_config = AgentConfig(
+            # model="anthropic/claude-3-5-haiku-latest",
+            model="llama3.2:3b-instruct-fp16",
+            instructions="You are a helpful Ansible Automation Platform assistant.",
+            sampling_params={
+                "strategy": {"type": "top_p", "temperature": 1.0, "top_p": 0.9},
+            },
+            toolgroups=(
+                [
+                    # "mcp::weather",
+                    # "mcp::github",
+                    # "mcp::fs",
+                    # "mcp::aap_api",
+                    # "mcp::gateway_api",
+                    # "builtin::websearch",
+                ]
+            ),
+            tool_choice="auto",
+            input_shields=[],  # available_shields if available_shields else [],
+            output_shields=[],  # available_shields if available_shields else [],
+            enable_session_persistence=False,
+        )
+        self.client = AsyncLlamaStackClient(
+            base_url="http://localhost:8321",
+        )
+        self.agent = AsyncAgent(self.client, agent_config)
+
 
     @extend_schema(
         request=StreamingChatRequestSerializer,
         responses={
@@ -1205,5 +1236,12 @@ def post(self, request) -> StreamingHttpResponse:
                 provider=req_provider,
                 conversation_id=conversation_id,
                 media_type=media_type,
+                chatbackend_config={
+                    # "agent_config": self.agent_config,
+                    # "agent_id": self.agent.agent_id,
+                    # "session_id": self.session_id,
+                    "client": self.client,
+                    "agent": self.agent,
+                },
             )
         )
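
With the permission classes commented out, the streaming endpoint can be smoke-tested with a bare HTTP client. A sketch, assuming a local dev server on port 8000; the URL path comes from the project's urls.py and is an assumption here, as is the request body beyond query:

# Hypothetical smoke test; the URL path and payload shape are assumed,
# not taken from the diff. No credentials are attached because this
# commit disables the endpoint's permission checks.
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/ai/streaming_chat/",  # assumed path
    json={"query": "How do I install nginx?"},
    stream=True,
    timeout=60,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))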

ansible_ai_connect/main/settings/base.py (+4 -0)
@@ -77,9 +77,12 @@
     "import_export",
     "ansible_base.resource_registry",
     "ansible_base.jwt_consumer",
+    "corsheaders",
 ]
 
 MIDDLEWARE = [
+    "corsheaders.middleware.CorsMiddleware",
+    "django.middleware.common.CommonMiddleware",
     "allow_cidr.middleware.AllowCIDRMiddleware",
     "django_prometheus.middleware.PrometheusBeforeMiddleware",
     "django.middleware.security.SecurityMiddleware",
@@ -95,6 +98,7 @@
     "django_prometheus.middleware.PrometheusAfterMiddleware",
     "csp.middleware.CSPMiddleware",
 ]
+CORS_ALLOW_ALL_ORIGINS = True
 
 if os.environ.get("CSRF_TRUSTED_ORIGINS"):
     CSRF_TRUSTED_ORIGINS = os.environ.get("CSRF_TRUSTED_ORIGINS").split(",")
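
The middleware ordering is deliberate: django-cors-headers requires CorsMiddleware to sit above CommonMiddleware. CORS_ALLOW_ALL_ORIGINS = True, however, opens the API to every origin and is only appropriate for local hacking; anything shared should pin origins instead. A sketch, with an example origin value that is not from this repo:

# settings sketch: pin the allowed origins instead of allowing all.
CORS_ALLOW_ALL_ORIGINS = False
CORS_ALLOWED_ORIGINS = [
    "http://localhost:3000",  # example: a local chat UI during development
]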
