Merged
30 changes: 6 additions & 24 deletions next_task.md
@@ -1,45 +1,27 @@
# Completed Tasks

## ~~Fix recording not stopped on agent loop error~~
Fixed — recording is now stopped unconditionally before propagating agent loop errors.

## ~~Add per-execution timeout to evaluator exec calls~~
Fixed — all `session.exec()` and `session.exec_with_exit_code()` calls in evaluator are now wrapped with `tokio::time::timeout()`. Default timeout: 120s. Configurable per-task via `eval_timeout_secs` on `EvaluatorConfig`.

---

# Next Tasks — Medium-Impact Improvements

Items identified during the structural refactoring that are worth addressing but were out of scope for the split.

## 1. Provider HTTP dedup

`src/provider/openai.rs` and `src/provider/custom.rs` are ~95% identical (same HTTP client setup, request building, response parsing). Extract a shared `http_base.rs` that both providers delegate to, differing only in URL construction and auth headers.

## 2. Legacy agent removal
## 1. Legacy agent removal

`src/agent/mod.rs`, `src/agent/tools.rs`, and `src/agent/openai.rs` (654 lines total) implement the v1 tool-call-based agent loop, fully superseded by v2 (`loop_v2.rs`). The only remaining caller is `run_inner()` in `orchestration.rs` (the legacy CLI path). Removing these requires migrating or deprecating the legacy CLI path.

## 3. Blanket `#![allow(dead_code)]` cleanup
## 2. Blanket `#![allow(dead_code)]` cleanup

~10 files have a blanket `#![allow(dead_code)]` at the top even though the code in them is actively used. Remove the blanket allows and fix any remaining dead-code warnings individually.

## 4. System prompt splitting
## 3. System prompt splitting

`src/agent/context.rs` builds the system prompt in a single 116-line `format!()` invocation. Split it into section builder functions (e.g., `build_interaction_guidelines()`, `build_output_format()`) for readability and testability.

## 5. `recording.rs` `format_caption()` decomposition
## 4. `recording.rs` `format_caption()` decomposition

`format_caption()` mixes concerns: text truncation, layout calculation, and magic numbers for font sizes and margins. Extract these into smaller functions with named constants.

## 6. `task.rs` `validate()` decomposition
## 5. `task.rs` `validate()` decomposition

`validate()` is a 188-line monolith that checks setup steps, evaluator config, metrics, and app config in one pass. Split it into `validate_setup_steps()`, `validate_evaluator()`, `validate_metrics()`, and `validate_app_config()`.

## 7. Shared constants dedup
## 6. Shared constants dedup

`DEFAULT_STEP_TIMEOUT_SECS` is defined in both `src/agent/pyautogui.rs` and `src/agent/loop_v2.rs`. Consolidate into a single location (e.g., `src/agent/mod.rs` or a shared constants module).

## 8. `is_context_length_error` false-positive on `max_tokens`

`src/agent/context.rs:317` — `lower.contains("max_tokens")` is too broad. Parameter validation errors like "max_tokens must be less than X" are misclassified as context length errors, triggering an unnecessary fallback (clear trajectory + retry). Replace with a more specific pattern like `"input length and max_tokens exceed"` or remove the `"max_tokens"` check entirely (the other patterns cover real context length errors). Flagged by Sentry in PR #16.
4 changes: 3 additions & 1 deletion src/agent/context.rs
@@ -314,7 +314,6 @@ pub fn is_context_length_error(error_msg: &str) -> bool {
|| lower.contains("maximum context length")
|| lower.contains("prompt is too long")
|| lower.contains("maximum number of tokens")
|| lower.contains("max_tokens")
|| lower.contains("context window")
|| lower.contains("token limit")
}
@@ -643,6 +642,9 @@ mod tests {
assert!(!is_context_length_error("rate limit exceeded"));
assert!(!is_context_length_error("authentication failed"));
assert!(!is_context_length_error("internal server error"));
assert!(!is_context_length_error(
"max_tokens must be less than 128000"
));
}

// --- Multiple trajectory with mixed observation types ---
4 changes: 2 additions & 2 deletions src/agent/openai.rs
@@ -1,6 +1,6 @@
//! Backward-compatibility re-exports.
//!
//! The LLM provider implementation has moved to `crate::provider::openai`.
//! The LLM provider implementation lives in `crate::provider::http_base`.
//! Types and helpers have moved to `crate::provider`.
//! This module re-exports them so existing code continues to work.

@@ -10,4 +10,4 @@ pub use crate::provider::{
ChatMessage, FunctionCall, ToolCall,
};
#[allow(unused_imports)]
pub use crate::provider::openai::OpenAiProvider as OpenAiClient;
pub use crate::provider::http_base::HttpProvider as OpenAiClient;
247 changes: 9 additions & 238 deletions src/provider/custom.rs
@@ -1,132 +1,18 @@
#![allow(dead_code)]
use super::http_base::HttpProvider;

use std::pin::Pin;

use tracing::info;

use super::{ChatMessage, LlmProvider};
use crate::error::AppError;

/// Response shape from OpenAI-compatible chat completions endpoints.
#[derive(Debug, serde::Deserialize)]
struct ChatCompletionResponse {
choices: Vec<Choice>,
}

#[derive(Debug, serde::Deserialize)]
struct Choice {
message: ChatMessage,
}

/// Custom provider for OpenAI-compatible endpoints with a configurable base URL.
///
/// This allows using any API that implements the OpenAI chat completions interface,
/// such as local inference servers, vLLM, Ollama, etc.
pub struct CustomProvider {
http: reqwest::Client,
api_key: String,
model: String,
base_url: String,
}
/// Custom provider for OpenAI-compatible endpoints — thin wrapper around [`HttpProvider`].
pub struct CustomProvider;

impl CustomProvider {
pub fn new(api_key: &str, model: &str, base_url: &str) -> Self {
Self {
http: reqwest::Client::new(),
api_key: api_key.into(),
model: model.into(),
base_url: base_url.into(),
}
}

/// Get the full URL for the chat completions endpoint.
pub fn completions_url(&self) -> String {
format!("{}/v1/chat/completions", self.base_url)
}
}

impl LlmProvider for CustomProvider {
fn chat_completion<'a>(
&'a self,
messages: &'a [ChatMessage],
tools: &'a [serde_json::Value],
) -> Pin<Box<dyn std::future::Future<Output = Result<ChatMessage, AppError>> + Send + 'a>> {
Box::pin(async move {
let mut body = serde_json::json!({
"model": self.model,
"messages": messages,
});

if !tools.is_empty() {
body["tools"] = serde_json::json!(tools);
body["tool_choice"] = serde_json::json!("auto");
}

let body_str = serde_json::to_string(&body).unwrap_or_default();
let payload_kb = body_str.len() / 1024;
info!(
"Custom API request to {}: {} messages, ~{} KB payload",
self.base_url,
messages.len(),
payload_kb
);

let start = std::time::Instant::now();
let url = self.completions_url();

let response = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| AppError::Agent(format!("HTTP request failed: {e}")))?;

let elapsed = start.elapsed();
let status = response.status();

if !status.is_success() {
let error_body = response.text().await.unwrap_or_default();
return Err(AppError::Agent(format!(
"Custom API error ({}): {}",
status, error_body
)));
}

let completion: ChatCompletionResponse = response
.json()
.await
.map_err(|e| AppError::Agent(format!("Failed to parse response: {e}")))?;

let msg = completion
.choices
.into_iter()
.next()
.ok_or_else(|| AppError::Agent("No choices in response".into()))?
.message;

let tool_names: Vec<&str> = msg
.tool_calls
.as_ref()
.map(|tcs| tcs.iter().map(|tc| tc.function.name.as_str()).collect())
.unwrap_or_default();
info!(
"Custom API response in {:.1}s: tool_calls={:?}",
elapsed.as_secs_f64(),
tool_names
);

Ok(msg)
})
pub fn new(api_key: &str, model: &str, base_url: &str) -> HttpProvider {
HttpProvider::new(api_key, model, base_url, "Custom")
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::provider::user_message;
use crate::provider::{user_message, LlmProvider};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

@@ -141,51 +27,13 @@ mod tests {
})
}

// ---------- URL construction tests ----------

#[test]
fn test_completions_url_default() {
let provider = CustomProvider::new("key", "model", "https://my-server.com");
assert_eq!(
provider.completions_url(),
"https://my-server.com/v1/chat/completions"
);
}

#[test]
fn test_completions_url_with_trailing_slash() {
// Users may or may not include trailing slash - our URL construction handles this
let provider = CustomProvider::new("key", "model", "https://my-server.com");
assert!(provider.completions_url().starts_with("https://my-server.com/"));
}

#[test]
fn test_completions_url_localhost() {
let provider = CustomProvider::new("key", "llama-3", "http://localhost:8080");
assert_eq!(
provider.completions_url(),
"http://localhost:8080/v1/chat/completions"
);
}

#[test]
fn test_completions_url_with_path_prefix() {
let provider = CustomProvider::new("key", "model", "https://gateway.example.com/api");
assert_eq!(
provider.completions_url(),
"https://gateway.example.com/api/v1/chat/completions"
);
}

// ---------- API integration tests ----------

#[tokio::test]
async fn test_simple_text_response() {
async fn test_custom_provider_with_base_url() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.respond_with(
ResponseTemplate::new(200).set_body_json(mock_text_response("Custom response!")),
ResponseTemplate::new(200).set_body_json(mock_text_response("Custom!")),
)
.mount(&server)
.await;
@@ -197,84 +45,7 @@ mod tests {
assert_eq!(result.role, "assistant");
assert_eq!(
result.content.unwrap(),
serde_json::Value::String("Custom response!".into())
serde_json::Value::String("Custom!".into())
);
}

#[tokio::test]
async fn test_api_error_handling() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.respond_with(ResponseTemplate::new(500).set_body_string("internal error"))
.mount(&server)
.await;

let provider = CustomProvider::new("sk-test", "model", &server.uri());
let result = provider
.chat_completion(&[user_message("test")], &[])
.await;

assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("500"));
}

#[tokio::test]
async fn test_provider_trait_via_box_dyn() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.respond_with(
ResponseTemplate::new(200)
.set_body_json(mock_text_response("Via trait!")),
)
.mount(&server)
.await;

let provider: Box<dyn LlmProvider> =
Box::new(CustomProvider::new("sk-test", "model", &server.uri()));
let messages = vec![user_message("Hi")];
let result = provider.chat_completion(&messages, &[]).await.unwrap();

assert_eq!(
result.content.unwrap(),
serde_json::Value::String("Via trait!".into())
);
}

#[tokio::test]
async fn test_tool_call_response() {
let server = MockServer::start().await;
let response_body = serde_json::json!({
"choices": [{
"message": {
"role": "assistant",
"content": null,
"tool_calls": [{
"id": "call_1",
"type": "function",
"function": {
"name": "test_tool",
"arguments": "{}"
}
}]
}
}]
});
Mock::given(method("POST"))
.and(path("/v1/chat/completions"))
.respond_with(ResponseTemplate::new(200).set_body_json(response_body))
.mount(&server)
.await;

let provider = CustomProvider::new("sk-test", "model", &server.uri());
let result = provider
.chat_completion(&[user_message("test")], &[])
.await
.unwrap();

let tool_calls = result.tool_calls.unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.name, "test_tool");
}
}