feat: use latest @copilot-extensions/preview-sdk and all its goodies
gr2m committed Sep 5, 2024
1 parent 485c77a commit 24a0a06
Showing 7 changed files with 49 additions and 71 deletions.
13 changes: 7 additions & 6 deletions src/functions.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { PromptFunction, InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { ModelsAPI } from "./models-api.js";
 
 // defaultModel is the model used for internal calls - for tool calling,
@@ -8,26 +9,26 @@ export const defaultModel = "gpt-4o-mini";
 // RunnerResponse is the response from a function call.
 export interface RunnerResponse {
   model: string;
-  messages: OpenAI.ChatCompletionMessageParam[];
+  messages: InteropMessage[];
 }
 
 export abstract class Tool {
   modelsAPI: ModelsAPI;
-  static definition: OpenAI.FunctionDefinition;
+  static definition: PromptFunction["function"];
 
   constructor(modelsAPI: ModelsAPI) {
     this.modelsAPI = modelsAPI;
   }
 
-  static get tool(): OpenAI.Chat.Completions.ChatCompletionTool {
+  static get tool(): PromptFunction {
     return {
       type: "function",
       function: this.definition,
     };
   }
 
   abstract execute(
-    messages: OpenAI.ChatCompletionMessageParam[],
-    args: object
+    messages: InteropMessage[],
+    args: Record<string, unknown>
   ): Promise<RunnerResponse>;
 }
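
With the OpenAI types gone, every tool now describes itself via the SDK's PromptFunction shape. The following is a minimal sketch of a concrete subclass under the new signatures; the echoModel tool, its name, and its behavior are hypothetical and only illustrate the pattern:

import type { InteropMessage } from "@copilot-extensions/preview-sdk";

import { RunnerResponse, defaultModel, Tool } from "../functions.js";

// Hypothetical example tool, not part of this commit.
export class echoModel extends Tool {
  static definition = {
    name: "echo_model",
    description: "Echoes back which model the user asked about.",
    parameters: {
      type: "object",
      properties: {
        model: { type: "string", description: "The model name to echo." },
      },
      required: ["model"],
    },
  };

  async execute(
    messages: InteropMessage[],
    args: { model: string }
  ): Promise<RunnerResponse> {
    // Hand the conversation back to the default model with one extra system message.
    return {
      model: defaultModel,
      messages: [
        ...messages,
        { role: "system", content: `The user asked about ${args.model}.` },
      ],
    };
  }
}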
5 changes: 3 additions & 2 deletions src/functions/describe-model.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class describeModel extends Tool {
@@ -19,7 +20,7 @@ export class describeModel extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[],
+    messages: InteropMessage[],
     args: { model: string }
   ): Promise<RunnerResponse> {
     const [model, modelSchema] = await Promise.all([
5 changes: 3 additions & 2 deletions src/functions/execute-model.ts
@@ -1,7 +1,8 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, Tool } from "../functions.js";
 
-type MessageWithReferences = OpenAI.ChatCompletionMessageParam & {
+type MessageWithReferences = InteropMessage & {
   copilot_references: Reference[];
 };
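
The only change here is the base of the intersection type; it still narrows the same way. A hypothetical type guard, not part of this commit, shows the shape in use:

function hasReferences(
  message: InteropMessage
): message is MessageWithReferences {
  // copilot_references is the field the intersection type adds on top of the SDK message.
  return Array.isArray((message as MessageWithReferences).copilot_references);
}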
5 changes: 3 additions & 2 deletions src/functions/list-models.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class listModels extends Tool {
@@ -15,7 +16,7 @@ export class listModels extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[]
+    messages: InteropMessage[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();
5 changes: 3 additions & 2 deletions src/functions/recommend-model.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class recommendModel extends Tool {
@@ -15,7 +16,7 @@ export class recommendModel extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[]
+    messages: InteropMessage[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();
77 changes: 30 additions & 47 deletions src/index.ts
@@ -1,7 +1,6 @@
-import { createServer, IncomingMessage } from "node:http";
+import { createServer } from "node:http";
 
-import { verifyAndParseRequest, createAckEvent } from "@copilot-extensions/preview-sdk";
-import OpenAI from "openai";
+import { prompt, getFunctionCalls, createAckEvent, createDoneEvent, verifyAndParseRequest, createTextEvent } from "@copilot-extensions/preview-sdk";
 
 import { describeModel } from "./functions/describe-model.js";
 import { executeModel } from "./functions/execute-model.js";
@@ -12,6 +11,7 @@ import { ModelsAPI } from "./models-api.js";
 
 const server = createServer(async (request, response) => {
   if (request.method === "GET") {
+    // health check
     response.statusCode = 200;
     response.end(`OK`);
     return;
@@ -54,15 +54,9 @@ const server = createServer(async (request, response) => {
   response.write(createAckEvent().toString());
 
   // List of functions that are available to be called
-  const modelsAPI = new ModelsAPI(apiKey);
+  const modelsAPI = new ModelsAPI();
   const functions = [listModels, describeModel, executeModel, recommendModel];
 
-  // Use the Copilot API to determine which function to execute
-  const capiClient = new OpenAI({
-    baseURL: "https://api.githubcopilot.com",
-    apiKey,
-  });
-
   // Prepend a system message that includes the list of models, so that
   // tool calls can better select the right model to use.
   const models = await modelsAPI.listModels();
@@ -90,56 +84,48 @@ const server = createServer(async (request, response) => {
   ].concat(payload.messages);
 
   console.time("tool-call");
-  const toolCaller = await capiClient.chat.completions.create({
-    stream: false,
-    model: "gpt-4",
+  const promptResult = await prompt({
     messages: toolCallMessages,
-    tool_choice: "auto",
+    token: apiKey,
     tools: functions.map((f) => f.tool),
-  });
+  })
   console.timeEnd("tool-call");
 
+  const [functionToCall] = getFunctionCalls(promptResult)
+
   if (
-    !toolCaller.choices[0] ||
-    !toolCaller.choices[0].message ||
-    !toolCaller.choices[0].message.tool_calls ||
-    !toolCaller.choices[0].message.tool_calls[0].function
+    !functionToCall
   ) {
     console.log("No tool call found");
     // No tool to call, so just call the model with the original messages
-    const stream = await capiClient.chat.completions.create({
-      stream: true,
-      model: "gpt-4",
-      // @ts-expect-error - TODO @gr2m - type incompatibility between @openai/api and @copilot-extensions/preview-sdk
+
+    const { stream } = await prompt.stream({
       messages: payload.messages,
-    });
+      token: apiKey,
+    })
 
     for await (const chunk of stream) {
-      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
-      response.write(chunkStr);
+      response.write(new TextDecoder().decode(chunk));
     }
-    response.write("data: [DONE]\n\n");
-    response.end();
+
+    response.end(createDoneEvent().toString());
     return;
   }
 
-  const functionToCall = toolCaller.choices[0].message.tool_calls[0].function;
-  const args = JSON.parse(functionToCall.arguments);
+  const args = JSON.parse(functionToCall.function.arguments);
 
   console.time("function-exec");
   let functionCallRes: RunnerResponse;
   try {
-    console.log("Executing function", functionToCall.name);
+    console.log("Executing function", functionToCall.function.name);
     const funcClass = functions.find(
-      (f) => f.definition.name === functionToCall.name
+      (f) => f.definition.name === functionToCall.function.name
     );
     if (!funcClass) {
      throw new Error("Unknown function");
    }
 
    console.log("\t with args", args);
    const func = new funcClass(modelsAPI);
-    // @ts-expect-error - TODO @gr2m - type incompatibility between @openai/api and @copilot-extensions/preview-sdk
    functionCallRes = await func.execute(payload.messages, args);
  } catch (err) {
    console.error(err);
@@ -150,23 +136,20 @@ const server = createServer(async (request, response) => {
   console.timeEnd("function-exec");
 
   try {
-    const stream = await modelsAPI.inference.chat.completions.create({
+    console.time("streaming");
+    const { stream } = await prompt.stream({
+      endpoint: 'https://models.inference.ai.azure.com/chat/completions',
       model: functionCallRes.model,
       messages: functionCallRes.messages,
-      stream: true,
-      stream_options: {
-        include_usage: false,
-      },
-    });
+      token: apiKey,
+    })
 
-    console.time("streaming");
     for await (const chunk of stream) {
-      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
-      response.write(chunkStr);
+      response.write(new TextDecoder().decode(chunk));
    }
-    response.write("data: [DONE]\n\n");
-
+
+    response.end(createDoneEvent().toString());
     console.timeEnd("streaming");
-    response.end();
   } catch (err) {
     console.error(err);
     response.statusCode = 500
@@ -178,12 +161,12 @@ const port = process.env.PORT || "3000"
 server.listen(port);
 console.log(`Server running at http://localhost:${port}`);
 
-function getBody(request: IncomingMessage): Promise<string> {
+function getBody(request: any): Promise<string> {
   return new Promise((resolve) => {
     const bodyParts: any[] = [];
     let body;
     request
-      .on("data", (chunk) => {
+      .on("data", (chunk: Buffer) => {
         bodyParts.push(chunk);
       })
       .on("end", () => {
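
Taken together, the refactor swaps the hand-rolled OpenAI client for two SDK calls: prompt() for the non-streaming tool-selection pass and prompt.stream() for the streamed answer. A sketch condensed from the hunks above (error handling and the system-message setup omitted; all identifiers as in the diff):

// Ack the request, then let one non-streaming prompt() call pick a tool.
response.write(createAckEvent().toString());

const promptResult = await prompt({
  messages: toolCallMessages,
  token: apiKey,
  tools: functions.map((f) => f.tool),
});
const [functionToCall] = getFunctionCalls(promptResult);

if (!functionToCall) {
  // No tool call: stream the model's answer straight through to the client.
  const { stream } = await prompt.stream({
    messages: payload.messages,
    token: apiKey,
  });
  for await (const chunk of stream) {
    response.write(new TextDecoder().decode(chunk));
  }
  response.end(createDoneEvent().toString());
}
// ...otherwise parse functionToCall.function.arguments and run the matching
// tool, then stream its RunnerResponse, as in the diff above.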
10 changes: 0 additions & 10 deletions src/models-api.ts
@@ -1,5 +1,3 @@
-import OpenAI from "openai";
-
 // Model is the structure of a model in the model catalog.
 export interface Model {
   id: string;
@@ -33,16 +31,8 @@ export type ModelSchemaParameter = {
 };
 
 export class ModelsAPI {
-  inference: OpenAI;
   private _models: Model[] | null = null;
 
-  constructor(apiKey: string) {
-    this.inference = new OpenAI({
-      baseURL: "https://models.inference.ai.azure.com",
-      apiKey,
-    });
-  }
-
   async getModel(modelName: string): Promise<Model> {
     const modelRes = await fetch(
       "https://modelcatalog.azure-api.net/v1/model/" + modelName
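
With the embedded OpenAI client gone, ModelsAPI only talks to the model catalog over plain fetch and needs no credentials; inference traffic now flows through prompt.stream() with an explicit endpoint, as shown in src/index.ts above. Construction reduces to:

const modelsAPI = new ModelsAPI(); // no apiKey: catalog lookups use plain fetch
const models = await modelsAPI.listModels();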
