Skip to content

Commit a5a497b

Browse files
authored
pass globally unique conversation identifier as sessionId in databricks api call (#8576)
1 parent c29e320 commit a5a497b

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

crates/goose/src/instance_id.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use crate::config::paths::Paths;
2+
use once_cell::sync::Lazy;
3+
use std::fs;
4+
use uuid::Uuid;
5+
6+
static INSTANCE_ID: Lazy<String> = Lazy::new(load_or_create);
7+
8+
fn file_path() -> std::path::PathBuf {
9+
Paths::state_dir().join("instance_id")
10+
}
11+
12+
fn load_or_create() -> String {
13+
let path = file_path();
14+
15+
if let Ok(id) = fs::read_to_string(&path) {
16+
let id = id.trim().to_string();
17+
if !id.is_empty() {
18+
return id;
19+
}
20+
}
21+
22+
let id = Uuid::new_v4().to_string();
23+
24+
if let Some(parent) = path.parent() {
25+
let _ = fs::create_dir_all(parent);
26+
}
27+
let _ = fs::write(&path, &id);
28+
29+
id
30+
}
31+
32+
/// Returns a stable, globally unique identifier for this Goose installation.
33+
/// The ID is generated once and persisted to disk, surviving restarts.
34+
pub fn get_instance_id() -> &'static str {
35+
&INSTANCE_ID
36+
}
37+
38+
#[cfg(test)]
39+
mod tests {
40+
use super::*;
41+
42+
#[test]
43+
fn test_instance_id_is_stable() {
44+
let id1 = get_instance_id();
45+
let id2 = get_instance_id();
46+
assert_eq!(id1, id2);
47+
assert!(!id1.is_empty());
48+
}
49+
}

crates/goose/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub mod execution;
1818
pub mod gateway;
1919
pub mod goose_apps;
2020
pub mod hints;
21+
pub mod instance_id;
2122
pub mod logging;
2223
pub mod mcp_utils;
2324
pub mod model;

crates/goose/src/providers/databricks.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use super::retry::ProviderRetry;
2929
use super::utils::{ImageFormat, RequestLog};
3030
use crate::config::ConfigError;
3131
use crate::conversation::message::Message;
32+
use crate::instance_id::get_instance_id;
3233
use crate::model::ModelConfig;
3334
use crate::providers::retry::{
3435
RetryConfig, DEFAULT_BACKOFF_MULTIPLIER, DEFAULT_INITIAL_RETRY_INTERVAL_MS,
@@ -132,6 +133,8 @@ pub struct DatabricksProvider {
132133
name: String,
133134
#[serde(skip)]
134135
token_cache: Arc<Mutex<Option<String>>>,
136+
#[serde(skip)]
137+
instance_id: Option<String>,
135138
}
136139

137140
impl DatabricksProvider {
@@ -186,6 +189,7 @@ impl DatabricksProvider {
186189
fast_retry_config,
187190
name: DATABRICKS_PROVIDER_NAME.to_string(),
188191
token_cache,
192+
instance_id: Self::resolve_instance_id(),
189193
};
190194
provider.model =
191195
model.with_fast(DATABRICKS_DEFAULT_FAST_MODEL, DATABRICKS_PROVIDER_NAME)?;
@@ -249,9 +253,21 @@ impl DatabricksProvider {
249253
fast_retry_config: RetryConfig::new(0, 0, 1.0, 0),
250254
name: DATABRICKS_PROVIDER_NAME.to_string(),
251255
token_cache,
256+
instance_id: Self::resolve_instance_id(),
252257
})
253258
}
254259

260+
fn resolve_instance_id() -> Option<String> {
261+
let enabled = crate::config::Config::global()
262+
.get_param::<bool>("GOOSE_DATABRICKS_CLIENT_REQUEST_ID")
263+
.unwrap_or(false);
264+
if enabled {
265+
Some(get_instance_id().to_string())
266+
} else {
267+
None
268+
}
269+
}
270+
255271
fn is_responses_model(model_name: &str) -> bool {
256272
let normalized = model_name.to_ascii_lowercase();
257273
normalized.contains("codex")
@@ -267,16 +283,31 @@ impl DatabricksProvider {
267283
}
268284
}
269285

286+
fn build_client_request_id(&self, session_id: &str) -> Option<String> {
287+
self.instance_id.as_ref().map(|instance_id| {
288+
json!({
289+
"sessionId": format!("{}_{}", instance_id, session_id),
290+
})
291+
.to_string()
292+
})
293+
}
294+
270295
async fn post(
271296
&self,
272297
session_id: Option<&str>,
273-
payload: Value,
298+
mut payload: Value,
274299
model_name: Option<&str>,
275300
) -> Result<Value, ProviderError> {
276301
let is_embedding = payload.get("input").is_some() && payload.get("messages").is_none();
277302
let model_to_use = model_name.unwrap_or(&self.model.model_name);
278303
let path = self.get_endpoint_path(model_to_use, is_embedding);
279304

305+
if let Some(session_id) = session_id {
306+
if let Some(client_request_id) = self.build_client_request_id(session_id) {
307+
payload["client_request_id"] = Value::String(client_request_id);
308+
}
309+
}
310+
280311
let response = self
281312
.api_client
282313
.response_post(session_id, &path, &payload)
@@ -341,10 +372,14 @@ impl Provider for DatabricksProvider {
341372
tools: &[Tool],
342373
) -> Result<MessageStream, ProviderError> {
343374
let path = self.get_endpoint_path(&model_config.model_name, false);
375+
let client_request_id = self.build_client_request_id(session_id);
344376

345377
if Self::is_responses_model(&model_config.model_name) {
346378
let mut payload = create_responses_request(model_config, system, messages, tools)?;
347379
payload["stream"] = Value::Bool(true);
380+
if let Some(ref client_request_id) = client_request_id {
381+
payload["client_request_id"] = Value::String(client_request_id.clone());
382+
}
348383

349384
let mut log = RequestLog::start(model_config, &payload)?;
350385

@@ -383,6 +418,9 @@ impl Provider for DatabricksProvider {
383418
.as_object_mut()
384419
.expect("payload should have model key")
385420
.remove("model");
421+
if let Some(client_request_id) = client_request_id {
422+
payload["client_request_id"] = Value::String(client_request_id);
423+
}
386424

387425
payload
388426
.as_object_mut()

0 commit comments

Comments
 (0)