Commit 9185e1a

Added tool calling support

1 parent 68f26e3 commit 9185e1a

File tree: 8 files changed, +264 -149 lines changed

packages/alphawave/package.json

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
   "name": "alphawave",
   "author": "Steven Ickman",
   "description": "A very opinionated client for interfacing with Large Language Models.",
-  "version": "0.21.3",
+  "version": "0.22.0",
   "license": "MIT",
   "keywords": [
     "ai",
@@ -31,7 +31,7 @@
     "gpt-3-encoder": "1.1.4",
     "json-colorizer": "^2.2.2",
     "jsonschema": "1.4.1",
-    "promptrix": "^0.5.0",
+    "promptrix": "^0.6.0",
     "strict-event-emitter-types": "^2.0.0",
     "yaml": "2.3.1"
   },

packages/alphawave/src/AlphaWave.ts

Lines changed: 6 additions & 6 deletions

@@ -6,7 +6,7 @@ import { DefaultResponseValidator } from "./DefaultResponseValidator";
 import { MemoryFork } from "./MemoryFork";
 import { Colorize } from "./internals";
 import { OpenAIModel } from "./OpenAIModel";
-import { FunctionResponseValidator } from "./FunctionResponseValidator";
+import { ToolResponseValidator } from "./ToolResponseValidator";
 
 /**
  * Options for an AlphaWave instance.
@@ -326,8 +326,8 @@ export class AlphaWave extends (EventEmitter as { new(): AlphaWaveEmitter }) {
         // Create validator to use
         if (!this.options.validator) {
             // Check for an OpenAI model using functions
-            if (this.options.model instanceof OpenAIModel && this.options.model.options.functions) {
-                this.options.validator = new FunctionResponseValidator(this.options.model.options.functions);
+            if (this.options.model instanceof OpenAIModel && this.options.model.options.tools) {
+                this.options.validator = new ToolResponseValidator(this.options.model.options.tools);
             } else {
                 this.options.validator = new DefaultResponseValidator();
             }
@@ -398,10 +398,10 @@ export class AlphaWave extends (EventEmitter as { new(): AlphaWaveEmitter }) {
         const { prompt, memory, functions, tokenizer, validator, max_repair_attempts, history_variable, input_variable } = this.options;
         let { model } = this.options;
 
-        // Check for OpenAI model being used with a function validator
-        if (model instanceof OpenAIModel && validator instanceof FunctionResponseValidator && !model.options.functions) {
+        // Check for OpenAI model being used with a tool validator
+        if (model instanceof OpenAIModel && validator instanceof ToolResponseValidator && !model.options.tools) {
             // Create a clone of the model that's configured to use the validators functions
-            model = model.clone({ functions: validator.functions })
+            model = model.clone({ tools: validator.tools })
         }
 
         // Update/get user input
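
With this change, supplying `tools` on an OpenAIModel is enough for AlphaWave to wire up a ToolResponseValidator on its own. A minimal sketch of the new path, assuming the option names shown elsewhere in this commit; the tool literal, the 'gpt-4o' model name, and the promptrix Prompt/SystemMessage/UserMessage sections are illustrative assumptions, not confirmed by this diff:

    import { AlphaWave, OpenAIModel } from "alphawave";
    import { Prompt, SystemMessage, UserMessage } from "promptrix";

    // Hypothetical tool definition; ChatCompletionTool is assumed to mirror
    // OpenAI's `{ type: 'function', function: { ... } }` wrapper.
    const model = new OpenAIModel({
        apiKey: process.env.OPENAI_API_KEY!,
        completion_type: 'chat',
        model: 'gpt-4o',
        tools: [{
            type: 'function',
            function: {
                name: 'get_weather',
                description: 'Returns the current weather for a city.',
                parameters: {
                    type: 'object',
                    properties: { city: { type: 'string' } },
                    required: ['city']
                }
            }
        }]
    });

    // No validator is passed, so the branch above should fall through to
    // `new ToolResponseValidator(this.options.model.options.tools)`.
    const wave = new AlphaWave({
        model,
        prompt: new Prompt([
            new SystemMessage('You are a helpful assistant.'),
            new UserMessage('{{$input}}')
        ])
    });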

packages/alphawave/src/FunctionResponseValidator.ts

Lines changed: 1 addition & 2 deletions

@@ -4,10 +4,9 @@ import { Schema } from "jsonschema";
 import { JSONResponseValidator } from "./JSONResponseValidator";
 
 /**
+ * @deprecated
  * Validates function calls returned by the model.
- *
  * @remarks
- *
 */
 export class FunctionResponseValidator implements PromptResponseValidator {
     private readonly _functions: Map<string, ChatCompletionFunction> = new Map();
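
FunctionResponseValidator is deprecated in favor of the new ToolResponseValidator. A hedged migration sketch, assuming ChatCompletionTool wraps the existing ChatCompletionFunction shape the way OpenAI's API does, and that both types are exported from the package root:

    import { ChatCompletionFunction, ChatCompletionTool } from "alphawave";

    // Before (deprecated): bare function definitions.
    const functions: ChatCompletionFunction[] = [{
        name: 'get_weather',
        parameters: { type: 'object', properties: { city: { type: 'string' } } }
    }];

    // After: each function is wrapped in a `{ type: 'function' }` tool
    // descriptor. The exact ChatCompletionTool shape is an assumption.
    const tools: ChatCompletionTool[] = functions.map((fn) => ({
        type: 'function',
        function: fn
    }));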

packages/alphawave/src/OpenAIModel.ts

Lines changed: 92 additions & 135 deletions

@@ -1,6 +1,6 @@
 import axios, { AxiosInstance, AxiosResponse, AxiosRequestConfig } from 'axios';
 import { PromptFunctions, PromptMemory, PromptSection, Tokenizer } from "promptrix";
-import { PromptCompletionModel, PromptResponse, ChatCompletionFunction, PromptResponseDetails, JsonSchema } from "./types";
+import { PromptCompletionModel, PromptResponse, ChatCompletionFunction, PromptResponseDetails, JsonSchema, ChatCompletionTool } from "./types";
 import { ChatCompletionRequestMessage, CreateChatCompletionRequest, CreateChatCompletionResponse, CreateCompletionRequest, CreateCompletionResponse, OpenAICreateChatCompletionRequest, OpenAICreateCompletionRequest } from "./internals";
 import { Colorize } from "./internals";
 
@@ -111,11 +111,13 @@ export interface BaseOpenAIModelOptions {
     requestConfig?: AxiosRequestConfig;
 
     /**
+     * @deprecated
      * Optional. A list of functions the model may generate JSON inputs for.
      */
     functions?: ChatCompletionFunction[];
 
     /**
+     * @deprecated
      * Optional. Controls how the model responds to function calls.
      * @remarks
      * `"none"` means the model does not call a function, and responds to the end-user.
@@ -139,6 +141,25 @@
      * Only available on select models but can be used to improve the models determinism in its responses.
      */
     seed?: number;
+
+    /**
+     * Optional. A list of tools the model may generate JSON inputs for.
+     */
+    tools?: ChatCompletionTool[];
+
+    /**
+     * Optional. Controls how the model responds to tool calls.
+     * @remarks
+     * Defaults to `auto`.
+     */
+    tool_choice?: 'auto' | 'none' | 'required' | ChatCompletionTool;
+
+    /**
+     * Optional. Whether to support calling tools in parallel.
+     * @remarks
+     * Defaults to `true`.
+     */
+    parallel_tool_calls?: boolean;
 }
 
 /**
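
The three new options flow straight through to the chat completions request (they are appended to the `copyOptionsToRequest` field list in the completePrompt hunk below). A minimal configuration sketch, assuming the `apiKey` and `model` option names used elsewhere in the package; the tool literal is hypothetical:

    const model = new OpenAIModel({
        apiKey: process.env.OPENAI_API_KEY!,
        completion_type: 'chat',
        model: 'gpt-4o',
        // New in 0.22.0:
        tools: [{
            type: 'function',
            function: {
                name: 'get_weather',
                parameters: { type: 'object', properties: { city: { type: 'string' } } }
            }
        }],
        tool_choice: 'required',     // force a tool call; 'auto' is the default
        parallel_tool_calls: false   // opt out of parallel calls (defaults to true)
    });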
@@ -302,148 +323,84 @@ export class OpenAIModel implements PromptCompletionModel {
         const startTime = Date.now();
         const max_input_tokens = this.options.max_input_tokens ?? 1024;
         if (this.options.completion_type == 'text') {
-            // Render prompt
-            const result = await prompt.renderAsText(memory, functions, tokenizer, max_input_tokens);
-            if (result.tooLong) {
-                return {
-                    status: 'too_long',
-                    prompt: result.output,
-                    error: `The generated text completion prompt had a length of ${result.length} tokens which exceeded the max_input_tokens of ${max_input_tokens}.`,
-                };
-            }
-            if (this.options.logRequests) {
-                console.log(Colorize.title('PROMPT:'));
-                console.log(Colorize.output(result.output));
+            throw new Error('Text completions are no longer supported by OpenAI.');
+        }
+
+        // Render prompt
+        const result = await prompt.renderAsMessages(memory, functions, tokenizer, max_input_tokens);
+        if (result.tooLong) {
+            return {
+                status: 'too_long',
+                prompt: result.output,
+                error: `The generated chat completion prompt had a length of ${result.length} tokens which exceeded the max_input_tokens of ${max_input_tokens}.`
+            };
+        }
+        if (this.options.logRequests) {
+            console.log(Colorize.title('CHAT PROMPT:'));
+            console.log(Colorize.output(result.output));
+            if (Array.isArray(this.options.tools) && this.options.tools.length > 0) {
+                console.log(Colorize.title('TOOLS:'));
+                console.log(Colorize.output(this.options.tools));
             }
+        }
 
-            // Call text completion API
-            const request: CreateCompletionRequest = this.copyOptionsToRequest<CreateCompletionRequest>({
-                prompt: result.output,
-            }, this.options, ['max_tokens', 'temperature', 'top_p', 'n', 'stream', 'logprobs', 'echo', 'stop', 'presence_penalty', 'frequency_penalty', 'best_of', 'logit_bias', 'user']);
-            const response = await this.createCompletion(request);
-            const request_duration = Date.now() - startTime;;
-            if (this.options.logRequests) {
-                console.log(Colorize.title('RESPONSE:'));
-                console.log(Colorize.value('status', response.status));
-                console.log(Colorize.value('duration', request_duration, 'ms'));
-                console.log(Colorize.output(response.data));
-            }
+        // Call chat completion API
+        const request: CreateChatCompletionRequest = this.patchBreakingChanges(this.copyOptionsToRequest<CreateChatCompletionRequest>({
+            messages: result.output as ChatCompletionRequestMessage[],
+        }, this.options, [
+            'max_tokens', 'temperature', 'top_p', 'n', 'stream', 'logprobs', 'echo', 'stop', 'presence_penalty',
+            'frequency_penalty', 'best_of', 'logit_bias', 'user', 'functions', 'function_call', 'response_format',
+            'seed', 'tools', 'tool_choice', 'parallel_tool_calls'
+        ]));
+        const response = await this.createChatCompletion(request);
+        const request_duration = Date.now() - startTime;
+        if (this.options.logRequests) {
+            console.log(Colorize.title('CHAT RESPONSE:'));
+            console.log(Colorize.value('status', response.status));
+            console.log(Colorize.value('duration', request_duration, 'ms'));
+            console.log(Colorize.output(response.data));
+        }
 
-            // Process response
-            if (response.status < 300) {
-                const completion = response.data.choices[0];
-                const usage = response.data.usage;
-                const details: PromptResponseDetails = {
-                    finish_reason: completion.finish_reason as any,
-                    completion_tokens: usage?.completion_tokens ?? -1,
-                    prompt_tokens: usage?.prompt_tokens ?? -1,
-                    total_tokens: usage?.total_tokens ?? -1,
-                    request_duration,
-                };
-
-                // Ensure content is text
-                // - We sometimes get an object back from the API
-                let content = completion.text ?? '';
-                if (typeof content == 'object') {
-                    content = JSON.stringify(content);
-                }
-
-                return {
-                    status: 'success',
-                    prompt: result.output,
-                    message: { role: 'assistant', content },
-                    details
-                };
-            } else if (response.status == 429 && !response.statusText.includes('quota')) {
-                if (this.options.logRequests) {
-                    console.log(Colorize.title('HEADERS:'));
-                    console.log(Colorize.output(response.headers));
-                }
-                return {
-                    status: 'rate_limited',
-                    prompt: result.output,
-                    error: `The text completion API returned a rate limit error.`
-                }
-            } else {
-                return {
-                    status: 'error',
-                    prompt: result.output,
-                    error: `The text completion API returned an error status of ${response.status}: ${response.statusText}`
-                };
-            }
-        } else {
-            // Render prompt
-            const result = await prompt.renderAsMessages(memory, functions, tokenizer, max_input_tokens);
-            if (result.tooLong) {
-                return {
-                    status: 'too_long',
-                    prompt: result.output,
-                    error: `The generated chat completion prompt had a length of ${result.length} tokens which exceeded the max_input_tokens of ${max_input_tokens}.`
-                };
-            }
-            if (this.options.logRequests) {
-                console.log(Colorize.title('CHAT PROMPT:'));
-                console.log(Colorize.output(result.output));
-                if (Array.isArray(this.options.functions) && this.options.functions.length > 0) {
-                    console.log(Colorize.title('FUNCTIONS:'));
-                    console.log(Colorize.output(this.options.functions));
-                }
+        // Process response
+        if (response.status < 300) {
+            const completion = response.data.choices[0];
+            const usage = response.data.usage;
+            const details: PromptResponseDetails = {
+                finish_reason: completion.finish_reason as any,
+                completion_tokens: usage?.completion_tokens ?? -1,
+                prompt_tokens: usage?.prompt_tokens ?? -1,
+                total_tokens: usage?.total_tokens ?? -1,
+                request_duration,
+            };
+
+            // Ensure message content is text
+            const message = completion.message ?? { role: 'assistant', content: '' };
+            if (typeof message.content == 'object') {
+                message.content = JSON.stringify(message.content);
             }
 
-            // Call chat completion API
-            const request: CreateChatCompletionRequest = this.patchBreakingChanges(this.copyOptionsToRequest<CreateChatCompletionRequest>({
-                messages: result.output as ChatCompletionRequestMessage[],
-            }, this.options, ['max_tokens', 'temperature', 'top_p', 'n', 'stream', 'logprobs', 'echo', 'stop', 'presence_penalty', 'frequency_penalty', 'best_of', 'logit_bias', 'user', 'functions', 'function_call', 'response_format', 'seed']));
-            const response = await this.createChatCompletion(request);
-            const request_duration = Date.now() - startTime;
+            return {
+                status: 'success',
+                prompt: result.output,
+                message,
+                details
+            };
+        } else if (response.status == 429 && !response.statusText.includes('quota')) {
             if (this.options.logRequests) {
-                console.log(Colorize.title('CHAT RESPONSE:'));
-                console.log(Colorize.value('status', response.status));
-                console.log(Colorize.value('duration', request_duration, 'ms'));
-                console.log(Colorize.output(response.data));
+                console.log(Colorize.title('HEADERS:'));
+                console.log(Colorize.output(response.headers));
             }
-
-            // Process response
-            if (response.status < 300) {
-                const completion = response.data.choices[0];
-                const usage = response.data.usage;
-                const details: PromptResponseDetails = {
-                    finish_reason: completion.finish_reason as any,
-                    completion_tokens: usage?.completion_tokens ?? -1,
-                    prompt_tokens: usage?.prompt_tokens ?? -1,
-                    total_tokens: usage?.total_tokens ?? -1,
-                    request_duration,
-                };
-
-                // Ensure message content is text
-                const message = completion.message ?? { role: 'assistant', content: '' };
-                if (typeof message.content == 'object') {
-                    message.content = JSON.stringify(message.content);
-                }
-
-                return {
-                    status: 'success',
-                    prompt: result.output,
-                    message,
-                    details
-                };
-            } else if (response.status == 429 && !response.statusText.includes('quota')) {
-                if (this.options.logRequests) {
-                    console.log(Colorize.title('HEADERS:'));
-                    console.log(Colorize.output(response.headers));
-                }
-                return {
-                    status: 'rate_limited',
-                    prompt: result.output,
-                    error: `The chat completion API returned a rate limit error.`
-                }
-            } else {
-                return {
-                    status: 'error',
-                    prompt: result.output,
-                    error: `The chat completion API returned an error status of ${response.status}: ${response.statusText}`
-                };
+            return {
+                status: 'rate_limited',
+                prompt: result.output,
+                error: `The chat completion API returned a rate limit error.`
             }
+        } else {
+            return {
+                status: 'error',
+                prompt: result.output,
+                error: `The chat completion API returned an error status of ${response.status}: ${response.statusText}`
+            };
         }
     }
 
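
Because the rewritten completePrompt returns the assistant message unmodified, any tool calls the model makes ride along on that message for the caller (or a ToolResponseValidator) to handle. A hedged sketch of consuming such a response, assuming the PromptCompletionModel signature `completePrompt(memory, functions, tokenizer, prompt)`; the `tool_calls` array and its `{ id, type, function: { name, arguments } }` entries follow OpenAI's chat completions shape and are an assumption here, since this diff passes the message through without inspecting it:

    const response = await model.completePrompt(memory, functions, tokenizer, prompt);
    if (response.status == 'success') {
        // `arguments` arrives as a JSON-encoded string per the OpenAI API.
        const toolCalls = (response.message as any).tool_calls ?? [];
        for (const call of toolCalls) {
            if (call.type == 'function') {
                const args = JSON.parse(call.function.arguments);
                console.log(`Model requested ${call.function.name} with`, args);
            }
        }
    }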
