diff --git a/baml_language/Cargo.lock b/baml_language/Cargo.lock index 1a3cdf431e..a68e52ec13 100644 --- a/baml_language/Cargo.lock +++ b/baml_language/Cargo.lock @@ -562,6 +562,7 @@ name = "baml_compiler_ppir" version = "0.0.0-beta" dependencies = [ "baml_base", + "baml_compiler_diagnostics", "baml_compiler_parser", "baml_compiler_syntax", "baml_workspace", @@ -786,6 +787,7 @@ dependencies = [ "baml_compiler_emit", "baml_compiler_hir", "baml_compiler_mir", + "baml_compiler_ppir", "baml_compiler_tir", "baml_compiler_vir", "baml_db", diff --git a/baml_language/crates/baml_base/src/files.rs b/baml_language/crates/baml_base/src/files.rs index 1b7f89e45e..be43339582 100644 --- a/baml_language/crates/baml_base/src/files.rs +++ b/baml_language/crates/baml_base/src/files.rs @@ -1,29 +1,109 @@ //! File management with Salsa 2022 API. //! //! Defines the core structures for accessing file contents and paths. +//! +//! ## Virtual file classification +//! +//! The compiler uses two categories of non-user files, identified by path prefix: +//! +//! | Category | Prefix | Example | +//! |-----------|-----------------|--------------------------------------| +//! | Built-in | `/` | `/baml/llm.baml` | +//! | Generated | `` | +//! +//! The canonical prefix constants live in this file. +//! Other crates should classify files through the [`SourceFile::is_builtin`], +//! [`SourceFile::is_generated`], and [`SourceFile::is_virtual`] methods, and +//! only use [`SourceFile::builtin_path_prefix`] / +//! [`SourceFile::generated_path_prefix`] when they must construct or strip +//! virtual paths. use std::path::PathBuf; use crate::FileId; -/// Input structure representing a source file in the compilation. +// ── Canonical virtual-file path prefixes ────────────────────────────────────── +const BUILTIN_PATH_PREFIX: &str = "/"; +const GENERATED_PATH_PREFIX: &str = "/`) define internal standard-library + /// types and functions. They are excluded from user-facing validation and + /// diagnostic reporting. + pub fn is_builtin(self, db: &dyn salsa::Database) -> bool { + self.has_prefix(db, BUILTIN_PATH_PREFIX) + } + + /// Returns `true` if this file was synthesized by a compiler pass. + /// + /// Generated files (path prefix ` bool { + self.has_prefix(db, GENERATED_PATH_PREFIX) + } + + /// Returns `true` if this file is virtual, i.e. either built-in or generated. + /// + /// Virtual files are not user-authored and should be skipped by passes that + /// operate only on user source (e.g. name validation, stream diagnostics). + pub fn is_virtual(self, db: &dyn salsa::Database) -> bool { + self.has_prefix(db, BUILTIN_PATH_PREFIX) || self.has_prefix(db, GENERATED_PATH_PREFIX) + } + + // ── Prefix accessors ──────────────────────────────────────────────────── + + /// Returns the path prefix string used for built-in files (`"/"`). + /// + /// Use this only when you need to **construct** or **strip** a built-in + /// path (e.g. in `baml_builtins` or `baml_compiler_hir::file_namespace`). + /// For classification, prefer [`is_builtin`](Self::is_builtin). + pub fn builtin_path_prefix() -> &'static str { + BUILTIN_PATH_PREFIX + } + + /// Returns the path prefix string used for generated files (`" &'static str { + GENERATED_PATH_PREFIX + } + + // ── private ────────────────────────────────────────────────────────────── + + fn has_prefix(self, db: &dyn salsa::Database, prefix: &str) -> bool { + self.path(db).to_string_lossy().starts_with(prefix) + } +} diff --git a/baml_language/crates/baml_builtins/src/lib.rs b/baml_language/crates/baml_builtins/src/lib.rs index 648c1510e6..baf4654577 100644 --- a/baml_language/crates/baml_builtins/src/lib.rs +++ b/baml_language/crates/baml_builtins/src/lib.rs @@ -823,7 +823,10 @@ mod tests { // ============================================================================ /// Builtin BAML source files for built-in functions. -pub const BUILTIN_PATH_PREFIX: &str = "/"; +/// The path prefix for builtin files, sourced from `baml_base::SourceFile`. +pub fn builtin_path_prefix() -> &'static str { + baml_base::SourceFile::builtin_path_prefix() +} /// /// These files are compiled together with user code and provide /// implementations for builtin namespaces like `baml.llm`. @@ -903,16 +906,17 @@ pub fn baml_sources() -> impl Iterator PackageInfo { let path = file.path(db); let path_str = path.to_string_lossy(); - if let Some(relative) = path_str.strip_prefix("/") { + if let Some(relative) = path_str.strip_prefix(baml_base::SourceFile::builtin_path_prefix()) { // /baml/llm.baml → package "baml", namespace ["llm"] // /env.baml → package "env", namespace [] let segments: Vec<&str> = relative.split('/').collect(); diff --git a/baml_language/crates/baml_compiler_diagnostics/src/compiler_error.rs b/baml_language/crates/baml_compiler_diagnostics/src/compiler_error.rs index 37eb118a65..307b5d704a 100644 --- a/baml_language/crates/baml_compiler_diagnostics/src/compiler_error.rs +++ b/baml_language/crates/baml_compiler_diagnostics/src/compiler_error.rs @@ -148,6 +148,7 @@ const DUPLICATE_VARIANT: ErrorCode = ErrorCode(13); const DUPLICATE_ATTRIBUTE: ErrorCode = ErrorCode(14); const UNKNOWN_ATTRIBUTE: ErrorCode = ErrorCode(15); const INVALID_ATTRIBUTE_CONTEXT: ErrorCode = ErrorCode(16); +const CONFLICTING_STREAM_ATTRIBUTES: ErrorCode = ErrorCode(101); // Generator diagnostics const UNKNOWN_GENERATOR_PROPERTY: ErrorCode = ErrorCode(17); diff --git a/baml_language/crates/baml_compiler_diagnostics/src/compiler_error/error_format.rs b/baml_language/crates/baml_compiler_diagnostics/src/compiler_error/error_format.rs index 017506c796..e6a939ba8b 100644 --- a/baml_language/crates/baml_compiler_diagnostics/src/compiler_error/error_format.rs +++ b/baml_language/crates/baml_compiler_diagnostics/src/compiler_error/error_format.rs @@ -2,9 +2,9 @@ use ariadne::{Label, ReportBuilder}; use baml_base::Span; use super::{ - ARGUMENT_COUNT_MISMATCH, CompilerError, DUPLICATE_ATTRIBUTE, DUPLICATE_FIELD, DUPLICATE_NAME, - DUPLICATE_VARIANT, ErrorCode, FIELD_NAME_MATCHES_TYPE_NAME, HTTP_CONFIG_NOT_BLOCK, - HirDiagnostic, INVALID_ATTRIBUTE_CONTEXT, INVALID_CLIENT_RESPONSE_TYPE, + ARGUMENT_COUNT_MISMATCH, CONFLICTING_STREAM_ATTRIBUTES, CompilerError, DUPLICATE_ATTRIBUTE, + DUPLICATE_FIELD, DUPLICATE_NAME, DUPLICATE_VARIANT, ErrorCode, FIELD_NAME_MATCHES_TYPE_NAME, + HTTP_CONFIG_NOT_BLOCK, HirDiagnostic, INVALID_ATTRIBUTE_CONTEXT, INVALID_CLIENT_RESPONSE_TYPE, INVALID_CONSTRAINT_SYNTAX, INVALID_GENERATOR_PROPERTY_VALUE, INVALID_OPERATOR, MISSING_CONDITION_PARENS, MISSING_GENERATOR_PROPERTY, MISSING_PROVIDER, MISSING_RETURN_EXPRESSION, MISSING_SEMICOLON, NEGATIVE_TIMEOUT, NO_SUCH_FIELD, @@ -334,6 +334,20 @@ where span, INVALID_ATTRIBUTE_CONTEXT, ), + HirDiagnostic::ConflictingStreamAttributes { + first_attr, + second_attr, + first_span, + second_span, + } => ( + Report::build(ReportKind::Error, second_span) + .with_message(format!( + "Conflicting stream attributes '@{first_attr}' and '@{second_attr}'" + )) + .with_label(Label::new(second_span).with_message("conflicting attribute")) + .with_label(Label::new(first_span).with_message("first attribute here")), + CONFLICTING_STREAM_ATTRIBUTES, + ), HirDiagnostic::UnknownGeneratorProperty { generator_name, property_name, diff --git a/baml_language/crates/baml_compiler_diagnostics/src/diagnostic.rs b/baml_language/crates/baml_compiler_diagnostics/src/diagnostic.rs index 60da291fa8..5975ff1e4b 100644 --- a/baml_language/crates/baml_compiler_diagnostics/src/diagnostic.rs +++ b/baml_language/crates/baml_compiler_diagnostics/src/diagnostic.rs @@ -117,9 +117,10 @@ pub enum DiagnosticId { // Constraint attribute errors (E0032) InvalidConstraintSyntax, - // Attribute value errors (E0037-E0038) + // Attribute value errors (E0037-E0038, E0101) InvalidAttributeArg, UnexpectedAttributeArg, + ConflictingStreamAttributes, // Type literal errors (E0033) UnsupportedFloatLiteral, @@ -254,6 +255,7 @@ impl DiagnosticId { // Attribute value errors DiagnosticId::InvalidAttributeArg => "E0037", DiagnosticId::UnexpectedAttributeArg => "E0038", + DiagnosticId::ConflictingStreamAttributes => "E0101", // Type literal errors DiagnosticId::UnsupportedFloatLiteral => "E0033", diff --git a/baml_language/crates/baml_compiler_diagnostics/src/errors/hir_diagnostic.rs b/baml_language/crates/baml_compiler_diagnostics/src/errors/hir_diagnostic.rs index 004d5194ff..11b6a7c023 100644 --- a/baml_language/crates/baml_compiler_diagnostics/src/errors/hir_diagnostic.rs +++ b/baml_language/crates/baml_compiler_diagnostics/src/errors/hir_diagnostic.rs @@ -65,7 +65,7 @@ pub enum HirDiagnostic { UnknownAttribute { attr_name: String, span: Span, - valid_attributes: Vec<&'static str>, + valid_attributes: &'static [&'static str], }, /// Attribute used in wrong context. @@ -245,6 +245,14 @@ pub enum HirDiagnostic { /// Attribute takes no arguments but received some (e.g., @@dynamic("unexpected")). UnexpectedAttributeArg { attr_name: String, span: Span }, + /// Conflicting stream attributes on the same type expression. + ConflictingStreamAttributes { + first_attr: &'static str, + second_attr: &'static str, + first_span: Span, + second_span: Span, + }, + // ============ Type Diagnostics ============ /// Float literal used as a type, which is not supported. UnsupportedFloatLiteral { value: String, span: Span }, diff --git a/baml_language/crates/baml_compiler_diagnostics/src/to_diagnostic.rs b/baml_language/crates/baml_compiler_diagnostics/src/to_diagnostic.rs index 2d860bb0b3..d7f3b66e4e 100644 --- a/baml_language/crates/baml_compiler_diagnostics/src/to_diagnostic.rs +++ b/baml_language/crates/baml_compiler_diagnostics/src/to_diagnostic.rs @@ -1001,6 +1001,20 @@ impl ToDiagnostic for HirDiagnostic { ) .with_primary(*span, "unexpected argument"), + HirDiagnostic::ConflictingStreamAttributes { + first_attr, + second_attr, + first_span, + second_span, + } => Diagnostic::error( + DiagnosticId::ConflictingStreamAttributes, + format!( + "Conflicting stream attributes `@{first_attr}` and `@{second_attr}`." + ), + ) + .with_primary(*second_span, "conflicting attribute") + .with_secondary(*first_span, "first attribute here"), + HirDiagnostic::UnsupportedFloatLiteral { value, span } => Diagnostic::error( DiagnosticId::UnsupportedFloatLiteral, format!("Float literal values are not supported: {value}"), diff --git a/baml_language/crates/baml_compiler_hir/src/lib.rs b/baml_language/crates/baml_compiler_hir/src/lib.rs index a8de095482..6947098920 100644 --- a/baml_language/crates/baml_compiler_hir/src/lib.rs +++ b/baml_language/crates/baml_compiler_hir/src/lib.rs @@ -287,12 +287,6 @@ pub fn function_signature_source_map<'db>( source_map } -/// The prefix used for builtin BAML files. -/// -/// Files with paths starting with this prefix are treated as builtins -/// and their functions are namespaced accordingly. -pub const BUILTIN_PATH_PREFIX: &str = "/"; - /// Derive the namespace for a file based on its path. /// /// Builtin files (paths starting with `/`) get namespaced: @@ -316,12 +310,13 @@ pub fn file_namespace(db: &dyn Db, file: SourceFile) -> Option { let path = file.path(db); let path_str = path.to_string_lossy(); - if !path_str.starts_with(BUILTIN_PATH_PREFIX) { + if !file.is_builtin(db) { return None; } // Extract path after prefix: "/baml/llm.baml" -> "baml/llm.baml" - let after_prefix = &path_str[BUILTIN_PATH_PREFIX.len()..]; + let prefix_len = SourceFile::builtin_path_prefix().len(); + let after_prefix = &path_str[prefix_len..]; // Remove .baml extension and split by / let without_ext = after_prefix.strip_suffix(".baml").unwrap_or(after_prefix); diff --git a/baml_language/crates/baml_compiler_ppir/Cargo.toml b/baml_language/crates/baml_compiler_ppir/Cargo.toml index ee9c0264df..507c004fc1 100644 --- a/baml_language/crates/baml_compiler_ppir/Cargo.toml +++ b/baml_language/crates/baml_compiler_ppir/Cargo.toml @@ -18,6 +18,7 @@ workspace = true [dependencies] baml_base = { workspace = true } +baml_compiler_diagnostics = { workspace = true } baml_compiler_parser = { workspace = true } baml_compiler_syntax = { workspace = true } baml_workspace = { workspace = true } diff --git a/baml_language/crates/baml_compiler_ppir/src/desugar.rs b/baml_language/crates/baml_compiler_ppir/src/desugar.rs index a7f056b678..42b322c299 100644 --- a/baml_language/crates/baml_compiler_ppir/src/desugar.rs +++ b/baml_language/crates/baml_compiler_ppir/src/desugar.rs @@ -9,7 +9,7 @@ use baml_compiler_syntax::{GreenNode, SyntaxNode}; use smol_str::SmolStr; use crate::{ - PpirNames, + PpirNames, StreamAttrKind, parse_stream_attr, ty::{PpirField, PpirTy, PpirTypeAttrs}, }; @@ -256,11 +256,17 @@ pub(crate) fn build_ppir_fields(class_def: &baml_compiler_syntax::ast::ClassDef) // ones (@stream.starts_as, @stream.not_null). if let Some(type_expr) = field_node.ty() { for attr in type_expr.attributes() { - if let Some(attr_name) = attr.full_name() { - match attr_name.as_str() { - "stream.starts_as" => starts_as = attr.arg_syntax_node(), - "stream.not_null" => not_null = true, - _ => {} + let Some(attr_name) = attr.full_name() else { + continue; + }; + let Some(kind) = parse_stream_attr(attr_name.as_str()) else { + continue; + }; + + match kind { + StreamAttrKind::StartsAs => starts_as = attr.arg_syntax_node(), + StreamAttrKind::NotNull => not_null = true, + StreamAttrKind::Done | StreamAttrKind::Type | StreamAttrKind::WithState => { } } } @@ -311,7 +317,7 @@ pub(crate) fn desugar_field( attrs: PpirTypeAttrs::default(), }); PpirStreamStartsAs::Explicit { green, typeof_s } - } else if type_has_block_attr(&pf.ty, "stream.not_null", names, db) { + } else if type_has_block_attr(&pf.ty, StreamAttrKind::NotNull, names, db) { PpirStreamStartsAs::Never } else { default_sap_starts_as(&stream_type) @@ -330,11 +336,17 @@ pub(crate) fn desugar_field( /// /// Only matches bare named types (e.g., `Foo`). Does NOT match `Foo[]`, `Foo?`, /// `Foo | Bar`, etc. — those use their own default `starts_as` behavior. -fn type_has_block_attr(ty: &PpirTy, attr: &str, names: PpirNames<'_>, db: &dyn crate::Db) -> bool { +fn type_has_block_attr( + ty: &PpirTy, + attr: StreamAttrKind, + names: PpirNames<'_>, + db: &dyn crate::Db, +) -> bool { let PpirTy::Named { name, .. } = ty else { return false; }; - let has_attr = |attrs: &Vec| attrs.iter().any(|a| a == attr); + let attr_name = attr.name(); + let has_attr = |attrs: &Vec| attrs.iter().any(|a| a == attr_name); names .class_names(db) .get(name.as_str()) @@ -362,7 +374,7 @@ pub fn extract_starts_as_text(green: &GreenNode) -> String { let text = child.text().to_string(); let trimmed = text.trim(); if trimmed.starts_with('"') && trimmed.ends_with('"') && trimmed.len() >= 2 { - return trimmed[1..trimmed.len() - 1].to_string(); + return trimmed[1..trimmed.len() - 1].to_owned(); } } SyntaxKind::RAW_STRING_LITERAL => { @@ -374,7 +386,7 @@ pub fn extract_starts_as_text(green: &GreenNode) -> String { if inner.starts_with('"') { if let Some(end_pos) = inner.rfind('"') { if end_pos > 0 { - return inner[1..end_pos].to_string(); + return inner[1..end_pos].to_owned(); } } } @@ -400,6 +412,6 @@ pub fn extract_starts_as_text(green: &GreenNode) -> String { | SyntaxKind::COMMA ) }) - .map(|token| token.text().to_string()) + .map(|token| token.text().to_owned()) .collect() } diff --git a/baml_language/crates/baml_compiler_ppir/src/lib.rs b/baml_language/crates/baml_compiler_ppir/src/lib.rs index 2cac22a5e1..273a93db35 100644 --- a/baml_language/crates/baml_compiler_ppir/src/lib.rs +++ b/baml_language/crates/baml_compiler_ppir/src/lib.rs @@ -24,7 +24,9 @@ use smol_str::SmolStr; mod desugar; pub mod expand_cst; pub mod normalize; +mod stream_attrs; mod ty; +mod validate; pub use desugar::{ PpirDesugaredClass, PpirDesugaredField, PpirDesugaredTypeAlias, PpirStreamStartsAs, @@ -33,7 +35,11 @@ pub use desugar::{ pub use normalize::{ StartsAs, StartsAsLiteral, default_starts_as_semantic, infer_typeof_s, parse_starts_as_value, }; +pub use stream_attrs::{ + StreamAttrArgRule, StreamAttrKind, StreamAttrTarget, parse_stream_attr, valid_stream_attr_names, +}; pub use ty::{PpirField, PpirTy, PpirTypeAttrs}; +pub use validate::ppir_stream_diagnostics; // // ──────────────────────────────────────────────────────────── DATABASE ───── @@ -126,67 +132,54 @@ pub struct PpirExpansionCst<'db> { /// no downstream queries are invalidated. #[salsa::tracked] pub fn ppir_names(db: &dyn Db, project: Project) -> PpirNames<'_> { - /// Collect @@stream.* block attribute names from a definition. - fn collect_stream_block_attrs( + /// Collect `@@stream.*` block attribute names from any definition's block attrs. + fn stream_block_attrs( block_attrs: impl Iterator, ) -> Vec { block_attrs .filter_map(|a| { let name = a.full_name()?; - if name.starts_with("stream.") { - Some(SmolStr::from(name.as_str())) - } else { - None - } + name.starts_with("stream.") + .then(|| SmolStr::from(name.as_str())) }) .collect() } let mut class_names: FxHashMap> = FxHashMap::default(); let mut enum_names: FxHashMap> = FxHashMap::default(); - let mut type_alias_names = FxHashSet::default(); + let mut type_alias_names: FxHashSet = FxHashSet::default(); for file in project.files(db) { - // Skip builtin files — they define internal types, not user-defined classes/enums/aliases. - if file - .path(db) - .to_str() - .is_some_and(|p| p.starts_with("/") || p.starts_with(" { - if let Some(class_def) = - baml_compiler_syntax::ast::ClassDef::cast(child.clone()) - { - if let Some(name_tok) = class_def.name() { - let name = SmolStr::new(name_tok.text()); - let stream_attrs = - collect_stream_block_attrs(class_def.block_attributes()); - class_names.insert(name, stream_attrs); + if let Some(def) = baml_compiler_syntax::ast::ClassDef::cast(child) { + if let Some(tok) = def.name() { + let name = SmolStr::new(tok.text()); + class_names + .entry(name) + .or_insert_with(|| stream_block_attrs(def.block_attributes())); } } } SyntaxKind::ENUM_DEF => { - if let Some(enum_def) = baml_compiler_syntax::ast::EnumDef::cast(child.clone()) - { - if let Some(name_tok) = enum_def.name() { - let name = SmolStr::new(name_tok.text()); - let stream_attrs = - collect_stream_block_attrs(enum_def.block_attributes()); - enum_names.insert(name, stream_attrs); + if let Some(def) = baml_compiler_syntax::ast::EnumDef::cast(child) { + if let Some(tok) = def.name() { + let name = SmolStr::new(tok.text()); + enum_names + .entry(name) + .or_insert_with(|| stream_block_attrs(def.block_attributes())); } } } SyntaxKind::TYPE_ALIAS_DEF => { - if let Some(alias_def) = - baml_compiler_syntax::ast::TypeAliasDef::cast(child.clone()) - { - if let Some(name_tok) = alias_def.name() { - type_alias_names.insert(SmolStr::new(name_tok.text())); + if let Some(def) = baml_compiler_syntax::ast::TypeAliasDef::cast(child) { + if let Some(tok) = def.name() { + type_alias_names.insert(SmolStr::new(tok.text())); } } } @@ -208,11 +201,20 @@ pub fn ppir_names(db: &dyn Db, project: Project) -> PpirNames<'_> { /// via clone-and-transform of the original CST. #[salsa::tracked] pub fn ppir_desugared_items(db: &dyn Db, file: SourceFile) -> PpirDesugaredItems<'_> { - let file_path = file.path(db); - if file_path - .to_str() - .is_some_and(|p| p.starts_with("/") || p.starts_with(", + tok: Option, + ) -> Option { + let name = SmolStr::new(tok?.text()); + if name.starts_with("stream_") || !seen.insert(name.clone()) { + return None; + } + Some(name) + } + + if file.is_virtual(db) { return PpirDesugaredItems::new(db, Vec::new(), Vec::new()); } @@ -220,72 +222,43 @@ pub fn ppir_desugared_items(db: &dyn Db, file: SourceFile) -> PpirDesugaredItems let project = db.project(); let names = ppir_names(db, project); - let mut desugared_classes = Vec::new(); - let mut desugared_aliases = Vec::new(); - let mut seen_class_names = FxHashSet::default(); - let mut seen_alias_names = FxHashSet::default(); + let mut desugared_classes: Vec = Vec::new(); + let mut desugared_aliases: Vec = Vec::new(); + let mut seen_class_names: FxHashSet = FxHashSet::default(); + let mut seen_alias_names: FxHashSet = FxHashSet::default(); for child in cst.children() { match child.kind() { SyntaxKind::CLASS_DEF => { - let Some(class_def) = baml_compiler_syntax::ast::ClassDef::cast(child.clone()) - else { + let Some(def) = baml_compiler_syntax::ast::ClassDef::cast(child) else { continue; }; - let Some(name_tok) = class_def.name() else { + let Some(name) = accept_name(&mut seen_class_names, def.name()) else { continue; }; - let class_name: Name = SmolStr::new(name_tok.text()); - if class_name.starts_with("stream_") { - continue; - } - if !seen_class_names.insert(class_name.clone()) { - continue; - } - - // Build PPIR fields from CST (type-level attrs captured by PpirTy::from_ast) - let ppir_fields = desugar::build_ppir_fields(&class_def); - - // Desugar each field - let desugared_fields: Vec = ppir_fields + let fields = desugar::build_ppir_fields(&def) .iter() .map(|pf| desugar::desugar_field(pf, names, db)) .collect(); - - desugared_classes.push(PpirDesugaredClass { - name: class_name, - fields: desugared_fields, - }); + desugared_classes.push(PpirDesugaredClass { name, fields }); } SyntaxKind::TYPE_ALIAS_DEF => { - let Some(alias_def) = baml_compiler_syntax::ast::TypeAliasDef::cast(child.clone()) - else { + let Some(def) = baml_compiler_syntax::ast::TypeAliasDef::cast(child) else { continue; }; - let Some(name_tok) = alias_def.name() else { + let Some(name) = accept_name(&mut seen_alias_names, def.name()) else { continue; }; - let alias_name: Name = SmolStr::new(name_tok.text()); - if alias_name.starts_with("stream_") { - continue; - } - if !seen_alias_names.insert(alias_name.clone()) { - continue; - } - - let ty = - alias_def - .ty() - .map(|te| PpirTy::from_ast(&te)) - .unwrap_or(PpirTy::Unknown { - attrs: PpirTypeAttrs::default(), - }); - + let ty = def + .ty() + .map(|te| PpirTy::from_ast(&te)) + .unwrap_or(PpirTy::Unknown { + attrs: PpirTypeAttrs::default(), + }); let expanded_body = desugar::stream_expand(&ty, names, db); - desugared_aliases.push(PpirDesugaredTypeAlias { - name: alias_name, + name, expanded_body, }); } @@ -323,8 +296,9 @@ pub fn ppir_expansion_cst(db: &dyn Db, file: SourceFile) -> PpirExpansionCst<'_> .file_name() .unwrap_or_default() .to_string_lossy() - .to_string(); - let display_path = format!(""); + .into_owned(); + let display_path = + format!("{}stream/{file_name}>", SourceFile::generated_path_prefix()); let file_id = FileId::stream_expansion(file.file_id(db)); let synth_file = SourceFile::new(db, text.clone(), PathBuf::from(&display_path), file_id); diff --git a/baml_language/crates/baml_compiler_ppir/src/normalize.rs b/baml_language/crates/baml_compiler_ppir/src/normalize.rs index ffea0cd3eb..3386672403 100644 --- a/baml_language/crates/baml_compiler_ppir/src/normalize.rs +++ b/baml_language/crates/baml_compiler_ppir/src/normalize.rs @@ -95,27 +95,28 @@ pub fn parse_starts_as_value(s: &str) -> StartsAs { if let Ok(i) = s.parse::() { return StartsAs::Literal(StartsAsLiteral::Int(i)); } - // Try float (exclude strings with alphabetic chars to avoid Foo.Bar confusion) - if s.contains('.') && !s.contains(|c: char| c.is_alphabetic()) { - if s.parse::().is_ok() { - return StartsAs::Literal(StartsAsLiteral::Float(s.to_string())); - } + // Try float: must contain '.' and have no alphabetic chars (rules out Foo.Bar) + if s.bytes().fold((false, true), |(dot, alpha_free), b| { + (dot || b == b'.', alpha_free && !b.is_ascii_alphabetic()) + }) == (true, true) + && s.parse::().is_ok() + { + return StartsAs::Literal(StartsAsLiteral::Float(s.to_owned())); } // Try enum value: Foo.Bar pattern (exactly one dot, both parts are identifiers) if let Some((left, right)) = s.split_once('.') { if is_identifier(left) && is_identifier(right) { return StartsAs::EnumValue { - enum_name: left.to_string(), - variant_name: right.to_string(), + enum_name: left.to_owned(), + variant_name: right.to_owned(), }; } } // If it looks like an expression (has parens), it's unrecognized. - // Everything else is a string literal. if s.contains('(') || s.contains(')') { - return StartsAs::Unknown(s.to_string()); + return StartsAs::Unknown(s.to_owned()); } - StartsAs::Literal(StartsAsLiteral::String(s.to_string())) + StartsAs::Literal(StartsAsLiteral::String(s.to_owned())) } } } @@ -143,7 +144,7 @@ pub fn infer_typeof_s( StartsAs::Null => Some(PpirTy::Null { attrs: d }), StartsAs::Literal(lit) => Some(match lit { StartsAsLiteral::String(s) => PpirTy::StringLiteral { - value: s.clone(), + value: s.to_owned(), attrs: d, }, StartsAsLiteral::Int(i) => PpirTy::IntLiteral { @@ -159,14 +160,12 @@ pub fn infer_typeof_s( StartsAs::EmptyList => None, StartsAs::EmptyMap => None, StartsAs::EnumValue { enum_name, .. } => { - if enum_names.contains_key(enum_name.as_str()) { - Some(PpirTy::Named { + enum_names + .contains_key(enum_name.as_str()) + .then(|| PpirTy::Named { name: SmolStr::new(enum_name), attrs: d, }) - } else { - None // Unknown type name → caller falls back to Never - } } StartsAs::Unknown(_) => None, // Cannot infer → caller falls back to Never } diff --git a/baml_language/crates/baml_compiler_ppir/src/stream_attrs.rs b/baml_language/crates/baml_compiler_ppir/src/stream_attrs.rs new file mode 100644 index 0000000000..299f2b94be --- /dev/null +++ b/baml_language/crates/baml_compiler_ppir/src/stream_attrs.rs @@ -0,0 +1,84 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StreamAttrTarget { + Type, + Block, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StreamAttrArgRule { + None, + ExactlyOne, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum StreamAttrKind { + Done, + NotNull, + StartsAs, + Type, + WithState, +} + +impl StreamAttrKind { + #[inline] + pub const fn name(self) -> &'static str { + match self { + Self::Done => "stream.done", + Self::NotNull => "stream.not_null", + Self::StartsAs => "stream.starts_as", + Self::Type => "stream.type", + Self::WithState => "stream.with_state", + } + } + + #[inline] + pub const fn arg_rule(self) -> StreamAttrArgRule { + match self { + Self::Done | Self::NotNull | Self::WithState => StreamAttrArgRule::None, + Self::StartsAs | Self::Type => StreamAttrArgRule::ExactlyOne, + } + } + + #[inline] + pub const fn supports_target(self, target: StreamAttrTarget) -> bool { + match (self, target) { + (Self::Done, StreamAttrTarget::Type | StreamAttrTarget::Block) => true, + (Self::NotNull, StreamAttrTarget::Type | StreamAttrTarget::Block) => true, + (Self::StartsAs, StreamAttrTarget::Type) => true, + (Self::Type, StreamAttrTarget::Type) => true, + (Self::WithState, StreamAttrTarget::Type) => true, + _ => false, + } + } +} + +const TYPE_STREAM_ATTR_NAMES: &[&str] = &[ + StreamAttrKind::Done.name(), + StreamAttrKind::NotNull.name(), + StreamAttrKind::StartsAs.name(), + StreamAttrKind::Type.name(), + StreamAttrKind::WithState.name(), +]; + +const BLOCK_STREAM_ATTR_NAMES: &[&str] = + &[StreamAttrKind::Done.name(), StreamAttrKind::NotNull.name()]; + +#[inline] +pub fn parse_stream_attr(name: &str) -> Option { + match name { + "stream.done" => Some(StreamAttrKind::Done), + "stream.not_null" => Some(StreamAttrKind::NotNull), + "stream.starts_as" => Some(StreamAttrKind::StartsAs), + "stream.type" => Some(StreamAttrKind::Type), + "stream.with_state" => Some(StreamAttrKind::WithState), + _ => None, + } +} + +#[inline] +pub const fn valid_stream_attr_names(target: StreamAttrTarget) -> &'static [&'static str] { + match target { + StreamAttrTarget::Type => TYPE_STREAM_ATTR_NAMES, + StreamAttrTarget::Block => BLOCK_STREAM_ATTR_NAMES, + } +} diff --git a/baml_language/crates/baml_compiler_ppir/src/tests.rs b/baml_language/crates/baml_compiler_ppir/src/tests.rs index f9f2790a27..fce21de4a0 100644 --- a/baml_language/crates/baml_compiler_ppir/src/tests.rs +++ b/baml_language/crates/baml_compiler_ppir/src/tests.rs @@ -1,4 +1,9 @@ //! Unit tests for PPIR stream expansion and normalization. +//! +//! Tests that require a `ProjectDatabase` (i.e. full compiler pipeline) live in +//! `baml_tests/tests/ppir_diagnostics.rs` to avoid pulling `baml_project` into +//! this crate's dev-dependencies, which would create a duplicate copy of +//! `baml_compiler_ppir` and break Salsa's trait-bound checks. use baml_base::Name; use rustc_hash::FxHashMap; diff --git a/baml_language/crates/baml_compiler_ppir/src/ty.rs b/baml_language/crates/baml_compiler_ppir/src/ty.rs index 84bba3618d..f8fb8b25c6 100644 --- a/baml_language/crates/baml_compiler_ppir/src/ty.rs +++ b/baml_language/crates/baml_compiler_ppir/src/ty.rs @@ -10,6 +10,8 @@ use baml_compiler_syntax::SyntaxNode; use rowan::ast::AstNode as _; use smol_str::SmolStr; +use crate::{StreamAttrKind, parse_stream_attr}; + // // ──────────────────────────────────────────────── TYPE ATTRS ───── // @@ -303,26 +305,33 @@ impl PpirTy { // Parse the type structure let mut result = Self::from_ast_structure(type_expr); - // Capture type-level annotations from ATTRIBUTE children + // Capture type-level annotations from ATTRIBUTE children. + // Validation of target legality and argument shape happens in `validate.rs`; + // PPIR only consumes the subset of stream attrs that affect desugaring. for attr in type_expr.attributes() { - if let Some(attr_name) = attr.full_name() { - match attr_name.as_str() { - "stream.done" => { - result.attrs_mut().stream_done = true; - } - "stream.type" => { - if let Some(type_arg) = attr.string_arg() { - result.attrs_mut().stream_type = - Some(Box::new(Self::from_type_name(&type_arg))); - } - } - "stream.with_state" => { - result.attrs_mut().stream_with_state = true; - } - _ => { - // Other type-level attrs (e.g., @assert, @check) ignored by PPIR + let Some(attr_name) = attr.full_name() else { + continue; + }; + let Some(kind) = parse_stream_attr(attr_name.as_str()) else { + continue; + }; + + match kind { + StreamAttrKind::Done => { + result.attrs_mut().stream_done = true; + } + StreamAttrKind::Type => { + if let Some(type_arg) = attr.string_arg() { + result.attrs_mut().stream_type = + Some(Box::new(Self::from_type_name(&type_arg))); } } + StreamAttrKind::WithState => { + result.attrs_mut().stream_with_state = true; + } + StreamAttrKind::NotNull | StreamAttrKind::StartsAs => { + // Field-level stream attrs are handled in `desugar::build_ppir_fields`. + } } } @@ -492,7 +501,7 @@ impl PpirTy { let type_arg_exprs: Vec<_> = type_args_node .children() .filter(|n| n.kind() == baml_compiler_syntax::SyntaxKind::TYPE_EXPR) - .map(|n| baml_compiler_syntax::ast::TypeExpr::cast(n).unwrap()) + .filter_map(baml_compiler_syntax::ast::TypeExpr::cast) .collect(); if type_arg_exprs.len() == 2 { diff --git a/baml_language/crates/baml_compiler_ppir/src/validate.rs b/baml_language/crates/baml_compiler_ppir/src/validate.rs new file mode 100644 index 0000000000..95caca40bc --- /dev/null +++ b/baml_language/crates/baml_compiler_ppir/src/validate.rs @@ -0,0 +1,302 @@ +use baml_base::{FileId, SourceFile, Span}; +use baml_compiler_diagnostics::HirDiagnostic; +use baml_compiler_parser::{parse_errors, syntax_tree}; +use baml_compiler_syntax::{ + SyntaxKind, TypeExpr, + ast::{Attribute, BlockAttribute, ClassDef, EnumDef, TypeAliasDef}, +}; +use rowan::{TextRange, ast::AstNode as _}; + +use crate::{ + StreamAttrArgRule, StreamAttrKind, StreamAttrTarget, parse_stream_attr, valid_stream_attr_names, +}; + +const STREAM_TYPE_ATTRS: &[&str] = valid_stream_attr_names(StreamAttrTarget::Type); +const STREAM_BLOCK_ATTRS: &[&str] = valid_stream_attr_names(StreamAttrTarget::Block); + +// ── local trait to unify Attribute / BlockAttribute span queries ────────────── + +/// Abstracts the span operations shared by [`Attribute`] and [`BlockAttribute`]. +/// +/// Both types expose identical `full_name_range`, `name`, and `args_span` +/// methods but share no common trait in `baml_compiler_syntax`. This sealed +/// trait lets us write the span helpers once. +trait AttrSpans { + fn full_name_range(&self) -> Option; + fn name_token_range(&self) -> Option; + fn args_span(&self) -> Option; + + fn best_span_range(&self) -> Option { + self.args_span() + .or_else(|| self.full_name_range()) + .or_else(|| self.name_token_range()) + } +} + +impl AttrSpans for Attribute { + fn full_name_range(&self) -> Option { + self.full_name_range() + } + fn name_token_range(&self) -> Option { + self.name().map(|t| t.text_range()) + } + fn args_span(&self) -> Option { + self.args_span() + } +} + +impl AttrSpans for BlockAttribute { + fn full_name_range(&self) -> Option { + self.full_name_range() + } + fn name_token_range(&self) -> Option { + self.name().map(|t| t.text_range()) + } + fn args_span(&self) -> Option { + self.args_span() + } +} + +// ── span helpers (one implementation, works for both attribute kinds) ───────── + +fn name_span(attr: &A, file_id: FileId) -> Option { + attr.full_name_range() + .or_else(|| attr.name_token_range()) + .map(|r| Span::new(file_id, r)) +} + +fn best_span(attr: &A, file_id: FileId) -> Option { + attr.best_span_range().map(|r| Span::new(file_id, r)) +} + +// ── public entry point ──────────────────────────────────────────────────────── + +pub fn ppir_stream_diagnostics(db: &dyn crate::Db, file: SourceFile) -> Vec { + if file.is_virtual(db) || !parse_errors(db, file).is_empty() { + return Vec::new(); + } + + let tree = syntax_tree(db, file); + let file_id = file.file_id(db); + let mut diagnostics = Vec::new(); + + for child in tree.children() { + match child.kind() { + SyntaxKind::CLASS_DEF => { + if let Some(def) = ClassDef::cast(child) { + validate_class(&def, file_id, &mut diagnostics); + } + } + SyntaxKind::ENUM_DEF => { + if let Some(def) = EnumDef::cast(child) { + validate_block_stream_attrs(def.block_attributes(), file_id, &mut diagnostics); + } + } + SyntaxKind::TYPE_ALIAS_DEF => { + if let Some(def) = TypeAliasDef::cast(child) { + if let Some(ty) = def.ty() { + validate_stream_type_expr(&ty, file_id, &mut diagnostics); + } + } + } + _ => {} + } + } + + diagnostics +} + +// ── per-definition validators ───────────────────────────────────────────────── + +fn validate_class(def: &ClassDef, file_id: FileId, out: &mut Vec) { + validate_block_stream_attrs(def.block_attributes(), file_id, out); + for field in def.fields() { + if let Some(ty) = field.ty() { + validate_stream_type_expr(&ty, file_id, out); + } + } +} + +fn validate_stream_type_expr(type_expr: &TypeExpr, file_id: FileId, out: &mut Vec) { + let mut not_null_span: Option = None; + let mut starts_as_span: Option = None; + let mut done_span: Option = None; + let mut type_span: Option = None; + + for attr in type_expr.attributes() { + let Some(attr_name) = attr.full_name() else { + continue; + }; + let key = attr_name.as_str(); + + if !key.starts_with("stream.") { + continue; + } + + let Some(kind) = parse_stream_attr(key) else { + if let Some(span) = name_span(&attr, file_id) { + out.push(HirDiagnostic::UnknownAttribute { + attr_name: key.to_owned(), + span, + valid_attributes: STREAM_TYPE_ATTRS, + }); + } + continue; + }; + + if !kind.supports_target(StreamAttrTarget::Type) { + if let Some(span) = name_span(&attr, file_id) { + out.push(HirDiagnostic::UnknownAttribute { + attr_name: key.to_owned(), + span, + valid_attributes: STREAM_TYPE_ATTRS, + }); + } + continue; + } + + let span = name_span(&attr, file_id); + + match kind.arg_rule() { + StreamAttrArgRule::None => forbid_args(&attr, kind.name(), file_id, out), + StreamAttrArgRule::ExactlyOne => require_single_arg(&attr, kind.name(), file_id, out), + } + + match kind { + StreamAttrKind::Type => { + if type_span.is_none() { + type_span = span; + } + } + StreamAttrKind::StartsAs => { + if starts_as_span.is_none() { + starts_as_span = span; + } + } + StreamAttrKind::NotNull => { + if not_null_span.is_none() { + not_null_span = span; + } + } + StreamAttrKind::Done => { + if done_span.is_none() { + done_span = span; + } + } + StreamAttrKind::WithState => {} + } + } + + if let (Some(first), Some(second)) = (not_null_span, starts_as_span) { + out.push(HirDiagnostic::ConflictingStreamAttributes { + first_attr: StreamAttrKind::NotNull.name(), + second_attr: StreamAttrKind::StartsAs.name(), + first_span: first, + second_span: second, + }); + } + + if let (Some(first), Some(second)) = (done_span, type_span) { + out.push(HirDiagnostic::ConflictingStreamAttributes { + first_attr: StreamAttrKind::Done.name(), + second_attr: StreamAttrKind::Type.name(), + first_span: first, + second_span: second, + }); + } +} + +fn validate_block_stream_attrs( + attrs: impl Iterator, + file_id: FileId, + out: &mut Vec, +) { + for attr in attrs { + let Some(attr_name) = attr.full_name() else { + continue; + }; + let key = attr_name.as_str(); + + if !key.starts_with("stream.") { + continue; + } + + let Some(kind) = parse_stream_attr(key) else { + if let Some(span) = name_span(&attr, file_id) { + out.push(HirDiagnostic::UnknownAttribute { + attr_name: key.to_owned(), + span, + valid_attributes: STREAM_BLOCK_ATTRS, + }); + } + continue; + }; + + if !kind.supports_target(StreamAttrTarget::Block) { + if let Some(span) = name_span(&attr, file_id) { + out.push(HirDiagnostic::UnknownAttribute { + attr_name: key.to_owned(), + span, + valid_attributes: STREAM_BLOCK_ATTRS, + }); + } + continue; + } + + if kind.arg_rule() == StreamAttrArgRule::None && attr.has_args() { + if let Some(span) = best_span(&attr, file_id) { + out.push(HirDiagnostic::UnexpectedAttributeArg { + attr_name: kind.name().to_owned(), + span, + }); + } + } + } +} + +// ── argument validators ─────────────────────────────────────────────────────── + +fn forbid_args(attr: &Attribute, name: &str, file_id: FileId, out: &mut Vec) { + if attr.has_args() { + if let Some(span) = best_span(attr, file_id) { + out.push(HirDiagnostic::UnexpectedAttributeArg { + attr_name: name.to_owned(), + span, + }); + } + } +} + +fn require_single_arg(attr: &Attribute, name: &str, file_id: FileId, out: &mut Vec) { + if attr.arg_count() != 1 { + if let Some(span) = best_span(attr, file_id) { + out.push(HirDiagnostic::InvalidAttributeArg { + attr_name: name.to_owned(), + span, + received: describe_args(attr), + }); + } + } +} + +// ── arg description ─────────────────────────────────────────────────────────── + +fn describe_args(attr: &Attribute) -> String { + match attr.arg_count() { + 0 => "no arguments".to_owned(), + 1 => attr + .args() + .next() + .map(|arg| match arg.kind() { + SyntaxKind::STRING_LITERAL | SyntaxKind::RAW_STRING_LITERAL => { + format!("`{}`", arg.text()) + } + SyntaxKind::EXPR | SyntaxKind::UNQUOTED_STRING => { + format!("an expression `{}`", arg.text()) + } + _ => "an unknown value".to_owned(), + }) + .unwrap_or_else(|| "an unknown value".to_owned()), + n => format!("{n} arguments"), + } +} diff --git a/baml_language/crates/baml_project/src/check.rs b/baml_language/crates/baml_project/src/check.rs index 2e2d153582..8b2f3542e3 100644 --- a/baml_language/crates/baml_project/src/check.rs +++ b/baml_language/crates/baml_project/src/check.rs @@ -21,6 +21,7 @@ use baml_compiler_hir::{ llm_function_file_offset, llm_function_meta, project_class_field_type_spans, project_type_alias_type_spans, project_type_item_spans, template_string_file_offset, }; +use baml_compiler_ppir::ppir_stream_diagnostics; use baml_compiler_tir::{self, class_field_types, enum_variants, type_aliases, typing_context}; use baml_db::{FileId, SourceFile, baml_compiler_parser}; use baml_workspace::Project; @@ -71,12 +72,16 @@ pub fn collect_diagnostics( } } - // 2. Collect HIR lowering diagnostics (per-file validation) + // 2. Collect per-file lowering diagnostics (HIR + PPIR stream) for source_file in source_files { let lowering_result = file_lowering(db, *source_file); for diag in lowering_result.diagnostics(db) { diagnostics.push(diag.to_diagnostic()); } + + for diag in ppir_stream_diagnostics(db, *source_file) { + diagnostics.push(diag.to_diagnostic()); + } } // 3. Collect validation errors (duplicates across files, reserved names) diff --git a/baml_language/crates/baml_project/src/db.rs b/baml_language/crates/baml_project/src/db.rs index 399aedb2c4..0639d58534 100644 --- a/baml_language/crates/baml_project/src/db.rs +++ b/baml_language/crates/baml_project/src/db.rs @@ -372,7 +372,7 @@ impl ProjectDatabase { pub fn non_builtin_file_paths(&self) -> impl Iterator { self.file_map .keys() - .filter(|path| !path.starts_with(baml_builtins::BUILTIN_PATH_PREFIX)) + .filter(|path| !path.starts_with(SourceFile::builtin_path_prefix())) .cloned() } diff --git a/baml_language/crates/baml_tests/Cargo.toml b/baml_language/crates/baml_tests/Cargo.toml index c997ff911b..1a80cb2125 100644 --- a/baml_language/crates/baml_tests/Cargo.toml +++ b/baml_language/crates/baml_tests/Cargo.toml @@ -15,6 +15,7 @@ baml_compiler2_ast = { workspace = true } baml_compiler2_hir = { workspace = true } baml_compiler2_tir = { workspace = true } baml_compiler_diagnostics = { workspace = true } +baml_compiler_ppir = { workspace = true } baml_compiler_emit = { workspace = true } baml_compiler_hir = { workspace = true } baml_compiler_mir = { workspace = true } diff --git a/baml_language/crates/baml_tests/tests/ppir_diagnostics.rs b/baml_language/crates/baml_tests/tests/ppir_diagnostics.rs new file mode 100644 index 0000000000..ddf089fd84 --- /dev/null +++ b/baml_language/crates/baml_tests/tests/ppir_diagnostics.rs @@ -0,0 +1,209 @@ +//! Integration tests for PPIR stream-annotation diagnostics. +//! +//! These tests exercise the full `ppir_stream_diagnostics` pipeline through a +//! real `ProjectDatabase`. They live here (in `baml_tests`) rather than in +//! `baml_compiler_ppir/src/tests.rs` because `baml_compiler_ppir` must not +//! take `baml_project` as a dev-dependency: doing so would introduce a second +//! copy of `baml_compiler_ppir` in the dependency graph, which breaks Salsa's +//! `Db` trait-bound resolution. + +use tempfile::tempdir; + +use baml_compiler_diagnostics::HirDiagnostic; +use baml_compiler_ppir::ppir_stream_diagnostics; +use baml_project::ProjectDatabase; + +// ── helpers ────────────────────────────────────────────────────────────────── + +/// Run `ppir_stream_diagnostics` on `source` inside a throwaway temp project. +fn stream_diagnostics(source: &str) -> Vec { + let dir = tempdir().expect("failed to create temp project dir"); + + let mut db = ProjectDatabase::new(); + db.set_project_root(dir.path()); + let file_path = dir.path().join("test.baml"); + let file = db.add_or_update_file(&file_path, source); + ppir_stream_diagnostics(&db, file) +} + +// ── @stream.* type-level attribute diagnostics ─────────────────────────────── + +#[test] +fn unknown_stream_type_attr_is_reported() { + let diagnostics = stream_diagnostics( + r#"class Foo { + name string @stream.typo +}"#, + ); + + let found = diagnostics.iter().any(|d| { + matches!( + d, + HirDiagnostic::UnknownAttribute { attr_name, .. } + if attr_name == "stream.typo" + ) + }); + + assert!( + found, + "expected UnknownAttribute for 'stream.typo', got: {diagnostics:#?}" + ); +} + +#[test] +fn conflicting_not_null_and_starts_as_is_reported() { + let diagnostics = stream_diagnostics( + r#"class Foo { + name string @stream.not_null @stream.starts_as("x") +}"#, + ); + + let found = diagnostics.iter().any(|d| { + matches!( + d, + HirDiagnostic::ConflictingStreamAttributes { + first_attr, + second_attr, + .. + } if *first_attr == "stream.not_null" && *second_attr == "stream.starts_as" + ) + }); + + assert!( + found, + "expected ConflictingStreamAttributes for not_null+starts_as, got: {diagnostics:#?}" + ); +} + +#[test] +fn conflicting_done_and_type_is_reported() { + let diagnostics = stream_diagnostics( + r#"class Foo { + name string @stream.done @stream.type(string) +}"#, + ); + + let found = diagnostics.iter().any(|d| { + matches!( + d, + HirDiagnostic::ConflictingStreamAttributes { + first_attr, + second_attr, + .. + } if *first_attr == "stream.done" && *second_attr == "stream.type" + ) + }); + + assert!( + found, + "expected ConflictingStreamAttributes for done+type, got: {diagnostics:#?}" + ); +} + +// ── @@stream.* block-level attribute diagnostics ───────────────────────────── + +#[test] +fn block_attr_with_args_is_rejected() { + let diagnostics = stream_diagnostics( + r#"class Foo { + @@stream.done("unexpected") + name string +}"#, + ); + + let found = diagnostics.iter().any(|d| { + matches!( + d, + HirDiagnostic::UnexpectedAttributeArg { attr_name, .. } + if attr_name == "stream.done" + ) + }); + + assert!( + found, + "expected UnexpectedAttributeArg for @@stream.done(\"unexpected\"), got: {diagnostics:#?}" + ); +} + +#[test] +fn unknown_stream_block_attr_is_reported() { + let diagnostics = stream_diagnostics( + r#"class Foo { + @@stream.unknown_block_attr + name string +}"#, + ); + + let found = diagnostics.iter().any(|d| { + matches!( + d, + HirDiagnostic::UnknownAttribute { attr_name, .. } + if attr_name == "stream.unknown_block_attr" + ) + }); + + assert!( + found, + "expected UnknownAttribute for '@@stream.unknown_block_attr', got: {diagnostics:#?}" + ); +} + +// ── clean inputs produce no diagnostics ────────────────────────────────────── + +#[test] +fn valid_stream_type_attr_produces_no_diagnostics() { + let diagnostics = stream_diagnostics( + r#"class Foo { + name string @stream.type(string) +}"#, + ); + + assert!( + diagnostics.is_empty(), + "expected no diagnostics for valid @stream.type, got: {diagnostics:#?}" + ); +} + +#[test] +fn valid_stream_done_block_attr_produces_no_diagnostics() { + let diagnostics = stream_diagnostics( + r#"class Foo { + @@stream.done + name string +}"#, + ); + + assert!( + diagnostics.is_empty(), + "expected no diagnostics for valid @@stream.done, got: {diagnostics:#?}" + ); +} + +#[test] +fn valid_stream_not_null_type_attr_produces_no_diagnostics() { + let diagnostics = stream_diagnostics( + r#"class Foo { + name string @stream.not_null +}"#, + ); + + assert!( + diagnostics.is_empty(), + "expected no diagnostics for valid @stream.not_null, got: {diagnostics:#?}" + ); +} + +#[test] +fn non_stream_class_produces_no_diagnostics() { + let diagnostics = stream_diagnostics( + r#"class Resume { + name string + age int +}"#, + ); + + assert!( + diagnostics.is_empty(), + "expected no diagnostics for plain class, got: {diagnostics:#?}" + ); +}