Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,6 @@ impl IRHelper for IntermediateRepr {
// Get best match.
let tests = function
.walk_tests()
.inspect(|t| log::info!("walking test: {:?}", t.item.1.elem.name))
.map(|t| t.item.1.elem.name.as_str())
.collect::<Vec<_>>();
error_not_found!("test", test_name, &tests)
Expand Down
3 changes: 1 addition & 2 deletions engine/baml-schema-wasm/src/runtime_wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1055,8 +1055,7 @@ impl WasmRuntime {
},
test_snippet: snippet,
test_cases: f
.ir
.walk_function_test_pairs()
.walk_tests()
.map(|tc| {
let params = match tc.test_case_params(&ctx) {
Ok(params) => Ok(params
Expand Down
7 changes: 3 additions & 4 deletions engine/baml-schema-wasm/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
"clean": "git clean -xdf ./dist .turbo node_modules",
"comment": "don't use 'wasm pack --dev' below or it will be too big of a bundle and playground will be slow to load.",
"build": "pnpm run release",
"release": "pnpm run pack --release",
"pack": "wasm-pack build ../ --target bundler --out-dir ./web/dist"
"release": "pnpm run pack --dev",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The release script is still using '--dev', but the PR description says 'revert wasm-pack to --release'. Please update the release script accordingly.

Suggested change
"release": "pnpm run pack --dev",
"release": "pnpm run pack --release",

"pack": "wasm-pack build ../ --target bundler --out-dir ./web/dist"
},
"exports": {
".": {
Expand All @@ -26,6 +26,5 @@
"directory": "baml-schema-wasm"
},
"author": "BAML",

"license": "Apache-2.0"
}
}
229 changes: 50 additions & 179 deletions engine/language_server/src/cors_bypass_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,6 @@ use serde::Deserialize;
use tokio::{fs, net::TcpListener};
use tower_http::cors::{Any, CorsLayer};

/// Configuration for API key injection per provider
/// Format: (origin_url, header_name, env_var_name, baml_header_name)
const API_PROVIDERS: &[(&str, &str, &str, &str)] = &[
(
"https://api.openai.com",
"Authorization",
"OPENAI_API_KEY",
"baml-openai-api-key",
),
(
"https://api.anthropic.com",
"x-api-key",
"ANTHROPIC_API_KEY",
"baml-anthropic-api-key",
),
(
"https://generativelanguage.googleapis.com",
"x-goog-api-key",
"GOOGLE_API_KEY",
"baml-google-api-key",
),
(
"https://openrouter.ai",
"Authorization",
"OPENROUTER_API_KEY",
"baml-openrouter-api-key",
),
(
"https://api.llmapi.com",
"Authorization",
"LLAMA_API_KEY",
"baml-llama-api-key",
),
];

#[derive(Debug, Clone)]
pub struct ProxyConfig {
pub port: u16,
Expand Down Expand Up @@ -117,16 +82,10 @@ pub fn create_proxy_router() -> Router<ProxyConfig> {
AUTHORIZATION,
HeaderName::from_static("x-api-key"),
HeaderName::from_static("baml-original-url"),
HeaderName::from_static("baml-openai-api-key"),
HeaderName::from_static("baml-anthropic-api-key"),
HeaderName::from_static("baml-google-api-key"),
HeaderName::from_static("baml-openrouter-api-key"),
HeaderName::from_static("baml-llama-api-key"),
])
.max_age(std::time::Duration::from_secs(86400));

Router::new()
.route("/static/{*path}", get(serve_static_file))
.route("/{*path}", options(handle_preflight))
.route("/{*path}", any(handle_proxy_request))
.layer(cors)
Expand All @@ -137,81 +96,30 @@ async fn handle_preflight() -> impl IntoResponse {
StatusCode::OK
}

/// Serve static files from the current working directory
async fn serve_static_file(Path(path): Path<String>) -> Result<AxumResponse, ProxyError> {
let file_path = path.strip_prefix("static/").unwrap_or(&path);
let current_dir = std::env::current_dir()
.map_err(|e| ProxyError::internal_error(format!("Failed to get current dir: {e}")))?;

// Try multiple potential base directories
let potential_paths = vec![
current_dir.join(file_path),
current_dir.join("baml_src").join(file_path),
current_dir.join("../baml_src").join(file_path),
];

let absolute_path = potential_paths
.into_iter()
.find(|path| path.exists())
.unwrap_or_else(|| current_dir.join(file_path));

match fs::read(&absolute_path).await {
Ok(contents) => {
let mime_type = from_path(file_path).first_or_octet_stream();
let content_type = mime_type.as_ref().to_string();

Ok((
StatusCode::OK,
[
("content-type", content_type.as_str()),
("access-control-allow-origin", "*"),
],
contents,
)
.into_response())
}
Err(err) => {
tracing::warn!("Failed to read static file {}: {}", file_path, err);

match err.kind() {
std::io::ErrorKind::NotFound => Err(ProxyError::not_found(format!(
"File not found: {file_path}"
))),
std::io::ErrorKind::PermissionDenied => Err(ProxyError::new(
format!("Permission denied: {file_path}"),
StatusCode::FORBIDDEN,
)),
_ => Err(ProxyError::internal_error(format!(
"Error reading file: {file_path}"
))),
}
}
}
}

/// Main proxy request handler
async fn handle_proxy_request(
method: Method,
uri: Uri,
headers: HeaderMap,
body: Body,
) -> Result<AxumResponse, ProxyError> {
tracing::debug!(
"handle_proxy_request: {:?}: {:?}: {:?}",
method,
uri,
headers
);
let path_str = uri.path();

// Handle static file serving
if path_str.starts_with("/static/") && method == Method::GET {
return serve_static_file(Path(path_str.to_string())).await;
}

// Extract and validate the original URL
let original_url = extract_original_url(&headers)?;
let mut clean_headers = clean_headers(&headers);
let clean_headers = clean_headers(&headers);

// Parse the target URL
let mut target_url = parse_target_url(&original_url)?;

// Handle simple GET requests that don't need proxying
if is_simple_get_request(&method, path_str) {
// Handle image requests that should return empty content (matching Express behavior)
if is_image_request(&method, path_str) {
return Ok(create_empty_response().into_response());
}

Expand All @@ -224,9 +132,6 @@ async fn handle_proxy_request(
.await
.map_err(|e| ProxyError::bad_request(format!("Failed to read body: {e}")))?;

// Inject API keys for supported providers
inject_api_key(&mut clean_headers, &target_url, &headers);

// Execute the request and return the response
execute_request(method, &target_url, clean_headers, body_bytes.to_vec()).await
}
Expand All @@ -243,7 +148,7 @@ fn extract_original_url(headers: &HeaderMap) -> Result<String, ProxyError> {
/// Remove headers that shouldn't be forwarded
fn clean_headers(headers: &HeaderMap) -> HeaderMap {
let mut clean_headers = headers.clone();
let headers_to_remove = ["baml-original-url", "origin", "authorization", "host"];
let headers_to_remove = ["baml-original-url", "origin", "host"];

for header_name in &headers_to_remove {
clean_headers.remove(*header_name);
Expand All @@ -258,83 +163,43 @@ fn parse_target_url(url_str: &str) -> Result<url::Url, ProxyError> {
url::Url::parse(clean_url).map_err(|e| ProxyError::bad_request(format!("Invalid URL: {e}")))
}

/// Check if this is a simple GET request that doesn't need proxying
fn is_simple_get_request(method: &Method, path: &str) -> bool {
path.matches('.').count() == 1 && method == Method::GET
/// Check if this is an image request that should return empty content
fn is_image_request(method: &Method, path: &str) -> bool {
if method != Method::GET {
return false;
}

// Match the Express regex: /\.(png|jpe?g|gif|bmp|webp|svg)$/i
let path_lower = path.to_lowercase();
path_lower.ends_with(".png")
|| path_lower.ends_with(".jpg")
|| path_lower.ends_with(".jpeg")
|| path_lower.ends_with(".gif")
|| path_lower.ends_with(".bmp")
|| path_lower.ends_with(".webp")
|| path_lower.ends_with(".svg")
}

/// Create an empty successful response
fn create_empty_response() -> impl IntoResponse {
(StatusCode::OK, [("access-control-allow-origin", "*")], "")
}

/// Construct the final path for the target URL
/// Construct the final path for the target URL (matching Express behavior)
fn construct_final_path(url: &url::Url, path_str: &str) -> String {
let base_path = url.path().trim_end_matches('/');

let final_path = if base_path.is_empty() {
path_str.trim_end_matches('/').to_string()
if base_path.is_empty() {
path_str.to_string()
} else if !path_str.starts_with(base_path) {
// Match Express logic: basePath + (req.url.startsWith('/') ? '' : '/') + req.url
if path_str.starts_with('/') {
format!("{base_path}{path_str}")
} else {
format!("{base_path}/{path_str}")
}
} else {
path_str.to_string()
};

final_path.trim_end_matches('/').to_string()
}

/// Inject appropriate API key based on the target URL
fn inject_api_key(headers: &mut HeaderMap, target_url: &url::Url, original_headers: &HeaderMap) {
let origin = get_origin_string(target_url);

for (allowed_origin, header_name, env_var, baml_header) in API_PROVIDERS {
if origin == *allowed_origin {
if let Some(api_key) = get_api_key(env_var, baml_header, original_headers) {
let header_value = format_api_key_header(header_name, &api_key);
if let Ok(header_val) = header_value.parse() {
headers.insert(*header_name, header_val);
}
}
break;
}
}
}

/// Convert URL origin to string format
fn get_origin_string(url: &url::Url) -> String {
match url.origin() {
url::Origin::Tuple(scheme, host, port) => match (scheme.as_str(), port) {
("http", 80) | ("https", 443) => format!("{scheme}://{host}"),
_ => format!("{scheme}://{host}:{port}"),
},
url::Origin::Opaque(_) => String::new(),
}
}

/// Get API key from environment or headers
fn get_api_key(env_var: &str, baml_header: &str, headers: &HeaderMap) -> Option<String> {
// Try environment variable first
std::env::var(env_var)
.ok()
// Then try custom header
.or_else(|| {
headers
.get(baml_header)
.and_then(|v| v.to_str().ok())
.map(String::from)
})
}

/// Format API key header value based on header type
fn format_api_key_header(header_name: &str, api_key: &str) -> String {
if header_name == "Authorization" {
format!("Bearer {api_key}")
} else {
api_key.to_string()
}
}

Expand Down Expand Up @@ -464,21 +329,27 @@ mod tests {
}

#[test]
fn test_get_origin_string() {
let url = url::Url::parse("https://api.example.com/v1/chat").unwrap();
assert_eq!(get_origin_string(&url), "https://api.example.com");

let url = url::Url::parse("http://localhost:8080/api").unwrap();
assert_eq!(get_origin_string(&url), "http://localhost:8080");
}

#[test]
fn test_format_api_key_header() {
assert_eq!(
format_api_key_header("Authorization", "sk-123"),
"Bearer sk-123"
);
assert_eq!(format_api_key_header("x-api-key", "key123"), "key123");
fn test_is_image_request() {
use axum::http::Method;

// Test image extensions
assert!(is_image_request(&Method::GET, "/path/image.png"));
assert!(is_image_request(&Method::GET, "/path/image.jpg"));
assert!(is_image_request(&Method::GET, "/path/image.jpeg"));
assert!(is_image_request(&Method::GET, "/path/image.gif"));
assert!(is_image_request(&Method::GET, "/path/image.bmp"));
assert!(is_image_request(&Method::GET, "/path/image.webp"));
assert!(is_image_request(&Method::GET, "/path/image.svg"));

// Test case insensitive
assert!(is_image_request(&Method::GET, "/path/IMAGE.PNG"));

// Test non-GET methods
assert!(!is_image_request(&Method::POST, "/path/image.png"));

// Test non-image paths
assert!(!is_image_request(&Method::GET, "/api/endpoint"));
assert!(!is_image_request(&Method::GET, "/pdf/2305.08675"));
}

#[test]
Expand Down
17 changes: 11 additions & 6 deletions engine/language_server/src/server/api/requests/execute_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl SyncRequestHandler for ExecuteCommand {
return Ok(None);
}

tracing::info!("Executing command: {:?}", params);
match RegisteredCommands::from_execute_command(params) {
Err(e) => {
return Err(crate::server::api::Error {
Expand All @@ -83,27 +84,31 @@ impl SyncRequestHandler for ExecuteCommand {
});
}
Ok(RegisteredCommands::OpenBamlPanel(args)) => {
session
let tx = session
.playground_tx
.send(PreLangServerToWasmMessage::FrontendMessage(
FrontendMessage::select_function {
// TODO: this can't be correct... but it looks like it is
root_path: args.project_id,
function_name: args.function_name,
},
))
.unwrap();
));
if let Err(e) = tx {
tracing::warn!("Error forwarding OpenBamlPanel to playground: {}", e);
}
}
Ok(RegisteredCommands::RunTest(args)) => {
session
let tx = session
.playground_tx
.send(PreLangServerToWasmMessage::FrontendMessage(
FrontendMessage::run_test {
function_name: args.function_name,
test_name: args.test_case_name,
},
))
.unwrap();
));
if let Err(e) = tx {
tracing::warn!("Error forwarding RunTest to playground: {}", e);
}
}
}

Expand Down
Loading