@@ -1091,4 +1091,200 @@ mod tests {
10911091 Ok ( ( ) )
10921092 }
10931093 }
1094+
1095+ #[ cfg( test) ]
1096+ mod cumulative_token_tests {
1097+ use super :: * ;
1098+ use async_trait:: async_trait;
1099+ use goose:: agents:: { AgentConfig , SessionConfig } ;
1100+ use goose:: config:: permission:: PermissionManager ;
1101+ use goose:: config:: GooseMode ;
1102+ use goose:: conversation:: message:: Message ;
1103+ use goose:: model:: ModelConfig ;
1104+ use goose:: providers:: base:: {
1105+ stream_from_single_message, MessageStream , Provider , ProviderDef , ProviderMetadata ,
1106+ ProviderUsage , Usage ,
1107+ } ;
1108+ use goose:: providers:: errors:: ProviderError ;
1109+ use goose:: session:: session_manager:: SessionType ;
1110+ use goose:: session:: SessionManager ;
1111+ use rmcp:: model:: Tool ;
1112+ use std:: path:: PathBuf ;
1113+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1114+ use std:: sync:: Arc ;
1115+
1116+ /// Mock provider that reports fixed token usage on every turn.
1117+ struct FixedUsageProvider {
1118+ call_count : AtomicUsize ,
1119+ input_tokens : i32 ,
1120+ output_tokens : i32 ,
1121+ }
1122+
1123+ impl FixedUsageProvider {
1124+ fn new ( input_tokens : i32 , output_tokens : i32 ) -> Self {
1125+ Self {
1126+ call_count : AtomicUsize :: new ( 0 ) ,
1127+ input_tokens,
1128+ output_tokens,
1129+ }
1130+ }
1131+ }
1132+
1133+ impl ProviderDef for FixedUsageProvider {
1134+ type Provider = Self ;
1135+
1136+ fn metadata ( ) -> ProviderMetadata {
1137+ ProviderMetadata {
1138+ name : "fixed-usage-mock" . to_string ( ) ,
1139+ display_name : "Fixed Usage Mock" . to_string ( ) ,
1140+ description : "Mock provider with fixed token usage for cumulative tests" . to_string ( ) ,
1141+ default_model : "mock-model" . to_string ( ) ,
1142+ known_models : vec ! [ ] ,
1143+ model_doc_link : "" . to_string ( ) ,
1144+ config_keys : vec ! [ ] ,
1145+ setup_steps : vec ! [ ] ,
1146+ model_selection_hint : None ,
1147+ }
1148+ }
1149+
1150+ fn from_env (
1151+ _model : ModelConfig ,
1152+ _extensions : Vec < goose:: config:: ExtensionConfig > ,
1153+ ) -> futures:: future:: BoxFuture < ' static , anyhow:: Result < Self > > {
1154+ Box :: pin ( async { Ok ( Self :: new ( 10 , 5 ) ) } )
1155+ }
1156+ }
1157+
1158+ #[ async_trait]
1159+ impl Provider for FixedUsageProvider {
1160+ async fn stream (
1161+ & self ,
1162+ _model_config : & ModelConfig ,
1163+ _session_id : & str ,
1164+ _system_prompt : & str ,
1165+ _messages : & [ Message ] ,
1166+ _tools : & [ Tool ] ,
1167+ ) -> Result < MessageStream , ProviderError > {
1168+ let _call = self . call_count . fetch_add ( 1 , Ordering :: SeqCst ) ;
1169+ let total = self . input_tokens + self . output_tokens ;
1170+ let usage = ProviderUsage :: new (
1171+ "mock-model" . to_string ( ) ,
1172+ Usage :: new ( Some ( self . input_tokens ) , Some ( self . output_tokens ) , Some ( total) ) ,
1173+ ) ;
1174+ let message = Message :: assistant ( ) . with_text ( "Hello" ) ;
1175+ Ok ( stream_from_single_message ( message, usage) )
1176+ }
1177+
1178+ fn get_model_config ( & self ) -> ModelConfig {
1179+ ModelConfig :: new ( "mock-model" ) . unwrap ( )
1180+ }
1181+
1182+ fn get_name ( & self ) -> & str {
1183+ "fixed-usage-mock"
1184+ }
1185+ }
1186+
1187+ #[ tokio:: test]
1188+ async fn test_accumulated_total_tokens_across_multiple_turns ( ) -> Result < ( ) > {
1189+ let temp_dir = tempfile:: tempdir ( ) ?;
1190+ let session_manager = Arc :: new ( SessionManager :: new ( temp_dir. path ( ) . to_path_buf ( ) ) ) ;
1191+ let config = AgentConfig :: new (
1192+ session_manager. clone ( ) ,
1193+ PermissionManager :: instance ( ) ,
1194+ None ,
1195+ GooseMode :: Auto ,
1196+ true , // disable session naming
1197+ GoosePlatform :: GooseCli ,
1198+ ) ;
1199+ let agent = Agent :: with_config ( config) ;
1200+ let provider = Arc :: new ( FixedUsageProvider :: new ( 10 , 5 ) ) ;
1201+
1202+ let session = session_manager
1203+ . create_session (
1204+ PathBuf :: default ( ) ,
1205+ "cumulative-token-test" . to_string ( ) ,
1206+ SessionType :: Hidden ,
1207+ GooseMode :: default ( ) ,
1208+ )
1209+ . await ?;
1210+
1211+ let session_id = session. id . clone ( ) ;
1212+ agent. update_provider ( provider. clone ( ) , & session_id) . await ?;
1213+
1214+ // Turn 1
1215+ let session_config1 = SessionConfig {
1216+ id : session_id. clone ( ) ,
1217+ schedule_id : None ,
1218+ max_turns : Some ( 1 ) ,
1219+ retry_config : None ,
1220+ } ;
1221+ let reply_stream1 = agent
1222+ . reply ( Message :: user ( ) . with_text ( "Turn 1" ) , session_config1, None )
1223+ . await ?;
1224+ tokio:: pin!( reply_stream1) ;
1225+ while let Some ( event) = reply_stream1. next ( ) . await {
1226+ let _ = event?;
1227+ }
1228+
1229+ let session_after_1 = session_manager. get_session ( & session_id, false ) . await ?;
1230+ assert_eq ! (
1231+ session_after_1. accumulated_total_tokens,
1232+ Some ( 15 ) ,
1233+ "After turn 1, accumulated_total_tokens should be 15"
1234+ ) ;
1235+
1236+ // Turn 2
1237+ let session_config2 = SessionConfig {
1238+ id : session_id. clone ( ) ,
1239+ schedule_id : None ,
1240+ max_turns : Some ( 1 ) ,
1241+ retry_config : None ,
1242+ } ;
1243+ let reply_stream2 = agent
1244+ . reply ( Message :: user ( ) . with_text ( "Turn 2" ) , session_config2, None )
1245+ . await ?;
1246+ tokio:: pin!( reply_stream2) ;
1247+ while let Some ( event) = reply_stream2. next ( ) . await {
1248+ let _ = event?;
1249+ }
1250+
1251+ let session_after_2 = session_manager. get_session ( & session_id, false ) . await ?;
1252+ assert_eq ! (
1253+ session_after_2. accumulated_total_tokens,
1254+ Some ( 30 ) ,
1255+ "After turn 2, accumulated_total_tokens should be 30"
1256+ ) ;
1257+
1258+ // Turn 3
1259+ let session_config3 = SessionConfig {
1260+ id : session_id. clone ( ) ,
1261+ schedule_id : None ,
1262+ max_turns : Some ( 1 ) ,
1263+ retry_config : None ,
1264+ } ;
1265+ let reply_stream3 = agent
1266+ . reply ( Message :: user ( ) . with_text ( "Turn 3" ) , session_config3, None )
1267+ . await ?;
1268+ tokio:: pin!( reply_stream3) ;
1269+ while let Some ( event) = reply_stream3. next ( ) . await {
1270+ let _ = event?;
1271+ }
1272+
1273+ let session_after_3 = session_manager. get_session ( & session_id, false ) . await ?;
1274+ assert_eq ! (
1275+ session_after_3. accumulated_total_tokens,
1276+ Some ( 45 ) ,
1277+ "After turn 3, accumulated_total_tokens should be 45"
1278+ ) ;
1279+
1280+ // total_tokens should still reflect the last turn only
1281+ assert_eq ! (
1282+ session_after_3. total_tokens,
1283+ Some ( 15 ) ,
1284+ "total_tokens should reflect last turn only (15)"
1285+ ) ;
1286+
1287+ Ok ( ( ) )
1288+ }
1289+ }
10941290}
0 commit comments