@@ -6,6 +6,7 @@ import type {
6
6
ModelParams ,
7
7
} from "@braintrust/core/typespecs" ;
8
8
import { AvailableModels , ModelFormat , ModelEndpointType } from "./models" ;
9
+ import { isObject } from "@braintrust/core" ;
9
10
10
11
export * from "./secrets" ;
11
12
export * from "./models" ;
@@ -28,31 +29,6 @@ export const MessageTypeToMessageType: {
28
29
model : "assistant" ,
29
30
} ;
30
31
31
- export const modelParamToModelParam : {
32
- [ name : string ] : keyof AnyModelParam | null ;
33
- } = {
34
- temperature : "temperature" ,
35
- top_p : "top_p" ,
36
- top_k : "top_k" ,
37
- max_tokens : "max_tokens" ,
38
- max_tokens_to_sample : null ,
39
- use_cache : "use_cache" ,
40
- maxOutputTokens : "max_tokens" ,
41
- topP : "top_p" ,
42
- topK : "top_k" ,
43
- presence_penalty : null ,
44
- frequency_penalty : null ,
45
- user : null ,
46
- function_call : null ,
47
- n : null ,
48
- logprobs : null ,
49
- stream_options : null ,
50
- parallel_tool_calls : null ,
51
- response_format : null ,
52
- reasoning_effort : null ,
53
- stop : null ,
54
- } ;
55
-
56
32
export const sliderSpecs : {
57
33
// min, max, step, required
58
34
[ name : string ] : [ number , number , number , boolean ] ;
@@ -423,60 +399,224 @@ ${content}<|im_end|>`,
423
399
) ;
424
400
}
425
401
402
+ const braintrustModelParamSchema = z . object ( {
403
+ use_cache : z . boolean ( ) . optional ( ) ,
404
+
405
+ temperature : z . number ( ) . optional ( ) ,
406
+ max_tokens : z . number ( ) . optional ( ) ,
407
+ // XXX how do we want to handle deprecated params
408
+ max_completion_tokens : z . number ( ) . optional ( ) ,
409
+ top_p : z . number ( ) . optional ( ) ,
410
+ top_k : z . number ( ) . optional ( ) ,
411
+ frequency_penalty : z . number ( ) . optional ( ) ,
412
+ presence_penalty : z . number ( ) . optional ( ) ,
413
+ response_format : z
414
+ . object ( {
415
+ type : z . literal ( "json_object" ) ,
416
+ } )
417
+ . nullish ( ) ,
418
+ /*
419
+ tool_choice: z.object({
420
+ type: z.literal("function"),
421
+ }).optional(),
422
+ function_call: z.object({
423
+ name: z.string().optional(),
424
+ }).optional(),
425
+ */
426
+ n : z . number ( ) . optional ( ) ,
427
+ stop : z . array ( z . string ( ) ) . optional ( ) ,
428
+ reasoning_effort : z . enum ( [ "low" , "medium" , "high" ] ) . optional ( ) ,
429
+ } ) ;
430
+ type BraintrustModelParams = z . infer < typeof braintrustModelParamSchema > ;
431
+ type BraintrustParamMapping =
432
+ | keyof BraintrustModelParams
433
+ | {
434
+ key : keyof BraintrustModelParams | null ;
435
+ deprecated ?: boolean ;
436
+ o1_like ?: boolean ;
437
+ } ;
438
+
439
+ // XXX add to sdk
440
+ type ConverseModelParams = {
441
+ maxTokens : number ;
442
+ stopSequences : string [ ] ;
443
+ } ;
444
+
445
+ const anyModelParamToBraintrustModelParam : Record <
446
+ keyof AnyModelParam | keyof ConverseModelParams ,
447
+ BraintrustParamMapping
448
+ > = {
449
+ use_cache : "use_cache" ,
450
+ temperature : "temperature" ,
451
+
452
+ max_tokens : "max_tokens" ,
453
+ max_completion_tokens : { key : "max_tokens" , o1_like : true } ,
454
+ maxOutputTokens : "max_tokens" ,
455
+ maxTokens : "max_tokens" ,
456
+ // XXX map this to max_tokens?
457
+ max_tokens_to_sample : { key : null , deprecated : true } ,
458
+
459
+ top_p : "top_p" ,
460
+ topP : "top_p" ,
461
+ top_k : "top_k" ,
462
+ topK : "top_k" ,
463
+ frequency_penalty : "frequency_penalty" , // null
464
+ presence_penalty : "presence_penalty" , // null
465
+
466
+ stop : "stop" , // null
467
+ stop_sequences : "stop" , // null
468
+ stopSequences : "stop" , // null
469
+
470
+ n : "n" , // null
471
+
472
+ reasoning_effort : { key : "reasoning_effort" , o1_like : true } ,
473
+
474
+ response_format : "response_format" , // handled elsewhere?
475
+ function_call : { key : null } , // handled elsewhere
476
+ tool_choice : { key : null } , // handled elsewhere
477
+ // parallel_tool_calls: { key: null }, // handled elsewhere
478
+ } ;
479
+
480
+ function translateKey (
481
+ toProvider : ModelFormat | undefined ,
482
+ key : string ,
483
+ ) : keyof ModelParams | null {
484
+ const braintrustKey =
485
+ anyModelParamToBraintrustModelParam [ key as keyof AnyModelParam ] ;
486
+ let normalizedKey : keyof BraintrustModelParams | null = null ;
487
+ if ( braintrustKey === undefined ) {
488
+ normalizedKey = null ;
489
+ } else if ( ! isObject ( braintrustKey ) ) {
490
+ normalizedKey = braintrustKey ;
491
+ } else if ( isObject ( braintrustKey ) ) {
492
+ if ( braintrustKey . deprecated ) {
493
+ console . warn ( `Deprecated model param: ${ key } ` ) ;
494
+ }
495
+
496
+ if ( braintrustKey . key === null ) {
497
+ normalizedKey = null ;
498
+ } else {
499
+ normalizedKey = braintrustKey . key ;
500
+ }
501
+ } else {
502
+ normalizedKey = braintrustKey ;
503
+ }
504
+
505
+ if ( normalizedKey === null ) {
506
+ return null ;
507
+ }
508
+
509
+ // XXX if toProvider is undefined, return the normalized key. this is useful for the ui to parse span data when the
510
+ // provider is not known. maybe we can try harder to infer the provider?
511
+ if ( toProvider === undefined ) {
512
+ return normalizedKey ;
513
+ }
514
+
515
+ // XXX turn these into Record<keyof BraintrustModelParams, keyof z.infer<typeof anthropicModelParamsSchema> | null>
516
+ // maps from braintrust key to provider key
517
+ switch ( toProvider ) {
518
+ case "openai" :
519
+ switch ( normalizedKey ) {
520
+ case "temperature" :
521
+ return "temperature" ;
522
+ case "max_tokens" :
523
+ return "max_tokens" ;
524
+ case "top_p" :
525
+ return "top_p" ;
526
+ case "stop" :
527
+ return "stop" ;
528
+ case "frequency_penalty" :
529
+ return "frequency_penalty" ;
530
+ case "presence_penalty" :
531
+ return "presence_penalty" ;
532
+ case "n" :
533
+ return "n" ;
534
+ default :
535
+ return null ;
536
+ }
537
+ case "anthropic" :
538
+ switch ( normalizedKey ) {
539
+ case "temperature" :
540
+ return "temperature" ;
541
+ case "max_tokens" :
542
+ return "max_tokens" ;
543
+ case "top_k" :
544
+ return "top_k" ;
545
+ case "top_p" :
546
+ return "top_p" ;
547
+ case "stop" :
548
+ return "stop_sequences" ;
549
+ default :
550
+ return null ;
551
+ }
552
+ case "google" :
553
+ switch ( normalizedKey ) {
554
+ case "temperature" :
555
+ return "temperature" ;
556
+ case "top_p" :
557
+ return "topP" ;
558
+ case "top_k" :
559
+ return "topK" ;
560
+ /* XXX add support for this?
561
+ case "stop":
562
+ return "stopSequences";
563
+ */
564
+ case "max_tokens" :
565
+ return "maxOutputTokens" ;
566
+ default :
567
+ return null ;
568
+ }
569
+ case "window" :
570
+ switch ( normalizedKey ) {
571
+ case "temperature" :
572
+ return "temperature" ;
573
+ case "top_k" :
574
+ return "topK" ;
575
+ default :
576
+ return null ;
577
+ }
578
+ case "converse" :
579
+ switch ( normalizedKey ) {
580
+ case "temperature" :
581
+ return "temperature" ;
582
+ case "max_tokens" :
583
+ return "maxTokens" ;
584
+ case "top_k" :
585
+ return "topK" ;
586
+ case "top_p" :
587
+ return "topP" ;
588
+ case "stop" :
589
+ return "stopSequences" ;
590
+ default :
591
+ return null ;
592
+ }
593
+ case "js" :
594
+ return null ;
595
+ default :
596
+ const _exhaustiveCheck : never = toProvider ;
597
+ throw new Error ( `Unknown provider: ${ _exhaustiveCheck } ` ) ;
598
+ }
599
+ }
600
+
426
601
export function translateParams (
427
- toProvider : ModelFormat ,
602
+ toProvider : ModelFormat | undefined ,
428
603
params : Record < string , unknown > ,
429
- ) : Record < string , unknown > {
430
- const translatedParams : Record < string , unknown > = { } ;
604
+ ) : Record < keyof ModelParams , unknown > {
605
+ const translatedParams : Record < keyof ModelParams , unknown > = { } ;
431
606
for ( const [ k , v ] of Object . entries ( params || { } ) ) {
432
607
const safeValue = v ?? undefined ; // Don't propagate "null" along
433
- const translatedKey = modelParamToModelParam [ k as keyof ModelParams ] as
434
- | keyof ModelParams
435
- | undefined
436
- | null ;
608
+ const translatedKey = translateKey ( toProvider , k ) ;
437
609
if ( translatedKey === null ) {
438
610
continue ;
439
- } else if (
440
- translatedKey !== undefined &&
441
- defaultModelParamSettings [ toProvider ] [ translatedKey ] !== undefined
442
- ) {
611
+ } else if ( safeValue !== undefined ) {
443
612
translatedParams [ translatedKey ] = safeValue ;
444
- } else {
445
- translatedParams [ k ] = safeValue ;
446
613
}
614
+ // XXX should we add default params from defaultModelParamSettings?
615
+ // probably only do that if translateParams is being called from the prompt ui but not for proxy calls
616
+ //
617
+ // also, the previous logic here seemed incorrect in doing translatedParams[k] = saveValue. i dont
618
+ // see why we would want to pass along params we know are not accepted by toProvider
447
619
}
448
620
449
621
return translatedParams ;
450
622
}
451
-
452
- export const anthropicSupportedMediaTypes = [
453
- "image/jpeg" ,
454
- "image/png" ,
455
- "image/gif" ,
456
- "image/webp" ,
457
- ] ;
458
-
459
- export const anthropicTextBlockSchema = z . object ( {
460
- type : z . literal ( "text" ) . optional ( ) ,
461
- text : z . string ( ) . default ( "" ) ,
462
- } ) ;
463
- export const anthropicImageBlockSchema = z . object ( {
464
- type : z . literal ( "image" ) . optional ( ) ,
465
- source : z . object ( {
466
- type : z . enum ( [ "base64" ] ) . optional ( ) ,
467
- media_type : z . enum ( [ "image/jpeg" , "image/png" , "image/gif" , "image/webp" ] ) ,
468
- data : z . string ( ) . default ( "" ) ,
469
- } ) ,
470
- } ) ;
471
- const anthropicContentBlockSchema = z . union ( [
472
- anthropicTextBlockSchema ,
473
- anthropicImageBlockSchema ,
474
- ] ) ;
475
- const anthropicContentBlocksSchema = z . array ( anthropicContentBlockSchema ) ;
476
- const anthropicContentSchema = z . union ( [
477
- z . string ( ) . default ( "" ) ,
478
- anthropicContentBlocksSchema ,
479
- ] ) ;
480
-
481
- export type AnthropicImageBlock = z . infer < typeof anthropicImageBlockSchema > ;
482
- export type AnthropicContent = z . infer < typeof anthropicContentSchema > ;
0 commit comments