Skip to content

Commit 0903d76

Browse files
✨ feat(ai): 新增对 LM Studio 作为 AI 提供商的支持
- 【新功能】引入 `LMStudioProvider`,允许用户连接到本地运行的、兼容 OpenAI API 的 LM Studio 服务。 - 【重构】重构 `BaseOpenAIProvider`,将 `apiKey` 设为可选,以支持不需要 API 密钥的自托管模型。 - 【配置】在配置中添加 `lmstudio.baseUrl` 选项,并为差异分析添加 `autoDetectStaged` 和 `fallbackToAll` 选项。 - 【依赖】更新 `openai` 依赖至 `^5.21.0`。
1 parent 79cd783 commit 0903d76

File tree

8 files changed

+100
-17
lines changed

8 files changed

+100
-17
lines changed

package-lock.json

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@
219219
"Azure OpenAI",
220220
"Cloudflare",
221221
"GoogleAI",
222-
"VertexAI"
222+
"VertexAI",
223+
"LMStudio"
223224
]
224225
},
225226
"dish-ai-commit.base.model": {
@@ -383,20 +384,36 @@
383384
"default": "",
384385
"description": "Groq API Key / Groq API 密钥"
385386
},
387+
"dish-ai-commit.providers.lmstudio.baseUrl": {
388+
"type": "string",
389+
"default": "http://localhost:1234/v1",
390+
"description": "LMStudio API Base URL / LMStudio API 基础地址"
391+
},
386392
"dish-ai-commit.features.suppressNonCriticalWarnings": {
387393
"type": "boolean",
388394
"default": true,
389395
"description": "Suppress non-critical warning popups, such as context length warnings. / 禁用非关键性警告弹窗,例如上下文长度警告。"
390396
},
391397
"dish-ai-commit.features.codeAnalysis.diffTarget": {
392398
"type": "string",
393-
"default": "all",
394-
"description": "Specify the target for git diff: 'staged' for staged changes, 'all' for all changes. / 指定 git diff 的目标:'staged' 表示暂存区的更改,'all' 表示所有更改。",
399+
"default": "auto",
400+
"description": "Specify the target for git diff: 'staged' for staged changes, 'all' for all changes, 'auto' for automatic detection. / 指定 git diff 的目标:'staged' 表示暂存区的更改,'all' 表示所有更改,'auto' 表示自动检测",
395401
"enum": [
396402
"staged",
397-
"all"
403+
"all",
404+
"auto"
398405
]
399406
},
407+
"dish-ai-commit.features.codeAnalysis.autoDetectStaged": {
408+
"type": "boolean",
409+
"default": true,
410+
"description": "Automatically detect staged content and prioritize it over all changes when diffTarget is 'auto'. / 当 diffTarget 为 'auto' 时,自动检测暂存区内容并优先使用。"
411+
},
412+
"dish-ai-commit.features.codeAnalysis.fallbackToAll": {
413+
"type": "boolean",
414+
"default": true,
415+
"description": "When staged area is empty, fallback to analyze all working directory changes. / 当暂存区为空时,回退到分析所有工作目录更改。"
416+
},
400417
"dish-ai-commit.features.codeAnalysis.simplifyDiff": {
401418
"type": "boolean",
402419
"default": false,
@@ -691,7 +708,7 @@
691708
"inversify": "^7.5.2",
692709
"node-notifier": "^10.0.1",
693710
"ollama": "^0.5.16",
694-
"openai": "^5.3.0",
711+
"openai": "^5.21.0",
695712
"tiktoken": "^1.0.21",
696713
"tree-sitter-wasms": "^0.1.11",
697714
"vite": "^7.0.6",
@@ -738,5 +755,5 @@
738755
]
739756
}
740757
},
741-
"packageManager": "pnpm@10.12.1"
758+
"packageManager": "pnpm@10.16.1"
742759
}

src/ai/ai-provider-factory.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import { CloudflareWorkersAIProvider } from "./providers/cloudflare-workersai-pr
2222
import { VertexAIProvider } from "./providers/vertexai-provider";
2323
import { GroqAIProvider } from "./providers/groq-provider";
2424
import { BaiduQianfanProvider } from "./providers/baidu-qianfan-provider";
25+
import { LMStudioProvider } from "./providers/lmstudio-provider";
2526

2627
/**
2728
* AI提供者工厂类,负责创建和管理不同AI服务提供者的实例
@@ -148,6 +149,8 @@ export class AIProviderFactory {
148149
return { ...baseConfig, ...config.providers?.mistral };
149150
case AIProvider.BAIDU_QIANFAN:
150151
return { ...baseConfig, ...config.providers?.baiduQianfan };
152+
case AIProvider.LMSTUDIO:
153+
return { ...baseConfig, ...config.providers?.lmstudio };
151154
default:
152155
return baseConfig;
153156
}
@@ -260,6 +263,9 @@ export class AIProviderFactory {
260263
case AIProvider.BAIDU_QIANFAN:
261264
provider = new BaiduQianfanProvider();
262265
break;
266+
case AIProvider.LMSTUDIO:
267+
provider = new LMStudioProvider();
268+
break;
263269
default:
264270
throw new Error(formatMessage("provider.type.unknown", [type]));
265271
}
@@ -301,6 +307,7 @@ export class AIProviderFactory {
301307
new MistralAIProvider(),
302308
new GroqAIProvider(),
303309
new BaiduQianfanProvider(),
310+
new LMStudioProvider(),
304311
];
305312
}
306313

src/ai/providers/base-openai-provider.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import { generateWithRetry, getSystemPrompt } from "../utils/generate-helper"; /
1818
*/
1919
export interface OpenAIProviderConfig {
2020
/** OpenAI API密钥 */
21-
apiKey: string;
21+
apiKey?: string;
2222
/** API基础URL,对于非官方OpenAI端点可自定义 */
2323
baseURL?: string;
2424
/** API版本号 */
@@ -72,16 +72,16 @@ export abstract class BaseOpenAIProvider extends AbstractAIProvider {
7272
* @protected
7373
*/
7474
protected createClient(): OpenAI {
75+
const apiKey = this.config.apiKey ?? "local-dummy-key";
7576
const config: any = {
76-
apiKey: this.config.apiKey,
77+
apiKey: apiKey,
7778
};
7879

7980
if (this.config.baseURL) {
8081
config.baseURL = this.config.baseURL;
81-
if (this.config.apiKey) {
82-
// config.defaultQuery = { "api-version": this.config.apiVersion };
83-
config.defaultHeaders = { "api-key": this.config.apiKey };
84-
}
82+
config.defaultHeaders = {
83+
"api-key": apiKey,
84+
};
8585
}
8686

8787
return new OpenAI(config);
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import { ConfigurationManager } from "../../config/configuration-manager";
2+
import { AIModel } from "../types";
3+
import { BaseOpenAIProvider } from "./base-openai-provider";
4+
5+
const provider = { id: "lmstudio", name: "LMStudio" } as const;
6+
7+
const models: AIModel[] = [
8+
{
9+
id: "lmstudio-model",
10+
name: "LMStudio Default Model",
11+
maxTokens: { input: 4096, output: 2048 },
12+
provider: provider,
13+
default: true,
14+
},
15+
];
16+
17+
export class LMStudioProvider extends BaseOpenAIProvider {
18+
constructor() {
19+
const configManager = ConfigurationManager.getInstance();
20+
super({
21+
baseURL: configManager.getConfig("PROVIDERS_LMSTUDIO_BASEURL"),
22+
providerId: "lmstudio",
23+
providerName: "LMStudio",
24+
models: models,
25+
defaultModel: "lmstudio-model",
26+
});
27+
}
28+
29+
async isAvailable(): Promise<boolean> {
30+
try {
31+
await this.withTimeout(
32+
this.withRetry(async () => {
33+
await this.openai.models.list();
34+
})
35+
);
36+
return true;
37+
} catch {
38+
return false;
39+
}
40+
}
41+
}

src/ai/types.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,8 @@ export type VertexAIModels =
437437

438438
export type XAIModels = "grok-1.5-flash" | "grok-1.5";
439439

440+
export type LMStudioModels = string;
441+
440442
// 所有支持的模型名称类型
441443
export type ModelNames =
442444
| OpenAIModels
@@ -456,6 +458,7 @@ export type ModelNames =
456458
| VertexAIModels
457459
| MistralAIModels
458460
| XAIModels
461+
| LMStudioModels
459462
| "mixtral-8x7b-32768";
460463

461464
export type PremAIModels = string;
@@ -483,7 +486,8 @@ export type AIProviders =
483486
| "cloudflare"
484487
| "vertexai"
485488
| "groq"
486-
| "siliconflow";
489+
| "siliconflow"
490+
| "lmstudio";
487491
export type AnthropicAIModels =
488492
| "claude-3-opus-20240229"
489493
| "claude-3-sonnet-20240229"
@@ -532,6 +536,8 @@ export type AIModels<Provider extends AIProviders = AIProviders> =
532536
? VertexAIModels
533537
: Provider extends "groq"
534538
? "mixtral-8x7b-32768"
539+
: Provider extends "lmstudio"
540+
? LMStudioModels
535541
: OpenAIModels;
536542

537543
export type SupportedAIModels =

src/config/config-schema.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ export const CONFIG_SCHEMA = {
7676
"Cloudflare",
7777
"GoogleAI",
7878
"VertexAI",
79+
"LMStudio",
7980
],
8081
description: "AI provider / AI 提供商",
8182
},
@@ -285,6 +286,13 @@ export const CONFIG_SCHEMA = {
285286
description: "Groq API Key / Groq API 密钥",
286287
},
287288
},
289+
lmstudio: {
290+
baseUrl: {
291+
type: "string",
292+
default: "http://localhost:1234/v1",
293+
description: "LMStudio API Base URL / LMStudio API 基础地址",
294+
},
295+
},
288296
},
289297
features: {
290298
suppressNonCriticalWarnings: {

src/config/generated/config-keys.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,14 @@ export const CONFIG_KEYS = {
5656
"PROVIDERS_VERTEXAI_GOOGLEAUTHOPTIONS": "dish-ai-commit.providers.vertexai.googleAuthOptions",
5757
"PROVIDERS_GROQ": "dish-ai-commit.providers.groq",
5858
"PROVIDERS_GROQ_APIKEY": "dish-ai-commit.providers.groq.apiKey",
59+
"PROVIDERS_LMSTUDIO": "dish-ai-commit.providers.lmstudio",
60+
"PROVIDERS_LMSTUDIO_BASEURL": "dish-ai-commit.providers.lmstudio.baseUrl",
5961
"FEATURES": "dish-ai-commit.features",
6062
"FEATURES_SUPPRESSNONCRITICALWARNINGS": "dish-ai-commit.features.suppressNonCriticalWarnings",
6163
"FEATURES_CODEANALYSIS": "dish-ai-commit.features.codeAnalysis",
6264
"FEATURES_CODEANALYSIS_DIFFTARGET": "dish-ai-commit.features.codeAnalysis.diffTarget",
65+
"FEATURES_CODEANALYSIS_AUTODETECTSTAGED": "dish-ai-commit.features.codeAnalysis.autoDetectStaged",
66+
"FEATURES_CODEANALYSIS_FALLBACKTOALL": "dish-ai-commit.features.codeAnalysis.fallbackToAll",
6367
"FEATURES_CODEANALYSIS_SIMPLIFYDIFF": "dish-ai-commit.features.codeAnalysis.simplifyDiff",
6468
"FEATURES_COMMITFORMAT": "dish-ai-commit.features.commitFormat",
6569
"FEATURES_COMMITFORMAT_ENABLEEMOJI": "dish-ai-commit.features.commitFormat.enableEmoji",

0 commit comments

Comments
 (0)