@@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session
1414use crate :: error:: ApiError ;
1515use crate :: prompt_cache:: { PromptCache , PromptCacheRecord , PromptCacheStats } ;
1616
17- use super :: { preflight_message_request , Provider , ProviderFuture } ;
17+ use super :: { model_token_limit , resolve_model_alias , Provider , ProviderFuture } ;
1818use crate :: sse:: SseParser ;
1919use crate :: types:: { MessageDeltaEvent , MessageRequest , MessageResponse , StreamEvent , Usage } ;
2020
@@ -294,7 +294,7 @@ impl AnthropicClient {
294294 }
295295 }
296296
297- preflight_message_request ( & request) ?;
297+ self . preflight_message_request ( & request) . await ?;
298298
299299 let response = self . send_with_retry ( & request) . await ?;
300300 let request_id = request_id_from_headers ( response. headers ( ) ) ;
@@ -339,7 +339,7 @@ impl AnthropicClient {
339339 & self ,
340340 request : & MessageRequest ,
341341 ) -> Result < MessageStream , ApiError > {
342- preflight_message_request ( request) ?;
342+ self . preflight_message_request ( request) . await ?;
343343 let response = self
344344 . send_with_retry ( & request. clone ( ) . with_streaming ( ) )
345345 . await ?;
@@ -466,18 +466,67 @@ impl AnthropicClient {
466466 request : & MessageRequest ,
467467 ) -> Result < reqwest:: Response , ApiError > {
468468 let request_url = format ! ( "{}/v1/messages" , self . base_url. trim_end_matches( '/' ) ) ;
469+ let request_body = self . request_profile . render_json_body ( request) ?;
470+ let request_builder = self . build_request ( & request_url) . json ( & request_body) ;
471+ request_builder. send ( ) . await . map_err ( ApiError :: from)
472+ }
473+
474+ fn build_request ( & self , request_url : & str ) -> reqwest:: RequestBuilder {
469475 let request_builder = self
470476 . http
471- . post ( & request_url)
477+ . post ( request_url)
472478 . header ( "content-type" , "application/json" ) ;
473479 let mut request_builder = self . auth . apply ( request_builder) ;
474480 for ( header_name, header_value) in self . request_profile . header_pairs ( ) {
475481 request_builder = request_builder. header ( header_name, header_value) ;
476482 }
483+ request_builder
484+ }
485+
486+ async fn preflight_message_request ( & self , request : & MessageRequest ) -> Result < ( ) , ApiError > {
487+ let Some ( limit) = model_token_limit ( & request. model ) else {
488+ return Ok ( ( ) ) ;
489+ } ;
490+
491+ let counted_input_tokens = match self . count_tokens ( request) . await {
492+ Ok ( count) => count,
493+ Err ( _) => return Ok ( ( ) ) ,
494+ } ;
495+ let estimated_total_tokens = counted_input_tokens. saturating_add ( request. max_tokens ) ;
496+ if estimated_total_tokens > limit. context_window_tokens {
497+ return Err ( ApiError :: ContextWindowExceeded {
498+ model : resolve_model_alias ( & request. model ) ,
499+ estimated_input_tokens : counted_input_tokens,
500+ requested_output_tokens : request. max_tokens ,
501+ estimated_total_tokens,
502+ context_window_tokens : limit. context_window_tokens ,
503+ } ) ;
504+ }
477505
506+ Ok ( ( ) )
507+ }
508+
509+ async fn count_tokens ( & self , request : & MessageRequest ) -> Result < u32 , ApiError > {
510+ #[ derive( serde:: Deserialize ) ]
511+ struct CountTokensResponse {
512+ input_tokens : u32 ,
513+ }
514+
515+ let request_url = format ! ( "{}/v1/messages/count_tokens" , self . base_url. trim_end_matches( '/' ) ) ;
478516 let request_body = self . request_profile . render_json_body ( request) ?;
479- request_builder = request_builder. json ( & request_body) ;
480- request_builder. send ( ) . await . map_err ( ApiError :: from)
517+ let response = self
518+ . build_request ( & request_url)
519+ . json ( & request_body)
520+ . send ( )
521+ . await
522+ . map_err ( ApiError :: from) ?;
523+
524+ let parsed = expect_success ( response)
525+ . await ?
526+ . json :: < CountTokensResponse > ( )
527+ . await
528+ . map_err ( ApiError :: from) ?;
529+ Ok ( parsed. input_tokens )
481530 }
482531
483532 fn record_request_failure ( & self , attempt : u32 , error : & ApiError ) {
0 commit comments