Skip to content

Commit a5a321b

Browse files
committed
cleaning up
1 parent 3473881 commit a5a321b

6 files changed

Lines changed: 40 additions & 24 deletions

File tree

crates/goose-cli/src/commands/gateway.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::sync::Arc;
55

66
pub async fn handle_gateway_status() -> Result<()> {
77
let agent_manager = AgentManager::instance().await?;
8-
let gateway_manager = Arc::new(GatewayManager::new(agent_manager));
8+
let gateway_manager = Arc::new(GatewayManager::new(agent_manager)?);
99
let statuses = gateway_manager.status().await;
1010

1111
if statuses.is_empty() {
@@ -39,7 +39,7 @@ pub async fn handle_gateway_start(
3939
platform_config: serde_json::Value,
4040
) -> Result<()> {
4141
let agent_manager = AgentManager::instance().await?;
42-
let gateway_manager = Arc::new(GatewayManager::new(agent_manager));
42+
let gateway_manager = Arc::new(GatewayManager::new(agent_manager)?);
4343

4444
let mut config = goose::gateway::GatewayConfig {
4545
gateway_type,
@@ -60,15 +60,15 @@ pub async fn handle_gateway_start(
6060

6161
pub async fn handle_gateway_stop(gateway_type: String) -> Result<()> {
6262
let agent_manager = AgentManager::instance().await?;
63-
let gateway_manager = Arc::new(GatewayManager::new(agent_manager));
63+
let gateway_manager = Arc::new(GatewayManager::new(agent_manager)?);
6464
gateway_manager.stop_gateway(&gateway_type).await?;
6565
println!("Gateway '{}' stopped.", gateway_type);
6666
Ok(())
6767
}
6868

6969
pub async fn handle_gateway_pair(gateway_type: String) -> Result<()> {
7070
let agent_manager = AgentManager::instance().await?;
71-
let gateway_manager = Arc::new(GatewayManager::new(agent_manager));
71+
let gateway_manager = Arc::new(GatewayManager::new(agent_manager)?);
7272
let (code, expires_at) = gateway_manager.generate_pairing_code(&gateway_type).await?;
7373

7474
let expires = chrono::DateTime::from_timestamp(expires_at, 0)

crates/goose-server/src/state.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl AppState {
5151

5252
let agent_manager = AgentManager::instance().await?;
5353
let tunnel_manager = Arc::new(TunnelManager::new());
54-
let gateway_manager = Arc::new(GatewayManager::new(agent_manager.clone()));
54+
let gateway_manager = Arc::new(GatewayManager::new(agent_manager.clone())?);
5555

5656
Ok(Arc::new(Self {
5757
agent_manager,

crates/goose/src/gateway/manager.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ pub struct GatewayManager {
6161
}
6262

6363
impl GatewayManager {
64-
pub fn new(agent_manager: Arc<AgentManager>) -> Self {
64+
pub fn new(agent_manager: Arc<AgentManager>) -> anyhow::Result<Self> {
6565
let db_path = Paths::data_dir().join("gateway").join("pairings.db");
66-
let pairing_store = Arc::new(PairingStore::new(&db_path));
66+
let pairing_store = Arc::new(PairingStore::new(&db_path)?);
6767

68-
Self {
68+
Ok(Self {
6969
gateways: RwLock::new(HashMap::new()),
7070
pairing_store,
7171
agent_manager,
72-
}
72+
})
7373
}
7474

7575
#[allow(dead_code)]

crates/goose/src/gateway/mod.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,28 @@ use utoipa::ToSchema;
1212

1313
use handler::GatewayHandler;
1414

15-
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
15+
#[derive(Debug, Clone, Serialize, Deserialize)]
1616
pub struct PlatformUser {
1717
pub platform: String,
1818
pub user_id: String,
1919
pub display_name: Option<String>,
2020
}
2121

22+
impl PartialEq for PlatformUser {
23+
fn eq(&self, other: &Self) -> bool {
24+
self.platform == other.platform && self.user_id == other.user_id
25+
}
26+
}
27+
28+
impl Eq for PlatformUser {}
29+
30+
impl std::hash::Hash for PlatformUser {
31+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
32+
self.platform.hash(state);
33+
self.user_id.hash(state);
34+
}
35+
}
36+
2237
#[derive(Debug, Clone)]
2338
#[allow(dead_code)]
2439
pub struct IncomingMessage {

crates/goose/src/gateway/pairing.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ pub struct PairingStore {
1414
}
1515

1616
impl PairingStore {
17-
pub fn new(db_path: &Path) -> Self {
17+
pub fn new(db_path: &Path) -> anyhow::Result<Self> {
1818
if let Some(parent) = db_path.parent() {
19-
std::fs::create_dir_all(parent).expect("Failed to create pairing database directory");
19+
std::fs::create_dir_all(parent)?;
2020
}
2121

2222
let options = SqliteConnectOptions::new()
@@ -27,11 +27,11 @@ impl PairingStore {
2727

2828
let pool = SqlitePoolOptions::new().connect_lazy_with(options);
2929

30-
Self {
30+
Ok(Self {
3131
pairings: RwLock::new(HashMap::new()),
3232
pool,
3333
initialized: OnceCell::new(),
34-
}
34+
})
3535
}
3636

3737
async fn ensure_initialized(&self) -> anyhow::Result<&Pool<Sqlite>> {
@@ -343,7 +343,7 @@ mod tests {
343343
#[tokio::test]
344344
async fn test_pairing_lifecycle() {
345345
let tmp = TempDir::new().unwrap();
346-
let store = PairingStore::new(&tmp.path().join("test.db"));
346+
let store = PairingStore::new(&tmp.path().join("test.db")).unwrap();
347347
let user = test_user("telegram", "12345");
348348

349349
let state = store.get(&user).await.unwrap();
@@ -381,7 +381,7 @@ mod tests {
381381
#[tokio::test]
382382
async fn test_pending_code_flow() {
383383
let tmp = TempDir::new().unwrap();
384-
let store = PairingStore::new(&tmp.path().join("test.db"));
384+
let store = PairingStore::new(&tmp.path().join("test.db")).unwrap();
385385

386386
let expires = chrono::Utc::now().timestamp() + 300;
387387
store
@@ -399,7 +399,7 @@ mod tests {
399399
#[tokio::test]
400400
async fn test_expired_code() {
401401
let tmp = TempDir::new().unwrap();
402-
let store = PairingStore::new(&tmp.path().join("test.db"));
402+
let store = PairingStore::new(&tmp.path().join("test.db")).unwrap();
403403

404404
let expired = chrono::Utc::now().timestamp() - 10;
405405
store
@@ -427,7 +427,7 @@ mod tests {
427427
let user = test_user("discord", "user42");
428428

429429
{
430-
let store = PairingStore::new(&db_path);
430+
let store = PairingStore::new(&db_path).unwrap();
431431
store
432432
.set(
433433
&user,
@@ -440,7 +440,7 @@ mod tests {
440440
.unwrap();
441441
}
442442

443-
let store2 = PairingStore::new(&db_path);
443+
let store2 = PairingStore::new(&db_path).unwrap();
444444
let state = store2.get(&user).await.unwrap();
445445
match state {
446446
PairingState::Paired { session_id, .. } => assert_eq!(session_id, "s-42"),
@@ -451,7 +451,7 @@ mod tests {
451451
#[tokio::test]
452452
async fn test_remove_all_for_platform() {
453453
let tmp = TempDir::new().unwrap();
454-
let store = PairingStore::new(&tmp.path().join("test.db"));
454+
let store = PairingStore::new(&tmp.path().join("test.db")).unwrap();
455455

456456
let tg1 = test_user("telegram", "111");
457457
let tg2 = test_user("telegram", "222");

crates/goose/src/gateway/telegram_format.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ fn collapse_newlines(text: &str) -> String {
7676
let mut chars = text.chars().peekable();
7777

7878
while let Some(ch) = chars.next() {
79-
if !in_pre && ch == '<' {
79+
if ch == '<' {
8080
let rest: String = chars.clone().take(4).collect();
81-
if rest.starts_with("pre>") || rest.starts_with("pre ") {
81+
if !in_pre && (rest.starts_with("pre>") || rest.starts_with("pre ")) {
8282
in_pre = true;
8383
}
84-
if rest.starts_with("/pre") {
84+
if in_pre && rest.starts_with("/pre") {
8585
in_pre = false;
8686
}
8787
result.push(ch);
@@ -106,6 +106,7 @@ fn escape_html(text: &str) -> String {
106106
text.replace('&', "&amp;")
107107
.replace('<', "&lt;")
108108
.replace('>', "&gt;")
109+
.replace('"', "&quot;")
109110
}
110111

111112
#[cfg(test)]
@@ -231,7 +232,7 @@ That's all!"#;
231232
assert!(html.contains("1. "));
232233
assert!(html.contains("2. "));
233234
assert!(html.contains("<pre><code>"));
234-
assert!(html.contains("print(\"hello\")"));
235+
assert!(html.contains("print(&quot;hello&quot;)"));
235236
assert!(html.contains("<a href="));
236237
assert!(html.contains("<blockquote>"));
237238
assert!(html.contains("That's all!"));

0 commit comments

Comments
 (0)