Skip to content

Commit b281c06

Browse files
committed
Merge remote-tracking branch 'origin/canary' into sam/jetbrains
2 parents b67454f + 4e2ed58 commit b281c06

File tree

118 files changed

+6579
-4653
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

118 files changed

+6579
-4653
lines changed

engine/baml-compiler/src/codegen.rs

Lines changed: 92 additions & 2152 deletions
Large diffs are not rendered by default.

engine/baml-compiler/src/hir/lowering.rs

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
//!
33
//! This file contains the conversions from Baml AST nodes to HIR nodes.
44
5+
use std::collections::HashSet;
6+
57
use baml_types::{
8+
ir_type::TypeGeneric,
69
type_meta::{self, base::StreamingBehavior},
710
Constraint, ConstraintLevel, TypeIR, TypeValue,
811
};
@@ -48,6 +51,38 @@ impl Hir {
4851
}
4952
}
5053

54+
let enums = HashSet::<&str>::from_iter(hir.enums.iter().map(|e| e.name.as_str()));
55+
56+
let param_type: fn(&mut Parameter) -> &mut TypeIR = |p| &mut p.r#type;
57+
58+
// Patch parameter and return types, because only here in the code do we
59+
// have the full context for enums.
60+
hir.expr_functions
61+
.iter_mut()
62+
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type))
63+
.chain(
64+
hir.llm_functions
65+
.iter_mut()
66+
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type)),
67+
)
68+
.chain(hir.classes.iter_mut().flat_map(|c| {
69+
c.methods
70+
.iter_mut()
71+
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type))
72+
}))
73+
.flat_map(|(parameters, return_type)| parameters.chain(std::iter::once(return_type)))
74+
.for_each(|ty| match ty {
75+
TypeIR::Class { name, meta, .. } if enums.contains(name.as_str()) => {
76+
*ty = TypeIR::Enum {
77+
name: name.to_owned(),
78+
dynamic: false, // TODO: How to know if it's dynamic.
79+
meta: meta.clone(),
80+
}
81+
}
82+
83+
_ => {}
84+
});
85+
5186
hir
5287
}
5388
}
@@ -125,22 +160,12 @@ pub fn type_ir_from_ast(type_: &ast::FieldType) -> TypeIR {
125160
};
126161

127162
match type_ {
128-
ast::FieldType::Symbol(_, name, _) => {
129-
if name.name().starts_with("Enum") {
130-
TypeIR::Enum {
131-
name: name.name().to_string(),
132-
dynamic: false,
133-
meta,
134-
}
135-
} else {
136-
TypeIR::Class {
137-
name: name.name().to_string(),
138-
mode: baml_types::ir_type::StreamingMode::NonStreaming,
139-
dynamic: false,
140-
meta,
141-
}
142-
}
143-
}
163+
ast::FieldType::Symbol(_, name, _) => TypeIR::Class {
164+
name: name.name().to_string(),
165+
mode: baml_types::ir_type::StreamingMode::NonStreaming,
166+
dynamic: false,
167+
meta,
168+
},
144169
ast::FieldType::Primitive(_, prim, _, _) => TypeIR::Primitive(*prim, meta),
145170
ast::FieldType::List(_, inner, dims, _, _) => {
146171
// Respect multi-dimensional arrays (e.g., int[][] has dims=2)

engine/baml-compiler/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,20 @@ pub mod test {
99
use internal_baml_diagnostics::Diagnostics;
1010
use internal_baml_parser_database::{parse_and_diagnostics, ParserDatabase};
1111

12+
use crate::{hir, thir};
13+
1214
/// Shim helper function for testing.
1315
pub fn ast(source: &'static str) -> anyhow::Result<ParserDatabase> {
14-
let (parser_db, diagnostics) = parse_and_diagnostics(source)?;
16+
let (parser_db, mut diagnostics) = parse_and_diagnostics(source)?;
17+
18+
if diagnostics.has_errors() {
19+
let errors = diagnostics.to_pretty_string();
20+
anyhow::bail!("{errors}");
21+
}
1522

23+
// Done here because of cyclic dependencies between crates.
24+
// TODO: We're building this like 3 different times, needs refactoring.
25+
thir::typecheck::typecheck(&hir::Hir::from_ast(&parser_db.ast), &mut diagnostics);
1626
if diagnostics.has_errors() {
1727
let errors = diagnostics.to_pretty_string();
1828
anyhow::bail!("{errors}");

engine/baml-compiler/src/thir.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
///
33
use baml_types::ir_type::TypeIR;
44

5-
use crate::hir::{self, AssignOp, BinaryOperator, Enum, LlmFunction, UnaryOperator};
5+
use crate::hir::{self, AssignOp, BinaryOperator, LlmFunction, UnaryOperator};
66

77
pub mod interpret;
88
pub mod typecheck;
@@ -54,6 +54,14 @@ pub struct Class<T> {
5454
pub span: Span,
5555
}
5656

57+
#[derive(Clone, Debug)]
58+
pub struct Enum {
59+
pub name: String,
60+
pub variants: Vec<hir::EnumVariant>,
61+
pub span: Span,
62+
pub ty: TypeIR, // TODO: Used for type checking, but do we need this?
63+
}
64+
5765
/// A BAML expression term.
5866
/// T is the type of the metadata.
5967
#[derive(Debug, Clone)]

engine/baml-compiler/src/thir/typecheck.rs

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub fn typecheck_returning_context<'a>(
4444
.map(|c| (c.name.clone(), c))
4545
.collect();
4646

47-
let enums = hir
47+
let enums: BamlMap<String, hir::Enum> = hir
4848
.enums
4949
.clone()
5050
.into_iter()
@@ -54,6 +54,7 @@ pub fn typecheck_returning_context<'a>(
5454
// Create typing context with all functions
5555
let mut typing_context = TypeContext::new();
5656
typing_context.classes.extend(classes.clone());
57+
typing_context.enums.extend(enums.clone());
5758

5859
// Add expr functions to typing context
5960
for func in &hir.expr_functions {
@@ -271,11 +272,32 @@ pub fn typecheck_returning_context<'a>(
271272
})
272273
.collect();
273274

275+
// TODO: These are HIR enums; figure out whether a THIR enum would need
276+
// anything different. Does it need a "type"?
277+
let thir_enums = enums
278+
.iter()
279+
.map(|(name, enum_def)| {
280+
(
281+
name.clone(),
282+
thir::Enum {
283+
name: enum_def.name.clone(),
284+
variants: enum_def.variants.clone(),
285+
span: enum_def.span.clone(),
286+
ty: TypeIR::Enum {
287+
name: enum_def.name.clone(),
288+
dynamic: false,
289+
meta: Default::default(),
290+
},
291+
},
292+
)
293+
})
294+
.collect();
295+
274296
(
275297
THir {
276298
llm_functions: hir.llm_functions.clone(),
277299
classes: thir_classes,
278-
enums,
300+
enums: thir_enums,
279301
expr_functions,
280302
global_assignments: BamlMap::new(),
281303
},
@@ -302,6 +324,7 @@ pub struct TypeContext<'func> {
302324
// Variables in scope with mutability info
303325
pub vars: BamlMap<String, VarInfo>,
304326
pub classes: BamlMap<String, hir::Class>,
327+
pub enums: BamlMap<String, hir::Enum>,
305328
// Used for knowing whether `break` and `continue` are inside a loop or not.
306329
pub is_inside_loop: bool,
307330

@@ -336,6 +359,7 @@ impl TypeContext<'_> {
336359
symbols: BamlMap::new(),
337360
vars,
338361
classes: BamlMap::new(),
362+
enums: BamlMap::new(),
339363
is_inside_loop: false,
340364
function_return_type: None,
341365
}
@@ -428,7 +452,9 @@ impl TypeContext<'_> {
428452
}
429453
hir::Expression::Identifier(name, _) => {
430454
// Look up type in context
431-
self.get_type(name).cloned()
455+
self.get_type(name)
456+
.cloned()
457+
.or_else(|| self.enums.get(name).map(|e| TypeIR::r#enum(&e.name)))
432458
}
433459
hir::Expression::Array(items, _) => {
434460
// Infer array type from first item
@@ -575,6 +601,14 @@ impl TypeContext<'_> {
575601
None
576602
}
577603
}
604+
TypeIR::Enum {
605+
name: enum_name, ..
606+
} => {
607+
// Look up the enum definition; the access resolves to the enum type itself (the accessed name is not checked here)
608+
self.enums
609+
.get(&enum_name)
610+
.map(|enum_def| TypeIR::r#enum(&enum_def.name))
611+
}
578612
_ => None, // Not a class
579613
}
580614
} else {
@@ -1162,6 +1196,14 @@ pub fn typecheck_expression(
11621196
BamlValueWithMeta::String(value.clone(), (span.clone(), Some(TypeIR::string()))),
11631197
),
11641198
hir::Expression::Identifier(name, span) => {
1199+
// Enum access: let x = Shape.Rectangle
1200+
if let Some(enum_def) = context.enums.get(name) {
1201+
return thir::Expr::Var(
1202+
name.clone(),
1203+
(span.clone(), Some(TypeIR::r#enum(&enum_def.name))),
1204+
);
1205+
}
1206+
11651207
// Look up type in context
11661208
let var_type = context.get_type(name).cloned();
11671209
if var_type.is_none() {
@@ -1740,6 +1782,20 @@ pub fn typecheck_expression(
17401782
None
17411783
}
17421784
}
1785+
Some(TypeIR::Enum {
1786+
name: enum_name, ..
1787+
}) => {
1788+
// Look up the enum definition; the access resolves to the enum type itself (the accessed name is not checked here)
1789+
if let Some(enum_def) = context.enums.get(enum_name) {
1790+
Some(TypeIR::r#enum(&enum_def.name))
1791+
} else {
1792+
diagnostics.push_error(DatamodelError::new_validation_error(
1793+
&format!("Enum {enum_name} not found"),
1794+
span.clone(),
1795+
));
1796+
None
1797+
}
1798+
}
17431799
_ => {
17441800
diagnostics.push_error(DatamodelError::new_validation_error(
17451801
"Can only access fields on class instances",
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//! Compiler tests for array construction.
2+
3+
use baml_vm::Instruction;
4+
5+
mod common;
6+
use common::{assert_compiles, Program};
7+
8+
#[test]
9+
fn array_constructor() -> anyhow::Result<()> {
10+
assert_compiles(Program {
11+
source: "
12+
fn main() -> int[] {
13+
let a = [1, 2, 3];
14+
a
15+
}
16+
",
17+
expected: vec![(
18+
"main",
19+
vec![
20+
Instruction::LoadConst(0),
21+
Instruction::LoadConst(1),
22+
Instruction::LoadConst(2),
23+
Instruction::AllocArray(3),
24+
Instruction::LoadVar(1),
25+
Instruction::Return,
26+
],
27+
)],
28+
})
29+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//! Compiler tests for assert statements.
2+
3+
use baml_vm::{BinOp, CmpOp, Instruction};
4+
5+
mod common;
6+
use common::{assert_compiles, Program};
7+
8+
#[test]
9+
fn assert_statement_ok() -> anyhow::Result<()> {
10+
assert_compiles(Program {
11+
source: "
12+
fn assertOk() -> int {
13+
assert 2 + 2 == 4;
14+
3
15+
}
16+
",
17+
expected: vec![(
18+
"assertOk",
19+
vec![
20+
Instruction::LoadConst(0), // 2
21+
Instruction::LoadConst(1), // 2
22+
Instruction::BinOp(BinOp::Add),
23+
Instruction::LoadConst(2), // 4
24+
Instruction::CmpOp(CmpOp::Eq),
25+
Instruction::Assert,
26+
Instruction::LoadConst(3), // 3
27+
Instruction::Return,
28+
],
29+
)],
30+
})
31+
}
32+
33+
#[test]
34+
fn assert_statement_not_ok() -> anyhow::Result<()> {
35+
assert_compiles(Program {
36+
source: "
37+
fn assertNotOk() -> int {
38+
assert 3 == 1;
39+
2
40+
}
41+
",
42+
expected: vec![(
43+
"assertNotOk",
44+
vec![
45+
Instruction::LoadConst(0), // 3
46+
Instruction::LoadConst(1), // 1
47+
Instruction::CmpOp(CmpOp::Eq),
48+
Instruction::Assert,
49+
Instruction::LoadConst(2), // 2
50+
Instruction::Return,
51+
],
52+
)],
53+
})
54+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//! Compiler tests for built-in method calls.
2+
3+
use baml_vm::{GlobalIndex, Instruction};
4+
5+
mod common;
6+
use common::{assert_compiles, Program};
7+
8+
#[test]
9+
fn builtin_method_call() -> anyhow::Result<()> {
10+
assert_compiles(Program {
11+
source: r#"
12+
fn main() -> int {
13+
let arr = [1, 2, 3];
14+
arr.len()
15+
}
16+
"#,
17+
expected: vec![(
18+
"main",
19+
vec![
20+
Instruction::LoadConst(0),
21+
Instruction::LoadConst(1),
22+
Instruction::LoadConst(2),
23+
Instruction::AllocArray(3),
24+
Instruction::LoadGlobal(GlobalIndex::from_raw(3)),
25+
Instruction::LoadVar(1),
26+
// call with one argument (self)
27+
Instruction::Call(1),
28+
Instruction::Return,
29+
],
30+
)],
31+
})
32+
}

0 commit comments

Comments
 (0)