Skip to content

WIP: revamped model parameter translation #217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 210 additions & 70 deletions packages/proxy/schema/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type {
ModelParams,
} from "@braintrust/core/typespecs";
import { AvailableModels, ModelFormat, ModelEndpointType } from "./models";
import { isObject } from "@braintrust/core";

export * from "./secrets";
export * from "./models";
Expand All @@ -28,31 +29,6 @@ export const MessageTypeToMessageType: {
model: "assistant",
};

export const modelParamToModelParam: {
[name: string]: keyof AnyModelParam | null;
} = {
temperature: "temperature",
top_p: "top_p",
top_k: "top_k",
max_tokens: "max_tokens",
max_tokens_to_sample: null,
use_cache: "use_cache",
maxOutputTokens: "max_tokens",
topP: "top_p",
topK: "top_k",
presence_penalty: null,
frequency_penalty: null,
user: null,
function_call: null,
n: null,
logprobs: null,
stream_options: null,
parallel_tool_calls: null,
response_format: null,
reasoning_effort: null,
stop: null,
};

export const sliderSpecs: {
// min, max, step, required
[name: string]: [number, number, number, boolean];
Expand Down Expand Up @@ -423,60 +399,224 @@ ${content}<|im_end|>`,
);
}

// XXX we can't use @braintrust/core from the workspace so testing with this here.
// eventually i'll move to sdk and import
const braintrustModelParamSchema = z.object({
use_cache: z.boolean().optional(),

temperature: z.number().optional(),
max_tokens: z.number().optional(),
// XXX how do we want to handle deprecated params
max_completion_tokens: z.number().optional(),
top_p: z.number().optional(),
top_k: z.number().optional(),
frequency_penalty: z.number().optional(),
presence_penalty: z.number().optional(),
/* XXX we special case these in the proxy but i need to understand how. will probably keep that logic where it is for now
response_format: z
.object({
type: z.literal("json_object"),
})
.nullish(),
tool_choice: z.object({
type: z.literal("function"),
}).optional(),
function_call: z.object({
name: z.string().optional(),
}).optional(),
*/
n: z.number().optional(),
stop: z.array(z.string()).optional(),
reasoning_effort: z.enum(["low", "medium", "high"]).optional(),
});
type BraintrustModelParams = z.infer<typeof braintrustModelParamSchema>;
type BraintrustParamMapping =
| keyof BraintrustModelParams
| {
key: keyof BraintrustModelParams | null;
deprecated?: boolean;
o1_like?: boolean;
};

// XXX add to sdk
type ConverseModelParams = {
maxTokens: number;
stopSequences: string[];
};

const anyModelParamToBraintrustModelParam: Record<
keyof AnyModelParam | keyof ConverseModelParams,
BraintrustParamMapping
> = {
use_cache: "use_cache",
temperature: "temperature",

max_tokens: "max_tokens",
max_completion_tokens: { key: "max_tokens", o1_like: true },
maxOutputTokens: "max_tokens",
maxTokens: "max_tokens",
// XXX map this to max_tokens?
max_tokens_to_sample: { key: null, deprecated: true },

top_p: "top_p",
topP: "top_p",
top_k: "top_k",
topK: "top_k",
frequency_penalty: "frequency_penalty", // null
presence_penalty: "presence_penalty", // null

stop: "stop", // null
stop_sequences: "stop", // null
stopSequences: "stop", // null

n: "n", // null

reasoning_effort: { key: "reasoning_effort", o1_like: true },

response_format: { key: null }, // handled elsewhere?
function_call: { key: null }, // handled elsewhere
tool_choice: { key: null }, // handled elsewhere
// XXX parallel_tool_calls: { key: null }, // handled elsewhere
};

function translateKey(
toProvider: ModelFormat | undefined,
key: string,
): keyof ModelParams | null {
const braintrustKey =
anyModelParamToBraintrustModelParam[key as keyof AnyModelParam];
let normalizedKey: keyof BraintrustModelParams | null = null;
if (braintrustKey === undefined) {
normalizedKey = null;
} else if (!isObject(braintrustKey)) {
normalizedKey = braintrustKey;
} else {
if (braintrustKey.deprecated) {
console.warn(`Deprecated model param: ${key}`);
}

if (braintrustKey.key === null) {
normalizedKey = null;
} else {
normalizedKey = braintrustKey.key;
}
}

if (normalizedKey === null) {
return null;
}

// XXX if toProvider is undefined, return the normalized key. this is useful for the ui to parse span data when the
// provider is not known. maybe we can try harder to infer the provider?
if (toProvider === undefined) {
return normalizedKey;
}

// XXX turn these into Record<keyof BraintrustModelParams, keyof z.infer<typeof anthropicModelParamsSchema> | null>
// maps from braintrust key to provider key. can live in proxy/providers
switch (toProvider) {
case "openai":
switch (normalizedKey) {
case "temperature":
return "temperature";
case "max_tokens":
return "max_tokens";
case "top_p":
return "top_p";
case "stop":
return "stop";
case "frequency_penalty":
return "frequency_penalty";
case "presence_penalty":
return "presence_penalty";
case "n":
return "n";
default:
return null;
}
case "anthropic":
switch (normalizedKey) {
case "temperature":
return "temperature";
case "max_tokens":
return "max_tokens";
case "top_k":
return "top_k";
case "top_p":
return "top_p";
case "stop":
return "stop_sequences";
default:
return null;
}
case "google":
switch (normalizedKey) {
case "temperature":
return "temperature";
case "top_p":
return "topP";
case "top_k":
return "topK";
/* XXX add support for this?
case "stop":
return "stopSequences";
*/
case "max_tokens":
return "maxOutputTokens";
default:
return null;
}
case "window":
switch (normalizedKey) {
case "temperature":
return "temperature";
case "top_k":
return "topK";
default:
return null;
}
case "converse":
switch (normalizedKey) {
case "temperature":
return "temperature";
case "max_tokens":
return "maxTokens";
case "top_k":
return "topK";
case "top_p":
return "topP";
case "stop":
return "stopSequences";
default:
return null;
}
case "js":
return null;
default:
const _exhaustiveCheck: never = toProvider;
throw new Error(`Unknown provider: ${_exhaustiveCheck}`);
}
}

export function translateParams(
toProvider: ModelFormat,
toProvider: ModelFormat | undefined,
params: Record<string, unknown>,
): Record<string, unknown> {
const translatedParams: Record<string, unknown> = {};
): Record<keyof ModelParams, unknown> {
const translatedParams: Record<keyof ModelParams, unknown> = {};
for (const [k, v] of Object.entries(params || {})) {
const safeValue = v ?? undefined; // Don't propagate "null" along
const translatedKey = modelParamToModelParam[k as keyof ModelParams] as
| keyof ModelParams
| undefined
| null;
const translatedKey = translateKey(toProvider, k);
if (translatedKey === null) {
continue;
} else if (
translatedKey !== undefined &&
defaultModelParamSettings[toProvider][translatedKey] !== undefined
) {
} else if (safeValue !== undefined) {
translatedParams[translatedKey] = safeValue;
} else {
translatedParams[k] = safeValue;
}
// XXX should we add default params from defaultModelParamSettings?
// probably only do that if translateParams is being called from the prompt ui but not for proxy calls
//
// also, the previous logic here seemed incorrect in doing translatedParams[k] = saveValue. i dont
// see why we would want to pass along params we know are not accepted by toProvider
}

return translatedParams;
}

export const anthropicSupportedMediaTypes = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
];

export const anthropicTextBlockSchema = z.object({
type: z.literal("text").optional(),
text: z.string().default(""),
});
export const anthropicImageBlockSchema = z.object({
type: z.literal("image").optional(),
source: z.object({
type: z.enum(["base64"]).optional(),
media_type: z.enum(["image/jpeg", "image/png", "image/gif", "image/webp"]),
data: z.string().default(""),
}),
});
const anthropicContentBlockSchema = z.union([
anthropicTextBlockSchema,
anthropicImageBlockSchema,
]);
const anthropicContentBlocksSchema = z.array(anthropicContentBlockSchema);
const anthropicContentSchema = z.union([
z.string().default(""),
anthropicContentBlocksSchema,
]);

export type AnthropicImageBlock = z.infer<typeof anthropicImageBlockSchema>;
export type AnthropicContent = z.infer<typeof anthropicContentSchema>;
Loading