Skip to content

Commit 8fc9fb0

Browse files
committed
Handle the case of LLM configuration missing from environment
1 parent d9e84f9 commit 8fc9fb0

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

learn2rag/pipeline/generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
def generate(query: str, search_results: list[ScoredPoint], opt_config: dict[str, Any]) -> Any:
18+
assert llm is not None
1819
if hasattr(search_results, "points"):
1920
search_results = search_results.points
2021
context = "\n\n".join([context_template.format(source=result.payload['path'], content=result.payload['content']) for result in search_results]) # type: ignore[index]
@@ -28,6 +29,7 @@ def generate(query: str, search_results: list[ScoredPoint], opt_config: dict[str
2829

2930
def generate_stream(query: str, search_results: list[ScoredPoint], opt_config: dict[str, Any], request_id: str | None=None) -> Generator[str, None, None]:
3031
profilingLogger.info('start', extra={'activity': 'generate', 'request_id': request_id})
32+
assert llm is not None
3133

3234
if hasattr(search_results, "points"):
3335
search_results = search_results.points

learn2rag/pipeline/llm.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,20 @@ def __init__(self, *, url: str, token: str | None, model: str, proxy: str | None
5454
)
5555

5656

57-
default_llm = OpenAIClient
58-
llm_id = os.environ.get('LLM_API_TYPE', default_llm.ID)
59-
logger.debug('Using LLM: %s', llm_id)
60-
61-
llm_kwargs = {
62-
'url': os.environ.get('LLM_API_URL'),
63-
'token': os.environ.get('LLM_API_TOKEN') or None,
64-
'model': os.environ.get('LLM_API_MODEL'),
65-
'proxy': os.environ.get('LLM_API_PROXY') or None,
66-
}
67-
logger.debug('Using LLM args: %s', llm_kwargs)
68-
69-
llm = llms[llm_id](**llm_kwargs).chat_model
57+
def chat_model_from_env() -> BaseChatModel:
    """Construct the chat model described by the ``LLM_API_*`` environment variables.

    Reads:
        LLM_API_TYPE  -- client id, defaults to ``OpenAIClient.ID`` when unset.
        LLM_API_URL   -- endpoint URL (may be ``None`` if unset).
        LLM_API_TOKEN -- auth token; empty string is normalized to ``None``.
        LLM_API_MODEL -- model name (may be ``None`` if unset).
        LLM_API_PROXY -- proxy URL; empty string is normalized to ``None``.

    Returns:
        The ``chat_model`` of the instantiated client.

    Raises:
        KeyError: if ``LLM_API_TYPE`` names a client id not present in ``llms``.
    """
    default_llm = OpenAIClient
    llm_id = os.environ.get('LLM_API_TYPE', default_llm.ID)
    logger.debug('Using LLM: %s', llm_id)
    llm_kwargs = {
        'url': os.environ.get('LLM_API_URL'),
        'token': os.environ.get('LLM_API_TOKEN') or None,
        'model': os.environ.get('LLM_API_MODEL'),
        'proxy': os.environ.get('LLM_API_PROXY') or None,
    }
    # Redact the token: llm_kwargs would otherwise leak the secret into debug logs.
    logger.debug('Using LLM args: %s', {**llm_kwargs, 'token': '***' if llm_kwargs['token'] else None})
    return llms[llm_id](**llm_kwargs).chat_model


# Treat the LLM as configured when either the client type or the endpoint URL
# is present. Gating only on LLM_API_TYPE would make the OpenAIClient default
# inside chat_model_from_env unreachable for deployments that set just
# LLM_API_URL / LLM_API_MODEL.
llm = chat_model_from_env() if os.environ.keys() & {'LLM_API_TYPE', 'LLM_API_URL'} else None
if llm is None:
    logger.warning('LLM is not configured')

0 commit comments

Comments
 (0)