Skip to content

Commit 02802a6

Browse files
authored
DX-1599: Add way to use every fetch-compatible model in agents (#77)
* feat: add way to use every fetch-compatible model in agents * fix: logs types * fix: split the agents.model to two * fix: use agents.AISDKModel instead of agents.model * fix: add @ai-sdk/anthropic to dev deps
1 parent ccee173 commit 02802a6

File tree

7 files changed

+200
-59
lines changed

7 files changed

+200
-59
lines changed

bun.lockb

1.38 KB
Binary file not shown.

package.json

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
},
8080
"homepage": "https://github.com/upstash/workflow-ts#readme",
8181
"devDependencies": {
82+
"@ai-sdk/anthropic": "^1.1.15",
8283
"@commitlint/cli": "^19.5.0",
8384
"@commitlint/config-conventional": "^19.5.0",
8485
"@eslint/js": "^9.11.1",

src/agents/adapters.ts

+67-48
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,60 @@ import { createOpenAI } from "@ai-sdk/openai";
77
import { HTTPMethods } from "@upstash/qstash";
88
import { WorkflowContext } from "../context";
99
import { tool } from "ai";
10-
import { AISDKTool, LangchainTool } from "./types";
10+
import { AISDKTool, LangchainTool, ProviderFunction } from "./types";
1111
import { AGENT_NAME_HEADER } from "./constants";
1212
import { z, ZodType } from "zod";
1313

14+
export const fetchWithContextCall = async (
15+
context: WorkflowContext,
16+
...params: Parameters<typeof fetch>
17+
) => {
18+
const [input, init] = params;
19+
try {
20+
// Prepare headers from init.headers
21+
const headers = init?.headers ? Object.fromEntries(new Headers(init.headers).entries()) : {};
22+
23+
// Prepare body from init.body
24+
const body = init?.body ? JSON.parse(init.body as string) : undefined;
25+
26+
// create step name
27+
const agentName = headers[AGENT_NAME_HEADER] as string | undefined;
28+
const stepName = agentName ? `Call Agent ${agentName}` : "Call Agent";
29+
30+
// Make network call
31+
const responseInfo = await context.call(stepName, {
32+
url: input.toString(),
33+
method: init?.method as HTTPMethods,
34+
headers,
35+
body,
36+
});
37+
38+
// Construct headers for the response
39+
const responseHeaders = new Headers(
40+
Object.entries(responseInfo.header).reduce(
41+
(acc, [key, values]) => {
42+
acc[key] = values.join(", ");
43+
return acc;
44+
},
45+
{} as Record<string, string>
46+
)
47+
);
48+
49+
// Return the constructed response
50+
return new Response(JSON.stringify(responseInfo.body), {
51+
status: responseInfo.status,
52+
headers: responseHeaders,
53+
});
54+
} catch (error) {
55+
if (error instanceof Error && error.name === "WorkflowAbort") {
56+
throw error;
57+
} else {
58+
console.error("Error in fetch implementation:", error);
59+
throw error; // Rethrow error for further handling
60+
}
61+
}
62+
};
63+
1464
/**
1565
* creates an AI SDK openai client with a custom
1666
* fetch implementation which uses context.call.
@@ -27,53 +77,22 @@ export const createWorkflowOpenAI = (
2777
baseURL,
2878
apiKey,
2979
compatibility: "strict",
30-
fetch: async (input, init) => {
31-
try {
32-
// Prepare headers from init.headers
33-
const headers = init?.headers
34-
? Object.fromEntries(new Headers(init.headers).entries())
35-
: {};
36-
37-
// Prepare body from init.body
38-
const body = init?.body ? JSON.parse(init.body as string) : undefined;
39-
40-
// create step name
41-
const agentName = headers[AGENT_NAME_HEADER] as string | undefined;
42-
const stepName = agentName ? `Call Agent ${agentName}` : "Call Agent";
43-
44-
// Make network call
45-
const responseInfo = await context.call(stepName, {
46-
url: input.toString(),
47-
method: init?.method as HTTPMethods,
48-
headers,
49-
body,
50-
});
51-
52-
// Construct headers for the response
53-
const responseHeaders = new Headers(
54-
Object.entries(responseInfo.header).reduce(
55-
(acc, [key, values]) => {
56-
acc[key] = values.join(", ");
57-
return acc;
58-
},
59-
{} as Record<string, string>
60-
)
61-
);
62-
63-
// Return the constructed response
64-
return new Response(JSON.stringify(responseInfo.body), {
65-
status: responseInfo.status,
66-
headers: responseHeaders,
67-
});
68-
} catch (error) {
69-
if (error instanceof Error && error.name === "WorkflowAbort") {
70-
throw error;
71-
} else {
72-
console.error("Error in fetch implementation:", error);
73-
throw error; // Rethrow error for further handling
74-
}
75-
}
76-
},
80+
fetch: async (...params) => fetchWithContextCall(context, ...params),
81+
});
82+
};
83+
84+
export const createWorkflowModel = <TProvider extends ProviderFunction>({
85+
context,
86+
provider,
87+
providerParams,
88+
}: {
89+
context: WorkflowContext;
90+
provider: TProvider;
91+
providerParams?: Omit<Required<Parameters<TProvider>>[0], "fetch">;
92+
}): ReturnType<TProvider> => {
93+
return provider({
94+
fetch: (...params) => fetchWithContextCall(context, ...params),
95+
...providerParams,
7796
});
7897
};
7998

src/agents/index.ts

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import { createOpenAI } from "@ai-sdk/openai";
12
import { WorkflowContext } from "../context";
2-
import { createWorkflowOpenAI, wrapTools } from "./adapters";
3+
import { createWorkflowModel, wrapTools } from "./adapters";
34
import { Agent } from "./agent";
45
import { Task } from "./task";
56
import {
@@ -95,7 +96,15 @@ export class WorkflowAgents {
9596
public openai(...params: CustomModelParams) {
9697
const [model, settings] = params;
9798
const { baseURL, apiKey, ...otherSettings } = settings ?? {};
98-
const openai = createWorkflowOpenAI(this.context, { baseURL, apiKey });
99-
return openai(model, otherSettings);
99+
100+
const openaiModel = this.AISDKModel({
101+
context: this.context,
102+
provider: createOpenAI,
103+
providerParams: { baseURL, apiKey, compatibility: "strict" },
104+
});
105+
106+
return openaiModel(model, otherSettings);
100107
}
108+
109+
public AISDKModel = createWorkflowModel;
101110
}

src/agents/task.test.ts

+109-7
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,19 @@ import { WorkflowAgents } from ".";
77
import { tool } from "ai";
88
import { z } from "zod";
99
import { DisabledWorkflowContext } from "../serve/authorization";
10+
import { createAnthropic } from "@ai-sdk/anthropic";
11+
import { createOpenAI } from "@ai-sdk/openai";
1012

11-
export const getAgentsApi = ({ disabledContext }: { disabledContext: boolean }) => {
13+
export const getAgentsApi = ({
14+
disabledContext,
15+
getModel,
16+
}: {
17+
disabledContext: boolean;
18+
getModel?: (
19+
agentsApi: WorkflowAgents,
20+
context: WorkflowContext
21+
) => ReturnType<typeof agentsApi.openai>;
22+
}) => {
1223
const workflowRunId = getWorkflowRunId();
1324
const token = nanoid();
1425

@@ -39,7 +50,7 @@ export const getAgentsApi = ({ disabledContext }: { disabledContext: boolean })
3950
const maxSteps = 2;
4051
const name = "my agent";
4152
const temparature = 0.4;
42-
const model = agentsApi.openai("gpt-3.5-turbo");
53+
const model = getModel ? getModel(agentsApi, context) : agentsApi.openai("gpt-3.5-turbo");
4354

4455
const agent = agentsApi.agent({
4556
tools: {
@@ -62,6 +73,7 @@ export const getAgentsApi = ({ disabledContext }: { disabledContext: boolean })
6273
workflowRunId,
6374
context,
6475
agentsApi,
76+
model,
6577
};
6678
};
6779

@@ -131,13 +143,28 @@ describe("tasks", () => {
131143
});
132144

133145
test("multi agent with baseURL", async () => {
134-
const { agentsApi, agent, token, workflowRunId } = getAgentsApi({ disabledContext: false });
135-
136-
const customURL = "https://api.deepseek.com/v1";
137146
const customApiKey = nanoid();
147+
const { agentsApi, agent, token, workflowRunId, model } = getAgentsApi({
148+
disabledContext: false,
149+
getModel(agentsApi, context) {
150+
const model = agentsApi.AISDKModel({
151+
context,
152+
provider: createOpenAI,
153+
providerParams: {
154+
baseURL: "https://api.deepseek.com/v1",
155+
apiKey: customApiKey,
156+
},
157+
});
158+
159+
return model("gpt-4o", {
160+
reasoningEffort: "low",
161+
});
162+
},
163+
});
164+
138165
const task = agentsApi.task({
139166
agents: [agent],
140-
model: agentsApi.openai("gpt-3.5-turbo", { baseURL: customURL, apiKey: customApiKey }),
167+
model: model,
141168
maxSteps: 2,
142169
prompt: "hello world!",
143170
});
@@ -159,7 +186,7 @@ describe("tasks", () => {
159186
token,
160187
body: [
161188
{
162-
body: '{"model":"gpt-3.5-turbo","temperature":0.1,"messages":[{"role":"system","content":"You are an agent orchestrating other AI Agents.\\n\\nThese other agents have tools available to them.\\n\\nGiven a prompt, utilize these agents to address requests.\\n\\nDon\'t always call all the agents provided to you at the same time. You can call one and use it\'s response to call another.\\n\\nAvoid calling the same agent twice in one turn. Instead, prefer to call it once but provide everything\\nyou need from that agent.\\n"},{"role":"user","content":"hello world!"}],"tools":[{"type":"function","function":{"name":"my agent","description":"An AI Agent with the following background: an agentHas access to the following tools: ai sdk tool","parameters":{"type":"object","properties":{"prompt":{"type":"string"}},"required":["prompt"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}}}],"tool_choice":"auto"}',
189+
body: '{"model":"gpt-4o","temperature":0.1,"reasoning_effort":"low","messages":[{"role":"system","content":"You are an agent orchestrating other AI Agents.\\n\\nThese other agents have tools available to them.\\n\\nGiven a prompt, utilize these agents to address requests.\\n\\nDon\'t always call all the agents provided to you at the same time. You can call one and use it\'s response to call another.\\n\\nAvoid calling the same agent twice in one turn. Instead, prefer to call it once but provide everything\\nyou need from that agent.\\n"},{"role":"user","content":"hello world!"}],"tools":[{"type":"function","function":{"name":"my agent","description":"An AI Agent with the following background: an agentHas access to the following tools: ai sdk tool","parameters":{"type":"object","properties":{"prompt":{"type":"string"}},"required":["prompt"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}}}],"tool_choice":"auto"}',
163190
destination: "https://api.deepseek.com/v1/chat/completions",
164191
headers: {
165192
"upstash-workflow-sdk-version": "1",
@@ -195,4 +222,79 @@ describe("tasks", () => {
195222
},
196223
});
197224
});
225+
226+
test("anthropic model", async () => {
227+
const { agentsApi, agent, token, workflowRunId } = getAgentsApi({
228+
disabledContext: false,
229+
getModel: (agentsApi, context) => {
230+
const model = agentsApi.AISDKModel({
231+
context,
232+
provider: createAnthropic,
233+
providerParams: {
234+
apiKey: "antrhopic-key",
235+
},
236+
});
237+
238+
return model("claude-3-sonnet-20240229");
239+
},
240+
});
241+
242+
const task = agentsApi.task({
243+
agent,
244+
prompt: "hello world!",
245+
});
246+
247+
await mockQStashServer({
248+
execute: () => {
249+
const throws = () => task.run();
250+
expect(throws).toThrowError(`Aborting workflow after executing step 'Call Agent my agent'`);
251+
},
252+
responseFields: {
253+
status: 200,
254+
body: "msgId",
255+
},
256+
receivesRequest: {
257+
method: "POST",
258+
url: `${MOCK_QSTASH_SERVER_URL}/v2/batch`,
259+
token,
260+
body: [
261+
{
262+
body: '{"model":"claude-3-sonnet-20240229","max_tokens":4096,"temperature":0.4,"system":[{"type":"text","text":"an agent"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello world!"}]}],"tools":[{"name":"tool","description":"ai sdk tool","input_schema":{"type":"object","properties":{"expression":{"type":"string"}},"required":["expression"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}}],"tool_choice":{"type":"auto"}}',
263+
destination: "https://api.anthropic.com/v1/messages",
264+
headers: {
265+
"upstash-workflow-sdk-version": "1",
266+
"content-type": "application/json",
267+
"upstash-callback": "https://requestcatcher.com/api",
268+
"upstash-callback-feature-set": "LazyFetch,InitialBody",
269+
"upstash-callback-forward-upstash-workflow-callback": "true",
270+
"upstash-callback-forward-upstash-workflow-concurrent": "1",
271+
"upstash-callback-forward-upstash-workflow-contenttype": "application/json",
272+
"upstash-callback-forward-upstash-workflow-stepid": "1",
273+
"upstash-callback-forward-upstash-workflow-steptype": "Call",
274+
"upstash-callback-forward-upstash-workflow-invoke-count": "0",
275+
"upstash-callback-retries": "3",
276+
"upstash-callback-workflow-calltype": "fromCallback",
277+
"upstash-callback-workflow-init": "false",
278+
"upstash-callback-workflow-runid": workflowRunId,
279+
"upstash-callback-workflow-url": "https://requestcatcher.com/api",
280+
"upstash-failure-callback-retries": "3",
281+
"upstash-feature-set": "WF_NoDelete,InitialBody",
282+
"upstash-forward-content-type": "application/json",
283+
"upstash-forward-upstash-agent-name": "my agent",
284+
"upstash-method": "POST",
285+
"upstash-retries": "0",
286+
"upstash-workflow-calltype": "toCallback",
287+
"upstash-workflow-init": "false",
288+
"upstash-workflow-runid": workflowRunId,
289+
"upstash-workflow-url": "https://requestcatcher.com/api",
290+
"upstash-callback-forward-upstash-workflow-stepname": "Call Agent my agent",
291+
// anthropic specific headers:
292+
"upstash-forward-x-api-key": "antrhopic-key",
293+
"upstash-forward-anthropic-version": "2023-06-01",
294+
},
295+
},
296+
],
297+
},
298+
});
299+
});
198300
});

src/agents/types.ts

+5
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,8 @@ export type ManagerAgentParameters = {
9393
type ModelParams = Parameters<ReturnType<typeof createWorkflowOpenAI>>;
9494
type CustomModelSettings = ModelParams["1"] & { baseURL?: string; apiKey?: string };
9595
export type CustomModelParams = [ModelParams[0], CustomModelSettings?];
96+
97+
export type ProviderFunction = (params: {
98+
fetch: typeof fetch;
99+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
100+
}) => any;

src/client/types.ts

+6-1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ export type StepLog = BaseStepLog &
133133
AsOptional<CallUrlGroup> &
134134
AsOptional<CallResponseStatusGroup> &
135135
AsOptional<InvokedWorkflowGroup> &
136+
AsOptional<{ sleepFor: number }> &
137+
AsOptional<{ sleepUntil: number }> &
136138
AsOptional<WaitEventGroup>;
137139

138140
type StepLogGroup =
@@ -160,7 +162,10 @@ type StepLogGroup =
160162
/**
161163
* Log which belongs to the next step
162164
*/
163-
steps: { messageId: string; state: "STEP_PROGRESS" | "STEP_RETRY" | "STEP_FAILED" }[];
165+
steps: {
166+
messageId: string;
167+
state: "STEP_PROGRESS" | "STEP_RETRY" | "STEP_FAILED" | "STEP_CANCELED";
168+
}[];
164169
/**
165170
* Log which belongs to the next step
166171
*/

0 commit comments

Comments
 (0)