Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,6 @@ impl IRHelper for IntermediateRepr {
// Get best match.
let tests = function
.walk_tests()
.inspect(|t| log::info!("walking test: {:?}", t.item.1.elem.name))
.map(|t| t.item.1.elem.name.as_str())
.collect::<Vec<_>>();
error_not_found!("test", test_name, &tests)
Expand Down
3 changes: 1 addition & 2 deletions engine/baml-schema-wasm/src/runtime_wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,8 +964,7 @@ impl WasmRuntime {
},
test_snippet: snippet,
test_cases: f
.ir
.walk_function_test_pairs()
.walk_tests()
.map(|tc| {
let params = match tc.test_case_params(&ctx) {
Ok(params) => Ok(params
Expand Down
3 changes: 1 addition & 2 deletions engine/baml-schema-wasm/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,5 @@
"directory": "baml-schema-wasm"
},
"author": "BAML",

"license": "Apache-2.0"
}
}
229 changes: 50 additions & 179 deletions engine/language_server/src/cors_bypass_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,6 @@ use serde::Deserialize;
use tokio::{fs, net::TcpListener};
use tower_http::cors::{Any, CorsLayer};

/// Configuration for API key injection per provider
/// Format: (origin_url, header_name, env_var_name, baml_header_name)
const API_PROVIDERS: &[(&str, &str, &str, &str)] = &[
(
"https://api.openai.com",
"Authorization",
"OPENAI_API_KEY",
"baml-openai-api-key",
),
(
"https://api.anthropic.com",
"x-api-key",
"ANTHROPIC_API_KEY",
"baml-anthropic-api-key",
),
(
"https://generativelanguage.googleapis.com",
"x-goog-api-key",
"GOOGLE_API_KEY",
"baml-google-api-key",
),
(
"https://openrouter.ai",
"Authorization",
"OPENROUTER_API_KEY",
"baml-openrouter-api-key",
),
(
"https://api.llmapi.com",
"Authorization",
"LLAMA_API_KEY",
"baml-llama-api-key",
),
];

#[derive(Debug, Clone)]
pub struct ProxyConfig {
pub port: u16,
Expand Down Expand Up @@ -117,16 +82,10 @@ pub fn create_proxy_router() -> Router<ProxyConfig> {
AUTHORIZATION,
HeaderName::from_static("x-api-key"),
HeaderName::from_static("baml-original-url"),
HeaderName::from_static("baml-openai-api-key"),
HeaderName::from_static("baml-anthropic-api-key"),
HeaderName::from_static("baml-google-api-key"),
HeaderName::from_static("baml-openrouter-api-key"),
HeaderName::from_static("baml-llama-api-key"),
])
.max_age(std::time::Duration::from_secs(86400));

Router::new()
.route("/static/{*path}", get(serve_static_file))
.route("/{*path}", options(handle_preflight))
.route("/{*path}", any(handle_proxy_request))
.layer(cors)
Expand All @@ -137,81 +96,30 @@ async fn handle_preflight() -> impl IntoResponse {
StatusCode::OK
}

/// Serve static files from the current working directory
async fn serve_static_file(Path(path): Path<String>) -> Result<AxumResponse, ProxyError> {
let file_path = path.strip_prefix("static/").unwrap_or(&path);
let current_dir = std::env::current_dir()
.map_err(|e| ProxyError::internal_error(format!("Failed to get current dir: {e}")))?;

// Try multiple potential base directories
let potential_paths = vec![
current_dir.join(file_path),
current_dir.join("baml_src").join(file_path),
current_dir.join("../baml_src").join(file_path),
];

let absolute_path = potential_paths
.into_iter()
.find(|path| path.exists())
.unwrap_or_else(|| current_dir.join(file_path));

match fs::read(&absolute_path).await {
Ok(contents) => {
let mime_type = from_path(file_path).first_or_octet_stream();
let content_type = mime_type.as_ref().to_string();

Ok((
StatusCode::OK,
[
("content-type", content_type.as_str()),
("access-control-allow-origin", "*"),
],
contents,
)
.into_response())
}
Err(err) => {
tracing::warn!("Failed to read static file {}: {}", file_path, err);

match err.kind() {
std::io::ErrorKind::NotFound => Err(ProxyError::not_found(format!(
"File not found: {file_path}"
))),
std::io::ErrorKind::PermissionDenied => Err(ProxyError::new(
format!("Permission denied: {file_path}"),
StatusCode::FORBIDDEN,
)),
_ => Err(ProxyError::internal_error(format!(
"Error reading file: {file_path}"
))),
}
}
}
}

/// Main proxy request handler
async fn handle_proxy_request(
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<AxumResponse, ProxyError> {
tracing::debug!(
"handle_proxy_request: {:?}: {:?}: {:?}",
method,
uri,
headers
);
let path_str = uri.path();

// Handle static file serving
if path_str.starts_with("/static/") && method == Method::GET {
return serve_static_file(Path(path_str.to_string())).await;
}

// Extract and validate the original URL
let original_url = extract_original_url(&headers)?;
let mut clean_headers = clean_headers(&headers);
let clean_headers = clean_headers(&headers);

// Parse the target URL
let mut target_url = parse_target_url(&original_url)?;

// Handle simple GET requests that don't need proxying
if is_simple_get_request(&method, path_str) {
// Handle image requests that should return empty content (matching Express behavior)
if is_image_request(&method, path_str) {
return Ok(create_empty_response().into_response());
}

Expand All @@ -224,9 +132,6 @@ async fn handle_proxy_request(
.await
.map_err(|e| ProxyError::bad_request(format!("Failed to read body: {e}")))?;

// Inject API keys for supported providers
inject_api_key(&mut clean_headers, &target_url, &headers);

// Execute the request and return the response
execute_request(method, &target_url, clean_headers, body_bytes.to_vec()).await
}
Expand All @@ -243,7 +148,7 @@ fn extract_original_url(headers: &HeaderMap) -> Result<String, ProxyError> {
/// Remove headers that shouldn't be forwarded
fn clean_headers(headers: &HeaderMap) -> HeaderMap {
let mut clean_headers = headers.clone();
let headers_to_remove = ["baml-original-url", "origin", "authorization", "host"];
let headers_to_remove = ["baml-original-url", "origin", "host"];

for header_name in &headers_to_remove {
clean_headers.remove(*header_name);
Expand All @@ -258,83 +163,43 @@ fn parse_target_url(url_str: &str) -> Result<url::Url, ProxyError> {
url::Url::parse(clean_url).map_err(|e| ProxyError::bad_request(format!("Invalid URL: {e}")))
}

/// Check if this is a simple GET request that doesn't need proxying
fn is_simple_get_request(method: &Method, path: &str) -> bool {
path.matches('.').count() == 1 && method == Method::GET
/// Check if this is an image request that should return empty content
fn is_image_request(method: &Method, path: &str) -> bool {
if method != Method::GET {
return false;
}

// Match the Express regex: /\.(png|jpe?g|gif|bmp|webp|svg)$/i
let path_lower = path.to_lowercase();
path_lower.ends_with(".png")
|| path_lower.ends_with(".jpg")
|| path_lower.ends_with(".jpeg")
|| path_lower.ends_with(".gif")
|| path_lower.ends_with(".bmp")
|| path_lower.ends_with(".webp")
|| path_lower.ends_with(".svg")
}

/// Create an empty successful response
fn create_empty_response() -> impl IntoResponse {
(StatusCode::OK, [("access-control-allow-origin", "*")], "")
}

/// Construct the final path for the target URL
/// Construct the final path for the target URL (matching Express behavior)
fn construct_final_path(url: &url::Url, path_str: &str) -> String {
let base_path = url.path().trim_end_matches('/');

let final_path = if base_path.is_empty() {
path_str.trim_end_matches('/').to_string()
if base_path.is_empty() {
path_str.to_string()
} else if !path_str.starts_with(base_path) {
// Match Express logic: basePath + (req.url.startsWith('/') ? '' : '/') + req.url
if path_str.starts_with('/') {
format!("{base_path}{path_str}")
} else {
format!("{base_path}/{path_str}")
}
} else {
path_str.to_string()
};

final_path.trim_end_matches('/').to_string()
}

/// Inject appropriate API key based on the target URL
fn inject_api_key(headers: &mut HeaderMap, target_url: &url::Url, original_headers: &HeaderMap) {
let origin = get_origin_string(target_url);

for (allowed_origin, header_name, env_var, baml_header) in API_PROVIDERS {
if origin == *allowed_origin {
if let Some(api_key) = get_api_key(env_var, baml_header, original_headers) {
let header_value = format_api_key_header(header_name, &api_key);
if let Ok(header_val) = header_value.parse() {
headers.insert(*header_name, header_val);
}
}
break;
}
}
}

/// Convert URL origin to string format
fn get_origin_string(url: &url::Url) -> String {
match url.origin() {
url::Origin::Tuple(scheme, host, port) => match (scheme.as_str(), port) {
("http", 80) | ("https", 443) => format!("{scheme}://{host}"),
_ => format!("{scheme}://{host}:{port}"),
},
url::Origin::Opaque(_) => String::new(),
}
}

/// Get API key from environment or headers
fn get_api_key(env_var: &str, baml_header: &str, headers: &HeaderMap) -> Option<String> {
// Try environment variable first
std::env::var(env_var)
.ok()
// Then try custom header
.or_else(|| {
headers
.get(baml_header)
.and_then(|v| v.to_str().ok())
.map(String::from)
})
}

/// Format API key header value based on header type
fn format_api_key_header(header_name: &str, api_key: &str) -> String {
if header_name == "Authorization" {
format!("Bearer {api_key}")
} else {
api_key.to_string()
}
}

Expand Down Expand Up @@ -464,21 +329,27 @@ mod tests {
}

#[test]
fn test_get_origin_string() {
let url = url::Url::parse("https://api.example.com/v1/chat").unwrap();
assert_eq!(get_origin_string(&url), "https://api.example.com");

let url = url::Url::parse("http://localhost:8080/api").unwrap();
assert_eq!(get_origin_string(&url), "http://localhost:8080");
}

#[test]
fn test_format_api_key_header() {
assert_eq!(
format_api_key_header("Authorization", "sk-123"),
"Bearer sk-123"
);
assert_eq!(format_api_key_header("x-api-key", "key123"), "key123");
fn test_is_image_request() {
use axum::http::Method;

// Test image extensions
assert!(is_image_request(&Method::GET, "/path/image.png"));
assert!(is_image_request(&Method::GET, "/path/image.jpg"));
assert!(is_image_request(&Method::GET, "/path/image.jpeg"));
assert!(is_image_request(&Method::GET, "/path/image.gif"));
assert!(is_image_request(&Method::GET, "/path/image.bmp"));
assert!(is_image_request(&Method::GET, "/path/image.webp"));
assert!(is_image_request(&Method::GET, "/path/image.svg"));

// Test case insensitive
assert!(is_image_request(&Method::GET, "/path/IMAGE.PNG"));

// Test non-GET methods
assert!(!is_image_request(&Method::POST, "/path/image.png"));

// Test non-image paths
assert!(!is_image_request(&Method::GET, "/api/endpoint"));
assert!(!is_image_request(&Method::GET, "/pdf/2305.08675"));
}

#[test]
Expand Down
17 changes: 11 additions & 6 deletions engine/language_server/src/server/api/requests/execute_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl SyncRequestHandler for ExecuteCommand {
return Ok(None);
}

tracing::info!("Executing command: {:?}", params);
match RegisteredCommands::from_execute_command(params) {
Err(e) => {
return Err(crate::server::api::Error {
Expand All @@ -83,27 +84,31 @@ impl SyncRequestHandler for ExecuteCommand {
});
}
Ok(RegisteredCommands::OpenBamlPanel(args)) => {
session
let tx = session
.playground_tx
.send(PreLangServerToWasmMessage::FrontendMessage(
FrontendMessage::select_function {
// TODO: this can't be correct... but it looks like it is
root_path: args.project_id,
function_name: args.function_name,
},
))
.unwrap();
));
if let Err(e) = tx {
tracing::warn!("Error forwarding OpenBamlPanel to playground: {}", e);
}
}
Ok(RegisteredCommands::RunTest(args)) => {
session
let tx = session
.playground_tx
.send(PreLangServerToWasmMessage::FrontendMessage(
FrontendMessage::run_test {
function_name: args.function_name,
test_name: args.test_case_name,
},
))
.unwrap();
));
if let Err(e) = tx {
tracing::warn!("Error forwarding RunTest to playground: {}", e);
}
}
}

Expand Down
Loading
Loading