Commit a99e77b

feat: prioritize running local inference jobs on skynet nodes with vLLM (#208)

1 parent: b746722

File tree: 4 files changed, +52 −7 lines

skynet/auth/user_info.py

Lines changed: 2 additions & 1 deletion
@@ -13,9 +13,10 @@
 
 
 class CredentialsType(Enum):
-    OPENAI = 'OPENAI'
     AZURE_OPENAI = 'AZURE_OPENAI'
+    LOCAL = 'LOCAL'
     OCI = 'OCI'
+    OPENAI = 'OPENAI'
 
 
 async def open_yaml(file_path):
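For reference, a minimal sketch of how a credentials `type` string from config maps onto the reordered enum; the enum body is copied from the diff above, everything else is illustrative:

```python
from enum import Enum

class CredentialsType(Enum):
    AZURE_OPENAI = 'AZURE_OPENAI'
    LOCAL = 'LOCAL'
    OCI = 'OCI'
    OPENAI = 'OPENAI'

# A credentials entry whose `type` is 'LOCAL' now resolves to the new
# member; unrecognized strings still raise ValueError as before.
assert CredentialsType('LOCAL') is CredentialsType.LOCAL
```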

skynet/modules/ttt/llm_selector.py

Lines changed: 4 additions & 4 deletions
@@ -47,11 +47,12 @@ def get_job_processor(customer_id: str, job_id: Optional[str] = None) -> Process
         elif api_type == CredentialsType.AZURE_OPENAI.value:
             return Processors.AZURE
 
+        if api_type == CredentialsType.LOCAL.value:
+            return Processors.LOCAL
+
         if oci_available:
             return Processors.OCI
 
-        log.warning(f'OCI is not available, falling back to local processing for customer {customer_id}')
-
         return Processors.LOCAL
 
     @staticmethod
@@ -109,8 +110,7 @@ def select(
                 service_endpoint=oci_service_endpoint,
             )
         else:
-            if customer_id:
-                log.info(f'Customer {customer_id} has no API key configured, falling back to local processing')
+            log.info(f'Forwarding inference to local LLM for customer {customer_id}')
 
             return ChatOpenAI(
                 api_key='placeholder',  # use a placeholder value to bypass validation
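Read end-to-end, the precedence after this change is: customer OpenAI key, then Azure, then an explicit LOCAL opt-in, then OCI when available, then the local fallback (now silent, since the warning was dropped). A flattened sketch of just that decision ladder; the OPENAI branch is inferred from the `elif` context line above, and plain strings stand in for skynet's enums:

```python
def pick_processor(api_type: str | None, oci_available: bool) -> str:
    # Illustrative reconstruction of LLMSelector.get_job_processor's
    # branching after this change; not the repo's actual code.
    if api_type == 'OPENAI':
        return 'OPENAI'
    elif api_type == 'AZURE_OPENAI':
        return 'AZURE'
    if api_type == 'LOCAL':      # new: explicit opt-in to local inference
        return 'LOCAL'
    if oci_available:
        return 'OCI'
    return 'LOCAL'               # fallback, no longer logged as a warning
```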

skynet/modules/ttt/processor_test.py

Lines changed: 27 additions & 0 deletions
@@ -137,6 +137,33 @@ async def test_process_with_azure_open_ai(self, process_fixture):
 
         LLMSelector.select.assert_called_once()
 
+    @pytest.mark.asyncio
+    async def test_process_with_local(self, process_fixture):
+        '''Test that a job is sent for local inference if there is a customer id configured for it.'''
+
+        from skynet.modules.ttt.llm_selector import LLMSelector
+        from skynet.modules.ttt.processor import process
+
+        process_fixture.patch(
+            'skynet.modules.ttt.llm_selector.get_credentials',
+            return_value={'type': 'LOCAL'},
+        )
+        process_fixture.patch('skynet.modules.ttt.llm_selector.oci_available', True)
+
+        job = Job(
+            payload=DocumentPayload(
+                text="Andrew: Hello. Beatrix: Honey? It’s me . . . Andrew: Where are you? Beatrix: At the station. I missed my train."
+            ),
+            metadata=DocumentMetadata(customer_id='test'),
+            type=JobType.SUMMARY,
+        )
+
+        assert LLMSelector.get_job_processor(job.metadata.customer_id, job.id) == Processors.LOCAL
+
+        await process(job)
+
+        LLMSelector.select.assert_called_once()
+
     @pytest.mark.asyncio
     async def test_process_with_oci(self, process_fixture):
         '''Test that a job is sent for inference to oci if there is a customer id configured for it.'''

skynet/modules/ttt/summaries/jobs.py

Lines changed: 19 additions & 2 deletions
@@ -3,7 +3,7 @@
 import time
 
 from skynet.constants import ERROR_JOBS_KEY, PENDING_JOBS_KEY, RUNNING_JOBS_KEY
-from skynet.env import enable_batching, job_timeout, max_concurrency, modules, redis_exp_seconds
+from skynet.env import enable_batching, job_timeout, max_concurrency, modules, redis_exp_seconds, use_vllm
 from skynet.logs import get_logger
 from skynet.modules.monitoring import (
     OPENAI_API_RESTART_COUNTER,
@@ -82,6 +82,10 @@ async def create_job(job_type: JobType, payload: DocumentPayload, metadata: Docu
     """Create a job and add it to the db queue if it can't be started immediately."""
 
     job = Job(payload=payload, type=job_type, metadata=metadata)
+    processor = LLMSelector.get_job_processor(metadata.customer_id)
+
+    # encode the processor in the job id to avoid having to retrieve the whole job object
+    job.id += f':{processor.value}'
    job_id = job.id
 
     await db.set(job_id, Job.model_dump_json(job))
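With the processor encoded as a `:`-suffix on the job id, consumers can recover it with a string split instead of fetching the whole job body from Redis. A hypothetical helper showing the idea; the function name and the Processors stub are mine, not from the repo:

```python
from enum import Enum

class Processors(Enum):  # stub standing in for skynet's Processors enum
    AZURE = 'AZURE'
    LOCAL = 'LOCAL'
    OCI = 'OCI'
    OPENAI = 'OPENAI'

def processor_from_job_id(job_id: str) -> Processors:
    # create_job appends f':{processor.value}', so the processor is the
    # last colon-separated segment of the id
    return Processors(job_id.rsplit(':', 1)[-1])

assert processor_from_job_id('1f2e3d:LOCAL') is Processors.LOCAL
```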
@@ -198,7 +202,20 @@ async def maybe_run_next_job() -> None:
     if not can_run_next_job():
         return
 
-    next_job_id = await db.lpop(PENDING_JOBS_KEY)
+    next_job_id = None
+
+    if use_vllm:
+        pending_jobs_keys = await db.lrange(PENDING_JOBS_KEY, 0, -1)
+
+        for job_id in pending_jobs_keys:
+            if job_id.endswith(Processors.LOCAL.value):
+                next_job_id = job_id
+                await db.lrem(PENDING_JOBS_KEY, 0, job_id)
+
+                break
+
+    if not next_job_id:
+        next_job_id = await db.lpop(PENDING_JOBS_KEY)
 
     await update_summary_queue_metric()
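The new dequeue step scans the pending list front-to-back for the first id ending in the LOCAL suffix, removes it with LREM, and only falls back to a plain LPOP when no local job is waiting. The same policy against an in-memory list, as a self-contained sketch (names are mine):

```python
def pick_next_job(pending: list[str], use_vllm: bool) -> str | None:
    """Mirror maybe_run_next_job's selection against a plain Python list."""
    if use_vllm:
        for job_id in pending:
            if job_id.endswith('LOCAL'):  # suffix appended by create_job
                pending.remove(job_id)    # stands in for LREM
                return job_id
    return pending.pop(0) if pending else None  # stands in for LPOP

queue = ['a:OCI', 'b:LOCAL', 'c:AZURE']
assert pick_next_job(queue, use_vllm=True) == 'b:LOCAL'  # local job jumps ahead
assert pick_next_job(queue, use_vllm=False) == 'a:OCI'   # plain FIFO otherwise
```

One caveat worth noting: LRANGE followed by LREM is two round trips, so the scan is not atomic. That seems acceptable here since can_run_next_job gates how many jobs run, but multiple competing workers would need an atomic claim (a Lua script, for instance).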
