Skip to content

Commit be4ffdc

Browse files
committed
revamped model parameter translation
1 parent 80ed656 commit be4ffdc

File tree

5 files changed

+584
-100
lines changed

5 files changed

+584
-100
lines changed

packages/proxy/schema/index.ts

Lines changed: 212 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import type {
66
ModelParams,
77
} from "@braintrust/core/typespecs";
88
import { AvailableModels, ModelFormat, ModelEndpointType } from "./models";
9+
import { isObject } from "@braintrust/core";
910

1011
export * from "./secrets";
1112
export * from "./models";
@@ -28,31 +29,6 @@ export const MessageTypeToMessageType: {
2829
model: "assistant",
2930
};
3031

31-
export const modelParamToModelParam: {
32-
[name: string]: keyof AnyModelParam | null;
33-
} = {
34-
temperature: "temperature",
35-
top_p: "top_p",
36-
top_k: "top_k",
37-
max_tokens: "max_tokens",
38-
max_tokens_to_sample: null,
39-
use_cache: "use_cache",
40-
maxOutputTokens: "max_tokens",
41-
topP: "top_p",
42-
topK: "top_k",
43-
presence_penalty: null,
44-
frequency_penalty: null,
45-
user: null,
46-
function_call: null,
47-
n: null,
48-
logprobs: null,
49-
stream_options: null,
50-
parallel_tool_calls: null,
51-
response_format: null,
52-
reasoning_effort: null,
53-
stop: null,
54-
};
55-
5632
export const sliderSpecs: {
5733
// min, max, step, required
5834
[name: string]: [number, number, number, boolean];
@@ -423,60 +399,226 @@ ${content}<|im_end|>`,
423399
);
424400
}
425401

// XXX: `@braintrust/core` cannot be consumed from the workspace yet, so this
// schema is defined locally for testing. Eventually it should move to the SDK
// and be imported from there.

/**
 * Zod schema for the normalized ("Braintrust") model parameters. Every field
 * is optional.
 */
const braintrustModelParamSchema = z.object({
  use_cache: z.boolean().optional(),

  temperature: z.number().optional(),
  max_tokens: z.number().optional(),
  // XXX: decide how deprecated params should be handled.
  max_completion_tokens: z.number().optional(),
  top_p: z.number().optional(),
  top_k: z.number().optional(),
  frequency_penalty: z.number().optional(),
  presence_penalty: z.number().optional(),
  /*
   * XXX: these are special-cased in the proxy, but how that works still needs
   * to be understood. That logic will probably stay where it is for now.
   *
  response_format: z
    .object({
      type: z.literal("json_object"),
    })
    .nullish(),
  tool_choice: z.object({
    type: z.literal("function"),
  }).optional(),
  function_call: z.object({
    name: z.string().optional(),
  }).optional(),
  */
  n: z.number().optional(),
  stop: z.array(z.string()).optional(),
  reasoning_effort: z.enum(["low", "medium", "high"]).optional(),
});
/** Normalized parameter bag, inferred from `braintrustModelParamSchema`. */
type BraintrustModelParams = z.infer<typeof braintrustModelParamSchema>;

/**
 * Describes how one provider-specific parameter name maps onto a normalized
 * Braintrust parameter name.
 *
 * - A bare key name is the simple case: the parameter maps to that key.
 * - The object form carries extra metadata:
 *   - `key`: the normalized name, or `null` when there is no counterpart.
 *   - `deprecated`: when set, `translateKey` logs a console warning.
 *   - `o1_like`: set on some entries of the mapping table; it is not read by
 *     any code visible in this file section.
 */
type BraintrustParamMapping =
  | keyof BraintrustModelParams
  | {
      key: keyof BraintrustModelParams | null;
      deprecated?: boolean;
      o1_like?: boolean;
    };

// XXX add to sdk
/** Parameter names used by the "converse" format that `AnyModelParam` lacks. */
type ConverseModelParams = {
  maxTokens: number;
  stopSequences: string[];
};

/**
 * Lookup table from every provider-specific parameter name to its normalized
 * Braintrust counterpart.
 *
 * The `Record` type forces an entry for each key of `AnyModelParam` and
 * `ConverseModelParams`. A bare string means "maps to this key"; an object
 * entry with `key: null` means the parameter has no normalized counterpart.
 *
 * NOTE(review): the trailing `// null` notes appear to record that the legacy
 * translation table dropped those parameters — TODO confirm before removing.
 */
const anyModelParamToBraintrustModelParam: Record<
  keyof AnyModelParam | keyof ConverseModelParams,
  BraintrustParamMapping
> = {
  use_cache: "use_cache",
  temperature: "temperature",

  // Output-length limits: each spelling below collapses onto `max_tokens`,
  // except `max_tokens_to_sample`, which is deprecated and unmapped.
  max_tokens: "max_tokens",
  max_completion_tokens: { key: "max_tokens", o1_like: true },
  maxOutputTokens: "max_tokens",
  maxTokens: "max_tokens",
  // XXX map this to max_tokens?
  max_tokens_to_sample: { key: null, deprecated: true },

  // Sampling controls, in both snake_case and camelCase spellings.
  top_p: "top_p",
  topP: "top_p",
  top_k: "top_k",
  topK: "top_k",
  frequency_penalty: "frequency_penalty", // null
  presence_penalty: "presence_penalty", // null

  // Stop sequences: three spellings, one normalized key.
  stop: "stop", // null
  stop_sequences: "stop", // null
  stopSequences: "stop", // null

  n: "n", // null

  reasoning_effort: { key: "reasoning_effort", o1_like: true },

  // No normalized counterpart for these.
  response_format: { key: null }, // handled elsewhere?
  function_call: { key: null }, // handled elsewhere
  tool_choice: { key: null }, // handled elsewhere
  // XXX parallel_tool_calls: { key: null }, // handled elsewhere
};

/**
 * Translate a single model-parameter name into the name `toProvider` expects.
 *
 * The name is first normalized through `anyModelParamToBraintrustModelParam`
 * and then mapped to the provider-specific spelling.
 *
 * @param toProvider - Target format. When `undefined`, the normalized
 *   (Braintrust) key is returned as-is.
 * @param key - Parameter name in any known provider spelling.
 * @returns The translated parameter name, or `null` when the name is unknown,
 *   has no normalized counterpart, or is not supported by `toProvider`.
 * @throws Error if `toProvider` is a value outside the known formats.
 */
function translateKey(
  toProvider: ModelFormat | undefined,
  key: string,
): keyof ModelParams | null {
  // Only consult the table's own entries. A plain `table[key]` lookup also
  // resolves inherited members such as "constructor", "toString" or
  // "__proto__", which would leak a function or `undefined` out of here
  // instead of `null`.
  const mapping = Object.prototype.hasOwnProperty.call(
    anyModelParamToBraintrustModelParam,
    key,
  )
    ? anyModelParamToBraintrustModelParam[
        key as keyof typeof anyModelParamToBraintrustModelParam
      ]
    : undefined;
  if (mapping === undefined) {
    return null;
  }

  let normalizedKey: keyof BraintrustModelParams | null;
  if (isObject(mapping)) {
    if (mapping.deprecated) {
      console.warn(`Deprecated model param: ${key}`);
    }
    normalizedKey = mapping.key;
  } else {
    normalizedKey = mapping;
  }

  if (normalizedKey === null) {
    return null;
  }

  // XXX if toProvider is undefined, return the normalized key. This is useful
  // for the UI to parse span data when the provider is not known. Maybe we can
  // try harder to infer the provider?
  if (toProvider === undefined) {
    return normalizedKey;
  }

  // XXX turn these into
  // Record<keyof BraintrustModelParams, keyof z.infer<typeof anthropicModelParamsSchema> | null>
  // maps from braintrust key to provider key. These can live in proxy/providers.
  switch (toProvider) {
    case "openai":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "max_tokens":
          return "max_tokens";
        case "top_p":
          return "top_p";
        case "stop":
          return "stop";
        case "frequency_penalty":
          return "frequency_penalty";
        case "presence_penalty":
          return "presence_penalty";
        case "n":
          return "n";
        default:
          return null;
      }
    case "anthropic":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "max_tokens":
          return "max_tokens";
        case "top_k":
          return "top_k";
        case "top_p":
          return "top_p";
        case "stop":
          return "stop_sequences";
        default:
          return null;
      }
    case "google":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "top_p":
          return "topP";
        case "top_k":
          return "topK";
        /* XXX add support for this?
        case "stop":
          return "stopSequences";
        */
        case "max_tokens":
          return "maxOutputTokens";
        default:
          return null;
      }
    case "window":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "top_k":
          return "topK";
        default:
          return null;
      }
    case "converse":
      switch (normalizedKey) {
        case "temperature":
          return "temperature";
        case "max_tokens":
          return "maxTokens";
        case "top_k":
          return "topK";
        case "top_p":
          return "topP";
        case "stop":
          return "stopSequences";
        default:
          return null;
      }
    case "js":
      return null;
    default: {
      // Compile-time exhaustiveness check: adding a new ModelFormat member
      // without a case above makes this assignment fail to type-check.
      const _exhaustiveCheck: never = toProvider;
      throw new Error(`Unknown provider: ${_exhaustiveCheck}`);
    }
  }
}

/**
 * Translate a bag of model parameters into the names `toProvider` expects.
 *
 * Each entry's name is passed through `translateKey`. Entries whose name does
 * not translate, or whose value is `null` or `undefined`, are left out of the
 * result. When several input names translate to the same output name, the
 * last one encountered wins.
 *
 * @param toProvider - Target format, or `undefined` to get normalized keys.
 * @param params - Raw parameters keyed by any known provider spelling.
 * @returns A new object holding only the translated entries.
 */
export function translateParams(
  toProvider: ModelFormat | undefined,
  params: Record<string, unknown>,
): Record<keyof ModelParams, unknown> {
  // Build into a Partial: an empty literal is not assignable to a Record
  // whose keys are a union of required literal names.
  const translatedParams: Partial<Record<keyof ModelParams, unknown>> = {};
  for (const [k, v] of Object.entries(params || {})) {
    const safeValue = v ?? undefined; // Don't propagate "null" along
    // Translate before checking the value so any warning emitted by
    // translateKey still fires for entries whose value is unset.
    const translatedKey = translateKey(toProvider, k);
    // `null` means the name is unsupported. The typeof check also rejects any
    // other non-string result, so it can never be used as a property name.
    if (typeof translatedKey !== "string") {
      continue;
    }
    if (safeValue !== undefined) {
      translatedParams[translatedKey] = safeValue;
    }
    // XXX should we add default params from defaultModelParamSettings?
    // Probably only do that if translateParams is being called from the prompt
    // UI, but not for proxy calls.
    //
    // Also, the previous logic here seemed incorrect in doing
    // translatedParams[k] = safeValue. I don't see why we would want to pass
    // along params we know are not accepted by toProvider.
  }

  // The declared return type lists every key even though only translated
  // entries are present; every value reads as `unknown` either way.
  return translatedParams as Record<keyof ModelParams, unknown>;
}
451-
452-
export const anthropicSupportedMediaTypes = [
453-
"image/jpeg",
454-
"image/png",
455-
"image/gif",
456-
"image/webp",
457-
];
458-
459-
export const anthropicTextBlockSchema = z.object({
460-
type: z.literal("text").optional(),
461-
text: z.string().default(""),
462-
});
463-
export const anthropicImageBlockSchema = z.object({
464-
type: z.literal("image").optional(),
465-
source: z.object({
466-
type: z.enum(["base64"]).optional(),
467-
media_type: z.enum(["image/jpeg", "image/png", "image/gif", "image/webp"]),
468-
data: z.string().default(""),
469-
}),
470-
});
471-
const anthropicContentBlockSchema = z.union([
472-
anthropicTextBlockSchema,
473-
anthropicImageBlockSchema,
474-
]);
475-
const anthropicContentBlocksSchema = z.array(anthropicContentBlockSchema);
476-
const anthropicContentSchema = z.union([
477-
z.string().default(""),
478-
anthropicContentBlocksSchema,
479-
]);
480-
481-
export type AnthropicImageBlock = z.infer<typeof anthropicImageBlockSchema>;
482-
export type AnthropicContent = z.infer<typeof anthropicContentSchema>;

0 commit comments

Comments
 (0)