Commit f2e00ea

feat: support streaming output

1 parent 19f6288
8 files changed: +82 -18 lines

package.json (+1 -1)

@@ -1,6 +1,6 @@
 {
   "name": "aigc-detector",
-  "version": "1.0.3",
+  "version": "1.0.4",
   "description": "Detect if content is generated by AI",
   "keywords": [
     "aigc",

src/cli/commands/chat.ts (+18 -9)

@@ -42,8 +42,6 @@ class ChatCommand extends BaseCommand {
 
   static flags = {};
 
-  private lastMessage = 'How can I help you today?';
-
   private messages = new ChatMessageHistory();
 
   async run(): Promise<void> {

@@ -54,16 +52,25 @@ class ChatCommand extends BaseCommand {
         apiKey: config.apiKey,
         platform: config.platform as unknown as Platform
       });
-      const userDisplay = this.getDisplayContent(PromptRole.USER);
+      const aiDisplay = this.getDisplayContent(PromptRole.AI);
+      let lastMessage = 'How can I help you today?';
+
+      process.stdout.write(aiDisplay + lastMessage + '\n');
 
       // eslint-disable-next-line no-constant-condition
       while (true) {
-        const aiMessage = await this.addMessage(PromptRole.AI, this.lastMessage);
-        const userMessage = await this.getUserMessage(aiMessage + `\n${userDisplay}`);
-        const answer = await detector.chat(userMessage, await this.messages.getMessages());
+        const userMessage = await this.getUserMessage();
+        const stream = detector.chat(userMessage, await this.messages.getMessages());
+
+        process.stdout.write(aiDisplay);
+        stream.pipe(process.stdout);
+
+        lastMessage = await stream.getData();
+
+        process.stdout.write('\n');
 
         await this.addMessage(PromptRole.USER, userMessage);
-        this.lastMessage = answer;
+        await this.addMessage(PromptRole.AI, lastMessage);
       }
     } else {
       this.showHelp();

@@ -84,9 +91,11 @@ class ChatCommand extends BaseCommand {
     return chalk[roleDisplay.color](`[${roleDisplay.name}] `);
   }
 
-  private getUserMessage(aiMessage: string): Promise<string> {
+  private getUserMessage(): Promise<string> {
+    const userDisplay = this.getDisplayContent(PromptRole.USER);
+
    return new Promise<string>((resolve) => {
-      reader.question(aiMessage, resolve);
+      reader.question(userDisplay, resolve);
     });
   }
 }
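
The reworked loop prints the greeting once, then for each turn pipes the streamed reply to stdout while the new Stream helper (added later in this commit) buffers it for the history. A self-contained sketch of one turn; LangChain's in-memory history and a plain '[AI] ' prefix stand in for the class helpers, which is an assumption, not commit code:

import { AIMessage, HumanMessage } from '@langchain/core/messages';
import { ChatMessageHistory } from 'langchain/stores/message/in_memory';
import type { AIGC } from '../../core'; // path assumed from this repo's layout

const history = new ChatMessageHistory();
const aiDisplay = '[AI] '; // chalk coloring from getDisplayContent omitted

async function turn(detector: AIGC, userMessage: string): Promise<void> {
  // History is read before the current turn is appended, so the model sees
  // the prior conversation plus userMessage as the new content.
  const stream = detector.chat(userMessage, await history.getMessages());

  process.stdout.write(aiDisplay); // prefix, printed before the first token
  stream.pipe(process.stdout);     // reply renders token by token

  const answer = await stream.getData(); // full reply, buffered by Stream
  process.stdout.write('\n');

  await history.addMessage(new HumanMessage(userMessage));
  await history.addMessage(new AIMessage(answer));
}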

src/core/index.ts (+5 -4)

@@ -1,5 +1,7 @@
 import type { BaseMessage } from '@langchain/core/messages';
 
+import type Stream from './stream';
+
 import { PROMPT } from '../const';
 import { getPlatform, type Platform } from '../platform';
 import { getEnvConfig } from './env';

@@ -22,15 +24,14 @@ export class AIGC {
     this.platform = (env.platform as unknown as Platform) || options.platform;
   }
 
-  public async chat(content: string, messages: BaseMessage[]) {
+  public chat(content: string, messages: BaseMessage[]): Stream {
     const platform = getPlatform(this.platform);
-    const result = await platform.invoke(
+
+    return platform.stream(
       'You are a helpful assistant. Answer all questions to the best of your ability.',
       { content, messages },
       this.apiKey
     );
-
-    return result;
   }
 
   public async detect(content: string): Promise<ReturnType<typeof getDetectResult>> {
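
For callers, chat() is no longer awaited for a string; it hands back the Stream synchronously. A minimal caller-side sketch (the helper name ask is illustrative, not commit code):

import type { AIGC } from '.'; // import path assumed

async function ask(detector: AIGC, question: string): Promise<string> {
  const stream = detector.chat(question, []); // returns a Stream, no await

  stream.pipe(process.stdout); // live, token-by-token rendering
  return stream.getData();     // resolves with the complete reply
}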

src/core/stream.ts (+23 -0)

@@ -0,0 +1,23 @@
+import { Transform, type TransformCallback } from 'node:stream';
+
+class Stream extends Transform {
+  private data = '';
+
+  getData(): Promise<string> {
+    return new Promise((resolve) => {
+      this.on('close', () => {
+        resolve(this.data);
+      });
+    });
+  }
+
+  _transform(chunk: Buffer, encoding: BufferEncoding, callback: TransformCallback): void {
+    const data = chunk.toString();
+
+    this.data += data;
+
+    callback(null, data);
+  }
+}
+
+export default Stream;
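
The class does double duty: _transform re-emits every chunk downstream (so the stream stays pipeable) while accumulating it in data, and getData() resolves once the producer destroys the stream, which fires 'close'. A self-contained sketch of that contract, with toy strings in place of LLM tokens:

import Stream from './stream'; // the class added above

const stream = new Stream();

stream.pipe(process.stdout);   // consumers see chunks as they are written
const done = stream.getData(); // must subscribe before 'close' fires

stream.write('Hello, ');
stream.write('world!');
stream.destroy(); // producer signals completion; 'close' resolves getData()

done.then((full) => {
  console.log(`\nbuffered: ${full}`); // the same 'Hello, world!' from the buffer
});

Note that getData() only listens for a future 'close' event, so it has to be called before the stream is destroyed.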

src/platform/base.ts (+29 -1)

@@ -4,19 +4,22 @@ import { ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemp
 import { ChatOpenAI } from '@langchain/openai';
 import { LLMChain } from 'langchain/chains';
 
+import Stream from '../core/stream';
+
 type InvokeParameter = Parameters<InstanceType<typeof LLMChain>['invoke']>[0];
 
 abstract class Platform {
   protected temperature = 0.7;
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming = false): BaseLanguageModel {
     return new ChatOpenAI({
       apiKey,
       configuration: {
         baseURL: `https://${this.server}/v1`
       },
       frequencyPenalty: 1,
       model: this.model,
+      streaming,
       temperature: this.temperature
     });
   }

@@ -39,6 +42,31 @@ abstract class Platform {
     return result.text;
   }
 
+  public stream(prompt: string, params: InvokeParameter, apiKey?: string): Stream {
+    const promptTemplate = this.getPrompt(prompt);
+    const chain = new LLMChain({
+      llm: this.getChatModel(apiKey, true),
+      prompt: promptTemplate
+    });
+    const stream = new Stream();
+
+    chain
+      .invoke(params, {
+        callbacks: [
+          {
+            handleLLMNewToken(token: string) {
+              stream.write(token);
+            }
+          }
+        ]
+      })
+      .then(() => {
+        stream.destroy();
+      });
+
+    return stream;
+  }
+
   protected abstract model: string;
 
   public abstract name: string;
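
The producer side mirrors the consumer: handleLLMNewToken writes each token into the Stream, and when invoke() settles the stream is destroyed, which is what unblocks getData(). The same control flow can be exercised without LangChain or network access; the fake token source below is an assumption, not commit code:

import Stream from '../core/stream';

// Stand-in for LLMChain.invoke() plus handleLLMNewToken: any async source
// that calls back once per token drives the Stream the same way.
async function fakeInvoke(onToken: (token: string) => void): Promise<void> {
  for (const token of ['Str', 'eam', 'ing', '!\n']) {
    await new Promise((resolve) => setTimeout(resolve, 50)); // simulate latency
    onToken(token);
  }
}

function fakeStreamCall(): Stream {
  const stream = new Stream();

  fakeInvoke((token) => stream.write(token)).then(() => {
    stream.destroy(); // completion fires 'close', resolving getData()
  });

  return stream; // handed back synchronously, before any token arrives
}

fakeStreamCall().pipe(process.stdout);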

src/platform/minimax.ts (+2 -1)

@@ -11,11 +11,12 @@ class MiniMax extends Platform {
 
   protected server = 'api.minimax.chat';
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming = false): BaseLanguageModel {
     return new ChatMinimax({
       minimaxApiKey: apiKey,
       minimaxGroupId: '1782658868262748274',
       model: this.model,
+      streaming,
       temperature: this.temperature
     });
   }

src/platform/tongyi.ts (+2 -1)

@@ -11,10 +11,11 @@ class TongYi extends Platform {
 
   protected server = 'open.bigmodel.cn';
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming = false): BaseLanguageModel {
     return new ChatAlibabaTongyi({
       alibabaApiKey: apiKey,
       model: this.model,
+      streaming,
       temperature: this.temperature
     });
   }

src/platform/zhipu.ts (+2 -1)

@@ -11,9 +11,10 @@ class ZhiPu extends Platform {
 
   protected server = 'open.bigmodel.cn';
 
-  protected getChatModel(apiKey?: string): BaseLanguageModel {
+  protected getChatModel(apiKey?: string, streaming?: boolean): BaseLanguageModel {
     return new ChatZhipuAI({
       model: this.model,
+      streaming,
       temperature: this.temperature,
       zhipuAIApiKey: apiKey
     });
