
Commit 8e25f33

SN1-331: Adding initial draft for endpoints
1 parent 62ae30c commit 8e25f33

File tree: 8 files changed (+157, -87 lines)

neurons/miners/epistula_miner/miner.py

Lines changed: 31 additions & 44 deletions
@@ -5,17 +5,14 @@
 settings = settings.settings

 import time
-import asyncio
-import json
 import httpx
 import netaddr
 import uvicorn
 import requests
 import traceback
 import bittensor as bt
-from starlette.responses import JSONResponse
 from loguru import logger
-from fastapi import APIRouter, Depends, FastAPI, Request, HTTPException
+from fastapi import APIRouter, FastAPI, Request, HTTPException
 from starlette.background import BackgroundTask
 from starlette.responses import StreamingResponse
 from bittensor.subtensor import serve_extrinsic
@@ -34,44 +31,41 @@
 SYSTEM_PROMPT = """You are a helpful agent that does it's best to answer all questions!"""


-class OpenAIMiner():
-
+class OpenAIMiner:
     def __init__(self):
         self.should_exit = False
         self.client = httpx.AsyncClient(
-                base_url="https://api.openai.com/v1",
-                headers={
-                    "Authorization": f"Bearer {settings.OPENAI_API_KEY}",
-                    "Content-Type": "application/json",
-                },
-            )
+            base_url="https://api.openai.com/v1",
+            headers={
+                "Authorization": f"Bearer {settings.OPENAI_API_KEY}",
+                "Content-Type": "application/json",
+            },
+        )
         print("OpenAI Key: ", settings.OPENAI_API_KEY)

     async def format_openai_query(self, request: Request):
         # Read the JSON data once
         data = await request.json()
-
+
         # Extract the required fields
         openai_request = {}
         for key in ["messages", "model", "stream"]:
             if key in data:
                 openai_request[key] = data[key]
         openai_request["model"] = MODEL_ID
-
+
         return openai_request
-
+
     async def create_chat_completion(self, request: Request):
         bt.logging.info(
             "\u2713",
             f"Getting Chat Completion request from {request.headers.get('Epistula-Signed-By', '')[:8]}!",
         )
-        req = self.client.build_request(
-            "POST", "chat/completions", json = await self.format_openai_query(request)
-        )
+        logger.debug("Starting chat completion request...")
+        req = self.client.build_request("POST", "chat/completions", json=await self.format_openai_query(request))
         r = await self.client.send(req, stream=True)
-        return StreamingResponse(
-            r.aiter_raw(), background=BackgroundTask(r.aclose), headers=r.headers
-        )
+        logger.debug("Chat completion request returning...")
+        return StreamingResponse(r.aiter_raw(), background=BackgroundTask(r.aclose), headers=r.headers)

     # async def create_chat_completion(self, request: Request):
     #     bt.logging.info(
@@ -104,7 +98,7 @@ async def create_chat_completion(self, request: Request):
     #         "\u2713",
     #         f"Getting Chat Completion request from {request.headers.get('Epistula-Signed-By', '')[:8]}!",
     #     )
-
+
     #     async def word_stream():
     #         words = "This is a test stream".split()
     #         for word in words:
@@ -133,30 +127,27 @@ async def create_chat_completion(self, request: Request):
     #             }
     #             yield f"data: {json.dumps(data)}\n\n"
     #         yield "data: [DONE]\n\n"
-
+
     #     return StreamingResponse(word_stream(), media_type='text/event-stream')

     async def check_availability(self, request: Request):
         print("Checking availability")
         # Parse the incoming JSON request
         data = await request.json()
-        task_availabilities = data.get('task_availabilities', {})
-        llm_model_availabilities = data.get('llm_model_availabilities', {})
-
+        task_availabilities = data.get("task_availabilities", {})
+        llm_model_availabilities = data.get("llm_model_availabilities", {})
+
         # Set all task availabilities to True
         task_response = {key: True for key in task_availabilities}
-
+
         # Set all model availabilities to False
         model_response = {key: False for key in llm_model_availabilities}
-
+
         # Construct the response dictionary
-        response = {
-            'task_availabilities': task_response,
-            'llm_model_availabilities': model_response
-        }
-
+        response = {"task_availabilities": task_response, "llm_model_availabilities": model_response}
+
         return response
-
+
     async def verify_request(
         self,
         request: Request,
@@ -170,18 +161,14 @@ async def verify_request(
         signed_by = request.headers.get("Epistula-Signed-By")
         signed_for = request.headers.get("Epistula-Signed-For")
         if signed_for != self.wallet.hotkey.ss58_address:
-            raise HTTPException(
-                status_code=400, detail="Bad Request, message is not intended for self"
-            )
+            raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self")
         if signed_by not in self.metagraph.hotkeys:
             raise HTTPException(status_code=401, detail="Signer not in metagraph")

         uid = self.metagraph.hotkeys.index(signed_by)
         stake = self.metagraph.S[uid].item()
         if not self.config.no_force_validator_permit and stake < 10000:
-            bt.logging.warning(
-                f"Blacklisting request from {signed_by} [uid={uid}], not enough stake -- {stake}"
-            )
+            bt.logging.warning(f"Blacklisting request from {signed_by} [uid={uid}], not enough stake -- {stake}")
             raise HTTPException(status_code=401, detail="Stake below minimum: {stake}")

         # If anything is returned here, we can throw
@@ -200,8 +187,7 @@ async def verify_request(
             raise HTTPException(status_code=400, detail=err)

     def run(self):
-
-        external_ip = None #settings.EXTERNAL_IP
+        external_ip = None  # settings.EXTERNAL_IP
         if not external_ip or external_ip == "[::]":
             try:
                 external_ip = requests.get("https://checkip.amazonaws.com").text.strip()
@@ -232,7 +218,7 @@ def run(self):
         router.add_api_route(
             "/v1/chat/completions",
             self.create_chat_completion,
-            #dependencies=[Depends(self.verify_request)],
+            # dependencies=[Depends(self.verify_request)],
             methods=["POST"],
         )
         router.add_api_route(
@@ -244,7 +230,8 @@ def run(self):
         fast_config = uvicorn.Config(
             app,
             host="0.0.0.0",
-            port=settings.AXON_PORT,
+            # port=settings.AXON_PORT,
+            port=8008,
             log_level="info",
             loop="asyncio",
         )
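
Note: with verify_request commented out and the port pinned to 8008, the miner can be smoke-tested locally. A minimal sketch (not part of the commit), assuming the miner is running on localhost; the model name is a placeholder, since format_openai_query overwrites it with MODEL_ID:

import asyncio
import openai


async def main():
    # No Epistula headers are needed while verify_request stays disabled.
    client = openai.AsyncOpenAI(base_url="http://localhost:8008/v1", api_key="unused")
    stream = await client.chat.completions.create(
        model="placeholder",  # replaced server-side with MODEL_ID
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)


asyncio.run(main())

Because create_chat_completion streams the upstream bytes through unchanged (r.aiter_raw()), a standard OpenAI client can parse the chunks directly.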

neurons/validator.py

Lines changed: 11 additions & 3 deletions
@@ -25,6 +25,7 @@
 from prompting.weight_setting.weight_setter import weight_setter
 from prompting.llms.utils import GPUInfo
 from prompting.base.epistula import query_miners
+from prompting.api.api import start_api

 NEURON_SAMPLE_SIZE = 100

@@ -139,14 +140,18 @@ async def collect_responses(self, task: BaseTextTask) -> DendriteResponseEvent |
             logger.warning("No available miners. This should already have been caught earlier.")
             return

-
-        body = {"seed": task.seed, "model": task.llm_model_id, "messages": [{'role': 'user', 'content': task.query},]}
+        body = {
+            "seed": task.seed,
+            "model": task.llm_model_id,
+            "messages": [
+                {"role": "user", "content": task.query},
+            ],
+        }
         body_bytes = json.dumps(body).encode("utf-8")
         stream_results = await query_miners(task.__class__.__name__, uids, body_bytes)

         log_stream_results(stream_results)

-
         response_event = DendriteResponseEvent(
             stream_results=stream_results, uids=uids, timeout=settings.NEURON_TIMEOUT
         )
@@ -202,6 +207,9 @@ def __exit__(self, exc_type, exc_value, traceback):


 async def main():
+    # start api
+    asyncio.create_task(start_api())
+
     GPUInfo.log_gpu_info()
     # start profiling
     asyncio.create_task(profiler.print_stats())
prompting/api/api.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+from fastapi import FastAPI
+import uvicorn
+from prompting.api.gpt_endpoints.api import router as gpt_router
+from prompting.api.miner_availabilities.api import router as miner_availabilities_router
+from loguru import logger
+
+app = FastAPI()
+
+app.include_router(gpt_router)
+app.include_router(miner_availabilities_router)
+
+
+async def start_api():
+    logger.info("Starting API")
+    uvicorn.run(app, host="0.0.0.0", port=8000)
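
Note: uvicorn.run() is blocking and starts its own event loop, so scheduling start_api() with asyncio.create_task() on the validator's already-running loop (as neurons/validator.py does in this commit) will likely fail at runtime. A minimal sketch of an awaitable alternative, reusing this module's app and logger:

async def start_api():
    logger.info("Starting API")
    # uvicorn.Server(...).serve() is a coroutine, so it composes with the
    # caller's existing event loop instead of trying to create a new one.
    config = uvicorn.Config(app, host="0.0.0.0", port=8000, log_level="info")
    server = uvicorn.Server(config)
    await server.serve()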

prompting/api/gpt_endpoints/api.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+from fastapi import APIRouter, Request
+import openai
+from prompting.settings import settings
+from httpx import Timeout
+from prompting.base.epistula import create_header_hook
+from fastapi.responses import StreamingResponse
+import json
+
+router = APIRouter()
+
+
+async def process_stream(stream):
+    async for chunk in stream:
+        if hasattr(chunk, "choices") and chunk.choices:
+            # Extract the delta content from the chunk
+            delta = chunk.choices[0].delta
+            if hasattr(delta, "content") and delta.content is not None:
+                # Format as SSE data
+                yield f"data: {json.dumps(chunk.model_dump())}\n\n"
+    yield "data: [DONE]\n\n"
+
+
+@router.post("/v1/chat/completions")
+async def proxy_chat_completions(request: Request):
+    # Get the request body
+    body = await request.json()
+
+    # Ensure streaming is enabled
+    body["stream"] = True
+
+    # TODO: Forward to actual miners
+    miner = openai.AsyncOpenAI(
+        base_url="http://localhost:8008/v1",
+        max_retries=0,
+        timeout=Timeout(settings.NEURON_TIMEOUT, connect=5, read=5),
+        http_client=openai.DefaultAsyncHttpxClient(
+            event_hooks={"request": [create_header_hook(settings.WALLET.hotkey, None)]}
+        ),
+    )
+
+    # Create streaming request to OpenAI
+    response = await miner.chat.completions.create(**body)
+
+    # Return a streaming response with properly formatted chunks
+    return StreamingResponse(process_stream(response), media_type="text/event-stream")
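
Note: the proxy's output is plain server-sent events, so any HTTP client that can iterate lines can consume it. An illustrative sketch (not part of the commit), assuming the app from prompting/api/api.py is listening on localhost:8000; the model value is arbitrary because the endpoint forwards the body as-is:

import asyncio
import httpx


async def main():
    body = {"model": "placeholder", "messages": [{"role": "user", "content": "Say hello"}]}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/v1/chat/completions", json=body) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line[len("data: "):])  # one JSON chunk per line, then [DONE]


asyncio.run(main())

One behavioural detail worth noting: process_stream only re-emits chunks whose delta carries content, so role-only and finish_reason chunks are dropped from the proxied stream.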

prompting/api/gpt_endpoints/serialisers.py

Whitespace-only changes.
prompting/api/miner_availabilities/api.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+from fastapi import APIRouter
+from prompting.miner_availability.miner_availability import miner_availabilities
+from loguru import logger
+
+router = APIRouter()
+
+
+@router.post("/miner_availabilities")
+async def get_miner_availabilities(uids: list[int] | None = None):
+    if uids:
+        return {uid: miner_availabilities.miners.get(uid) for uid in uids}
+    logger.info(f"Returning all miner availabilities for {len(miner_availabilities.miners)} miners")
+    return miner_availabilities.miners
+
+
+@router.get("/get_available_miners")
+async def get_available_miners(task: str | None = None, model: str | None = None, k: int = 10):
+    logger.info(f"Getting {k} available miners for task {task} and model {model}")
+    return miner_availabilities.get_available_miners(task=task, model=model, k=k)
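
Note: both routes can be exercised with a plain HTTP client once the API is up. An illustrative sketch against localhost:8000 (the uids and task name are placeholders); the bare list[int] parameter on a POST route is read from the JSON body by FastAPI:

import httpx

# POST /miner_availabilities: pass an explicit uid list, or null/no body for all miners
print(httpx.post("http://localhost:8000/miner_availabilities", json=[1, 2, 3]).json())

# GET /get_available_miners: filter by task and/or model, capped at k results
print(httpx.get("http://localhost:8000/get_available_miners", params={"task": "SomeTask", "k": 5}).json())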
