Skip to content

Commit d9e84f9

Browse files
committed
Fix language model configuration
1 parent 14fb79d commit d9e84f9

File tree

4 files changed

+76
-38
lines changed

4 files changed

+76
-38
lines changed

learn2rag/pipeline/llm.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,69 @@
11
import logging
22
import os
33
from pydantic import SecretStr
4+
from langchain_core.language_models.chat_models import BaseChatModel
45
from langchain_ollama import ChatOllama
56
from langchain_openai import ChatOpenAI
6-
from typing import Callable, Any
77

88

9-
def ollama_client(*, url: str, token: str | None, model: str, proxy: str | None) -> ChatOllama:
10-
return ChatOllama(
11-
model=model,
12-
temperature=0,
13-
base_url=url,
14-
client_kwargs={
15-
'headers': {'Authorization': f'Bearer {token}'} if token else {},
16-
'proxy': proxy,
17-
},
18-
)
9+
logger = logging.getLogger(__name__)
1910

2011

21-
def openai_client(*, url: str, token: SecretStr, model: str, proxy: str | None) -> ChatOpenAI:
22-
return ChatOpenAI(
23-
model=model,
24-
temperature=0,
25-
base_url=url,
26-
api_key=token,
27-
)
12+
class LLMClient():
13+
# ID is used as a key to store in user data, should not be changed
14+
ID: str
15+
# LABEL is a display label for user interface
16+
LABEL: str
17+
chat_model: BaseChatModel
2818

2919

30-
# TODO: set up the right llm for user_config
20+
llms = {}
21+
def llm_client(cls: type[LLMClient]) -> type[LLMClient]:
22+
llms[cls.ID] = cls; return cls
23+
24+
25+
# First @llm_client would be the default in UI when adding an external model
26+
@llm_client
27+
class OpenAIClient(LLMClient):
28+
ID = 'ChatOpenAI'
29+
LABEL = 'OpenAI'
30+
31+
def __init__(self, *, url: str, token: SecretStr, model: str, proxy: str | None) -> None:
32+
self.chat_model = ChatOpenAI(
33+
model=model,
34+
temperature=0,
35+
base_url=url,
36+
api_key=token,
37+
)
38+
39+
40+
@llm_client
41+
class OllamaClient(LLMClient):
42+
ID = 'ChatOllama'
43+
LABEL = 'Ollama'
44+
45+
def __init__(self, *, url: str, token: str | None, model: str, proxy: str | None) -> None:
46+
self.chat_model = ChatOllama(
47+
model=model,
48+
temperature=0,
49+
base_url=url,
50+
client_kwargs={
51+
'headers': {'Authorization': f'Bearer {token}'} if token else {},
52+
'proxy': proxy,
53+
},
54+
)
55+
56+
57+
default_llm = OpenAIClient
58+
llm_id = os.environ.get('LLM_API_TYPE', default_llm.ID)
59+
logger.debug('Using LLM: %s', llm_id)
3160

3261
llm_kwargs = {
3362
'url': os.environ.get('LLM_API_URL'),
3463
'token': os.environ.get('LLM_API_TOKEN') or None,
3564
'model': os.environ.get('LLM_API_MODEL'),
3665
'proxy': os.environ.get('LLM_API_PROXY') or None,
3766
}
38-
logging.info('LLM args: %s', llm_kwargs)
39-
40-
# the keys are written by the configurator UI
41-
llms: dict[str, Callable[..., Any]] = {
42-
'ChatOllama': ollama_client,
43-
'ChatOpenAI': openai_client,
44-
}
67+
logger.debug('Using LLM args: %s', llm_kwargs)
4568

46-
llm = llms[os.environ.get('LLM_API_TYPE', 'ChatOllama')](**llm_kwargs)
69+
llm = llms[llm_id](**llm_kwargs).chat_model

learn2rag/ui/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from learn2rag.compose import Project
2828
import learn2rag.data
29+
import learn2rag.pipeline.llm
2930

3031
from datetime import datetime # <-- ADD THIS
3132

@@ -179,6 +180,7 @@ def inject_info() -> dict[str, Any]:
179180
'firststeps_storage_path': app.instance_path + '/storage/example',
180181
'debug_logging': config.get('logging', {}).get('debug', False),
181182
'current_timestamp': math.floor(time.time()),
183+
'llm': learn2rag.pipeline.llm,
182184
}
183185

184186
@app.context_processor
@@ -243,7 +245,7 @@ def model_create() -> 'str | werkzeug.wrappers.response.Response':
243245
ok = True
244246
model = request.form['model']
245247
api = request.form['api']
246-
if api == 'ollama_clientent':
248+
if api == learn2rag.pipeline.llm.OllamaClient.ID:
247249
url = request.form.get('url') or 'http://127.0.0.1:' + str(app.config['OLLAMA']['port']) + '/'
248250
# TODO setup tokens for locally running ollama
249251
token = request.form.get('token') or ''
@@ -252,7 +254,7 @@ def model_create() -> 'str | werkzeug.wrappers.response.Response':
252254
model += ':latest'
253255
start_project('ollama_download', components_template_path / 'ollama-download.yml', Path(), {'model': model})
254256
return flask_redirect(url_for('model_pulling', model=model))
255-
elif api == 'openai_clientent':
257+
elif api == learn2rag.pipeline.llm.OpenAIClient.ID:
256258
url = request.form['url']
257259
token = request.form['token']
258260
else:
@@ -284,7 +286,7 @@ def model_pulling() -> 'str | werkzeug.wrappers.response.Response':
284286
'url': 'http://127.0.0.1:' + str(app.config['OLLAMA']['port']) + '/',
285287
'token': '',
286288
'model': model,
287-
'api': 'ollama_clientent',
289+
'api': learn2rag.pipeline.llm.OllamaClient.ID,
288290
})
289291
flash(pgettext('flash', 'Downloaded a language model: %(model)s', model=model))
290292
res = make_response(render_template('model_pulling_success.html'))

learn2rag/ui/config.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ SUGGESTED_MODELS:
66
link: https://ollama.com/library/gemma3
77
ollama: pull
88
config:
9-
api: ollama_client
9+
api: ChatOllama
1010
model: gemma3:27b
1111
llama3.3_70b:
1212
label: Meta Llama 3.3
@@ -15,13 +15,13 @@ SUGGESTED_MODELS:
1515
link: https://ollama.com/library/llama3.3
1616
ollama: pull
1717
config:
18-
api: ollama_client
18+
api: ChatOllama
1919
model: llama3.3:70b
2020
tinyllama:
2121
label: TinyLlama
2222
image: models/tinyllama.png
2323
link: https://github.com/jzhang38/TinyLlama
2424
ollama: pull
2525
config:
26-
api: ollama_client
26+
api: ChatOllama
2727
model: tinyllama:latest

learn2rag/ui/templates/models_list.html

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,14 @@ <h2 class="accordion-header">
3030
<div class="mb-3">
3131
<label for="api" class="form-label">{{pgettext('form_label', 'API type')}}</label>
3232
<select required="required" class="form-select" name="api">
33-
<option selected="selected">ChatOpenAI</option>
34-
<option>ChatOllama</option>
33+
{% for llm_client in llm.llms.values() %}
34+
<option
35+
value="{{ llm_client.ID }}"
36+
{% if loop.first %}
37+
selected="selected"
38+
{% endif %}
39+
>{{ llm_client.LABEL }}</option>
40+
{% endfor %}
3541
</select>
3642
</div>
3743
<div class="mb-3">
@@ -103,7 +109,7 @@ <h5 class="card-title">{{gettext('Another model')}}</h5>
103109
<label for="model" class="form-label">{{pgettext('form_label', 'Language model')}}</label>
104110
<input class="form-control" name="model" required="required" title="{{pgettext('tooltip', 'For example: %(model)s', model=firststeps_model.get('config', {}).get('model'))}}">
105111
</div>
106-
<input type="hidden" name="api" value="ollama_clientent"/>
112+
<input type="hidden" name="api" value="{{ llm.OllamaClient.ID }}"/>
107113
<input type="hidden" name="ollama" value="pull"/>
108114
<button type="submit" class="btn btn-primary">Save</button>
109115
</form>
@@ -140,7 +146,7 @@ <h2 class="accordion-header">
140146
{% endfor %}
141147
</select>
142148
</div>
143-
<input type="hidden" name="api" value="ollama_clientent"/>
149+
<input type="hidden" name="api" value="{{ llm.OllamaClient.ID }}"/>
144150
<input type="hidden" name="ollama" value="use"/>
145151
<button type="submit" class="btn btn-primary">{{pgettext('button', 'Save')}}</button>
146152
</form>
@@ -168,7 +174,14 @@ <h2 class="accordion-header">
168174
{% for name, model in models.items() %}
169175
<tr>
170176
<td title="{{ name }}">{{ model.label }}</td>
171-
<td>{{ model.api }}</td>
177+
<td>
178+
{% if model.api in llm.llms %}
179+
{{ llm.llms[model.api].LABEL }}
180+
{% else %}
181+
<code>{{ model.api }}</code>
182+
<span title="{{pgettext('tooltip', 'Unknown value')}}" style="cursor: default;">⚠️</span>
183+
{% endif %}
184+
</td>
172185
<td>{{ model.url }}</td>
173186
<td>{{ model.token }}</td>
174187
<td>{{ model.model }}</td>

0 commit comments

Comments (0)