diff --git a/baml_language/crates/baml_compiler_hir/src/generics.rs b/baml_language/crates/baml_compiler_hir/src/generics.rs index 0a047a723c..a09d6c0f0b 100644 --- a/baml_language/crates/baml_compiler_hir/src/generics.rs +++ b/baml_language/crates/baml_compiler_hir/src/generics.rs @@ -4,8 +4,17 @@ //! from the `ItemTree` to maintain the invalidation barrier. Changes to generic //! parameters don't invalidate the `ItemTree`. -use baml_base::Name; +use std::sync::Arc; + +use baml_base::{Name, SourceFile}; +use baml_compiler_parser::syntax_tree; +use baml_compiler_syntax::ast; use la_arena::{Arena, Idx}; +use rowan::ast::AstNode; + +use crate::fqn::QualifiedName; +use crate::ids::{ItemKind, LocalIdAllocator}; +use crate::{ClassId, Db, EnumId, FunctionId, TypeAliasId}; /// Type parameter in a generic definition. /// @@ -53,3 +62,274 @@ impl std::ops::Index for GenericParams { &self.type_params[index] } } + +fn empty_generic_params() -> Arc { + Arc::new(GenericParams::new()) +} + +fn generic_params_from_list(list: Option) -> Arc { + let Some(list) = list else { + return empty_generic_params(); + }; + let mut params = GenericParams::new(); + for param_token in list.params() { + params.type_params.alloc(TypeParam { + name: Name::new(param_token.text()), + }); + } + Arc::new(params) +} + +fn with_source_file( + db: &dyn Db, + file: SourceFile, + f: impl FnOnce(ast::SourceFile) -> T, +) -> Option { + let tree = syntax_tree(db, file); + ast::SourceFile::cast(tree).map(f) +} + +fn item_name_and_list( + name: Option, + list: Option, +) -> Option<(Name, Option)> { + name.map(|name| (name, list)) +} + +fn match_generic_params( + allocator: &mut LocalIdAllocator, + kind: ItemKind, + name: &Name, + list: Option, + target_id: u32, +) -> Option> { + let id = allocator.alloc_id::<()>(kind, name); + (id.as_u32() == target_id).then(|| generic_params_from_list(list)) +} + +fn class_name_for_methods(class_node: &ast::ClassDef) -> String { + class_node + .name() + .map(|token| token.text().to_string()) + .unwrap_or_else(|| "UnnamedClass".to_string()) +} + +fn scan_client_resolve_for_generics( + allocator: &mut LocalIdAllocator, + client_node: &ast::ClientDef, + target_id: u32, +) -> Option> { + let name_token = client_node.name()?; + let resolve_name = Name::new(format!("{}.resolve", name_token.text())); + match_generic_params( + allocator, + ItemKind::Function, + &resolve_name, + None, + target_id, + ) +} + +fn find_top_level_generic_params( + source_file: &ast::SourceFile, + allocator: &mut LocalIdAllocator, + target_id: u32, + item_kind: ItemKind, + mut extract: impl FnMut(ast::Item) -> Option<(Name, Option)>, +) -> Option> { + for item in source_file.items() { + if let Some((name, list)) = extract(item) { + if let Some(params) = match_generic_params(allocator, item_kind, &name, list, target_id) + { + return Some(params); + } + } + } + None +} + +fn scan_function_item_for_generics( + allocator: &mut LocalIdAllocator, + func_node: &ast::FunctionDef, + target_id: u32, +) -> Option> { + let name_token = func_node.name()?; + let base_name = Name::new(name_token.text()); + + if let Some(params) = match_generic_params( + allocator, + ItemKind::Function, + &base_name, + func_node.generic_param_list(), + target_id, + ) { + return Some(params); + } + + func_node.llm_body()?; + + let render_name = Name::new(format!("{base_name}.render_prompt")); + if let Some(params) = match_generic_params( + allocator, + ItemKind::Function, + &render_name, + func_node.generic_param_list(), + target_id, + ) { + return Some(params); + } + + let build_name = Name::new(format!("{base_name}.build_request")); + match_generic_params( + allocator, + ItemKind::Function, + &build_name, + func_node.generic_param_list(), + target_id, + ) +} + +fn scan_class_methods_for_generics( + allocator: &mut LocalIdAllocator, + class_node: &ast::ClassDef, + target_id: u32, +) -> Option> { + let class_name = class_name_for_methods(class_node); + for method in class_node.methods() { + let Some(method_name) = method.name() else { + continue; + }; + let qualified_method_name = + QualifiedName::local_method_from_str(&class_name, method_name.text()); + let params = match_generic_params( + allocator, + ItemKind::Function, + &qualified_method_name, + method.generic_param_list(), + target_id, + ); + if params.is_some() { + return params; + } + } + None +} + +pub(crate) fn function_generic_params_from_cst( + db: &dyn Db, + func: FunctionId<'_>, +) -> Arc { + let file = func.file(db); + let target_id = func.id(db).as_u32(); + + let result = with_source_file(db, file, |source_file| { + let mut allocator = LocalIdAllocator::new(); + for item in source_file.items() { + let params = match item { + ast::Item::Function(func_node) => { + scan_function_item_for_generics(&mut allocator, &func_node, target_id) + } + ast::Item::Class(class_node) => { + scan_class_methods_for_generics(&mut allocator, &class_node, target_id) + } + ast::Item::Client(client_node) => { + scan_client_resolve_for_generics(&mut allocator, &client_node, target_id) + } + _ => None, + }; + if let Some(params) = params { + return params; + } + } + empty_generic_params() + }); + + result.unwrap_or_else(empty_generic_params) +} + +pub(crate) fn class_generic_params_from_cst(db: &dyn Db, class: ClassId<'_>) -> Arc { + let file = class.file(db); + let target_id = class.id(db).as_u32(); + + let result = with_source_file(db, file, |source_file| { + let mut allocator = LocalIdAllocator::new(); + find_top_level_generic_params( + &source_file, + &mut allocator, + target_id, + ItemKind::Class, + |item| { + if let ast::Item::Class(class_node) = item { + item_name_and_list( + class_node.name().map(|token| Name::new(token.text())), + class_node.generic_param_list(), + ) + } else { + None + } + }, + ) + }); + + result.flatten().unwrap_or_else(empty_generic_params) +} + +pub(crate) fn enum_generic_params_from_cst( + db: &dyn Db, + enum_def: EnumId<'_>, +) -> Arc { + let file = enum_def.file(db); + let target_id = enum_def.id(db).as_u32(); + + let result = with_source_file(db, file, |source_file| { + let mut allocator = LocalIdAllocator::new(); + find_top_level_generic_params( + &source_file, + &mut allocator, + target_id, + ItemKind::Enum, + |item| { + if let ast::Item::Enum(enum_node) = item { + item_name_and_list( + enum_node.name().map(|token| Name::new(token.text())), + enum_node.generic_param_list(), + ) + } else { + None + } + }, + ) + }); + + result.flatten().unwrap_or_else(empty_generic_params) +} + +pub(crate) fn type_alias_generic_params_from_cst( + db: &dyn Db, + alias: TypeAliasId<'_>, +) -> Arc { + let file = alias.file(db); + let target_id = alias.id(db).as_u32(); + + let result = with_source_file(db, file, |source_file| { + let mut allocator = LocalIdAllocator::new(); + find_top_level_generic_params( + &source_file, + &mut allocator, + target_id, + ItemKind::TypeAlias, + |item| { + if let ast::Item::TypeAlias(alias_node) = item { + item_name_and_list( + alias_node.name().map(|token| Name::new(token.text())), + alias_node.generic_param_list(), + ) + } else { + None + } + }, + ) + }); + + result.flatten().unwrap_or_else(empty_generic_params) +} diff --git a/baml_language/crates/baml_compiler_hir/src/ids.rs b/baml_language/crates/baml_compiler_hir/src/ids.rs index 5a459f33ec..34003a6efd 100644 --- a/baml_language/crates/baml_compiler_hir/src/ids.rs +++ b/baml_language/crates/baml_compiler_hir/src/ids.rs @@ -8,6 +8,9 @@ use std::marker::PhantomData; +use baml_base::Name; +use rustc_hash::FxHashMap; + /// Identifier for a class definition. pub use crate::loc::ClassLoc as ClassId; /// Identifier for a client configuration. @@ -188,3 +191,52 @@ pub enum ItemKind { TemplateString, RetryPolicy, } + +pub(crate) fn allocate_local_id( + next_index: &mut FxHashMap<(ItemKind, u16), u16>, + kind: ItemKind, + name: &Name, +) -> LocalItemId { + let hash = hash_name(name); + let index = next_index.entry((kind, hash)).or_insert(0); + let id = LocalItemId::new(hash, *index); + // Saturating add prevents silent wraparound in release builds. + // Reaching u16::MAX collisions for a single (kind, hash) bucket is + // practically impossible (requires 65 535 items with the same 16-bit + // name hash), so saturation is a safe sentinel rather than a recoverable + // error path. + debug_assert!( + *index < u16::MAX, + "LocalItemId collision index saturated for {name:?}" + ); + *index = index.saturating_add(1); + id +} + +/// Allocator for `LocalItemId`s with collision handling. +/// +/// Replays the same hashing and collision-indexing logic as `ItemTree`. +/// This allows queries to reproduce the same local IDs when scanning CST. +pub struct LocalIdAllocator { + next_index: FxHashMap<(ItemKind, u16), u16>, +} + +impl LocalIdAllocator { + /// Create a new allocator with empty collision state. + pub fn new() -> Self { + Self { + next_index: FxHashMap::default(), + } + } + + /// Allocate a `LocalItemId` for a named item, updating collision state. + pub fn alloc_id(&mut self, kind: ItemKind, name: &Name) -> LocalItemId { + allocate_local_id(&mut self.next_index, kind, name) + } +} + +impl Default for LocalIdAllocator { + fn default() -> Self { + Self::new() + } +} diff --git a/baml_language/crates/baml_compiler_hir/src/item_tree.rs b/baml_language/crates/baml_compiler_hir/src/item_tree.rs index 54ad6fa2b1..da7af64001 100644 --- a/baml_language/crates/baml_compiler_hir/src/item_tree.rs +++ b/baml_language/crates/baml_compiler_hir/src/item_tree.rs @@ -12,7 +12,7 @@ use rowan::TextRange; use rustc_hash::FxHashMap; use crate::{ - ids::{ItemKind, LocalItemId, hash_name}, + ids::{ItemKind, LocalItemId, allocate_local_id}, loc::{ ClassMarker, ClientMarker, EnumMarker, FunctionMarker, GeneratorMarker, RetryPolicyMarker, TemplateStringMarker, TestMarker, TypeAliasMarker, @@ -111,11 +111,7 @@ impl ItemTree { /// Allocate a collision-resistant ID for an item. /// Returns a `LocalItemId` with the name's hash and a unique collision index. fn alloc_id(&mut self, kind: ItemKind, name: &Name) -> LocalItemId { - let hash = hash_name(name); - let index = self.next_index.entry((kind, hash)).or_insert(0); - let id = LocalItemId::new(hash, *index); - *index += 1; - id + allocate_local_id(&mut self.next_index, kind, name) } /// Add a function and return its local ID. diff --git a/baml_language/crates/baml_compiler_hir/src/lib.rs b/baml_language/crates/baml_compiler_hir/src/lib.rs index 3578c2f3a1..a2776c04b8 100644 --- a/baml_language/crates/baml_compiler_hir/src/lib.rs +++ b/baml_language/crates/baml_compiler_hir/src/lib.rs @@ -20,8 +20,8 @@ use std::sync::Arc; use baml_base::{FileId, Name, SourceFile, Span, TyAttr}; use baml_compiler_diagnostics::{HirDiagnostic, NameError}; use baml_compiler_parser::syntax_tree; -use baml_compiler_syntax::SyntaxNode; -use rowan::{SyntaxToken, TextRange, ast::AstNode}; +use baml_compiler_syntax::{SyntaxNode, SyntaxToken}; +use rowan::{TextRange, ast::AstNode}; // Module declarations mod body; @@ -210,34 +210,27 @@ pub fn project_items(db: &dyn Db, root: baml_workspace::Project) -> ProjectItems /// /// This is queried separately from `ItemTree` for incrementality - changes to /// generic parameters don't invalidate the `ItemTree`. -/// -/// For now, this returns empty generic parameters since BAML doesn't currently -/// parse generic syntax. Future work will extract `` from the CST. #[salsa::tracked] -pub fn function_generic_params(_db: &dyn Db, _func: FunctionId<'_>) -> Arc { - // TODO: Extract generic parameters from CST when BAML adds generic syntax - Arc::new(GenericParams::new()) +pub fn function_generic_params(db: &dyn Db, func: FunctionId<'_>) -> Arc { + generics::function_generic_params_from_cst(db, func) } /// Tracked: Get generic parameters for a class. #[salsa::tracked] -pub fn class_generic_params(_db: &dyn Db, _class: ClassId<'_>) -> Arc { - // TODO: Extract generic parameters from CST when BAML adds generic syntax - Arc::new(GenericParams::new()) +pub fn class_generic_params(db: &dyn Db, class: ClassId<'_>) -> Arc { + generics::class_generic_params_from_cst(db, class) } /// Tracked: Get generic parameters for an enum. #[salsa::tracked] -pub fn enum_generic_params(_db: &dyn Db, _enum: EnumId<'_>) -> Arc { - // TODO: Extract generic parameters from CST when BAML adds generic syntax - Arc::new(GenericParams::new()) +pub fn enum_generic_params(db: &dyn Db, enum_def: EnumId<'_>) -> Arc { + generics::enum_generic_params_from_cst(db, enum_def) } /// Tracked: Get generic parameters for a type alias. #[salsa::tracked] -pub fn type_alias_generic_params(_db: &dyn Db, _alias: TypeAliasId<'_>) -> Arc { - // TODO: Extract generic parameters from CST when BAML adds generic syntax - Arc::new(GenericParams::new()) +pub fn type_alias_generic_params(db: &dyn Db, alias: TypeAliasId<'_>) -> Arc { + generics::type_alias_generic_params_from_cst(db, alias) } // @@ -1909,7 +1902,6 @@ fn lower_item(tree: &mut ItemTree, node: &SyntaxNode, ctx: &mut LoweringContext) }), }; tree.alloc_function(resolve_fn); - tree.alloc_client(c); } } diff --git a/baml_language/crates/baml_compiler_parser/src/parser.rs b/baml_language/crates/baml_compiler_parser/src/parser.rs index 6a94f9c438..d58a485d3f 100644 --- a/baml_language/crates/baml_compiler_parser/src/parser.rs +++ b/baml_language/crates/baml_compiler_parser/src/parser.rs @@ -1833,6 +1833,10 @@ impl<'a> Parser<'a> { p.error_unexpected_token("enum name".to_string()); } + if p.at(TokenKind::Less) { + p.parse_generic_param_list(); + } + // Opening brace if !p.expect(TokenKind::LBrace) { return; // Error recovery: stop here @@ -1954,6 +1958,10 @@ impl<'a> Parser<'a> { p.error_unexpected_token("class name".to_string()); } + if p.at(TokenKind::Less) { + p.parse_generic_param_list(); + } + // Opening brace if !p.expect(TokenKind::LBrace) { return; @@ -2045,6 +2053,10 @@ impl<'a> Parser<'a> { } } + if p.at(TokenKind::Less) { + p.parse_generic_param_list(); + } + // Check for old-style function syntax: `function Name {` (without parens and return type) // If we see '{' directly after the name, emit a single helpful error and skip to body if p.at(TokenKind::LBrace) { @@ -2088,6 +2100,43 @@ impl<'a> Parser<'a> { }); } + fn parse_generic_param_list(&mut self) { + self.with_node(SyntaxKind::GENERIC_PARAM_LIST, |p| { + p.expect(TokenKind::Less); + + if p.at(TokenKind::Greater) || p.at(TokenKind::GreaterGreater) { + p.error_unexpected_token("generic parameter".to_string()); + p.expect_greater(); + return; + } + + if p.at(TokenKind::Word) { + p.bump(); + } else { + p.error_unexpected_token("generic parameter".to_string()); + if !p.at(TokenKind::Greater) && !p.at(TokenKind::GreaterGreater) { + p.bump(); + } + } + + while p.eat(TokenKind::Comma) { + if p.at(TokenKind::Greater) || p.at(TokenKind::GreaterGreater) { + break; + } + if p.at(TokenKind::Word) { + p.bump(); + } else { + p.error_unexpected_token("generic parameter".to_string()); + if !p.at(TokenKind::Greater) && !p.at(TokenKind::GreaterGreater) { + p.bump(); + } + } + } + + p.expect_greater(); + }); + } + fn parse_parameter_list(&mut self) { self.with_node(SyntaxKind::PARAMETER_LIST, |p| { p.expect(TokenKind::LParen); @@ -4277,6 +4326,10 @@ impl<'a> Parser<'a> { p.error_unexpected_token("type alias name".to_string()); } + if p.at(TokenKind::Less) { + p.parse_generic_param_list(); + } + // Equals p.expect(TokenKind::Equals); @@ -4368,6 +4421,7 @@ mod tests { use baml_base::FileId; use baml_compiler_lexer::lex_lossless; use baml_compiler_syntax::{SyntaxKind, SyntaxNode}; + use rowan::NodeOrToken; use super::{ParseError, parse_file}; @@ -4384,6 +4438,12 @@ mod tests { ); } + fn parse_source_no_errors(source: &str) -> SyntaxNode { + let (root, errors) = parse_source(source); + assert_no_errors(&errors); + root + } + #[test] fn error_on_parameter_type_without_colon() { // When the user writes `x int` instead of `x: int`, the parser should @@ -4763,4 +4823,65 @@ function Demo() -> int { assert_eq!(child_kinds[0], SyntaxKind::THROW_EXPR); assert_eq!(child_kinds[1], SyntaxKind::CATCH_CLAUSE); } + + fn generic_params_in(source: &str) -> Vec { + let root = parse_source_no_errors(source); + root.descendants() + .find(|n| n.kind() == SyntaxKind::GENERIC_PARAM_LIST) + .map(|node| { + node.children_with_tokens() + .filter_map(NodeOrToken::into_token) + .filter(|t| t.kind() == SyntaxKind::WORD) + .map(|t| t.text().to_string()) + .collect() + }) + .unwrap_or_default() + } + + #[test] + fn parses_function_generic_params() { + let source = "function Foo(x: T) -> T { return x; }"; + let params = generic_params_in(source); + assert_eq!(params, vec!["T", "U"]); + } + + #[test] + fn parses_class_generic_params() { + let source = "class Box { value: T }"; + let params = generic_params_in(source); + assert_eq!(params, vec!["T"]); + } + + #[test] + fn parses_type_alias_generic_params() { + let source = "type Id = T;"; + let params = generic_params_in(source); + assert_eq!(params, vec!["T"]); + } + + #[test] + fn parses_enum_generic_params() { + let source = "enum Option { Some\n None }"; + let params = generic_params_in(source); + assert_eq!(params, vec!["T"]); + } + + #[test] + fn recovers_from_invalid_first_generic_param() { + // Parser should skip invalid tokens and still collect valid params. + // This test intentionally produces parse errors, so we parse directly. + let (root, _errors) = parse_source("function Bad<123, T>(x: T) -> T { return x; }"); + let params: Vec = root + .descendants() + .find(|n| n.kind() == SyntaxKind::GENERIC_PARAM_LIST) + .map(|node| { + node.children_with_tokens() + .filter_map(NodeOrToken::into_token) + .filter(|t| t.kind() == SyntaxKind::WORD) + .map(|t| t.text().to_string()) + .collect() + }) + .unwrap_or_default(); + assert_eq!(params, vec!["T"]); + } } diff --git a/baml_language/crates/baml_compiler_syntax/src/ast.rs b/baml_language/crates/baml_compiler_syntax/src/ast.rs index d37759df0f..017cf1dbe2 100644 --- a/baml_language/crates/baml_compiler_syntax/src/ast.rs +++ b/baml_language/crates/baml_compiler_syntax/src/ast.rs @@ -88,6 +88,7 @@ ast_node!(RetryPolicyDef, RETRY_POLICY_DEF); ast_node!(TemplateStringDef, TEMPLATE_STRING_DEF); ast_node!(TypeAliasDef, TYPE_ALIAS_DEF); +ast_node!(GenericParamList, GENERIC_PARAM_LIST); ast_node!(ParameterList, PARAMETER_LIST); ast_node!(Parameter, PARAMETER); ast_node!(FunctionBody, FUNCTION_BODY); @@ -753,6 +754,11 @@ impl FunctionDef { .nth(0) // Get the first WORD (function keyword is KW_FUNCTION, not WORD) } + /// Get the generic parameter list. + pub fn generic_param_list(&self) -> Option { + self.syntax.children().find_map(GenericParamList::cast) + } + /// Get the parameter list. pub fn param_list(&self) -> Option { self.syntax.children().find_map(ParameterList::cast) @@ -1007,6 +1013,16 @@ impl Parameter { } } +impl GenericParamList { + /// Get all generic parameter tokens. + pub fn params(&self) -> impl Iterator { + self.syntax + .children_with_tokens() + .filter_map(rowan::NodeOrToken::into_token) + .filter(|token| token.kind() == SyntaxKind::WORD) + } +} + impl ParameterList { /// Get all parameters. pub fn params(&self) -> impl Iterator { @@ -1026,6 +1042,11 @@ impl ClassDef { .nth(0) // Get the first WORD (class keyword is KW_CLASS, not WORD) } + /// Get the generic parameter list. + pub fn generic_param_list(&self) -> Option { + self.syntax.children().find_map(GenericParamList::cast) + } + /// Get all fields. pub fn fields(&self) -> impl Iterator { self.syntax.children().filter_map(Field::cast) @@ -1074,6 +1095,11 @@ impl EnumDef { .nth(0) // Get the first WORD (enum keyword is KW_ENUM, not WORD) } + /// Get the generic parameter list. + pub fn generic_param_list(&self) -> Option { + self.syntax.children().find_map(GenericParamList::cast) + } + /// Check if this enum has a body (braces). /// Malformed enums from error recovery may not have braces. pub fn has_body(&self) -> bool { @@ -1518,6 +1544,11 @@ impl TypeAliasDef { .nth(1) // Skip "type" keyword (which is a WORD), get the actual name } + /// Get the generic parameter list. + pub fn generic_param_list(&self) -> Option { + self.syntax.children().find_map(GenericParamList::cast) + } + /// Get the aliased type expression. pub fn ty(&self) -> Option { self.syntax.children().find_map(TypeExpr::cast) diff --git a/baml_language/crates/baml_compiler_syntax/src/syntax_kind.rs b/baml_language/crates/baml_compiler_syntax/src/syntax_kind.rs index 498108fe9e..1afbb35d62 100644 --- a/baml_language/crates/baml_compiler_syntax/src/syntax_kind.rs +++ b/baml_language/crates/baml_compiler_syntax/src/syntax_kind.rs @@ -148,6 +148,7 @@ pub enum SyntaxKind { DYNAMIC_TYPE_DEF, // dynamic class/enum inside type_builder blocks // Function components + GENERIC_PARAM_LIST, PARAMETER_LIST, PARAMETER, FUNCTION_BODY, diff --git a/baml_language/crates/baml_tests/src/incremental/scenarios.rs b/baml_language/crates/baml_tests/src/incremental/scenarios.rs index c915a6cfc5..2ab7f104df 100644 --- a/baml_language/crates/baml_tests/src/incremental/scenarios.rs +++ b/baml_language/crates/baml_tests/src/incremental/scenarios.rs @@ -3,13 +3,164 @@ //! These tests verify that editing BAML files only recomputes the necessary //! queries, demonstrating Salsa's "early cutoff" optimization. -use baml_compiler_hir::{FunctionLoc, function_body, function_signature}; +use std::collections::HashMap; + +use baml_base::Name; +use baml_compiler_hir::{ + FunctionLoc, GenericParams, QualifiedName, class_generic_params, enum_generic_params, + function_body, function_generic_params, function_signature, hash_name, + type_alias_generic_params, +}; use baml_compiler_tir::function_type_inference; use baml_db::{SourceFile, baml_compiler_hir}; use salsa::Setter; use super::IncrementalTestDb; +fn param_names(params: &GenericParams) -> Vec { + params + .type_param_names() + .map(|name| name.as_str().to_string()) + .collect() +} + +fn find_item_by_name<'db, T>( + db: &'db baml_project::ProjectDatabase, + file: SourceFile, + name: &str, + kind: &str, + mut pick: impl FnMut(&baml_compiler_hir::ItemId<'db>, &baml_compiler_hir::ItemTree) -> Option, +) -> T { + let item_tree = baml_compiler_hir::file_item_tree(db, file); + let items = baml_compiler_hir::file_items(db, file); + match items + .items(db) + .iter() + .find_map(|item| pick(item, &item_tree)) + { + Some(value) => value, + None => panic!("{kind} not found: {name}"), + } +} + +fn collect_items<'db, T>( + db: &'db baml_project::ProjectDatabase, + file: SourceFile, + mut pick: impl FnMut(&baml_compiler_hir::ItemId<'db>, &baml_compiler_hir::ItemTree) -> Option, +) -> Vec { + let item_tree = baml_compiler_hir::file_item_tree(db, file); + let items = baml_compiler_hir::file_items(db, file); + items + .items(db) + .iter() + .filter_map(|item| pick(item, &item_tree)) + .collect() +} + +fn find_function_by_name<'db>( + db: &'db baml_project::ProjectDatabase, + file: SourceFile, + name: &str, +) -> FunctionLoc<'db> { + let funcs = find_functions_by_name(db, file, name); + match funcs.first().copied() { + Some(func) => func, + None => panic!("function not found: {name}"), + } +} + +fn find_functions_by_name<'db>( + db: &'db baml_project::ProjectDatabase, + file: SourceFile, + name: &str, +) -> Vec> { + let mut funcs: Vec<_> = collect_items(db, file, |item, item_tree| match item { + baml_compiler_hir::ItemId::Function(func_id) => { + let func = &item_tree[func_id.id(db)]; + if func.compiler_generated.is_none() && func.name.as_str() == name { + Some(*func_id) + } else { + None + } + } + _ => None, + }); + funcs.sort_by_key(|loc| loc.id(db).index()); + funcs +} + +fn find_class_by_name<'db>( + db: &'db baml_project::ProjectDatabase, + file: SourceFile, + name: &str, +) -> baml_compiler_hir::ClassLoc<'db> { + find_item_by_name(db, file, name, "class", |item, item_tree| match item { + baml_compiler_hir::ItemId::Class(class_id) => { + let class_def = &item_tree[class_id.id(db)]; + if class_def.name.as_str() == name { + Some(*class_id) + } else { + None + } + } + _ => None, + }) +} + +fn find_enum_by_name<'db>( + db: &'db baml_project::ProjectDatabase, + file: SourceFile, + name: &str, +) -> baml_compiler_hir::EnumLoc<'db> { + find_item_by_name(db, file, name, "enum", |item, item_tree| match item { + baml_compiler_hir::ItemId::Enum(enum_id) => { + let enum_def = &item_tree[enum_id.id(db)]; + if enum_def.name.as_str() == name { + Some(*enum_id) + } else { + None + } + } + _ => None, + }) +} + +fn find_type_alias_by_name<'db>( + db: &'db baml_project::ProjectDatabase, + file: SourceFile, + name: &str, +) -> baml_compiler_hir::TypeAliasLoc<'db> { + find_item_by_name(db, file, name, "type alias", |item, item_tree| match item { + baml_compiler_hir::ItemId::TypeAlias(alias_id) => { + let alias = &item_tree[alias_id.id(db)]; + if alias.name.as_str() == name { + Some(*alias_id) + } else { + None + } + } + _ => None, + }) +} + +fn find_name_hash_collision(prefix: &str) -> (String, String) { + // Pigeonhole: with 2^16 + 1 candidates and only 2^16 possible hash values, + // a collision is guaranteed by the pigeonhole principle. + let mut seen: HashMap = HashMap::new(); + for i in 0u32..=(u16::MAX as u32 + 1) { + let candidate = format!("{prefix}{i}"); + let hash = hash_name(&Name::new(&candidate)); + if let Some(existing) = seen.get(&hash) { + if existing != &candidate { + return (existing.clone(), candidate); + } + } else { + seen.insert(hash, candidate); + } + } + unreachable!("pigeonhole guarantees a collision within 2^16 + 1 candidates") +} + /// Query all function bodies in a file. /// This is a helper to avoid manually extracting function IDs in tests. fn query_all_function_bodies(db: &baml_project::ProjectDatabase, file: SourceFile) { @@ -31,6 +182,130 @@ fn query_all_function_signatures(db: &baml_project::ProjectDatabase, file: Sourc } } +#[test] +fn generic_params_from_items_and_methods() { + let mut test_db = IncrementalTestDb::new(); + let file = test_db.db_mut().add_file( + "test.baml", + r#" +function Foo(x: T) -> U { + return x; +} + +function NoGenerics(x: int) -> int { + return x; +} + +class Box { + value T + + function map(self, x: U) -> U { + return x; + } +} + +enum Option { + Some + None +} + +type Id = T +"#, + ); + + let db = test_db.db(); + + let foo_id = find_function_by_name(db, file, "Foo"); + let foo_params = function_generic_params(db, foo_id); + assert_eq!( + param_names(&foo_params), + vec!["T".to_string(), "U".to_string()] + ); + + let no_generics_id = find_function_by_name(db, file, "NoGenerics"); + let no_generics_params = function_generic_params(db, no_generics_id); + assert!(no_generics_params.is_empty()); + + let class_id = find_class_by_name(db, file, "Box"); + let class_params = class_generic_params(db, class_id); + assert_eq!(param_names(&class_params), vec!["T".to_string()]); + + let method_name = QualifiedName::local_method_from_str("Box", "map"); + let method_id = find_function_by_name(db, file, method_name.as_str()); + let method_params = function_generic_params(db, method_id); + assert_eq!(param_names(&method_params), vec!["U".to_string()]); + + let enum_id = find_enum_by_name(db, file, "Option"); + let enum_params = enum_generic_params(db, enum_id); + assert_eq!(param_names(&enum_params), vec!["T".to_string()]); + + let alias_id = find_type_alias_by_name(db, file, "Id"); + let alias_params = type_alias_generic_params(db, alias_id); + assert_eq!(param_names(&alias_params), vec!["T".to_string()]); +} + +#[test] +fn generic_params_respect_duplicate_name_indices() { + let mut test_db = IncrementalTestDb::new(); + let file = test_db.db_mut().add_file( + "test.baml", + r#" +function Dup(x: T) -> T { + return x; +} + +function Dup(x: U) -> U { + return x; +} +"#, + ); + + let db = test_db.db(); + let dups = find_functions_by_name(db, file, "Dup"); + assert_eq!(dups.len(), 2); + + let first_params = function_generic_params(db, dups[0]); + let second_params = function_generic_params(db, dups[1]); + + assert_eq!(param_names(&first_params), vec!["T".to_string()]); + assert_eq!(param_names(&second_params), vec!["U".to_string()]); +} + +#[test] +fn generic_params_handle_hash_collisions() { + let mut test_db = IncrementalTestDb::new(); + let (first_name, second_name) = find_name_hash_collision("Collide"); + assert_ne!(first_name, second_name); + assert_eq!( + hash_name(&Name::new(&first_name)), + hash_name(&Name::new(&second_name)) + ); + + let contents = format!( + r#" +function {first_name}(x: T) -> T {{ + return x; +}} + +function {second_name}(x: U) -> U {{ + return x; +}} +"#, + ); + + let file = test_db.db_mut().add_file("test.baml", &contents); + let db = test_db.db(); + + let first_id = find_function_by_name(db, file, &first_name); + let second_id = find_function_by_name(db, file, &second_name); + + let first_params = function_generic_params(db, first_id); + let second_params = function_generic_params(db, second_id); + + assert_eq!(param_names(&first_params), vec!["T".to_string()]); + assert_eq!(param_names(&second_params), vec!["U".to_string()]); +} + /// Test that editing a function body doesn't invalidate the item tree. /// /// The ItemTree only contains function names, not bodies. So changing a @@ -471,18 +746,13 @@ function NewName(x: string) -> string { /// Helper to get all function locations from a file. fn get_function_locs(db: &baml_project::ProjectDatabase, file: SourceFile) -> Vec> { - let items = baml_compiler_hir::file_items(db, file); - items - .items(db) - .iter() - .filter_map(|item| { - if let baml_compiler_hir::ItemId::Function(func_id) = item { - Some(*func_id) - } else { - None - } - }) - .collect() + collect_items(db, file, |item, _item_tree| { + if let baml_compiler_hir::ItemId::Function(func_id) = item { + Some(*func_id) + } else { + None + } + }) } /// Query type inference for all functions in a file. @@ -777,18 +1047,12 @@ function Greet(name: string) -> string { "## .to_string()); - // Adding a class changes the project-level type context, which may - // invalidate type inference. This test documents current behavior. - // Ideally we'd achieve early cutoff here too. - let (_, executed) = test_db.log_executed(|db| query_all_type_inference(db, file)); - let inference_count = executed - .iter() - .filter(|s| s.contains("function_type_inference")) - .count(); - - // Document behavior - this may be 0 (ideal) or 1 (acceptable) - println!( - "Type inference re-executions after adding unrelated class: {}", - inference_count + // Adding a class invalidates the project-level typing context, which + // currently causes all 3 functions (Greet + its 2 LLM-generated variants) + // to re-run type inference. The ideal count is 0 — once Salsa early-cutoff + // is implemented for project-level context changes, update this to 0. + test_db.assert_executed( + |db| query_all_type_inference(db, file), + &[("function_type_inference", 3)], ); }