Skip to content

Commit eeb2fc6

Browse files
authored
Merge pull request #6696 from remix-project-org/ai-model-selection
Adding model names
2 parents 9b9cac1 + a6035cf commit eeb2fc6

File tree

13 files changed

+683
-272
lines changed

13 files changed

+683
-272
lines changed

apps/remix-ide/src/app/plugins/remixAIPlugin.tsx

Lines changed: 184 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import * as packageJson from '../../../../../package.json'
22
import { Plugin } from '@remixproject/engine';
33
import { trackMatomoEvent } from '@remix-api'
4-
import { RemoteInferencer, IRemoteModel, IParams, GenerationParams, AssistantParams, CodeExplainAgent, SecurityAgent, CompletionParams, OllamaInferencer, isOllamaAvailable, getBestAvailableModel } from '@remix/remix-ai-core';
4+
import { RemoteInferencer, IRemoteModel, IParams, GenerationParams, AssistantParams, CodeExplainAgent, SecurityAgent, CompletionParams, OllamaInferencer, isOllamaAvailable, getBestAvailableModel, listModels } from '@remix/remix-ai-core';
55
import { CodeCompletionAgent, ContractAgent, workspaceAgent, IContextType, mcpDefaultServersConfig } from '@remix/remix-ai-core';
66
import { MCPInferencer } from '@remix/remix-ai-core';
77
import { IMCPServer, IMCPConnectionStatus } from '@remix/remix-ai-core';
88
import { RemixMCPServer, createRemixMCPServer } from '@remix/remix-ai-core';
9+
import { AIModel, getDefaultModel, getModelById } from '@remix/remix-ai-core';
910
import axios from 'axios';
1011
import { endpointUrls } from "@remix-endpoints-helper"
1112
import { QueryParams } from '@remix-project/remix-lib'
@@ -22,7 +23,8 @@ const profile = {
2223
"code_insertion", "error_explaining", "vulnerability_check", 'generate',
2324
"initialize", 'chatPipe', 'ProcessChatRequestBuffer', 'isChatRequestPending',
2425
'resetChatRequestBuffer', 'setAssistantThrId',
25-
'getAssistantThrId', 'getAssistantProvider', 'setAssistantProvider', 'setModel',
26+
'getAssistantThrId', 'getAssistantProvider', 'setAssistantProvider', 'setModel', 'setOllamaModel',
27+
'getSelectedModel', 'getModelAccess', 'getOllamaModels',
2628
'addMCPServer', 'removeMCPServer', 'getMCPConnectionStatus', 'getMCPResources', 'getMCPTools', 'executeMCPTool',
2729
'enableMCPEnhancement', 'disableMCPEnhancement', 'isMCPEnabled', 'getIMCPServers',
2830
'loadMCPServersFromSettings', 'clearCaches'
@@ -40,14 +42,15 @@ const profile = {
4042
// add Plugin<any, CustomRemixApi>
4143
export class RemixAIPlugin extends Plugin {
4244
aiIsActivated:boolean = false
43-
remoteInferencer:RemoteInferencer = null
45+
remoteInferencer:RemoteInferencer | OllamaInferencer | MCPInferencer = null
4446
isInferencing: boolean = false
4547
chatRequestBuffer: chatRequestBufferT<any> = null
4648
codeExpAgent: CodeExplainAgent
4749
securityAgent: SecurityAgent
4850
contractor: ContractAgent
4951
workspaceAgent: workspaceAgent
50-
assistantProvider: string = 'mistralai' // default provider
52+
selectedModel: AIModel = getDefaultModel() // default model
53+
selectedModelId: string = getDefaultModel().id
5154
assistantThreadId: string = ''
5255
useRemoteInferencer:boolean = true
5356
completionAgent: CodeCompletionAgent
@@ -83,7 +86,29 @@ export class RemixAIPlugin extends Plugin {
8386
this.isInferencing = false
8487
})
8588

86-
this.setAssistantProvider(this.assistantProvider) // propagate the provider to the remote inferencer
89+
// Load saved model preference
90+
const savedModelId = await this.call('settings', 'get', 'ai/selectedModel')
91+
if (savedModelId) {
92+
await this.setModel(savedModelId)
93+
} else {
94+
// Migration: Convert old provider preference to model
95+
const oldProvider = await this.call('settings', 'get', 'ai/assistantProvider')
96+
if (oldProvider) {
97+
const migrationMap = {
98+
'openai': 'gpt-4-turbo',
99+
'mistralai': 'mistral-medium-latest',
100+
'anthropic': 'claude-sonnet-4-5',
101+
'ollama': 'ollama'
102+
}
103+
const modelId = migrationMap[oldProvider] || getDefaultModel().id
104+
await this.call('settings', 'set', 'ai/selectedModel', modelId)
105+
await this.setModel(modelId)
106+
} else {
107+
// Set default model
108+
await this.setModel(this.selectedModelId)
109+
}
110+
}
111+
87112
this.aiIsActivated = true
88113

89114
this.on('blockchain', 'transactionExecuted', async () => {
@@ -233,7 +258,7 @@ export class RemixAIPlugin extends Plugin {
233258
async generateWorkspace (userPrompt: string, params: IParams=AssistantParams, newThreadID:string="", useRag:boolean=false, statusCallback?: (status: string) => Promise<void>): Promise<any> {
234259
params.stream_result = false // enforce no stream result
235260
params.threadId = newThreadID
236-
params.provider = this.assistantProvider
261+
params.provider = this.selectedModel.provider
237262
useRag = false
238263
trackMatomoEvent(this, { category: 'ai', action: 'GenerateNewAIWorkspace', name: 'WorkspaceAgentEdit', isClick: false })
239264

@@ -333,72 +358,109 @@ export class RemixAIPlugin extends Plugin {
333358
}
334359

335360
async getAssistantProvider(){
336-
return this.assistantProvider
361+
// Legacy method for backwards compatibility
362+
return this.selectedModel.provider
363+
}
364+
365+
async getSelectedModel(){
366+
return this.selectedModelId
337367
}
338368

339369
async setAssistantProvider(provider: string) {
340-
if (provider === 'openai' || provider === 'mistralai' || provider === 'anthropic') {
341-
GenerationParams.provider = provider
342-
CompletionParams.provider = provider
343-
AssistantParams.provider = provider
344-
345-
if (this.assistantProvider !== provider){
346-
// clear the threadDds
347-
this.assistantThreadId = ''
348-
GenerationParams.threadId = ''
349-
CompletionParams.threadId = ''
350-
AssistantParams.threadId = ''
351-
}
352-
this.assistantProvider = provider
353-
354-
// Switch back to remote inferencer for cloud providers -- important
355-
if (this.remoteInferencer && this.remoteInferencer instanceof OllamaInferencer) {
356-
this.remoteInferencer = new RemoteInferencer()
357-
this.remoteInferencer.event.on('onInference', () => {
358-
this.isInferencing = true
359-
})
360-
this.remoteInferencer.event.on('onInferenceDone', () => {
361-
this.isInferencing = false
362-
})
363-
}
364-
} else if (provider === 'ollama') {
365-
const isAvailable = await isOllamaAvailable();
366-
if (!isAvailable) {
367-
return
368-
}
370+
// Legacy method - map provider to a default model for backwards compatibility
371+
const providerToModelMap = {
372+
'openai': 'gpt-4-turbo',
373+
'mistralai': 'mistral-medium-latest',
374+
'anthropic': 'claude-sonnet-4-5',
375+
'ollama': 'ollama'
376+
}
377+
const modelId = providerToModelMap[provider] || getDefaultModel().id
378+
await this.setModel(modelId)
379+
}
369380

370-
const bestModel = await getBestAvailableModel();
371-
if (!bestModel) {
372-
return
373-
}
381+
async setModel(modelId: string) {
382+
let model = getModelById(modelId)
383+
if (!model) {
384+
model = getDefaultModel()
385+
modelId = model.id
386+
}
374387

375-
// Switch to Ollama inferencer
376-
this.remoteInferencer = new OllamaInferencer(bestModel);
377-
this.remoteInferencer.event.on('onInference', () => {
378-
this.isInferencing = true
379-
})
380-
this.remoteInferencer.event.on('onInferenceDone', () => {
381-
this.isInferencing = false
382-
})
388+
// Store previous model for comparison
389+
const previousModelId = this.selectedModelId
390+
391+
this.selectedModelId = modelId
392+
this.selectedModel = model
393+
394+
// Update inference parameters
395+
GenerationParams.provider = model.provider
396+
GenerationParams.model = modelId
397+
CompletionParams.provider = model.provider
398+
CompletionParams.model = modelId
399+
AssistantParams.provider = model.provider
400+
AssistantParams.model = modelId
401+
402+
// Clear thread IDs when switching models
403+
if (previousModelId !== modelId) {
404+
this.assistantThreadId = ''
405+
GenerationParams.threadId = ''
406+
CompletionParams.threadId = ''
407+
AssistantParams.threadId = ''
408+
}
383409

384-
if (this.assistantProvider !== provider){
385-
// clear the threadIds
386-
this.assistantThreadId = ''
387-
GenerationParams.threadId = ''
388-
CompletionParams.threadId = ''
389-
AssistantParams.threadId = ''
410+
// Switch inferencer based on provider
411+
if (model.provider === 'ollama') {
412+
// Ollama requires sub-model selection, use best available for now
413+
const isAvailable = await isOllamaAvailable();
414+
if (!isAvailable) {
415+
console.error('Ollama is not available. Please ensure Ollama is running. Falling back to default model.')
416+
const defaultModel = getDefaultModel()
417+
model = defaultModel
418+
modelId = defaultModel.id
419+
this.selectedModelId = modelId
420+
this.selectedModel = model
421+
GenerationParams.provider = model.provider
422+
GenerationParams.model = modelId
423+
CompletionParams.provider = model.provider
424+
CompletionParams.model = modelId
425+
AssistantParams.provider = model.provider
426+
AssistantParams.model = modelId
427+
} else {
428+
const bestModel = await getBestAvailableModel();
429+
if (!bestModel) {
430+
console.error('No Ollama models available. Falling back to default model.')
431+
// Fall back to default model
432+
const defaultModel = getDefaultModel()
433+
model = defaultModel
434+
modelId = defaultModel.id
435+
this.selectedModelId = modelId
436+
this.selectedModel = model
437+
GenerationParams.provider = model.provider
438+
GenerationParams.model = modelId
439+
CompletionParams.provider = model.provider
440+
CompletionParams.model = modelId
441+
AssistantParams.provider = model.provider
442+
AssistantParams.model = modelId
443+
} else {
444+
// Switch to Ollama inferencer
445+
this.remoteInferencer = new OllamaInferencer(bestModel);
446+
this.remoteInferencer.event.on('onInference', () => {
447+
this.isInferencing = true
448+
})
449+
this.remoteInferencer.event.on('onInferenceDone', () => {
450+
this.isInferencing = false
451+
})
452+
}
390453
}
391-
this.assistantProvider = provider
392-
} else {
393-
console.error(`Unknown assistant provider: ${provider}`)
394454
}
395455

396-
// If MCP is enabled, update it to use the new Ollama inferencer
456+
// Update MCP inferencer if enabled
397457
if (this.mcpEnabled) {
398458
this.mcpInferencer = new MCPInferencer(this.mcpServers, undefined, undefined, this.remixMCPServer, this.remoteInferencer);
399459
this.mcpInferencer.event.on('mcpServerConnected', (serverName: string) => {
460+
// Handle server connected
400461
})
401462
this.mcpInferencer.event.on('mcpServerError', (serverName: string, error: Error) => {
463+
// Handle server error
402464
})
403465
this.mcpInferencer.event.on('onInference', () => {
404466
this.isInferencing = true
@@ -408,31 +470,76 @@ export class RemixAIPlugin extends Plugin {
408470
})
409471
await this.mcpInferencer.connectAllServers();
410472
}
473+
474+
// Save preference
475+
await this.call('settings', 'set', 'ai/selectedModel', modelId)
476+
477+
// Emit event for UI updates
478+
this.emit('modelChanged', modelId)
411479
}
412480

413-
async setModel(modelName: string) {
414-
if (this.assistantProvider === 'ollama' && this.remoteInferencer instanceof OllamaInferencer) {
415-
try {
416-
const isAvailable = await isOllamaAvailable();
417-
if (!isAvailable) {
418-
console.error('Ollama is not available. Please ensure Ollama is running.')
419-
return
420-
}
481+
async setOllamaModel(ollamaModelName: string) {
482+
// Special method for selecting specific Ollama model after "Ollama" is selected
483+
if (this.selectedModel.provider !== 'ollama') {
484+
console.warn('setOllamaModel should only be called when Ollama provider is selected')
485+
return
486+
}
421487

422-
this.remoteInferencer = new OllamaInferencer(modelName);
423-
this.remoteInferencer.event.on('onInference', () => {
424-
this.isInferencing = true
425-
})
426-
this.remoteInferencer.event.on('onInferenceDone', () => {
427-
this.isInferencing = false
428-
})
488+
const isAvailable = await isOllamaAvailable();
489+
if (!isAvailable) {
490+
console.error('Ollama is not available. Please ensure Ollama is running.')
491+
return
492+
}
429493

430-
} catch (error) {
431-
console.error('Failed to set Ollama model:', error)
494+
this.remoteInferencer = new OllamaInferencer(ollamaModelName);
495+
this.remoteInferencer.event.on('onInference', () => {
496+
this.isInferencing = true
497+
})
498+
this.remoteInferencer.event.on('onInferenceDone', () => {
499+
this.isInferencing = false
500+
})
501+
502+
// Update MCP if enabled
503+
if (this.mcpEnabled && this.mcpInferencer) {
504+
this.mcpInferencer = new MCPInferencer(this.mcpServers, undefined, undefined, this.remixMCPServer, this.remoteInferencer);
505+
await this.mcpInferencer.connectAllServers();
506+
}
507+
}
508+
509+
async getModelAccess(): Promise<string[]> {
510+
try {
511+
const token = localStorage.getItem('remix_access_token')
512+
const headers = token ? { 'Authorization': `Bearer ${token}` } : {}
513+
514+
const response = await fetch(`${endpointUrls.sso}/accounts`, {
515+
credentials: 'include',
516+
headers
517+
})
518+
519+
if (response.ok) {
520+
const data = await response.json()
521+
return data.allowed_models || []
432522
}
433-
} else {
434-
console.warn(`setModel is only supported for Ollama provider. Current provider: ${this.assistantProvider}`)
523+
} catch (err) {
524+
console.error('Failed to fetch model access:', err)
435525
}
526+
527+
// Fallback: default model + ollama
528+
return [getDefaultModel().id, 'ollama']
529+
}
530+
531+
async getOllamaModels(): Promise<string[]> {
532+
if (this.selectedModel.provider !== 'ollama') {
533+
throw new Error('Ollama is not the selected provider')
534+
}
535+
536+
const available = await isOllamaAvailable()
537+
if (!available) {
538+
throw new Error('Ollama is not running')
539+
}
540+
541+
const models = await listModels()
542+
return models
436543
}
437544

438545
isChatRequestPending(){

libs/endpoints-helper/src/index.ts

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,22 @@ const localhostUrls: EndpointUrls = {
8282
github: 'http://localhost:3005/github',
8383
ghfolderpull: 'http://localhost:3005/ghfolderpull',
8484
gitHubLoginProxy: 'http://localhost:3005/github-login-proxy',
85-
85+
8686
// UTILITIES service (port 3007)
8787
solidityScan: 'http://localhost:3007/solidityscan',
8888
solidityScanWebSocket: 'ws://localhost:3007/solidityscan',
89-
89+
9090
// PLUGINS service (port 3006)
9191
ipfsGateway: 'http://localhost:3006/jqgt',
9292
embedly: 'http://localhost:3006/embedly',
9393
vyper2: 'http://localhost:3006/vyper2',
94-
94+
9595
// AI service (port 3003)
9696
solcoder: 'http://localhost:4000/solcoder',
9797
completion: 'http://localhost:3003/completion',
9898
gptChat: 'http://localhost:3003/gpt-chat',
9999
rag: 'http://localhost:3003/rag',
100-
100+
101101
// AUTH service (port 3001)
102102
sso: 'https://auth.api.remix.live:8443/sso',
103103

@@ -114,13 +114,13 @@ const localhostUrls: EndpointUrls = {
114114

115115
const resolvedUrls: EndpointUrls = prefix
116116
? (prefix.includes('localhost')
117-
? localhostUrls // Use direct service ports for localhost
118-
: Object.fromEntries( // Use prefix paths for production/ngrok
119-
Object.entries(defaultUrls).map(([key, _]) => [
120-
key,
121-
`${prefix}/${endpointPathMap[key as keyof EndpointUrls]}`,
122-
])
123-
) as EndpointUrls)
117+
? localhostUrls // Use direct service ports for localhost
118+
: Object.fromEntries( // Use prefix paths for production/ngrok
119+
Object.entries(defaultUrls).map(([key, _]) => [
120+
key,
121+
`${prefix}/${endpointPathMap[key as keyof EndpointUrls]}`,
122+
])
123+
) as EndpointUrls)
124124
: defaultUrls;
125125

126126
if (resolvedUrls.solidityScan.startsWith('https://')) {

libs/remix-ai-core/src/agents/contractAgent.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ export class ContractAgent {
1919

2020
private constructor(props) {
2121
this.plugin = props;
22-
AssistantParams.provider = this.plugin.assistantProvider
2322
}
2423

2524
public static getInstance(props) {

0 commit comments

Comments (0)