Skip to content

Commit 0b54c1d

Browse files
committed
fix(cli): report cumulative total_tokens in stream-json/json output
The CLI's --json and --stream-json output modes were reading session.total_tokens, which stores only the last turn's token usage. This caused total_tokens to reset to the current turn's count on every chunk, rather than reporting the running session total.

Fix by reading session.accumulated_total_tokens, which is already correctly maintained by update_session_metrics across all turns. Also update get_total_token_usage() to return the accumulated value for consistency.

Test: add test_accumulated_total_tokens_across_multiple_turns, asserting that accumulated_total_tokens grows cumulatively across turns while total_tokens remains per-turn.

Fixes #8871

Signed-off-by: Trinity <trinity@multica.ai>
Signed-off-by: Bright Zheng <bzqzheng@gmail.com>
1 parent 503ad20 commit 0b54c1d

2 files changed

Lines changed: 199 additions & 3 deletions

File tree

crates/goose-cli/src/session/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,7 @@ impl CliSession {
12651265
.await
12661266
{
12671267
Ok(session) => JsonMetadata {
1268-
total_tokens: session.total_tokens,
1268+
total_tokens: session.accumulated_total_tokens,
12691269
status: "completed".to_string(),
12701270
},
12711271
Err(_) => JsonMetadata {
@@ -1286,7 +1286,7 @@ impl CliSession {
12861286
.get_session(&self.session_id, false)
12871287
.await
12881288
.ok()
1289-
.and_then(|s| s.total_tokens);
1289+
.and_then(|s| s.accumulated_total_tokens);
12901290
emit_stream_event(&StreamEvent::Complete { total_tokens });
12911291
} else {
12921292
println!();
@@ -1445,7 +1445,7 @@ impl CliSession {
14451445
// Get the session's cumulative token usage, accumulated across all turns
14461446
pub async fn get_total_token_usage(&self) -> Result<Option<i32>> {
14471447
let metadata = self.get_session().await?;
1448-
Ok(metadata.total_tokens)
1448+
Ok(metadata.accumulated_total_tokens)
14491449
}
14501450

14511451
/// Display enhanced context usage with session totals

crates/goose/tests/agent.rs

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,4 +1091,200 @@ mod tests {
10911091
Ok(())
10921092
}
10931093
}
1094+
1095+
#[cfg(test)]
1096+
mod cumulative_token_tests {
1097+
use super::*;
1098+
use async_trait::async_trait;
1099+
use goose::agents::{AgentConfig, SessionConfig};
1100+
use goose::config::permission::PermissionManager;
1101+
use goose::config::GooseMode;
1102+
use goose::conversation::message::Message;
1103+
use goose::model::ModelConfig;
1104+
use goose::providers::base::{
1105+
stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata,
1106+
ProviderUsage, Usage,
1107+
};
1108+
use goose::providers::errors::ProviderError;
1109+
use goose::session::session_manager::SessionType;
1110+
use goose::session::SessionManager;
1111+
use rmcp::model::Tool;
1112+
use std::path::PathBuf;
1113+
use std::sync::atomic::{AtomicUsize, Ordering};
1114+
use std::sync::Arc;
1115+
1116+
/// Test double for a provider whose every response carries the same token usage.
struct FixedUsageProvider {
    /// Number of `stream` calls made so far; bumped once per call.
    call_count: AtomicUsize,
    /// Input-token count placed in each reported usage record.
    input_tokens: i32,
    /// Output-token count placed in each reported usage record.
    output_tokens: i32,
}
1122+
1123+
impl FixedUsageProvider {
1124+
fn new(input_tokens: i32, output_tokens: i32) -> Self {
1125+
Self {
1126+
call_count: AtomicUsize::new(0),
1127+
input_tokens,
1128+
output_tokens,
1129+
}
1130+
}
1131+
}
1132+
1133+
impl ProviderDef for FixedUsageProvider {
1134+
type Provider = Self;
1135+
1136+
fn metadata() -> ProviderMetadata {
1137+
ProviderMetadata {
1138+
name: "fixed-usage-mock".to_string(),
1139+
display_name: "Fixed Usage Mock".to_string(),
1140+
description: "Mock provider with fixed token usage for cumulative tests".to_string(),
1141+
default_model: "mock-model".to_string(),
1142+
known_models: vec![],
1143+
model_doc_link: "".to_string(),
1144+
config_keys: vec![],
1145+
setup_steps: vec![],
1146+
model_selection_hint: None,
1147+
}
1148+
}
1149+
1150+
fn from_env(
1151+
_model: ModelConfig,
1152+
_extensions: Vec<goose::config::ExtensionConfig>,
1153+
) -> futures::future::BoxFuture<'static, anyhow::Result<Self>> {
1154+
Box::pin(async { Ok(Self::new(10, 5)) })
1155+
}
1156+
}
1157+
1158+
#[async_trait]
1159+
impl Provider for FixedUsageProvider {
1160+
async fn stream(
1161+
&self,
1162+
_model_config: &ModelConfig,
1163+
_session_id: &str,
1164+
_system_prompt: &str,
1165+
_messages: &[Message],
1166+
_tools: &[Tool],
1167+
) -> Result<MessageStream, ProviderError> {
1168+
let _call = self.call_count.fetch_add(1, Ordering::SeqCst);
1169+
let total = self.input_tokens + self.output_tokens;
1170+
let usage = ProviderUsage::new(
1171+
"mock-model".to_string(),
1172+
Usage::new(Some(self.input_tokens), Some(self.output_tokens), Some(total)),
1173+
);
1174+
let message = Message::assistant().with_text("Hello");
1175+
Ok(stream_from_single_message(message, usage))
1176+
}
1177+
1178+
fn get_model_config(&self) -> ModelConfig {
1179+
ModelConfig::new("mock-model").unwrap()
1180+
}
1181+
1182+
fn get_name(&self) -> &str {
1183+
"fixed-usage-mock"
1184+
}
1185+
}
1186+
1187+
#[tokio::test]
1188+
async fn test_accumulated_total_tokens_across_multiple_turns() -> Result<()> {
1189+
let temp_dir = tempfile::tempdir()?;
1190+
let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf()));
1191+
let config = AgentConfig::new(
1192+
session_manager.clone(),
1193+
PermissionManager::instance(),
1194+
None,
1195+
GooseMode::Auto,
1196+
true, // disable session naming
1197+
GoosePlatform::GooseCli,
1198+
);
1199+
let agent = Agent::with_config(config);
1200+
let provider = Arc::new(FixedUsageProvider::new(10, 5));
1201+
1202+
let session = session_manager
1203+
.create_session(
1204+
PathBuf::default(),
1205+
"cumulative-token-test".to_string(),
1206+
SessionType::Hidden,
1207+
GooseMode::default(),
1208+
)
1209+
.await?;
1210+
1211+
let session_id = session.id.clone();
1212+
agent.update_provider(provider.clone(), &session_id).await?;
1213+
1214+
// Turn 1
1215+
let session_config1 = SessionConfig {
1216+
id: session_id.clone(),
1217+
schedule_id: None,
1218+
max_turns: Some(1),
1219+
retry_config: None,
1220+
};
1221+
let reply_stream1 = agent
1222+
.reply(Message::user().with_text("Turn 1"), session_config1, None)
1223+
.await?;
1224+
tokio::pin!(reply_stream1);
1225+
while let Some(event) = reply_stream1.next().await {
1226+
let _ = event?;
1227+
}
1228+
1229+
let session_after_1 = session_manager.get_session(&session_id, false).await?;
1230+
assert_eq!(
1231+
session_after_1.accumulated_total_tokens,
1232+
Some(15),
1233+
"After turn 1, accumulated_total_tokens should be 15"
1234+
);
1235+
1236+
// Turn 2
1237+
let session_config2 = SessionConfig {
1238+
id: session_id.clone(),
1239+
schedule_id: None,
1240+
max_turns: Some(1),
1241+
retry_config: None,
1242+
};
1243+
let reply_stream2 = agent
1244+
.reply(Message::user().with_text("Turn 2"), session_config2, None)
1245+
.await?;
1246+
tokio::pin!(reply_stream2);
1247+
while let Some(event) = reply_stream2.next().await {
1248+
let _ = event?;
1249+
}
1250+
1251+
let session_after_2 = session_manager.get_session(&session_id, false).await?;
1252+
assert_eq!(
1253+
session_after_2.accumulated_total_tokens,
1254+
Some(30),
1255+
"After turn 2, accumulated_total_tokens should be 30"
1256+
);
1257+
1258+
// Turn 3
1259+
let session_config3 = SessionConfig {
1260+
id: session_id.clone(),
1261+
schedule_id: None,
1262+
max_turns: Some(1),
1263+
retry_config: None,
1264+
};
1265+
let reply_stream3 = agent
1266+
.reply(Message::user().with_text("Turn 3"), session_config3, None)
1267+
.await?;
1268+
tokio::pin!(reply_stream3);
1269+
while let Some(event) = reply_stream3.next().await {
1270+
let _ = event?;
1271+
}
1272+
1273+
let session_after_3 = session_manager.get_session(&session_id, false).await?;
1274+
assert_eq!(
1275+
session_after_3.accumulated_total_tokens,
1276+
Some(45),
1277+
"After turn 3, accumulated_total_tokens should be 45"
1278+
);
1279+
1280+
// total_tokens should still reflect the last turn only
1281+
assert_eq!(
1282+
session_after_3.total_tokens,
1283+
Some(15),
1284+
"total_tokens should reflect last turn only (15)"
1285+
);
1286+
1287+
Ok(())
1288+
}
1289+
}
10941290
}

0 commit comments

Comments
 (0)