feat: use latest @copilot-extensions/preview-sdk and all its goodies #6

Open · wants to merge 10 commits into main
feat: use latest @copilot-extensions/preview-sdk and all its goodies
gr2m committed Sep 10, 2024
commit ccea89f5640b01ffed979a9957460e3aab21ed03
13 changes: 7 additions & 6 deletions src/functions.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { PromptFunction, InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { ModelsAPI } from "./models-api.js";
 
 // defaultModel is the model used for internal calls - for tool calling,
@@ -8,26 +9,26 @@ export const defaultModel = "gpt-4o-mini";
 // RunnerResponse is the response from a function call.
 export interface RunnerResponse {
   model: string;
-  messages: OpenAI.ChatCompletionMessageParam[];
+  messages: InteropMessage[];
 }
 
 export abstract class Tool {
   modelsAPI: ModelsAPI;
-  static definition: OpenAI.FunctionDefinition;
+  static definition: PromptFunction["function"];
 
   constructor(modelsAPI: ModelsAPI) {
     this.modelsAPI = modelsAPI;
   }
 
-  static get tool(): OpenAI.Chat.Completions.ChatCompletionTool {
+  static get tool(): PromptFunction {
     return {
       type: "function",
       function: this.definition,
     };
   }
 
   abstract execute(
-    messages: OpenAI.ChatCompletionMessageParam[],
-    args: object
+    messages: InteropMessage[],
+    args: Record<string, unknown>
   ): Promise<RunnerResponse>;
 }
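
For orientation, here is a minimal sketch of what a concrete tool could look like against the new types. The echoModel class below is hypothetical and not part of this PR; it only illustrates the shape that PromptFunction["function"] and InteropMessage[] impose on Tool subclasses:

import type { PromptFunction, InteropMessage } from "@copilot-extensions/preview-sdk";
import { RunnerResponse, defaultModel, Tool } from "./functions.js";

// Hypothetical tool, for illustration only: it answers with the default model.
export class echoModel extends Tool {
  static definition: PromptFunction["function"] = {
    name: "echo_model",
    description: "Echoes the conversation back to the user.",
    parameters: { type: "object", properties: {} },
  };

  async execute(
    messages: InteropMessage[],
    args: Record<string, unknown> // arguments arrive as parsed, untyped JSON
  ): Promise<RunnerResponse> {
    return { model: defaultModel, messages };
  }
}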
5 changes: 3 additions & 2 deletions src/functions/describe-model.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class describeModel extends Tool {
@@ -19,7 +20,7 @@ export class describeModel extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[],
+    messages: InteropMessage[],
     args: { model: string }
   ): Promise<RunnerResponse> {
     const [model, modelSchema] = await Promise.all([
5 changes: 3 additions & 2 deletions src/functions/execute-model.ts
@@ -1,7 +1,8 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, Tool } from "../functions.js";
 
-type MessageWithReferences = OpenAI.ChatCompletionMessageParam & {
+type MessageWithReferences = InteropMessage & {
   copilot_references: Reference[];
 };
 
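The intersection type keeps working unchanged after the base swap from OpenAI.ChatCompletionMessageParam to InteropMessage. A sketch of a message that would satisfy it (the Reference shape is abbreviated here, and the reference values are made up):

import type { InteropMessage } from "@copilot-extensions/preview-sdk";

// Abbreviated stand-in for the Reference type defined in this file.
type Reference = { type: string; data: Record<string, unknown> };

type MessageWithReferences = InteropMessage & {
  copilot_references: Reference[];
};

// A hypothetical user message carrying a Copilot reference.
const message: MessageWithReferences = {
  role: "user",
  content: "Run this model against my selection",
  copilot_references: [{ type: "client.selection", data: {} }],
};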
5 changes: 3 additions & 2 deletions src/functions/list-models.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class listModels extends Tool {
@@ -15,7 +16,7 @@ export class listModels extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[]
+    messages: InteropMessage[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();
 
5 changes: 3 additions & 2 deletions src/functions/recommend-model.ts
@@ -1,4 +1,5 @@
-import OpenAI from "openai";
+import type { InteropMessage } from "@copilot-extensions/preview-sdk";
+
 import { RunnerResponse, defaultModel, Tool } from "../functions.js";
 
 export class recommendModel extends Tool {
@@ -15,7 +16,7 @@ export class recommendModel extends Tool {
   };
 
   async execute(
-    messages: OpenAI.ChatCompletionMessageParam[]
+    messages: InteropMessage[]
   ): Promise<RunnerResponse> {
     const models = await this.modelsAPI.listModels();
 
60 changes: 28 additions & 32 deletions src/index.ts
@@ -1,6 +1,6 @@
-import { createServer, IncomingMessage } from "node:http";
+import { createServer } from "node:http";
 
-import { verifyAndParseRequest, transformPayloadForOpenAICompatibility } from "@copilot-extensions/preview-sdk";
+import { verifyAndParseRequest, transformPayloadForOpenAICompatibility, getFunctionCalls, createDoneEvent } from "@copilot-extensions/preview-sdk";
 import OpenAI from "openai";
 
 import { describeModel } from "./functions/describe-model.js";
@@ -12,6 +12,7 @@ import { ModelsAPI } from "./models-api.js";
 
 const server = createServer(async (request, response) => {
   if (request.method === "GET") {
+    // health check
    response.statusCode = 200;
    response.end(`OK`);
    return;
@@ -54,7 +55,7 @@ const server = createServer(async (request, response) => {
   }
 
   // List of functions that are available to be called
-  const modelsAPI = new ModelsAPI(apiKey);
+  const modelsAPI = new ModelsAPI();
   const functions = [listModels, describeModel, executeModel, recommendModel];
 
   // Use the Copilot API to determine which function to execute
@@ -86,51 +87,49 @@ const server = createServer(async (request, response) => {
        "<-- END OF LIST OF MODELS -->",
      ].join("\n"),
    },
    ...compatibilityPayload.messages,
  ];
 
  console.time("tool-call");
  const toolCaller = await capiClient.chat.completions.create({
    stream: false,
    model: "gpt-4o",
    messages: toolCallMessages,
-    tool_choice: "auto",
+    token: apiKey,
    tools: functions.map((f) => f.tool),
-  });
+  })
  console.timeEnd("tool-call");
 
+  const [functionToCall] = getFunctionCalls(promptResult)
+
  if (
-    !toolCaller.choices[0] ||
-    !toolCaller.choices[0].message ||
-    !toolCaller.choices[0].message.tool_calls ||
-    !toolCaller.choices[0].message.tool_calls[0].function
+    !functionToCall
  ) {
    console.log("No tool call found");
    // No tool to call, so just call the model with the original messages
    const stream = await capiClient.chat.completions.create({
      stream: true,
      model: "gpt-4o",
      messages: payload.messages,
-    });
+      token: apiKey,
+    })
 
    for await (const chunk of stream) {
-      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
-      response.write(chunkStr);
+      response.write(new TextDecoder().decode(chunk));
    }
-    response.write("data: [DONE]\n\n");
-    response.end();
+
+    response.end(createDoneEvent().toString());
    return;
  }
 
-  const functionToCall = toolCaller.choices[0].message.tool_calls[0].function;
-  const args = JSON.parse(functionToCall.arguments);
+  const args = JSON.parse(functionToCall.function.arguments);
 
  console.time("function-exec");
  let functionCallRes: RunnerResponse;
  try {
-    console.log("Executing function", functionToCall.name);
+    console.log("Executing function", functionToCall.function.name);
    const funcClass = functions.find(
-      (f) => f.definition.name === functionToCall.name
+      (f) => f.definition.name === functionToCall.function.name
    );
    if (!funcClass) {
      throw new Error("Unknown function");
@@ -148,23 +147,20 @@ const server = createServer(async (request, response) => {
  console.timeEnd("function-exec");
 
  try {
-    const stream = await modelsAPI.inference.chat.completions.create({
+    console.time("streaming");
+    const { stream } = await prompt.stream({
+      endpoint: 'https://models.inference.ai.azure.com/chat/completions',
      model: functionCallRes.model,
      messages: functionCallRes.messages,
-      stream: true,
-      stream_options: {
-        include_usage: false,
-      },
-    });
+      token: apiKey,
+    })
+
-    console.time("streaming");
    for await (const chunk of stream) {
-      const chunkStr = "data: " + JSON.stringify(chunk) + "\n\n";
-      response.write(chunkStr);
+      response.write(new TextDecoder().decode(chunk));
    }
-    response.write("data: [DONE]\n\n");
 
+    response.end(createDoneEvent().toString());
    console.timeEnd("streaming");
-    response.end();
  } catch (err) {
    console.error(err);
    response.statusCode = 500
@@ -176,12 +172,12 @@ const port = process.env.PORT || "3000"
 server.listen(port);
 console.log(`Server running at http://localhost:${port}`);
 
-function getBody(request: IncomingMessage): Promise<string> {
+function getBody(request: any): Promise<string> {
   return new Promise((resolve) => {
     const bodyParts: any[] = [];
     let body;
     request
-      .on("data", (chunk) => {
+      .on("data", (chunk: Buffer) => {
        bodyParts.push(chunk);
      })
      .on("end", () => {
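
Taken together, the src/index.ts changes swap hand-rolled response inspection and SSE framing for SDK helpers. A condensed sketch of the new control flow, with the server plumbing replaced by declare stand-ins (promptResult names the tool-call result used above):

import type { ServerResponse } from "node:http";
import { getFunctionCalls, createDoneEvent } from "@copilot-extensions/preview-sdk";

// Stand-ins for values that exist in the request handler above.
declare const promptResult: Parameters<typeof getFunctionCalls>[0];
declare const response: ServerResponse;

// getFunctionCalls replaces the manual drilling through
// toolCaller.choices[0].message.tool_calls.
const [functionToCall] = getFunctionCalls(promptResult);

if (!functionToCall) {
  // No tool requested: the SDK's done event replaces the hand-written
  // "data: [DONE]\n\n" SSE terminator.
  response.end(createDoneEvent().toString());
} else {
  // Tool calls now arrive wrapped, hence the extra .function hop.
  const args = JSON.parse(functionToCall.function.arguments);
  console.log("Executing function", functionToCall.function.name, args);
}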
10 changes: 0 additions & 10 deletions src/models-api.ts
@@ -1,5 +1,3 @@
-import OpenAI from "openai";
-
 // Model is the structure of a model in the model catalog.
 export interface Model {
   id: string;
@@ -33,16 +31,8 @@ export type ModelSchemaParameter = {
 };
 
 export class ModelsAPI {
-  inference: OpenAI;
   private _models: Model[] | null = null;
 
-  constructor(apiKey: string) {
-    this.inference = new OpenAI({
-      baseURL: "https://models.inference.ai.azure.com",
-      apiKey,
-    });
-  }
-
   async getModel(modelName: string): Promise<Model> {
     const modelRes = await fetch(
       "https://modelcatalog.azure-api.net/v1/model/" + modelName