diff --git a/packages/proxy/schema/index.ts b/packages/proxy/schema/index.ts
index 4108039..7381c77 100644
--- a/packages/proxy/schema/index.ts
+++ b/packages/proxy/schema/index.ts
@@ -6,6 +6,7 @@ import type {
   ModelParams,
 } from "@braintrust/core/typespecs";
 import { AvailableModels, ModelFormat, ModelEndpointType } from "./models";
+import { isObject } from "@braintrust/core";
 
 export * from "./secrets";
 export * from "./models";
@@ -28,31 +29,6 @@ export const MessageTypeToMessageType: {
   model: "assistant",
 };
 
-export const modelParamToModelParam: {
-  [name: string]: keyof AnyModelParam | null;
-} = {
-  temperature: "temperature",
-  top_p: "top_p",
-  top_k: "top_k",
-  max_tokens: "max_tokens",
-  max_tokens_to_sample: null,
-  use_cache: "use_cache",
-  maxOutputTokens: "max_tokens",
-  topP: "top_p",
-  topK: "top_k",
-  presence_penalty: null,
-  frequency_penalty: null,
-  user: null,
-  function_call: null,
-  n: null,
-  logprobs: null,
-  stream_options: null,
-  parallel_tool_calls: null,
-  response_format: null,
-  reasoning_effort: null,
-  stop: null,
-};
-
 export const sliderSpecs: {
   // min, max, step, required
   [name: string]: [number, number, number, boolean];
@@ -423,60 +399,224 @@ ${content}<|im_end|>`,
   );
 }
 
+// XXX We can't use @braintrust/core from the workspace, so testing with this
+// schema here. Eventually I'll move it to the SDK and import it from there.
+const braintrustModelParamSchema = z.object({
+  use_cache: z.boolean().optional(),
+
+  temperature: z.number().optional(),
+  max_tokens: z.number().optional(),
+  // XXX How do we want to handle deprecated params?
+  max_completion_tokens: z.number().optional(),
+  top_p: z.number().optional(),
+  top_k: z.number().optional(),
+  frequency_penalty: z.number().optional(),
+  presence_penalty: z.number().optional(),
+  /* XXX We special-case these in the proxy, but I need to understand how.
+     Will probably keep that logic where it is for now.
+  response_format: z
+    .object({
+      type: z.literal("json_object"),
+    })
+    .nullish(),
+  tool_choice: z.object({
+    type: z.literal("function"),
+  }).optional(),
+  function_call: z.object({
+    name: z.string().optional(),
+  }).optional(),
+  */
+  n: z.number().optional(),
+  stop: z.array(z.string()).optional(),
+  reasoning_effort: z.enum(["low", "medium", "high"]).optional(),
+});
+type BraintrustModelParams = z.infer<typeof braintrustModelParamSchema>;
+type BraintrustParamMapping =
+  | keyof BraintrustModelParams
+  | {
+      key: keyof BraintrustModelParams | null;
+      deprecated?: boolean;
+      o1_like?: boolean;
+    };
+
+// XXX Add to the SDK.
+type ConverseModelParams = {
+  maxTokens: number;
+  stopSequences: string[];
+};
+
+const anyModelParamToBraintrustModelParam: Record<
+  keyof AnyModelParam | keyof ConverseModelParams,
+  BraintrustParamMapping
+> = {
+  use_cache: "use_cache",
+  temperature: "temperature",
+
+  max_tokens: "max_tokens",
+  max_completion_tokens: { key: "max_tokens", o1_like: true },
+  maxOutputTokens: "max_tokens",
+  maxTokens: "max_tokens",
+  // XXX Map this to max_tokens?
+  max_tokens_to_sample: { key: null, deprecated: true },
+
+  top_p: "top_p",
+  topP: "top_p",
+  top_k: "top_k",
+  topK: "top_k",
+  frequency_penalty: "frequency_penalty", // previously null
+  presence_penalty: "presence_penalty", // previously null
+
+  stop: "stop", // previously null
+  stop_sequences: "stop", // previously null
+  stopSequences: "stop", // previously null
+
+  n: "n", // previously null
+
+  reasoning_effort: { key: "reasoning_effort", o1_like: true },
+
+  response_format: { key: null }, // handled elsewhere?
+  function_call: { key: null }, // handled elsewhere
+  tool_choice: { key: null }, // handled elsewhere
+  // XXX parallel_tool_calls: { key: null }, // handled elsewhere
+};
+
+function translateKey(
+  toProvider: ModelFormat | undefined,
+  key: string,
+): keyof ModelParams | null {
+  const braintrustKey =
+    anyModelParamToBraintrustModelParam[key as keyof AnyModelParam];
+  let normalizedKey: keyof BraintrustModelParams | null = null;
+  if (braintrustKey === undefined) {
+    normalizedKey = null;
+  } else if (!isObject(braintrustKey)) {
+    normalizedKey = braintrustKey;
+  } else {
+    if (braintrustKey.deprecated) {
+      console.warn(`Deprecated model param: ${key}`);
+    }
+
+    if (braintrustKey.key === null) {
+      normalizedKey = null;
+    } else {
+      normalizedKey = braintrustKey.key;
+    }
+  }
+
+  if (normalizedKey === null) {
+    return null;
+  }
+
+  // XXX If toProvider is undefined, return the normalized key. This is useful
+  // for the UI to parse span data when the provider is not known. Maybe we can
+  // try harder to infer the provider?
+  if (toProvider === undefined) {
+    return normalizedKey;
+  }
+
+  // XXX Turn these into Record<keyof BraintrustModelParams, keyof ModelParams | null>
+  // maps from a Braintrust key to a provider key. They can live in proxy/providers.
+  switch (toProvider) {
+    case "openai":
+      switch (normalizedKey) {
+        case "temperature":
+          return "temperature";
+        case "max_tokens":
+          return "max_tokens";
+        case "top_p":
+          return "top_p";
+        case "stop":
+          return "stop";
+        case "frequency_penalty":
+          return "frequency_penalty";
+        case "presence_penalty":
+          return "presence_penalty";
+        case "n":
+          return "n";
+        default:
+          return null;
+      }
+    case "anthropic":
+      switch (normalizedKey) {
+        case "temperature":
+          return "temperature";
+        case "max_tokens":
+          return "max_tokens";
+        case "top_k":
+          return "top_k";
+        case "top_p":
+          return "top_p";
+        case "stop":
+          return "stop_sequences";
+        default:
+          return null;
+      }
+    case "google":
+      switch (normalizedKey) {
+        case "temperature":
+          return "temperature";
+        case "top_p":
+          return "topP";
+        case "top_k":
+          return "topK";
+        /* XXX Add support for this?
+        case "stop":
+          return "stopSequences";
+        */
+        case "max_tokens":
+          return "maxOutputTokens";
+        default:
+          return null;
+      }
+    case "window":
+      switch (normalizedKey) {
+        case "temperature":
+          return "temperature";
+        case "top_k":
+          return "topK";
+        default:
+          return null;
+      }
+    case "converse":
+      switch (normalizedKey) {
+        case "temperature":
+          return "temperature";
+        case "max_tokens":
+          return "maxTokens";
+        case "top_k":
+          return "topK";
+        case "top_p":
+          return "topP";
+        case "stop":
+          return "stopSequences";
+        default:
+          return null;
+      }
+    case "js":
+      return null;
+    default:
+      const _exhaustiveCheck: never = toProvider;
+      throw new Error(`Unknown provider: ${_exhaustiveCheck}`);
+  }
+}
+
 export function translateParams(
-  toProvider: ModelFormat,
+  toProvider: ModelFormat | undefined,
   params: Record<string, string>,
-): Record<string, string> {
-  const translatedParams: Record<string, string> = {};
+): Record<string, unknown> {
+  const translatedParams: Record<string, unknown> = {};
   for (const [k, v] of Object.entries(params || {})) {
     const safeValue = v ?? 
undefined; // Don't propagate "null" along
-    const translatedKey = modelParamToModelParam[k as keyof ModelParams] as
-      | keyof ModelParams
-      | undefined
-      | null;
+    const translatedKey = translateKey(toProvider, k);
     if (translatedKey === null) {
       continue;
-    } else if (
-      translatedKey !== undefined &&
-      defaultModelParamSettings[toProvider][translatedKey] !== undefined
-    ) {
+    } else if (safeValue !== undefined) {
       translatedParams[translatedKey] = safeValue;
-    } else {
-      translatedParams[k] = safeValue;
     }
+    // XXX Should we add default params from defaultModelParamSettings?
+    // Probably only do that if translateParams is being called from the prompt
+    // UI, but not for proxy calls.
+    //
+    // Also, the previous logic here seemed incorrect in doing
+    // translatedParams[k] = safeValue. I don't see why we would want to pass
+    // along params we know are not accepted by toProvider.
   }
   return translatedParams;
 }
-
-export const anthropicSupportedMediaTypes = [
-  "image/jpeg",
-  "image/png",
-  "image/gif",
-  "image/webp",
-];
-
-export const anthropicTextBlockSchema = z.object({
-  type: z.literal("text").optional(),
-  text: z.string().default(""),
-});
-export const anthropicImageBlockSchema = z.object({
-  type: z.literal("image").optional(),
-  source: z.object({
-    type: z.enum(["base64"]).optional(),
-    media_type: z.enum(["image/jpeg", "image/png", "image/gif", "image/webp"]),
-    data: z.string().default(""),
-  }),
-});
-const anthropicContentBlockSchema = z.union([
-  anthropicTextBlockSchema,
-  anthropicImageBlockSchema,
-]);
-const anthropicContentBlocksSchema = z.array(anthropicContentBlockSchema);
-const anthropicContentSchema = z.union([
-  z.string().default(""),
-  anthropicContentBlocksSchema,
-]);
-
-export type AnthropicImageBlock = z.infer<typeof anthropicImageBlockSchema>;
-export type AnthropicContent = z.infer<typeof anthropicContentSchema>;
diff --git a/packages/proxy/schema/schema.test.ts b/packages/proxy/schema/schema.test.ts
new file mode 100644
index 0000000..72229b3
--- /dev/null
+++ b/packages/proxy/schema/schema.test.ts
@@ -0,0 +1,360 @@
+import { describe, test, expect } from "vitest";
+import { translateParams } from "./index";
+import { ModelFormat } from "./models";
+
+type TranslateParamsCase =
+  | {
+      from: Record<string, unknown>;
+      to: Record<string, unknown>;
+    }
+  | "skip";
+
+describe("translateParams", () => {
+  const temperature = 0.55;
+  const max_tokens = 12345;
+  const top_p = 0.123;
+  const top_k = 45;
+  const stop = ["\n"];
+  const frequency_penalty = 0.1;
+  const presence_penalty = 0.2;
+  const n = 2;
+
+  const matrix: Record<
+    ModelFormat,
+    Record<ModelFormat, TranslateParamsCase>
+  > = {
+    openai: {
+      openai: "skip",
+      anthropic: {
+        from: {
+          temperature,
+          max_tokens,
+          top_p,
+          stop,
+          frequency_penalty,
+          presence_penalty,
+          n,
+        },
+        to: { temperature, max_tokens, top_p, stop_sequences: stop },
+      },
+      google: {
+        from: {
+          temperature,
+          max_tokens,
+          top_p,
+          stop,
+          frequency_penalty,
+          presence_penalty,
+          n,
+        },
+        to: { temperature, maxOutputTokens: max_tokens, topP: top_p },
+      },
+      window: {
+        from: {
+          temperature,
+          max_tokens,
+          top_p,
+          stop,
+          frequency_penalty,
+          presence_penalty,
+          n,
+        },
+        to: { temperature },
+      },
+      converse: {
+        from: {
+          temperature,
+          max_tokens,
+          top_p,
+          stop,
+          frequency_penalty,
+          presence_penalty,
+          n,
+        },
+        to: {
+          temperature,
+          maxTokens: max_tokens,
+          topP: top_p,
+          stopSequences: stop,
+        },
+      },
+      js: {
+        from: {
+          temperature,
+          max_tokens,
+          top_p,
+          stop,
+          frequency_penalty,
+          presence_penalty,
+          n,
+        },
+        to: {},
+      },
+    },
+    anthropic: {
+      openai: {
+        from: { temperature, max_tokens, top_p, top_k, stop_sequences: stop },
+        to: { 
temperature, max_tokens, top_p, stop }, + }, + anthropic: "skip", + google: { + from: { temperature, max_tokens, top_p, top_k, stop_sequences: stop }, + to: { + temperature, + maxOutputTokens: max_tokens, + topP: top_p, + topK: top_k, + }, + }, + window: { + from: { temperature, max_tokens, top_p, top_k, stop_sequences: stop }, + to: { temperature, topK: top_k }, + }, + converse: { + from: { temperature, max_tokens, top_p, top_k, stop_sequences: stop }, + to: { + temperature, + maxTokens: max_tokens, + topP: top_p, + topK: top_k, + stopSequences: stop, + }, + }, + js: { + from: { temperature, max_tokens, top_p, top_k, stop_sequences: stop }, + to: {}, + }, + }, + google: { + openai: { + from: { + temperature, + maxOutputTokens: max_tokens, + topP: top_p, + topK: top_k, + }, + to: { temperature, max_tokens, top_p }, + }, + anthropic: { + from: { + temperature, + maxOutputTokens: max_tokens, + topP: top_p, + topK: top_k, + }, + to: { temperature, max_tokens, top_p, top_k }, + }, + google: "skip", + window: { + from: { + temperature, + maxOutputTokens: max_tokens, + topP: top_p, + topK: top_k, + }, + to: { temperature, topK: top_k }, + }, + converse: { + from: { + temperature, + maxOutputTokens: max_tokens, + topP: top_p, + topK: top_k, + }, + to: { temperature, maxTokens: max_tokens, topP: top_p, topK: top_k }, + }, + js: { + from: { + temperature, + maxOutputTokens: max_tokens, + topP: top_p, + topK: top_k, + }, + to: {}, + }, + }, + window: { + openai: { + from: { temperature, topK: top_k }, + to: { temperature }, + }, + anthropic: { + from: { temperature, topK: top_k }, + to: { temperature, top_k }, + }, + google: { + from: { temperature, topK: top_k }, + to: { temperature, topK: top_k }, + }, + window: "skip", + converse: { + from: { temperature, topK: top_k }, + to: { temperature, topK: top_k }, + }, + js: { + from: { temperature, topK: top_k }, + to: {}, + }, + }, + converse: { + openai: { + from: { + temperature, + maxTokens: max_tokens, + topK: top_k, + topP: top_p, + stopSequences: ["\n"], + }, + to: { temperature, max_tokens, top_p, stop }, + }, + anthropic: { + from: { + temperature, + maxTokens: max_tokens, + topK: top_k, + topP: top_p, + stopSequences: ["\n"], + }, + to: { temperature, max_tokens, top_k, top_p, stop_sequences: stop }, + }, + google: { + from: { + temperature, + maxTokens: max_tokens, + topK: top_k, + topP: top_p, + stopSequences: ["\n"], + }, + to: { + temperature, + maxOutputTokens: max_tokens, + topK: top_k, + topP: top_p, + }, + }, + window: { + from: { + temperature, + maxTokens: max_tokens, + topK: top_k, + topP: top_p, + stopSequences: ["\n"], + }, + to: { temperature, topK: top_k }, + }, + converse: "skip", + js: { + from: { temperature, topK: top_k }, + to: {}, + }, + }, + js: { + openai: { + from: { some_param: "foo" }, + to: {}, + }, + anthropic: { + from: { some_param: "foo" }, + to: {}, + }, + google: { + from: { some_param: "foo" }, + to: {}, + }, + window: { + from: { some_param: "foo" }, + to: {}, + }, + converse: { + from: { some_param: "foo" }, + to: {}, + }, + js: "skip", + }, + }; + + test.each( + Object.entries(matrix).flatMap(([fromProvider, toParams]) => + Object.entries(toParams).flatMap(([toProvider, args]) => { + if (args === "skip") { + // XXX maybe test roundtrip? 
+          return [];
+        } else {
+          return [
+            {
+              fromProvider: fromProvider as ModelFormat,
+              toProvider: toProvider as ModelFormat,
+              fromParams: args.from,
+              toParams: args.to,
+            },
+          ];
+        }
+      }),
+    ),
+  )(
+    "translateParams from $fromProvider to $toProvider",
+    ({ fromProvider, toProvider, fromParams, toParams }) => {
+      if (fromProvider === toProvider) {
+        expect(translateParams(fromProvider, fromParams)).toEqual(toParams);
+      } else {
+        expect(translateParams(toProvider, fromParams)).toEqual(toParams);
+      }
+    },
+  );
+
+  /*
+  test("openai -> anthropic", () => {
+    expect(translateParams("openai", {
+      temperature: 0.55,
+      top_p: 0.245,
+      max_tokens: 1000,
+      frequency_penalty: 0.1,
+      presence_penalty: 0.1,
+      response_format: { type: "json_schema", schema: { type: "object" } },
+      n: 1,
+      stop: ["\n"],
+      reasoning_effort: "low",
+    })).toEqual({
+      temperature: 0.55,
+      top_p: 0.245,
+      max_tokens: 1000,
+      //stop_sequences: ["\n"],
+    });
+  });
+
+  test("anthropic -> openai", () => {
+    expect(translateParams("anthropic", {
+      max_tokens: 1000,
+      temperature: 0.5,
+      top_p: 0.245,
+      top_k: 54,
+      stop_sequences: ["\n"],
+    })).toEqual({
+      temperature: 0.5,
+      top_p: 0.245,
+      top_k: 54,
+      max_tokens: 1000,
+      //stop: ["\n"],
+    });
+  });
+
+  test("openai -> google", () => {
+    expect(translateParams("google", {
+      temperature: 0.55,
+      top_p: 0.245,
+      max_tokens: 1000,
+      frequency_penalty: 0.1,
+      presence_penalty: 0.1,
+      response_format: { type: "json_schema", schema: { type: "object" } },
+      n: 1,
+      stop: ["\n"],
+      reasoning_effort: "low",
+    })).toEqual({
+      temperature: 0.55,
+      topP: 0.245,
+      maxOutputTokens: 1000,
+      //stop: ["\n"],
+    });
+  });
+  */
+});
diff --git a/packages/proxy/src/providers/bedrock.ts b/packages/proxy/src/providers/bedrock.ts
index 36c576d..8f03a75 100644
--- a/packages/proxy/src/providers/bedrock.ts
+++ b/packages/proxy/src/providers/bedrock.ts
@@ -806,6 +806,7 @@ function openAIResponse(
   };
 }
 
+// XXX can now replace with translateParams("converse", params)
 function translateInferenceConfig(
   params: Record<string, unknown>,
 ): InferenceConfiguration {
diff --git a/packages/proxy/src/providers/google.ts b/packages/proxy/src/providers/google.ts
index 2cad8d8..f62dd6d 100644
--- a/packages/proxy/src/providers/google.ts
+++ b/packages/proxy/src/providers/google.ts
@@ -260,15 +260,3 @@ export function googleCompletionToOpenAICompletion(
       : undefined,
   };
 }
-
-export const OpenAIParamsToGoogleParams: {
-  [name: string]: string | null;
-} = {
-  temperature: "temperature",
-  top_p: "topP",
-  stop: "stopSequences",
-  max_tokens: "maxOutputTokens",
-  frequency_penalty: null,
-  presence_penalty: null,
-  tool_choice: null,
-};
diff --git a/packages/proxy/src/proxy.ts b/packages/proxy/src/proxy.ts
index 12d80bf..d1caab8 100644
--- a/packages/proxy/src/proxy.ts
+++ b/packages/proxy/src/proxy.ts
@@ -45,7 +45,6 @@ import {
   googleEventToOpenAIChatEvent,
   openAIContentToGoogleContent,
   openAIMessagesToGoogleMessages,
-  OpenAIParamsToGoogleParams,
 } from "./providers/google";
 import {
   Message,
@@ -2081,10 +2080,10 @@ async function fetchAnthropicChatCompletions({
   }
 
   messages = flattenAnthropicMessages(messages);
-  const params: Record<string, unknown> = {
-    max_tokens: 1024, // Required param
-    ...translateParams("anthropic", oaiParams),
-  };
+  const params: Record<string, unknown> = translateParams(
+    "anthropic",
+    oaiParams,
+  );
   const stop = z
     .union([z.string(), z.array(z.string())])
     .optional()
@@ -2286,7 +2285,12 @@ async function googleSchemaFromJsonSchema(schema: any): Promise<any> {
   return schema;
 }
 
-async function openAIToolsToGoogleTools(params: 
ChatCompletionCreateParams) { +async function openAIToolsToGoogleTools( + params: Pick< + ChatCompletionCreateParams, + "tools" | "tool_choice" | "functions" + >, +) { if (params.tools || params.functions) { params.tools = params.tools || @@ -2522,18 +2526,7 @@ async function fetchGoogleChatCompletions({ const content = await openAIMessagesToGoogleMessages( oaiMessages.filter((m: any) => m.role !== "system"), ); - const params = Object.fromEntries( - Object.entries(translateParams("google", oaiParams)) - .map(([key, value]) => { - const translatedKey = OpenAIParamsToGoogleParams[key]; - if (translatedKey === null) { - // These are unsupported params - return [null, null]; - } - return [translatedKey ?? key, value]; - }) - .filter(([k, _]) => k !== null), - ); + const params = translateParams("google", oaiParams); let fullURL: URL; if (secret.type === "google") {