
Commit d4dcef9

Authored by srdas, pre-commit-ci[bot], and Darshan808
Allow embedding model fields, fix coupled model fields, add custom OpenAI provider (#1264)
* Simplifying the OpenAI provider to use multiple model providers
* Update openrouter.md
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* openai general interface added
* embedding
* Updated settings to take OpenAI generic embedding models
* added openai generic embeddings screenshot
* Fixed Issue 1261
* bump version floor on jupyter server
* linter
* adding embedding model fields
* Update test_config_manager
* Update pyproject.toml
* pyproject.toml fixes
* pyproject.toml updates
* Make Native Chat Handlers Overridable via Entry Points (#1249)
* make native chat handlers customizable
* remove-ci-error
* add-disabled-check-and-sort-entrypoints
* Refactor Chat Handlers to Simplify Initialization (#1257)
* simplify-entrypoints-loading
* fix-lint
* fix-tests
* add-retriever-typing
* remove-retriever-from-base
* fix-circular-import (ydoc import)
* fix-type-check-failure
* refactor-retriever-init
* Allow chat handlers to be initialized in any order (#1268)
* lazy-initialize-retriever
* add-retriever-property
* rebase-into-main
* update-docs
* update-documentation
* pyproject toml files
* pyproject toml updates
* update snapshot
* writing config file correctly
* tsx lint
* Update use-server-info.ts
* adds embedding_models attribute
* Fixed display of Base url for embeddings and completions
* removed embedding_models
* Added help fields
* Update chat-settings.tsx
* minor reversions moved to new issue

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Darshan Poudel <[email protected]>
1 parent (110311b) · commit d4dcef9

File tree: 16 files changed (+291 −75 lines)

Four binary screenshot files added (157 KB, 150 KB, 188 KB, 349 KB); image previews not included here.

docs/source/users/index.md

Lines changed: 4 additions & 2 deletions

```diff
@@ -346,11 +346,13 @@
 For details on enabling model access in your AWS account, using cross-region inference, or invoking custom/provisioned models, please see our dedicated documentation page on [using Amazon Bedrock in Jupyter AI](bedrock.md).
 
-### OpenRouter Usage
+### OpenRouter and OpenAI Interface Usage
 
 Jupyter AI enables use of language models accessible through [OpenRouter](https://openrouter.ai)'s unified interface. Examples of models that may be accessed via OpenRouter are: [Deepseek](https://openrouter.ai/deepseek/deepseek-chat), [Qwen](https://openrouter.ai/qwen/), [mistral](https://openrouter.ai/mistralai/), etc. OpenRouter enables usage of any model conforming to the OpenAI API.
 
-For details on enabling model access via the AI Settings and using models via OpenRouter, please see the dedicated documentation page on using [OpenRouter in Jupyter AI](openrouter.md).
+Likewise, for many models, you may directly choose the OpenAI provider in Jupyter AI instead of OpenRouter in the same way.
+
+For details on enabling model access via the AI Settings and using models via OpenRouter or OpenAI, please see the dedicated documentation page on using [OpenRouter and OpenAI providers in Jupyter AI](openrouter.md).
 
 ### SageMaker endpoints usage
```

docs/source/users/openrouter.md

Lines changed: 33 additions & 1 deletion

```diff
@@ -1,4 +1,4 @@
-# Using OpenRouter in Jupyter AI
+# Using OpenRouter or OpenAI Interfaces in Jupyter AI
 
 [(Return to the Chat Interface page)](index.md#openrouter-usage)
 
@@ -33,4 +33,36 @@ You should now be able to use Deepseek! An example of usage is shown next:
     alt='Screenshot of chat using Deepseek via the OpenRouter provider.'
     class="screenshot" />
 
+In a similar manner, models may also be invoked directly using the OpenAI provider interface in Jupyter AI. First, choose the OpenAI provider and then enter the model ID, as shown on the OpenAI [models page](https://platform.openai.com/docs/models). An example is shown below:
+
+<img src="../_static/openai-chat-openai.png"
+    width="75%"
+    alt='Screenshot of chat using gpt-4o via the OpenAI provider.'
+    class="screenshot" />
+
+DeepSeek models may be used via the same interface, if the base API URL is provided:
+
+<img src="../_static/openai-chat-deepseek.png"
+    width="75%"
+    alt='Screenshot of chat using deepseek via the OpenAI provider.'
+    class="screenshot" />
+
+For DeepSeek models, enter the DeepSeek API key in place of the OpenAI API key.
+
+Models deployed using vLLM may be used in a similar manner:
+
+<img src="../_static/openai-chat-vllm.png"
+    width="75%"
+    alt='Screenshot of chat using vllm via the OpenAI provider.'
+    class="screenshot" />
+
+Deployment and usage of models with vLLM is discussed [here](vllm.md).
+
+For embedding models from OpenAI, you can generically choose them using the AI Settings interface as well:
+
+<img src="../_static/openai-embeddings.png"
+    width="75%"
+    alt='Screenshot of embedding use via the OpenAI provider.'
+    class="screenshot" />
+
 [(Return to the Chat Interface page)](index.md#openrouter-usage)
```
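The documentation above has the user enter a model ID, an API key, and (for non-OpenAI backends) a base API URL. As a rough illustration of how those AI Settings fields fit together — a hedged sketch with a hypothetical helper, not Jupyter AI's actual code — the provider ultimately forwards them as keyword arguments to an OpenAI-compatible client:

```python
# Hypothetical sketch: assemble the settings that the "OpenAI (general
# interface)" provider would forward to an OpenAI-compatible client.
def build_client_kwargs(model_id, api_key, openai_api_base=None):
    kwargs = {"model_name": model_id, "openai_api_key": api_key}
    # Optional fields left blank in the AI Settings UI are simply omitted.
    if openai_api_base:
        kwargs["openai_api_base"] = openai_api_base
    return kwargs

# DeepSeek through the OpenAI interface: the DeepSeek key and base URL
# stand in for OpenAI's.
settings = build_client_kwargs(
    "deepseek-chat", "sk-...", openai_api_base="https://api.deepseek.com"
)
print(settings["openai_api_base"])
```

The same shape applies to vLLM deployments: point `openai_api_base` at the vLLM server and supply whatever key (if any) that server expects.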

packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/ollama.py

Lines changed: 1 addition & 6 deletions

```diff
@@ -24,12 +24,7 @@ class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
     name = "Ollama"
     # source: https://ollama.com/library
     model_id_key = "model"
-    models = [
-        "nomic-embed-text",
-        "mxbai-embed-large",
-        "all-minilm",
-        "snowflake-arctic-embed",
-    ]
+    models = ["*"]
     registry = True
     fields = [
         TextField(key="base_url", label="Base API URL (optional)", format="text"),
```
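The change above replaces Ollama's fixed embedding-model list with the `"*"` wildcard, which (together with `registry = True`) lets the user type any model ID instead of picking from a closed list. A minimal sketch of that acceptance rule — the helper name is hypothetical, not Jupyter AI's API:

```python
# Hypothetical sketch of the wildcard rule implied by models = ["*"]:
# a registry provider accepts any user-supplied model ID, while a
# fixed-list provider accepts only the IDs it declares.
def accepts_model(models, model_id):
    return "*" in models or model_id in models

print(accepts_model(["*"], "nomic-embed-text"))        # wildcard provider
print(accepts_model(["all-minilm"], "nomic-embed-text"))  # fixed list
```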

packages/jupyter-ai-magics/jupyter_ai_magics/partner_providers/openai.py

Lines changed: 30 additions & 1 deletion

```diff
@@ -76,6 +76,27 @@ def is_api_key_exc(cls, e: Exception):
         return False
 
 
+class ChatOpenAICustomProvider(BaseProvider, ChatOpenAI):
+    id = "openai-chat-custom"
+    name = "OpenAI (general interface)"
+    models = ["*"]
+    model_id_key = "model_name"
+    model_id_label = "Model ID"
+    pypi_package_deps = ["langchain_openai"]
+    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
+    fields = [
+        TextField(
+            key="openai_api_base", label="Base API URL (optional)", format="text"
+        ),
+        TextField(
+            key="openai_organization", label="Organization (optional)", format="text"
+        ),
+        TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
+    ]
+    help = "Supports non-OpenAI models that use the OpenAI API interface. Replace the OpenAI API key with the API key for the chosen provider."
+    registry = True
+
+
 class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
     id = "azure-chat-openai"
     name = "Azure OpenAI"
@@ -107,6 +128,15 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
     model_id_key = "model"
     pypi_package_deps = ["langchain_openai"]
     auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
+
+
+class OpenAIEmbeddingsCustomProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
+    id = "openai-custom"
+    name = "OpenAI (general interface)"
+    models = ["*"]
+    model_id_key = "model"
+    pypi_package_deps = ["langchain_openai"]
+    auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
     registry = True
     fields = [
         TextField(
@@ -128,7 +158,6 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding
     auth_strategy = EnvAuthStrategy(
         name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key"
     )
-    registry = True
     fields = [
         TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"),
     ]
```
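Both new providers declare `auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")`, meaning the key is read from an environment variable; for non-OpenAI backends, the same variable simply holds that backend's key, as the `help` text says. A hedged stdlib sketch of that env-variable lookup (the helper is hypothetical, not the `EnvAuthStrategy` implementation):

```python
import os

# Hypothetical sketch of an EnvAuthStrategy-style lookup: the provider
# reads its API key from a named environment variable, regardless of
# which backend the key actually belongs to.
def resolve_env_auth(var_name="OPENAI_API_KEY"):
    key = os.environ.get(var_name)
    if key is None:
        raise KeyError(f"{var_name} is not set; enter it in AI Settings")
    return key

# e.g. a DeepSeek key stored under OPENAI_API_KEY for the general interface
os.environ["OPENAI_API_KEY"] = "sk-demo"
print(resolve_env_auth())
```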

packages/jupyter-ai-magics/pyproject.toml

Lines changed: 2 additions & 0 deletions

```diff
@@ -66,6 +66,7 @@ huggingface_hub = "jupyter_ai_magics:HfHubProvider"
 ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaProvider"
 openai = "jupyter_ai_magics.partner_providers.openai:OpenAIProvider"
 openai-chat = "jupyter_ai_magics.partner_providers.openai:ChatOpenAIProvider"
+openai-chat-custom = "jupyter_ai_magics.partner_providers.openai:ChatOpenAICustomProvider"
 azure-chat-openai = "jupyter_ai_magics.partner_providers.openai:AzureChatOpenAIProvider"
 sagemaker-endpoint = "jupyter_ai_magics.partner_providers.aws:SmEndpointProvider"
 amazon-bedrock = "jupyter_ai_magics.partner_providers.aws:BedrockProvider"
@@ -87,6 +88,7 @@ gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
 huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
 ollama = "jupyter_ai_magics.partner_providers.ollama:OllamaEmbeddingsProvider"
 openai = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsProvider"
+openai-custom = "jupyter_ai_magics.partner_providers.openai:OpenAIEmbeddingsCustomProvider"
 qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"
 
 [tool.hatch.version]
```

packages/jupyter-ai/jupyter_ai/config/config_schema.json

Lines changed: 11 additions & 0 deletions

```diff
@@ -44,6 +44,17 @@
     },
     "additionalProperties": false
   },
+  "embeddings_fields": {
+    "$comment": "Dictionary of model-specific fields, mapping EM GIDs to sub-dictionaries of field key-value pairs for embeddings.",
+    "type": "object",
+    "default": {},
+    "patternProperties": {
+      "^.*$": {
+        "anyOf": [{ "type": "object" }]
+      }
+    },
+    "additionalProperties": false
+  },
   "completions_fields": {
     "$comment": "Dictionary of model-specific fields, mapping LM GIDs to sub-dictionaries of field key-value pairs for completions.",
     "type": "object",
```
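The new `embeddings_fields` property mirrors `fields` and `completions_fields`: a mapping from a global model ID to that model's per-field values. A config fragment shaped for this schema might look like the following (the model ID and field value are purely illustrative):

```json
{
  "embeddings_fields": {
    "openai-custom:text-embedding-3-small": {
      "openai_api_base": "https://api.example.com/v1"
    }
  }
}
```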

packages/jupyter-ai/jupyter_ai/config_manager.py

Lines changed: 71 additions & 5 deletions

```diff
@@ -184,6 +184,7 @@ def _process_existing_config(self, default_config):
     def _validate_model_ids(self, config):
         lm_provider_keys = ["model_provider_id", "completions_model_provider_id"]
         em_provider_keys = ["embeddings_provider_id"]
+        clm_provider_keys = ["completions_model_provider_id"]
 
         # if the currently selected language or embedding model are
         # forbidden, set them to `None` and log a warning.
@@ -201,6 +202,13 @@ def _validate_model_ids(self, config):
                     f"Embedding model {em_id} is forbidden by current allow/blocklists. Setting to None."
                 )
                 setattr(config, em_key, None)
+        for clm_key in clm_provider_keys:
+            clm_id = getattr(config, clm_key)
+            if clm_id is not None and not self._validate_model(clm_id, raise_exc=False):
+                self.log.warning(
+                    f"Completion model {clm_id} is forbidden by current allow/blocklists. Setting to None."
+                )
+                setattr(config, clm_key, None)
 
         # if the currently selected language or embedding model ids are
         # not associated with models, set them to `None` and log a warning.
@@ -218,6 +226,16 @@ def _validate_model_ids(self, config):
                     f"No embedding model is associated with '{em_id}'. Setting to None."
                 )
                 setattr(config, em_key, None)
+        for clm_key in clm_provider_keys:
+            clm_id = getattr(config, clm_key)
+            if (
+                clm_id is not None
+                and not get_lm_provider(clm_id, self._lm_providers)[1]
+            ):
+                self.log.warning(
+                    f"No completion model is associated with '{clm_id}'. Setting to None."
+                )
+                setattr(config, clm_key, None)
 
         return config
 
@@ -228,7 +246,8 @@ def _init_defaults(self):
         config_keys = GlobalConfig.model_fields.keys()
         schema_properties = self.validator.schema.get("properties", {})
         default_config = {
-            field: schema_properties.get(field).get("default") for field in config_keys
+            field: schema_properties.get(field, {}).get("default")
+            for field in config_keys
         }
         if self._defaults is None:
             return default_config
@@ -283,6 +302,36 @@ def _validate_config(self, config: GlobalConfig):
         # verify model is authenticated
         _validate_provider_authn(config, lm_provider)
 
+        # verify fields exist for this model if needed
+        if lm_provider.fields and config.model_provider_id not in config.fields:
+            config.fields[config.model_provider_id] = {}
+
+        # validate completions model config
+        if config.completions_model_provider_id:
+            _, completions_provider = get_lm_provider(
+                config.completions_model_provider_id, self._lm_providers
+            )
+
+            # verify model is declared by some provider
+            if not completions_provider:
+                raise ValueError(
+                    f"No language model is associated with '{config.completions_model_provider_id}'."
+                )
+
+            # verify model is not blocked
+            self._validate_model(config.completions_model_provider_id)
+
+            # verify model is authenticated
+            _validate_provider_authn(config, completions_provider)
+
+            # verify completions fields exist for this model if needed
+            if (
+                completions_provider.fields
+                and config.completions_model_provider_id
+                not in config.completions_fields
+            ):
+                config.completions_fields[config.completions_model_provider_id] = {}
+
         # validate embedding model config
         if config.embeddings_provider_id:
             _, em_provider = get_em_provider(
@@ -301,6 +350,13 @@ def _validate_config(self, config: GlobalConfig):
             # verify model is authenticated
             _validate_provider_authn(config, em_provider)
 
+            # verify embedding fields exist for this model if needed
+            if (
+                em_provider.fields
+                and config.embeddings_provider_id not in config.embeddings_fields
+            ):
+                config.embeddings_fields[config.embeddings_provider_id] = {}
+
     def _validate_model(self, model_id: str, raise_exc=True):
         """
         Validates a model against the set of allow/blocklists specified by the
@@ -349,6 +405,9 @@ def _write_config(self, new_config: GlobalConfig):
         new_config.completions_fields = {
             k: v for k, v in new_config.completions_fields.items() if v
         }
+        new_config.embeddings_fields = {
+            k: v for k, v in new_config.embeddings_fields.items() if v
+        }
 
         self._validate_config(new_config)
         with open(self.config_path, "w") as f:
@@ -462,18 +521,25 @@ def _provider_params(self, key, listing, completions: bool = False):
         # get config fields (e.g. base API URL, etc.)
         if completions:
             fields = config.completions_fields.get(model_uid, {})
+        elif key == "embeddings_provider_id":
+            fields = config.embeddings_fields.get(model_uid, {})
         else:
             fields = config.fields.get(model_uid, {})
 
         # exclude empty fields
         # TODO: modify the config manager to never save empty fields in the
         # first place.
-        for field_key in fields:
-            if isinstance(fields[field_key], str) and not len(fields[field_key]):
-                fields[field_key] = None
+        fields = {
+            k: None if isinstance(v, str) and not len(v) else v
+            for k, v in fields.items()
+        }
 
         # get authn fields
-        _, Provider = get_em_provider(model_uid, listing)
+        _, Provider = (
+            get_em_provider(model_uid, listing)
+            if key == "embeddings_provider_id"
+            else get_lm_provider(model_uid, listing)
+        )
         authn_fields = {}
         if Provider.auth_strategy and Provider.auth_strategy.type == "env":
             keyword_param = (
```
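One refactor in `_provider_params` replaces in-place mutation with a comprehension that maps blank field values to `None`, so empty strings saved from the UI never override a provider's defaults. Its behavior in isolation, extracted here as a standalone helper (the function name is ours, not Jupyter AI's):

```python
# Standalone version of the empty-field normalization used in
# _provider_params: blank string fields become None, everything else
# passes through unchanged.
def normalize_fields(fields):
    return {
        k: None if isinstance(v, str) and not len(v) else v
        for k, v in fields.items()
    }

print(normalize_fields({"openai_api_base": "", "openai_proxy": "http://p:8080"}))
```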

packages/jupyter-ai/jupyter_ai/models.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -53,19 +53,21 @@ class DescribeConfigResponse(BaseModel):
     last_read: int
     completions_model_provider_id: Optional[str] = None
     completions_fields: Dict[str, Dict[str, Any]]
+    embeddings_fields: Dict[str, Dict[str, Any]]
 
 
 class UpdateConfigRequest(BaseModel):
     model_provider_id: Optional[str] = None
     embeddings_provider_id: Optional[str] = None
+    completions_model_provider_id: Optional[str] = None
     send_with_shift_enter: Optional[bool] = None
     api_keys: Optional[Dict[str, str]] = None
-    fields: Optional[Dict[str, Dict[str, Any]]] = None
     # if passed, this will raise an Error if the config was written to after the
     # time specified by `last_read` to prevent write-write conflicts.
     last_read: Optional[int] = None
-    completions_model_provider_id: Optional[str] = None
+    fields: Optional[Dict[str, Dict[str, Any]]] = None
     completions_fields: Optional[Dict[str, Dict[str, Any]]] = None
+    embeddings_fields: Optional[Dict[str, Dict[str, Any]]] = None
 
     @field_validator("send_with_shift_enter", "api_keys", "fields", mode="before")
     @classmethod
@@ -88,6 +90,7 @@ class GlobalConfig(BaseModel):
     api_keys: Dict[str, str]
     completions_model_provider_id: Optional[str] = None
     completions_fields: Dict[str, Dict[str, Any]]
+    embeddings_fields: Dict[str, Dict[str, Any]]
 
 
 class ListSlashCommandsEntry(BaseModel):
```
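With `fields`, `completions_fields`, and `embeddings_fields` all optional on `UpdateConfigRequest`, a partial update only carries the keys it changes. A stdlib sketch of that "only set keys are applied" pattern, using a dataclass as a simplified stand-in for the Pydantic model (names and values illustrative):

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class UpdateRequest:  # simplified stand-in for UpdateConfigRequest
    embeddings_provider_id: Optional[str] = None
    embeddings_fields: Optional[dict] = None

def apply_update(config: dict, req: UpdateRequest) -> dict:
    # Only keys the request actually sets (non-None) overwrite the config,
    # so a partial update leaves every other setting untouched.
    updates = {k: v for k, v in asdict(req).items() if v is not None}
    return {**config, **updates}

cfg = {"embeddings_provider_id": "openai:text-embedding-3-small",
       "embeddings_fields": {}}
new = apply_update(
    cfg, UpdateRequest(embeddings_fields={"gid": {"openai_api_base": "https://example"}})
)
print(new)
```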

packages/jupyter-ai/jupyter_ai/tests/__snapshots__/test_config_manager.ambr

Lines changed: 2 additions & 0 deletions

```diff
@@ -6,6 +6,8 @@
     'completions_fields': dict({
     }),
     'completions_model_provider_id': None,
+    'embeddings_fields': dict({
+    }),
     'embeddings_provider_id': None,
     'fields': dict({
     }),
```
