Skip to content

Commit 0bb0101

Browse files
Merge pull request #202 from mindsdb/feature/config
Added Configuration Management Operations
2 parents 2713662 + d29184d commit 0bb0101

File tree

4 files changed

+277
-0
lines changed

4 files changed

+277
-0
lines changed

mindsdb_sdk/config.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from mindsdb_sdk.connectors.rest_api import RestAPI
2+
3+
4+
class Config():
5+
"""
6+
**Configuration for MindsDB**
7+
8+
This class provides methods to set and get the various configuration aspects of MindsDB.
9+
10+
Working with configuration:
11+
12+
Set default LLM configuration:
13+
14+
>>> server.config.set_default_llm(
15+
... provider='openai',
16+
... model_name='gpt-4',
17+
... api_key='sk-...'
18+
... )
19+
20+
Get default LLM configuration:
21+
22+
>>> llm_config = server.config.get_default_llm()
23+
>>> print(llm_config)
24+
25+
Set default embedding model:
26+
27+
>>> server.config.set_default_embedding_model(
28+
... provider='openai',
29+
... model_name='text-embedding-ada-002',
30+
... api_key='sk-...'
31+
... )
32+
33+
Get default embedding model:
34+
35+
>>> embedding_config = server.config.get_default_embedding_model()
36+
37+
Set default reranking model:
38+
39+
>>> server.config.set_default_reranking_model(
40+
... provider='openai',
41+
... model_name='gpt-4',
42+
... api_key='sk-...'
43+
... )
44+
45+
Get default reranking model:
46+
47+
>>> reranking_config = server.config.get_default_reranking_model()
48+
"""
49+
def __init__(self, api: RestAPI):
50+
self.api = api
51+
52+
def set_default_llm(
53+
self,
54+
provider: str,
55+
model_name: str,
56+
api_key: str = None,
57+
**kwargs
58+
):
59+
"""
60+
Set the default LLM configuration for MindsDB.
61+
62+
:param provider: The name of the LLM provider (e.g., 'openai', 'google').
63+
:param model_name: The name of the model to use.
64+
:param api_key: Optional API key for the provider.
65+
:param kwargs: Additional parameters for the LLM configuration.
66+
"""
67+
config = {
68+
"default_llm": {
69+
"provider": provider,
70+
"model_name": model_name,
71+
"api_key": api_key,
72+
**kwargs
73+
}
74+
}
75+
self.api.update_config(config)
76+
77+
def get_default_llm(self):
78+
"""
79+
Get the default LLM configuration for MindsDB.
80+
81+
:return: Dictionary containing the default LLM configuration.
82+
"""
83+
return self.api.get_config().get("default_llm", {})
84+
85+
def set_default_embedding_model(
86+
self,
87+
provider: str,
88+
model_name: str,
89+
api_key: str = None,
90+
**kwargs
91+
):
92+
"""
93+
Set the default embedding model configuration for MindsDB.
94+
95+
:param provider: The name of the embedding model provider (e.g., 'openai', 'google').
96+
:param model_name: The name of the embedding model to use.
97+
:param api_key: Optional API key for the provider.
98+
:param kwargs: Additional parameters for the embedding model configuration.
99+
"""
100+
config = {
101+
"default_embedding_model": {
102+
"provider": provider,
103+
"model_name": model_name,
104+
"api_key": api_key,
105+
**kwargs
106+
}
107+
}
108+
self.api.update_config(config)
109+
110+
def get_default_embedding_model(self):
111+
"""
112+
Get the default embedding model configuration for MindsDB.
113+
114+
:return: Dictionary containing the default embedding model configuration.
115+
"""
116+
return self.api.get_config().get("default_embedding_model", {})
117+
118+
def set_default_reranking_model(
119+
self,
120+
provider: str,
121+
model_name: str,
122+
api_key: str = None,
123+
**kwargs
124+
):
125+
"""
126+
Set the default reranking model configuration for MindsDB.
127+
128+
:param provider: The name of the reranking model provider (e.g., 'openai', 'google').
129+
:param model_name: The name of the reranking model to use.
130+
:param api_key: Optional API key for the provider.
131+
:param kwargs: Additional parameters for the reranking model configuration.
132+
"""
133+
config = {
134+
"default_reranking_model": {
135+
"provider": provider,
136+
"model_name": model_name,
137+
"api_key": api_key,
138+
**kwargs
139+
}
140+
}
141+
self.api.update_config(config)
142+
143+
def get_default_reranking_model(self):
144+
"""
145+
Get the default reranking model configuration for MindsDB.
146+
147+
:return: Dictionary containing the default reranking model configuration.
148+
"""
149+
return self.api.get_config().get("default_reranking_model", {})
150+

mindsdb_sdk/connectors/rest_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,3 +468,24 @@ def knowledge_base_completion(self, project: str, knowledge_base_name, payload):
468468
)
469469
_raise_for_status(r)
470470
return r.json()
471+
472+
def get_config(self):
473+
"""
474+
Get MindsDB configuration.
475+
476+
:return: Dictionary containing MindsDB configuration.
477+
"""
478+
url = self.url + '/api/config'
479+
r = self.session.get(url)
480+
_raise_for_status(r)
481+
return r.json()
482+
483+
def update_config(self, config: dict):
484+
"""
485+
Update MindsDB configuration with the provided settings.
486+
487+
:param config: Dictionary containing configuration settings.
488+
"""
489+
url = self.url + '/api/config'
490+
r = self.session.put(url, json=config)
491+
_raise_for_status(r)

mindsdb_sdk/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .ml_engines import MLEngines
55
from .handlers import Handlers
66
from .skills import Skills
7+
from .config import Config
78

89

910
class Server(Project):
@@ -48,6 +49,8 @@ def __init__(self, api, skills: Skills = None, agents: Agents = None):
4849
self.ml_handlers = Handlers(self.api, 'ml')
4950
self.data_handlers = Handlers(self.api, 'data')
5051

52+
self.config = Config(api)
53+
5154
def status(self) -> dict:
5255
"""
5356
Get server information. It could content version

tests/test_sdk.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,6 +1788,7 @@ def test_add_database(self, mock_post, mock_put, mock_get):
17881788
}
17891789
assert agent_update_json == expected_agent_json
17901790

1791+
17911792
class TestSkills():
17921793
@patch('requests.Session.get')
17931794
def test_list(self, mock_get):
@@ -1896,3 +1897,105 @@ def test_delete(self, mock_delete):
18961897
server.skills.drop('test_skill')
18971898
# Check API call.
18981899
assert mock_delete.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills/test_skill'
1900+
1901+
1902+
class TestConfig():
1903+
@patch('requests.Session.put')
1904+
@patch('requests.Session.get')
1905+
def test_set_and_get_default_llm(self, mock_get, mock_put):
1906+
server = mindsdb_sdk.connect()
1907+
response_mock(mock_put, {})
1908+
response_mock(mock_get, {
1909+
'default_llm': {
1910+
'provider': 'openai',
1911+
'model_name': 'gpt-4',
1912+
'api_key': 'sk-test123'
1913+
}
1914+
})
1915+
1916+
server.config.set_default_llm(
1917+
provider='openai',
1918+
model_name='gpt-4',
1919+
api_key='sk-test123'
1920+
)
1921+
assert mock_put.call_args[1]['json'] == {
1922+
'default_llm': {
1923+
'provider': 'openai',
1924+
'model_name': 'gpt-4',
1925+
'api_key': 'sk-test123'
1926+
}
1927+
}
1928+
1929+
llm_config = server.config.get_default_llm()
1930+
assert llm_config == {
1931+
'provider': 'openai',
1932+
'model_name': 'gpt-4',
1933+
'api_key': 'sk-test123'
1934+
}
1935+
1936+
@patch('requests.Session.put')
1937+
@patch('requests.Session.get')
1938+
def test_set_and_get_default_embedding_model(self, mock_get, mock_put):
1939+
server = mindsdb_sdk.connect()
1940+
response_mock(mock_put, {})
1941+
response_mock(mock_get, {
1942+
'default_embedding_model': {
1943+
'provider': 'openai',
1944+
'model_name': 'text-embedding-ada-002',
1945+
'api_key': 'sk-test456'
1946+
}
1947+
})
1948+
1949+
server.config.set_default_embedding_model(
1950+
provider='openai',
1951+
model_name='text-embedding-ada-002',
1952+
api_key='sk-test456'
1953+
)
1954+
assert mock_put.call_args[1]['json'] == {
1955+
'default_embedding_model': {
1956+
'provider': 'openai',
1957+
'model_name': 'text-embedding-ada-002',
1958+
'api_key': 'sk-test456'
1959+
}
1960+
}
1961+
1962+
embedding_config = server.config.get_default_embedding_model()
1963+
assert embedding_config == {
1964+
'provider': 'openai',
1965+
'model_name': 'text-embedding-ada-002',
1966+
'api_key': 'sk-test456'
1967+
}
1968+
1969+
@patch('requests.Session.put')
1970+
@patch('requests.Session.get')
1971+
def test_set_and_get_default_reranking_model(self, mock_get, mock_put):
1972+
server = mindsdb_sdk.connect()
1973+
response_mock(mock_put, {})
1974+
response_mock(mock_get, {
1975+
'default_reranking_model': {
1976+
'provider': 'cohere',
1977+
'model_name': 'rerank-english-v2.0',
1978+
'api_key': 'cohere-test789'
1979+
}
1980+
})
1981+
1982+
server.config.set_default_reranking_model(
1983+
provider='cohere',
1984+
model_name='rerank-english-v2.0',
1985+
api_key='cohere-test789'
1986+
)
1987+
assert mock_put.call_args[1]['json'] == {
1988+
'default_reranking_model': {
1989+
'provider': 'cohere',
1990+
'model_name': 'rerank-english-v2.0',
1991+
'api_key': 'cohere-test789'
1992+
}
1993+
}
1994+
1995+
reranking_config = server.config.get_default_reranking_model()
1996+
assert reranking_config == {
1997+
'provider': 'cohere',
1998+
'model_name': 'rerank-english-v2.0',
1999+
'api_key': 'cohere-test789'
2000+
}
2001+

0 commit comments

Comments
 (0)