Skip to content

Commit e116e29

Browse files
committed
revamped model parameter translation
1 parent 80ed656 commit e116e29

File tree

5 files changed

+582
-100
lines changed

5 files changed

+582
-100
lines changed

packages/proxy/schema/index.ts

Lines changed: 210 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import type {
66
ModelParams,
77
} from "@braintrust/core/typespecs";
88
import { AvailableModels, ModelFormat, ModelEndpointType } from "./models";
9+
import { isObject } from "@braintrust/core";
910

1011
export * from "./secrets";
1112
export * from "./models";
@@ -28,31 +29,6 @@ export const MessageTypeToMessageType: {
2829
model: "assistant",
2930
};
3031

31-
export const modelParamToModelParam: {
32-
[name: string]: keyof AnyModelParam | null;
33-
} = {
34-
temperature: "temperature",
35-
top_p: "top_p",
36-
top_k: "top_k",
37-
max_tokens: "max_tokens",
38-
max_tokens_to_sample: null,
39-
use_cache: "use_cache",
40-
maxOutputTokens: "max_tokens",
41-
topP: "top_p",
42-
topK: "top_k",
43-
presence_penalty: null,
44-
frequency_penalty: null,
45-
user: null,
46-
function_call: null,
47-
n: null,
48-
logprobs: null,
49-
stream_options: null,
50-
parallel_tool_calls: null,
51-
response_format: null,
52-
reasoning_effort: null,
53-
stop: null,
54-
};
55-
5632
export const sliderSpecs: {
5733
// min, max, step, required
5834
[name: string]: [number, number, number, boolean];
@@ -423,60 +399,224 @@ ${content}<|im_end|>`,
423399
);
424400
}
425401

/**
 * Provider-neutral ("Braintrust") model parameters.
 *
 * Parameter names coming from any provider are normalized onto the keys of
 * this schema before being translated to the name a target provider uses.
 * Every field may be omitted.
 */
const braintrustModelParamSchema = z.object({
  use_cache: z.boolean().optional(),

  temperature: z.number().optional(),
  max_tokens: z.number().optional(),
  // XXX how do we want to handle deprecated params
  max_completion_tokens: z.number().optional(),
  top_p: z.number().optional(),
  top_k: z.number().optional(),
  frequency_penalty: z.number().optional(),
  presence_penalty: z.number().optional(),
  // Unlike the other fields, an explicit `null` is accepted here (nullish).
  response_format: z
    .object({
      type: z.literal("json_object"),
    })
    .nullish(),
  /*
  tool_choice: z.object({
    type: z.literal("function"),
  }).optional(),
  function_call: z.object({
    name: z.string().optional(),
  }).optional(),
  */
  n: z.number().optional(),
  stop: z.array(z.string()).optional(),
  reasoning_effort: z.enum(["low", "medium", "high"]).optional(),
});
/** Static type of the values accepted by `braintrustModelParamSchema`. */
type BraintrustModelParams = z.infer<typeof braintrustModelParamSchema>;

/**
 * How one provider-specific parameter name maps onto a Braintrust parameter.
 *
 * Either the Braintrust key directly, or an object form carrying extra flags:
 * - `key`: the Braintrust key, or `null` when the parameter has no Braintrust
 *   equivalent and should not be translated.
 * - `deprecated`: when true, `translateKey` logs a warning if the parameter
 *   is used.
 * - `o1_like`: not read anywhere in this part of the file; presumably marks
 *   parameters specific to o1-style models (TODO confirm).
 */
type BraintrustParamMapping =
  | keyof BraintrustModelParams
  | {
      key: keyof BraintrustModelParams | null;
      deprecated?: boolean;
      o1_like?: boolean;
    };
438+
// XXX add to sdk
/**
 * Camel-cased parameter names that are not (yet) part of the SDK's model
 * param types. They are added to the key set of
 * `anyModelParamToBraintrustModelParam` and are the names `translateKey`
 * emits for the "converse" format.
 */
type ConverseModelParams = {
  maxTokens: number;
  stopSequences: string[];
};
444+
/**
 * Maps every known provider spelling of a model parameter to its Braintrust
 * (provider-neutral) key.
 *
 * A plain string value is the Braintrust key. The object form allows extra
 * flags; `key: null` means the parameter is recognized but is not translated
 * by `translateKey` (it returns `null` for it).
 *
 * NOTE: frequency_penalty, presence_penalty, stop and n were previously
 * dropped (mapped to null) by the table this one replaces; they are now
 * passed through to providers that support them.
 */
const anyModelParamToBraintrustModelParam: Record<
  keyof AnyModelParam | keyof ConverseModelParams,
  BraintrustParamMapping
> = {
  use_cache: "use_cache",
  temperature: "temperature",

  max_tokens: "max_tokens",
  max_completion_tokens: { key: "max_tokens", o1_like: true },
  maxOutputTokens: "max_tokens",
  maxTokens: "max_tokens",
  // XXX map this to max_tokens?
  max_tokens_to_sample: { key: null, deprecated: true },

  top_p: "top_p",
  topP: "top_p",
  top_k: "top_k",
  topK: "top_k",
  frequency_penalty: "frequency_penalty",
  presence_penalty: "presence_penalty",

  stop: "stop",
  stop_sequences: "stop",
  stopSequences: "stop",

  n: "n",

  reasoning_effort: { key: "reasoning_effort", o1_like: true },

  response_format: "response_format", // handled elsewhere?
  function_call: { key: null }, // handled elsewhere
  tool_choice: { key: null }, // handled elsewhere
  // parallel_tool_calls: { key: null }, // handled elsewhere
};
479+
/**
 * Translates a single model parameter name into the name used by `toProvider`.
 *
 * The incoming key (in any known provider's spelling) is first normalized to
 * its Braintrust key via `anyModelParamToBraintrustModelParam`, then mapped to
 * the target provider's spelling of that parameter.
 *
 * @param toProvider - Target model format. When `undefined`, the normalized
 *   Braintrust key is returned without provider-specific translation.
 * @param key - Parameter name as supplied by the caller.
 * @returns The translated parameter name, or `null` when the key is unknown,
 *   is mapped to `key: null`, or is not supported by `toProvider`.
 * @throws Error if `toProvider` is a format this function does not handle.
 */
function translateKey(
  toProvider: ModelFormat | undefined,
  key: string,
): keyof ModelParams | null {
  // `key` is an arbitrary string, so only consult the table's own properties.
  // A bare index lookup would resolve names such as "constructor" or
  // "toString" to inherited Object.prototype members.
  const mapping: BraintrustParamMapping | undefined =
    Object.prototype.hasOwnProperty.call(
      anyModelParamToBraintrustModelParam,
      key,
    )
      ? anyModelParamToBraintrustModelParam[key as keyof AnyModelParam]
      : undefined;
  if (mapping === undefined) {
    return null;
  }

  let normalizedKey: keyof BraintrustModelParams | null;
  if (isObject(mapping)) {
    if (mapping.deprecated) {
      console.warn(`Deprecated model param: ${key}`);
    }
    normalizedKey = mapping.key;
  } else {
    normalizedKey = mapping;
  }

  if (normalizedKey === null) {
    return null;
  }

  // XXX if toProvider is undefined, return the normalized key. this is useful for the ui to parse span data when the
  // provider is not known. maybe we can try harder to infer the provider?
  if (toProvider === undefined) {
    return normalizedKey;
  }

  // XXX turn these into Record<keyof BraintrustModelParams, keyof z.infer<typeof anthropicModelParamsSchema> | null>
  // maps from braintrust key to provider key
  switch (toProvider) {
    case "openai":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "max_tokens":
          return "max_tokens";
        case "top_p":
          return "top_p";
        case "stop":
          return "stop";
        case "frequency_penalty":
          return "frequency_penalty";
        case "presence_penalty":
          return "presence_penalty";
        case "n":
          return "n";
        default:
          return null;
      }
    case "anthropic":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "max_tokens":
          return "max_tokens";
        case "top_k":
          return "top_k";
        case "top_p":
          return "top_p";
        case "stop":
          return "stop_sequences";
        default:
          return null;
      }
    case "google":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "top_p":
          return "topP";
        case "top_k":
          return "topK";
        /* XXX add support for this?
        case "stop":
          return "stopSequences";
        */
        case "max_tokens":
          return "maxOutputTokens";
        default:
          return null;
      }
    case "window":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "top_k":
          return "topK";
        default:
          return null;
      }
    case "converse":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "max_tokens":
          return "maxTokens";
        case "top_k":
          return "topK";
        case "top_p":
          return "topP";
        case "stop":
          return "stopSequences";
        default:
          return null;
      }
    case "js":
      return null;
    default: {
      // Compile-time exhaustiveness check: adding a new ModelFormat without
      // handling it above makes this assignment a type error.
      const _exhaustiveCheck: never = toProvider;
      throw new Error(`Unknown provider: ${_exhaustiveCheck}`);
    }
  }
}
600+
/**
 * Translates a bag of model parameters into the names `toProvider` expects.
 *
 * Each key is passed through `translateKey`. Entries whose key translates to
 * `null` (unknown or unsupported for the provider) are dropped, as are
 * entries whose value is `null` or `undefined`. If several input keys
 * translate to the same output key, the one iterated last wins.
 *
 * @param toProvider - Target model format, or `undefined` to only normalize
 *   keys to their Braintrust names.
 * @param params - Parameters keyed by any known provider spelling. A falsy
 *   value is treated as an empty object.
 * @returns A new object containing only the translated, defined parameters.
 */
export function translateParams(
  toProvider: ModelFormat | undefined,
  params: Record<string, unknown>,
): Record<keyof ModelParams, unknown> {
  const translatedParams: Record<keyof ModelParams, unknown> = {};
  for (const [k, v] of Object.entries(params || {})) {
    const safeValue = v ?? undefined; // Don't propagate "null" along
    const translatedKey = translateKey(toProvider, k);
    if (translatedKey === null) {
      continue;
    } else if (safeValue !== undefined) {
      translatedParams[translatedKey] = safeValue;
    }
    // XXX should we add default params from defaultModelParamSettings?
    // probably only do that if translateParams is being called from the prompt ui but not for proxy calls
    //
    // also, the previous logic here seemed incorrect in doing translatedParams[k] = safeValue. I don't
    // see why we would want to pass along params we know are not accepted by toProvider
  }

  return translatedParams;
}
451-
452-
export const anthropicSupportedMediaTypes = [
453-
"image/jpeg",
454-
"image/png",
455-
"image/gif",
456-
"image/webp",
457-
];
458-
459-
export const anthropicTextBlockSchema = z.object({
460-
type: z.literal("text").optional(),
461-
text: z.string().default(""),
462-
});
463-
export const anthropicImageBlockSchema = z.object({
464-
type: z.literal("image").optional(),
465-
source: z.object({
466-
type: z.enum(["base64"]).optional(),
467-
media_type: z.enum(["image/jpeg", "image/png", "image/gif", "image/webp"]),
468-
data: z.string().default(""),
469-
}),
470-
});
471-
const anthropicContentBlockSchema = z.union([
472-
anthropicTextBlockSchema,
473-
anthropicImageBlockSchema,
474-
]);
475-
const anthropicContentBlocksSchema = z.array(anthropicContentBlockSchema);
476-
const anthropicContentSchema = z.union([
477-
z.string().default(""),
478-
anthropicContentBlocksSchema,
479-
]);
480-
481-
export type AnthropicImageBlock = z.infer<typeof anthropicImageBlockSchema>;
482-
export type AnthropicContent = z.infer<typeof anthropicContentSchema>;

0 commit comments

Comments
 (0)