Skip to content

Commit c6732b9

Browse files
authored
[onechat] factorize chat agent execution to runChatAgent function (elastic#222921)
## Summary Introduce a static `runChatAgent` function to allow re-using the agent's workflow from anywhere ### Example ```ts const completedRound = await runChatAgent( { nextInput, conversation, agentGraphName: defaultAgentGraphName, runId, onEvent: (event) => { events.emit(event); }, tools: toolProvider, }, { logger, runner, request, modelProvider, } ); ```
1 parent 60fc93c commit c6732b9

11 files changed

Lines changed: 200 additions & 67 deletions

File tree

x-pack/platform/packages/shared/onechat/onechat-server/agents/provider.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ export interface AgentHandlerContext {
5454
*/
5555
modelProvider: ModelProvider;
5656
/**
57-
* Tool provider that should be used to list of execute tools.
57+
* Tool provider that can be used to list or execute tools.
5858
*/
5959
toolProvider: ToolProvider;
6060
/**

x-pack/platform/packages/shared/onechat/onechat-server/src/tools.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ export interface ToolHandlerContext {
100100
* Can be used to access the inference APIs or chatModel.
101101
*/
102102
modelProvider: ModelProvider;
103+
/**
104+
* Tool provider that can be used to list or execute tools.
105+
*/
106+
toolProvider: ToolProvider;
103107
/**
104108
* Onechat runner scoped to the current execution.
105109
* Can be used to run other workchat primitive as part of the tool execution.

x-pack/platform/plugins/shared/onechat/server/services/agents/conversational/graph.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,16 @@ import { ToolNode } from '@langchain/langgraph/prebuilt';
1212
import type { StructuredTool } from '@langchain/core/tools';
1313
import type { Logger } from '@kbn/core/server';
1414
import { InferenceChatModel } from '@kbn/inference-langchain';
15-
import { withSystemPrompt } from './system_prompt';
15+
import { withSystemPrompt, defaultSystemPrompt } from './system_prompt';
1616

1717
export const createAgentGraph = async ({
1818
chatModel,
1919
tools,
20+
systemPrompt = defaultSystemPrompt,
2021
}: {
2122
chatModel: InferenceChatModel;
2223
tools: StructuredTool[];
24+
systemPrompt?: string;
2325
logger: Logger;
2426
}) => {
2527
const StateAnnotation = Annotation.Root({
@@ -44,6 +46,7 @@ export const createAgentGraph = async ({
4446
const callModel = async (state: typeof StateAnnotation.State) => {
4547
const response = await model.invoke(
4648
await withSystemPrompt({
49+
systemPrompt,
4750
messages: [...state.initialMessages, ...state.addedMessages],
4851
})
4952
);
@@ -68,7 +71,7 @@ export const createAgentGraph = async ({
6871
};
6972
};
7073

71-
// note: the node names are used in the event convertion logic, they should not be changed
74+
// note: the node names are used in the event convertion logic, they should *not* be changed
7275
const graph = new StateGraph(StateAnnotation)
7376
.addNode('agent', callModel)
7477
.addNode('tools', toolHandler)

x-pack/platform/plugins/shared/onechat/server/services/agents/conversational/handler.ts

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,9 @@
55
* 2.0.
66
*/
77

8-
import { from, filter, shareReplay, firstValueFrom, map } from 'rxjs';
98
import type { Logger } from '@kbn/logging';
10-
import { StreamEvent } from '@langchain/core/tracers/log_stream';
11-
import { isRoundCompleteEvent } from '@kbn/onechat-common';
129
import type { ConversationalAgentHandlerFn } from '@kbn/onechat-server';
13-
import { providerToLangchainTools, conversationLangchainMessages } from './utils';
14-
import { createAgentGraph } from './graph';
15-
import { convertGraphEvents, addRoundCompleteEvent } from './convert_graph_events';
10+
import { runChatAgent } from './run_chat_agent';
1611

1712
export interface CreateConversationalAgentHandlerParams {
1813
logger: Logger;
@@ -30,54 +25,29 @@ export const createHandler = ({
3025
{ agentParams: { nextInput, conversation = [] }, runId },
3126
{ request, modelProvider, toolProvider, events, runner }
3227
) => {
33-
const model = await modelProvider.getDefaultModel();
34-
const tools = await providerToLangchainTools({ request, toolProvider, runner, logger });
35-
const initialMessages = conversationLangchainMessages({
36-
nextInput,
37-
previousRounds: conversation,
38-
});
39-
const agentGraph = await createAgentGraph({ logger, chatModel: model.chatModel, tools });
40-
41-
const eventStream = agentGraph.streamEvents(
42-
{ initialMessages },
28+
const completedRound = await runChatAgent(
4329
{
44-
version: 'v2',
45-
runName: defaultAgentGraphName,
46-
metadata: {
47-
graphName: defaultAgentGraphName,
48-
runId,
30+
nextInput,
31+
conversation,
32+
agentGraphName: defaultAgentGraphName,
33+
runId,
34+
onEvent: (event) => {
35+
events.emit(event);
4936
},
50-
recursionLimit: 10,
51-
callbacks: [],
37+
tools: toolProvider,
38+
},
39+
{
40+
logger,
41+
runner,
42+
request,
43+
modelProvider,
5244
}
5345
);
5446

55-
const events$ = from(eventStream).pipe(
56-
filter(isStreamEvent),
57-
convertGraphEvents({ graphName: defaultAgentGraphName, runName: defaultAgentGraphName }),
58-
addRoundCompleteEvent({ userInput: nextInput }),
59-
shareReplay()
60-
);
61-
62-
events$.subscribe((event) => {
63-
events.emit(event);
64-
});
65-
66-
const round = await firstValueFrom(
67-
events$.pipe(
68-
filter(isRoundCompleteEvent),
69-
map((event) => event.data.round)
70-
)
71-
);
72-
7347
return {
7448
result: {
75-
round,
49+
round: completedRound,
7650
},
7751
};
7852
};
7953
};
80-
81-
const isStreamEvent = (input: any): input is StreamEvent => {
82-
return 'event' in input;
83-
};

x-pack/platform/plugins/shared/onechat/server/services/agents/conversational/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
*/
77

88
export { createDefaultAgentProvider } from './provider';
9+
export { runChatAgent } from './run_chat_agent';
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
import { v4 as uuidv4 } from 'uuid';
9+
import { Observable, from, filter, shareReplay, firstValueFrom, map } from 'rxjs';
10+
import type { Logger } from '@kbn/logging';
11+
import { StreamEvent } from '@langchain/core/tracers/log_stream';
12+
import type { KibanaRequest } from '@kbn/core-http-server';
13+
import {
14+
RoundInput,
15+
ConversationRound,
16+
ChatAgentEvent,
17+
isRoundCompleteEvent,
18+
} from '@kbn/onechat-common';
19+
import type {
20+
ModelProvider,
21+
ScopedRunner,
22+
ExecutableTool,
23+
ToolProvider,
24+
} from '@kbn/onechat-server';
25+
import { providerToLangchainTools, toLangchainTool, conversationLangchainMessages } from './utils';
26+
import { createAgentGraph } from './graph';
27+
import { convertGraphEvents, addRoundCompleteEvent } from './convert_graph_events';
28+
29+
export interface RunChatAgentContext {
30+
logger: Logger;
31+
request: KibanaRequest;
32+
modelProvider: ModelProvider;
33+
runner: ScopedRunner;
34+
}
35+
36+
export interface RunChatAgentParams {
37+
/**
38+
* The next message in this conversation that the agent should respond to.
39+
*/
40+
nextInput: RoundInput;
41+
/**
42+
* Previous rounds of conversation.
43+
*/
44+
conversation?: ConversationRound[];
45+
/**
46+
* Optional system prompt to override the default one.
47+
*/
48+
systemPrompt?: string;
49+
/**
50+
* List of tools that will be exposed to the agent.
51+
* Either a list of tools or a tool provider.
52+
*/
53+
tools?: ToolProvider | ExecutableTool[];
54+
/**
55+
* In case of nested calls (e.g calling from a tool), allows to define the runId.
56+
*/
57+
runId?: string;
58+
/**
59+
* Handler to react to the agent's events.
60+
*/
61+
onEvent?: (event: ChatAgentEvent) => void;
62+
/**
63+
* Can be used to override the graph's name. Used for tracing.
64+
*/
65+
agentGraphName?: string;
66+
}
67+
68+
export type RunChatAgentFn = (
69+
params: RunChatAgentParams,
70+
context: RunChatAgentContext
71+
) => Promise<ConversationRound>;
72+
73+
const defaultAgentGraphName = 'default-onechat-agent';
74+
75+
const noopOnEvent = () => {};
76+
77+
/**
78+
* Create the handler function for the default onechat agent.
79+
*/
80+
export const runChatAgent: RunChatAgentFn = async (
81+
{
82+
nextInput,
83+
conversation = [],
84+
tools = [],
85+
onEvent = noopOnEvent,
86+
runId = uuidv4(),
87+
systemPrompt,
88+
agentGraphName = defaultAgentGraphName,
89+
},
90+
{ logger, request, modelProvider }
91+
) => {
92+
const model = await modelProvider.getDefaultModel();
93+
const langchainTools = Array.isArray(tools)
94+
? tools.map((tool) => toLangchainTool({ tool, logger }))
95+
: await providerToLangchainTools({ request, toolProvider: tools, logger });
96+
const initialMessages = conversationLangchainMessages({
97+
nextInput,
98+
previousRounds: conversation,
99+
});
100+
const agentGraph = await createAgentGraph({
101+
logger,
102+
chatModel: model.chatModel,
103+
tools: langchainTools,
104+
systemPrompt,
105+
});
106+
107+
const eventStream = agentGraph.streamEvents(
108+
{ initialMessages },
109+
{
110+
version: 'v2',
111+
runName: agentGraphName,
112+
metadata: {
113+
graphName: agentGraphName,
114+
runId,
115+
},
116+
recursionLimit: 10,
117+
callbacks: [],
118+
}
119+
);
120+
121+
const events$ = from(eventStream).pipe(
122+
filter(isStreamEvent),
123+
convertGraphEvents({ graphName: agentGraphName, runName: agentGraphName }),
124+
addRoundCompleteEvent({ userInput: nextInput }),
125+
shareReplay()
126+
);
127+
128+
events$.subscribe(onEvent);
129+
130+
return await extractRound(events$);
131+
};
132+
133+
export const extractRound = async (events$: Observable<ChatAgentEvent>) => {
134+
return await firstValueFrom(
135+
events$.pipe(
136+
filter(isRoundCompleteEvent),
137+
map((event) => event.data.round)
138+
)
139+
);
140+
};
141+
142+
const isStreamEvent = (input: any): input is StreamEvent => {
143+
return 'event' in input;
144+
};

x-pack/platform/plugins/shared/onechat/server/services/agents/conversational/system_prompt.ts

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,25 @@
77

88
import { BaseMessage, BaseMessageLike } from '@langchain/core/messages';
99

10-
const getSystemPrompt = () => {
11-
return `You are a helpful chat assistant from the Elasticsearch company.
10+
export const defaultSystemPrompt =
11+
'You are a helpful chat assistant from the Elasticsearch company.';
1212

13-
You have tools at your disposal that you can use to answer the user's question.
13+
const getFullSystemPrompt = (systemPrompt: string) => {
14+
return `${systemPrompt}
1415
1516
### Additional info
17+
- You have tools at your disposal that you can use
1618
- The current date is: ${new Date().toISOString()}
1719
- You can use markdown format to structure your response
1820
`;
1921
};
2022

21-
export const withSystemPrompt = ({ messages }: { messages: BaseMessage[] }): BaseMessageLike[] => {
22-
return [['system', getSystemPrompt()], ...messages];
23+
export const withSystemPrompt = ({
24+
systemPrompt,
25+
messages,
26+
}: {
27+
systemPrompt: string;
28+
messages: BaseMessage[];
29+
}): BaseMessageLike[] => {
30+
return [['system', getFullSystemPrompt(systemPrompt)], ...messages];
2331
};

x-pack/platform/plugins/shared/onechat/server/services/agents/conversational/utils/tool_provider_to_langchain_tools.ts

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,44 +9,37 @@ import { StructuredTool, tool as toTool } from '@langchain/core/tools';
99
import { Logger } from '@kbn/logging';
1010
import type { KibanaRequest } from '@kbn/core-http-server';
1111
import { toolDescriptorToIdentifier, toSerializedToolIdentifier } from '@kbn/onechat-common';
12-
import type { ToolProvider, ExecutableTool, ScopedRunner } from '@kbn/onechat-server';
12+
import type { ToolProvider, ExecutableTool } from '@kbn/onechat-server';
1313

1414
export const providerToLangchainTools = async ({
1515
request,
1616
toolProvider,
1717
logger,
18-
runner,
1918
}: {
2019
request: KibanaRequest;
2120
toolProvider: ToolProvider;
2221
logger: Logger;
23-
runner: ScopedRunner;
2422
}): Promise<StructuredTool[]> => {
2523
const allTools = await toolProvider.list({ request });
2624
return Promise.all(
2725
allTools.map((tool) => {
28-
return toLangchainTool({ tool, logger, runner });
26+
return toLangchainTool({ tool, logger });
2927
})
3028
);
3129
};
3230

3331
export const toLangchainTool = ({
3432
tool,
3533
logger,
36-
runner,
3734
}: {
3835
tool: ExecutableTool;
39-
runner: ScopedRunner;
4036
logger: Logger;
4137
}): StructuredTool => {
4238
const toolId = toolDescriptorToIdentifier(tool);
4339
return toTool(
4440
async (input) => {
4541
try {
46-
const toolReturn = await runner.runTool({
47-
toolId,
48-
toolParams: input,
49-
});
42+
const toolReturn = await tool.execute({ toolParams: input });
5043
return JSON.stringify(toolReturn.result);
5144
} catch (e) {
5245
logger.warn(`error calling tool ${tool.name}: ${e.message}`);

x-pack/platform/plugins/shared/onechat/server/services/runner/run_agent.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import type {
1111
ConversationalAgentParams,
1212
RunAgentReturn,
1313
} from '@kbn/onechat-server';
14+
import { internalProviderToPublic } from '../tools/utils';
1415
import { createAgentEventEmitter, forkContextForAgentRun } from './utils';
1516
import type { RunnerManager } from './runner';
1617

@@ -29,7 +30,10 @@ export const createAgentHandlerContext = <TParams = Record<string, unknown>>({
2930
esClient: elasticsearch.client.asScoped(request),
3031
modelProvider: modelProviderFactory({ request, defaultConnectorId }),
3132
runner: manager.getRunner(),
32-
toolProvider: toolsService.registry.asPublicRegistry(),
33+
toolProvider: internalProviderToPublic({
34+
provider: toolsService.registry,
35+
getRunner: manager.getRunner,
36+
}),
3337
events: createAgentEventEmitter({ eventHandler: onEvent, context: manager.context }),
3438
};
3539
};

0 commit comments

Comments
 (0)