Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions atoma-proxy/src/server/handlers/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use openai_api::{
};
use openai_api::{CreateChatCompletionRequest, CreateChatCompletionStreamRequest};
use opentelemetry::KeyValue;
use reqwest::StatusCode;
use serde::Deserialize;
use serde_json::Value;
use sqlx::types::chrono::{DateTime, Utc};
Expand All @@ -46,9 +47,10 @@ use super::metrics::{
CHAT_COMPLETIONS_INPUT_TOKENS, CHAT_COMPLETIONS_INPUT_TOKENS_PER_USER,
CHAT_COMPLETIONS_LATENCY_METRICS, CHAT_COMPLETIONS_NUM_REQUESTS, CHAT_COMPLETIONS_TOTAL_TOKENS,
CHAT_COMPLETIONS_TOTAL_TOKENS_PER_USER, CHAT_COMPLETION_REQUESTS_PER_USER,
INTENTIONALLY_CANCELLED_CHAT_COMPLETION_STREAMING_REQUESTS, TOTAL_COMPLETED_REQUESTS,
TOTAL_FAILED_CHAT_REQUESTS, TOTAL_FAILED_REQUESTS,
UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER,
INTENTIONALLY_CANCELLED_CHAT_COMPLETION_STREAMING_REQUESTS, TOTAL_BAD_REQUESTS,
TOTAL_COMPLETED_REQUESTS, TOTAL_FAILED_CHAT_REQUESTS, TOTAL_FAILED_REQUESTS,
TOTAL_LOCKED_REQUESTS, TOTAL_TOO_EARLY_REQUESTS, TOTAL_TOO_MANY_REQUESTS,
TOTAL_UNAUTHORIZED_REQUESTS, UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER,
};
use super::request_model::{ComputeUnitsEstimate, RequestModel};
use super::{
Expand Down Expand Up @@ -78,6 +80,12 @@ pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
/// The messages field in the request payload.
const MESSAGES: &str = "messages";

/// The model key
const MODEL_KEY: &str = "model";

/// The user id key
const USER_ID_KEY: &str = "user_id";

#[derive(OpenApi)]
#[openapi(
paths(chat_completions_create, chat_completions_create_stream),
Expand Down Expand Up @@ -175,12 +183,32 @@ pub async fn chat_completions_create(
Ok(response)
}
Err(e) => {
TOTAL_FAILED_CHAT_REQUESTS
.add(1, &[KeyValue::new("model", metadata.model_name.clone())]);
TOTAL_FAILED_REQUESTS
.add(1, &[KeyValue::new("model", metadata.model_name.clone())]);
UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER
.add(1, &[KeyValue::new("user_id", metadata.user_id)]);
let model = metadata.model_name.clone();
match e.status_code() {
StatusCode::TOO_MANY_REQUESTS => {
TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::BAD_REQUEST => {
TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::LOCKED => {
TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::TOO_EARLY => {
TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::UNAUTHORIZED => {
TOTAL_UNAUTHORIZED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
_ => {
TOTAL_FAILED_CHAT_REQUESTS
.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);

UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER
.add(1, &[KeyValue::new(USER_ID_KEY, metadata.user_id)]);
}
}
if let Some(stack_small_id) = metadata.selected_stack_small_id {
update_state_manager(
&state.state_manager_sender,
Expand Down
90 changes: 70 additions & 20 deletions atoma-proxy/src/server/handlers/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use openai_api_completions::{
Usage,
};
use opentelemetry::KeyValue;
use reqwest::StatusCode;
use serde::Deserialize;
use serde_json::Value;
use sqlx::types::chrono::{DateTime, Utc};
Expand All @@ -31,11 +32,14 @@ use utoipa::OpenApi;

use super::metrics::{
CHAT_COMPLETIONS_COMPLETIONS_TOKENS, CHAT_COMPLETIONS_COMPLETIONS_TOKENS_PER_USER,
CHAT_COMPLETIONS_INPUT_TOKENS, CHAT_COMPLETIONS_INPUT_TOKENS_PER_USER,
CHAT_COMPLETIONS_LATENCY_METRICS, CHAT_COMPLETIONS_NUM_REQUESTS, CHAT_COMPLETIONS_TOTAL_TOKENS,
CHAT_COMPLETIONS_CONFIDENTIAL_NUM_REQUESTS, CHAT_COMPLETIONS_INPUT_TOKENS,
CHAT_COMPLETIONS_INPUT_TOKENS_PER_USER, CHAT_COMPLETIONS_LATENCY_METRICS,
CHAT_COMPLETIONS_NUM_REQUESTS, CHAT_COMPLETIONS_TOTAL_TOKENS,
CHAT_COMPLETIONS_TOTAL_TOKENS_PER_USER, CHAT_COMPLETION_REQUESTS_PER_USER,
INTENTIONALLY_CANCELLED_CHAT_COMPLETION_STREAMING_REQUESTS, TOTAL_COMPLETED_REQUESTS,
TOTAL_FAILED_CHAT_REQUESTS, TOTAL_FAILED_REQUESTS,
INTENTIONALLY_CANCELLED_CHAT_COMPLETION_STREAMING_REQUESTS, TOTAL_BAD_REQUESTS,
TOTAL_COMPLETED_REQUESTS, TOTAL_FAILED_CHAT_CONFIDENTIAL_REQUESTS, TOTAL_FAILED_CHAT_REQUESTS,
TOTAL_FAILED_REQUESTS, TOTAL_LOCKED_REQUESTS, TOTAL_TOO_EARLY_REQUESTS,
TOTAL_TOO_MANY_REQUESTS, TOTAL_UNAUTHORIZED_REQUESTS,
UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER,
};
use super::request_model::{ComputeUnitsEstimate, RequestModel};
Expand All @@ -57,6 +61,12 @@ pub const CONFIDENTIAL_COMPLETIONS_PATH: &str = "/v1/confidential/completions";
/// The key for the prompt in the request.
const PROMPT: &str = "prompt";

/// The model key
const MODEL_KEY: &str = "model";

/// The user id key
const USER_ID_KEY: &str = "user_id";

/// The OpenAPI schema for the completions endpoint.
#[derive(OpenApi)]
#[openapi(
Expand Down Expand Up @@ -133,12 +143,33 @@ pub async fn completions_create(
Ok(response)
}
Err(e) => {
TOTAL_FAILED_CHAT_REQUESTS
.add(1, &[KeyValue::new("model", metadata.model_name.clone())]);
TOTAL_FAILED_REQUESTS
.add(1, &[KeyValue::new("model", metadata.model_name.clone())]);
UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER
.add(1, &[KeyValue::new("user_id", metadata.user_id)]);
let model = metadata.model_name.clone();
match e.status_code() {
StatusCode::TOO_MANY_REQUESTS => {
TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::BAD_REQUEST => {
TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::LOCKED => {
TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::TOO_EARLY => {
TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::UNAUTHORIZED => {
TOTAL_UNAUTHORIZED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
_ => {
TOTAL_FAILED_CHAT_REQUESTS
.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);

UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER
.add(1, &[KeyValue::new(USER_ID_KEY, metadata.user_id)]);
}
}

if let Some(stack_small_id) = metadata.selected_stack_small_id {
update_state_manager(
&state.state_manager_sender,
Expand Down Expand Up @@ -381,19 +412,38 @@ pub async fn confidential_completions_create(
Ok(response) => {
if !is_streaming {
// The streaming metric is recorded in the streamer (final chunk)
TOTAL_COMPLETED_REQUESTS.add(1, &[KeyValue::new("model", metadata.model_name)]);
CHAT_COMPLETIONS_CONFIDENTIAL_NUM_REQUESTS
.add(1, &[KeyValue::new(MODEL_KEY, metadata.model_name)]);
}
Ok(response)
}
Err(e) => {
let model_label: String = metadata.model_name.clone();
TOTAL_FAILED_CHAT_REQUESTS.add(1, &[KeyValue::new("model", model_label.clone())]);

// Record the failed request in the total failed requests metric
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new("model", model_label)]);
UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER
.add(1, &[KeyValue::new("user_id", metadata.user_id)]);
let model = metadata.model_name.clone();
match e.status_code() {
StatusCode::TOO_MANY_REQUESTS => {
TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::BAD_REQUEST => {
TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::LOCKED => {
TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::TOO_EARLY => {
TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::UNAUTHORIZED => {
TOTAL_UNAUTHORIZED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
_ => {
TOTAL_FAILED_CHAT_CONFIDENTIAL_REQUESTS
.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);

UNSUCCESSFUL_CHAT_COMPLETION_REQUESTS_PER_USER
.add(1, &[KeyValue::new(USER_ID_KEY, metadata.user_id)]);
}
}
if let Some(stack_small_id) = metadata.selected_stack_small_id {
update_state_manager(
&state.state_manager_sender,
Expand Down Expand Up @@ -670,7 +720,7 @@ async fn handle_non_streaming_response(
/// * `node_address` - The address of the node
/// * `user_id` - The user id
/// * `headers` - The headers of the request
/// * `payload` - The payload of the request
/// * `payload` - The payload of the request
/// * `num_input_tokens` - The number of input tokens
/// * `estimated_output_tokens` - The estimated output tokens
/// * `price_per_million` - The price per million
Expand All @@ -687,7 +737,7 @@ async fn handle_non_streaming_response(
/// * `serde_json::Error` - If the request fails
/// * `flume::Error` - If the request fails
/// * `tokio::Error` - If the request fails
///
///
#[instrument(
level = "info",
skip_all,
Expand Down
75 changes: 63 additions & 12 deletions atoma-proxy/src/server/handlers/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use axum::{
Extension, Json,
};
use opentelemetry::KeyValue;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::types::chrono::{DateTime, Utc};
Expand All @@ -28,8 +29,10 @@ use super::{
handle_status_code_error,
metrics::{
EMBEDDING_TOTAL_TOKENS_PER_USER, SUCCESSFUL_TEXT_EMBEDDING_REQUESTS_PER_USER,
TEXT_EMBEDDINGS_LATENCY_METRICS, TEXT_EMBEDDINGS_NUM_REQUESTS, TOTAL_COMPLETED_REQUESTS,
TOTAL_FAILED_REQUESTS, TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS,
TEXT_EMBEDDINGS_LATENCY_METRICS, TEXT_EMBEDDINGS_NUM_REQUESTS, TOTAL_BAD_REQUESTS,
TOTAL_COMPLETED_REQUESTS, TOTAL_FAILED_CONFIDENTIAL_EMBEDDING_REQUESTS,
TOTAL_FAILED_REQUESTS, TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS, TOTAL_LOCKED_REQUESTS,
TOTAL_TOO_EARLY_REQUESTS, TOTAL_TOO_MANY_REQUESTS, TOTAL_UNAUTHORIZED_REQUESTS,
UNSUCCESSFUL_TEXT_EMBEDDING_REQUESTS_PER_USER,
},
request_model::{ComputeUnitsEstimate, RequestModel},
Expand All @@ -52,6 +55,12 @@ pub const EMBEDDINGS_PATH: &str = "/v1/embeddings";
/// The input field in the request payload.
const INPUT: &str = "input";

/// The model key
const MODEL_KEY: &str = "model";

/// The user id key
const USER_ID_KEY: &str = "user_id";

/// A model representing an embeddings request payload.
///
/// This struct encapsulates the necessary fields for processing an embeddings request
Expand Down Expand Up @@ -224,11 +233,32 @@ pub async fn embeddings_create(
Ok(Json(response).into_response())
}
Err(e) => {
let model_label: String = metadata.model_name.clone();
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new("model", model_label.clone())]);
TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS.add(1, &[KeyValue::new("model", model_label)]);
UNSUCCESSFUL_TEXT_EMBEDDING_REQUESTS_PER_USER
.add(1, &[KeyValue::new("user_id", metadata.user_id)]);
let model = metadata.model_name.clone();
match e.status_code() {
StatusCode::TOO_MANY_REQUESTS => {
TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::BAD_REQUEST => {
TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::LOCKED => {
TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::TOO_EARLY => {
TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::UNAUTHORIZED => {
TOTAL_UNAUTHORIZED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
_ => {
TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS
.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);

UNSUCCESSFUL_TEXT_EMBEDDING_REQUESTS_PER_USER
.add(1, &[KeyValue::new(USER_ID_KEY, metadata.user_id)]);
}
}
match metadata.selected_stack_small_id {
Some(stack_small_id) => {
update_state_manager(
Expand Down Expand Up @@ -374,11 +404,32 @@ pub async fn confidential_embeddings_create(
Ok(Json(response).into_response())
}
Err(e) => {
let model_label: String = metadata.model_name.clone();
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new("model", model_label.clone())]);
TOTAL_FAILED_TEXT_EMBEDDING_REQUESTS.add(1, &[KeyValue::new("model", model_label)]);
UNSUCCESSFUL_TEXT_EMBEDDING_REQUESTS_PER_USER
.add(1, &[KeyValue::new("user_id", metadata.user_id)]);
let model = metadata.model_name.clone();
match e.status_code() {
StatusCode::TOO_MANY_REQUESTS => {
TOTAL_TOO_MANY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::BAD_REQUEST => {
TOTAL_BAD_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::LOCKED => {
TOTAL_LOCKED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::TOO_EARLY => {
TOTAL_TOO_EARLY_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
StatusCode::UNAUTHORIZED => {
TOTAL_UNAUTHORIZED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);
}
_ => {
TOTAL_FAILED_CONFIDENTIAL_EMBEDDING_REQUESTS
.add(1, &[KeyValue::new(MODEL_KEY, model.clone())]);
TOTAL_FAILED_REQUESTS.add(1, &[KeyValue::new(MODEL_KEY, model)]);

UNSUCCESSFUL_TEXT_EMBEDDING_REQUESTS_PER_USER
.add(1, &[KeyValue::new(USER_ID_KEY, metadata.user_id)]);
}
}

match metadata.selected_stack_small_id {
Some(stack_small_id) => {
Expand Down
Loading