Skip to content

Commit be561bf

Browse files
committed
Use Anthropic count tokens for preflight
1 parent c1883d0 commit be561bf

1 file changed

Lines changed: 55 additions & 6 deletions

File tree

rust/crates/api/src/providers/anthropic.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, Session
1414
use crate::error::ApiError;
1515
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
1616

17-
use super::{preflight_message_request, Provider, ProviderFuture};
17+
use super::{model_token_limit, resolve_model_alias, Provider, ProviderFuture};
1818
use crate::sse::SseParser;
1919
use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
2020

@@ -294,7 +294,7 @@ impl AnthropicClient {
294294
}
295295
}
296296

297-
preflight_message_request(&request)?;
297+
self.preflight_message_request(&request).await?;
298298

299299
let response = self.send_with_retry(&request).await?;
300300
let request_id = request_id_from_headers(response.headers());
@@ -339,7 +339,7 @@ impl AnthropicClient {
339339
&self,
340340
request: &MessageRequest,
341341
) -> Result<MessageStream, ApiError> {
342-
preflight_message_request(request)?;
342+
self.preflight_message_request(request).await?;
343343
let response = self
344344
.send_with_retry(&request.clone().with_streaming())
345345
.await?;
@@ -466,18 +466,67 @@ impl AnthropicClient {
466466
request: &MessageRequest,
467467
) -> Result<reqwest::Response, ApiError> {
468468
let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
469+
let request_body = self.request_profile.render_json_body(request)?;
470+
let request_builder = self.build_request(&request_url).json(&request_body);
471+
request_builder.send().await.map_err(ApiError::from)
472+
}
473+
474+
fn build_request(&self, request_url: &str) -> reqwest::RequestBuilder {
469475
let request_builder = self
470476
.http
471-
.post(&request_url)
477+
.post(request_url)
472478
.header("content-type", "application/json");
473479
let mut request_builder = self.auth.apply(request_builder);
474480
for (header_name, header_value) in self.request_profile.header_pairs() {
475481
request_builder = request_builder.header(header_name, header_value);
476482
}
483+
request_builder
484+
}
485+
486+
async fn preflight_message_request(&self, request: &MessageRequest) -> Result<(), ApiError> {
487+
let Some(limit) = model_token_limit(&request.model) else {
488+
return Ok(());
489+
};
490+
491+
let counted_input_tokens = match self.count_tokens(request).await {
492+
Ok(count) => count,
493+
Err(_) => return Ok(()),
494+
};
495+
let estimated_total_tokens = counted_input_tokens.saturating_add(request.max_tokens);
496+
if estimated_total_tokens > limit.context_window_tokens {
497+
return Err(ApiError::ContextWindowExceeded {
498+
model: resolve_model_alias(&request.model),
499+
estimated_input_tokens: counted_input_tokens,
500+
requested_output_tokens: request.max_tokens,
501+
estimated_total_tokens,
502+
context_window_tokens: limit.context_window_tokens,
503+
});
504+
}
477505

506+
Ok(())
507+
}
508+
509+
async fn count_tokens(&self, request: &MessageRequest) -> Result<u32, ApiError> {
510+
#[derive(serde::Deserialize)]
511+
struct CountTokensResponse {
512+
input_tokens: u32,
513+
}
514+
515+
let request_url = format!("{}/v1/messages/count_tokens", self.base_url.trim_end_matches('/'));
478516
let request_body = self.request_profile.render_json_body(request)?;
479-
request_builder = request_builder.json(&request_body);
480-
request_builder.send().await.map_err(ApiError::from)
517+
let response = self
518+
.build_request(&request_url)
519+
.json(&request_body)
520+
.send()
521+
.await
522+
.map_err(ApiError::from)?;
523+
524+
let parsed = expect_success(response)
525+
.await?
526+
.json::<CountTokensResponse>()
527+
.await
528+
.map_err(ApiError::from)?;
529+
Ok(parsed.input_tokens)
481530
}
482531

483532
fn record_request_failure(&self, attempt: u32, error: &ApiError) {

0 commit comments

Comments
 (0)