@@ -18,6 +18,7 @@ use utoipa::OpenApi;
1818use crate :: error:: AiGatewayError ;
1919use crate :: handlers:: types:: AiGatewayAppState ;
2020use crate :: services:: gateway_service:: { ByokOverride , CredentialType } ;
21+ use crate :: services:: usage_service:: AiRequestContext ;
2122use crate :: services:: UsageService ;
2223use crate :: types:: * ;
2324
@@ -35,6 +36,37 @@ fn extract_byok(headers: &HeaderMap) -> ByokOverride {
3536 }
3637}
3738
39+ /// Extract AI request context (conversation, tags, trace) from request headers.
40+ fn extract_ai_context ( headers : & HeaderMap ) -> AiRequestContext {
41+ AiRequestContext {
42+ conversation_id : headers
43+ . get ( "x-conversation-id" )
44+ . and_then ( |v| v. to_str ( ) . ok ( ) )
45+ . map ( |s| s. to_string ( ) ) ,
46+ tags : headers
47+ . get ( "x-tags" )
48+ . and_then ( |v| v. to_str ( ) . ok ( ) )
49+ . map ( |s| {
50+ s. split ( ',' )
51+ . map ( |t| t. trim ( ) . to_string ( ) )
52+ . filter ( |t| !t. is_empty ( ) )
53+ . collect ( )
54+ } )
55+ . unwrap_or_default ( ) ,
56+ request_id : headers
57+ . get ( "x-request-id" )
58+ . and_then ( |v| v. to_str ( ) . ok ( ) )
59+ . map ( |s| s. to_string ( ) ) ,
60+ trace_id : headers
61+ . get ( "traceparent" )
62+ . and_then ( |v| v. to_str ( ) . ok ( ) )
63+ . and_then ( |tp| {
64+ // W3C traceparent: {version}-{trace-id}-{parent-id}-{flags}
65+ tp. split ( '-' ) . nth ( 1 ) . map ( String :: from)
66+ } ) ,
67+ }
68+ }
69+
3870fn credential_type_str ( ct : CredentialType ) -> & ' static str {
3971 match ct {
4072 CredentialType :: System => "system" ,
@@ -74,6 +106,7 @@ fn extract_usage_from_sse_line(line: &str) -> Option<(i64, i64)> {
74106
75107/// Wraps an upstream SSE byte stream to transparently intercept usage data
76108/// from the final chunks, then logs it after the stream ends.
109+ #[ allow( clippy:: too_many_arguments) ]
77110fn wrap_stream_with_usage_tracking (
78111 inner : std:: pin:: Pin <
79112 Box < dyn tokio_stream:: Stream < Item = Result < Bytes , AiGatewayError > > + Send > ,
@@ -84,6 +117,7 @@ fn wrap_stream_with_usage_tracking(
84117 model : String ,
85118 start : Instant ,
86119 is_byok : bool ,
120+ ai_context : AiRequestContext ,
87121) -> std:: pin:: Pin < Box < dyn tokio_stream:: Stream < Item = Result < Bytes , AiGatewayError > > + Send > > {
88122 use tokio_stream:: StreamExt ;
89123
@@ -130,7 +164,7 @@ fn wrap_stream_with_usage_tracking(
130164 if input > 0 || output > 0 {
131165 tokio:: spawn ( async move {
132166 if let Err ( e) = usage_service
133- . log_usage (
167+ . log_usage_with_context (
134168 Some ( user_id) ,
135169 & provider,
136170 & model,
@@ -141,6 +175,7 @@ fn wrap_stream_with_usage_tracking(
141175 200 ,
142176 true , // streaming
143177 is_byok,
178+ & ai_context,
144179 )
145180 . await
146181 {
@@ -358,6 +393,7 @@ async fn chat_completions(
358393 }
359394
360395 let byok = extract_byok ( & headers) ;
396+ let ai_context = extract_ai_context ( & headers) ;
361397 let start = Instant :: now ( ) ;
362398 let model = request. model . clone ( ) ;
363399 let is_streaming = request. stream ;
@@ -389,6 +425,7 @@ async fn chat_completions(
389425 model. clone ( ) ,
390426 start,
391427 cred_type == CredentialType :: Byok ,
428+ ai_context. clone ( ) ,
392429 ) ;
393430 let body = Body :: from_stream ( wrapped) ;
394431
@@ -436,9 +473,10 @@ async fn chat_completions(
436473 let output = usage. completion_tokens ;
437474 let latency_ms = latency. as_millis ( ) as i32 ;
438475 let is_byok = cred_type == CredentialType :: Byok ;
476+ let ctx = ai_context. clone ( ) ;
439477 tokio:: spawn ( async move {
440478 if let Err ( e) = usage_service
441- . log_usage (
479+ . log_usage_with_context (
442480 Some ( user_id) ,
443481 & provider_clone,
444482 & model_clone,
@@ -449,6 +487,7 @@ async fn chat_completions(
449487 200 ,
450488 false , // non-streaming path
451489 is_byok,
490+ & ctx,
452491 )
453492 . await
454493 {
@@ -894,4 +933,101 @@ mod tests {
894933 let line = r#"data: {"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}"# ;
895934 assert_eq ! ( extract_usage_from_sse_line( line) , None ) ;
896935 }
936+
937+ #[ test]
938+ fn test_extract_ai_context_empty_headers ( ) {
939+ let headers = HeaderMap :: new ( ) ;
940+ let ctx = extract_ai_context ( & headers) ;
941+ assert ! ( ctx. conversation_id. is_none( ) ) ;
942+ assert ! ( ctx. tags. is_empty( ) ) ;
943+ assert ! ( ctx. request_id. is_none( ) ) ;
944+ assert ! ( ctx. trace_id. is_none( ) ) ;
945+ }
946+
947+ #[ test]
948+ fn test_extract_ai_context_conversation_id ( ) {
949+ let mut headers = HeaderMap :: new ( ) ;
950+ headers. insert ( "x-conversation-id" , "conv_abc123" . parse ( ) . unwrap ( ) ) ;
951+ let ctx = extract_ai_context ( & headers) ;
952+ assert_eq ! ( ctx. conversation_id. as_deref( ) , Some ( "conv_abc123" ) ) ;
953+ }
954+
955+ #[ test]
956+ fn test_extract_ai_context_tags ( ) {
957+ let mut headers = HeaderMap :: new ( ) ;
958+ headers. insert ( "x-tags" , "agent:support, env:prod" . parse ( ) . unwrap ( ) ) ;
959+ let ctx = extract_ai_context ( & headers) ;
960+ assert_eq ! ( ctx. tags, vec![ "agent:support" , "env:prod" ] ) ;
961+ }
962+
963+ #[ test]
964+ fn test_extract_ai_context_tags_trims_whitespace ( ) {
965+ let mut headers = HeaderMap :: new ( ) ;
966+ headers. insert ( "x-tags" , " foo , bar , baz " . parse ( ) . unwrap ( ) ) ;
967+ let ctx = extract_ai_context ( & headers) ;
968+ assert_eq ! ( ctx. tags, vec![ "foo" , "bar" , "baz" ] ) ;
969+ }
970+
971+ #[ test]
972+ fn test_extract_ai_context_tags_filters_empty ( ) {
973+ let mut headers = HeaderMap :: new ( ) ;
974+ headers. insert ( "x-tags" , "foo,,bar," . parse ( ) . unwrap ( ) ) ;
975+ let ctx = extract_ai_context ( & headers) ;
976+ assert_eq ! ( ctx. tags, vec![ "foo" , "bar" ] ) ;
977+ }
978+
979+ #[ test]
980+ fn test_extract_ai_context_request_id ( ) {
981+ let mut headers = HeaderMap :: new ( ) ;
982+ headers. insert ( "x-request-id" , "req_xyz789" . parse ( ) . unwrap ( ) ) ;
983+ let ctx = extract_ai_context ( & headers) ;
984+ assert_eq ! ( ctx. request_id. as_deref( ) , Some ( "req_xyz789" ) ) ;
985+ }
986+
987+ #[ test]
988+ fn test_extract_ai_context_traceparent ( ) {
989+ let mut headers = HeaderMap :: new ( ) ;
990+ headers. insert (
991+ "traceparent" ,
992+ "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
993+ . parse ( )
994+ . unwrap ( ) ,
995+ ) ;
996+ let ctx = extract_ai_context ( & headers) ;
997+ assert_eq ! (
998+ ctx. trace_id. as_deref( ) ,
999+ Some ( "0af7651916cd43dd8448eb211c80319c" )
1000+ ) ;
1001+ }
1002+
1003+ #[ test]
1004+ fn test_extract_ai_context_invalid_traceparent ( ) {
1005+ let mut headers = HeaderMap :: new ( ) ;
1006+ headers. insert ( "traceparent" , "not-valid" . parse ( ) . unwrap ( ) ) ;
1007+ let ctx = extract_ai_context ( & headers) ;
1008+ // "not-valid" split by '-': ["not", "valid"] -> nth(1) = "valid"
1009+ assert_eq ! ( ctx. trace_id. as_deref( ) , Some ( "valid" ) ) ;
1010+ }
1011+
1012+ #[ test]
1013+ fn test_extract_ai_context_all_headers ( ) {
1014+ let mut headers = HeaderMap :: new ( ) ;
1015+ headers. insert ( "x-conversation-id" , "conv_123" . parse ( ) . unwrap ( ) ) ;
1016+ headers. insert ( "x-tags" , "agent:bot,tier:premium" . parse ( ) . unwrap ( ) ) ;
1017+ headers. insert ( "x-request-id" , "req_456" . parse ( ) . unwrap ( ) ) ;
1018+ headers. insert (
1019+ "traceparent" ,
1020+ "00-abcdef1234567890abcdef1234567890-1234567890abcdef-01"
1021+ . parse ( )
1022+ . unwrap ( ) ,
1023+ ) ;
1024+ let ctx = extract_ai_context ( & headers) ;
1025+ assert_eq ! ( ctx. conversation_id. as_deref( ) , Some ( "conv_123" ) ) ;
1026+ assert_eq ! ( ctx. tags, vec![ "agent:bot" , "tier:premium" ] ) ;
1027+ assert_eq ! ( ctx. request_id. as_deref( ) , Some ( "req_456" ) ) ;
1028+ assert_eq ! (
1029+ ctx. trace_id. as_deref( ) ,
1030+ Some ( "abcdef1234567890abcdef1234567890" )
1031+ ) ;
1032+ }
8971033}
0 commit comments