Skip to content

Commit 97e477a

Browse files
committed
Merge branch 'develop' of github.com:Champii/Rock
2 parents 5f699c9 + 054fd32 commit 97e477a

File tree

8 files changed

+49
-3
lines changed

8 files changed

+49
-3
lines changed

src/lib/ast/tree.rs

+10
Original file line numberDiff line numberDiff line change
@@ -468,11 +468,21 @@ impl Expression {
468468
}
469469
}
470470

471+
#[allow(dead_code)]
472+
pub fn is_return(&self) -> bool {
473+
matches!(&self, Expression::Return(_))
474+
}
475+
471476
#[allow(dead_code)]
472477
pub fn new_unary(unary: UnaryExpr) -> Expression {
473478
Expression::UnaryExpr(unary)
474479
}
475480

481+
#[allow(dead_code)]
482+
pub fn new_return(expr: Expression) -> Expression {
483+
Expression::Return(Box::new(expr))
484+
}
485+
476486
#[allow(dead_code)]
477487
pub fn new_binop(unary: UnaryExpr, operator: Operator, expr: Expression) -> Expression {
478488
Expression::BinopExpr(unary, operator, Box::new(expr))

src/lib/codegen/codegen_context.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,17 @@ impl<'a> CodegenContext<'a> {
322322

323323
builder.position_at_end(basic_block);
324324

325-
let stmt = body
325+
let first_return_idx = body
326326
.stmts
327327
.iter()
328+
.position(|s| s.is_return())
329+
.unwrap_or(body.stmts.len());
330+
331+
let stmts = body.stmts.iter().take(first_return_idx + 1);
332+
333+
// FIXME: Add warning here for unreachable statements
334+
335+
let stmt = stmts
328336
.map(|stmt| self.lower_stmt(stmt, builder))
329337
.last()
330338
.unwrap()?;
@@ -545,7 +553,8 @@ impl<'a> CodegenContext<'a> {
545553
ExpressionKind::Return(expr) => {
546554
let val = self.lower_expression(expr, builder)?;
547555

548-
builder.build_return(Some(&val.as_basic_value_enum()));
556+
// This is disabled because this is handled in lower_body
557+
// builder.build_return(Some(&val.as_basic_value_enum()));
549558

550559
val
551560
}

src/lib/hir/tree.rs

+15
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,13 @@ impl Statement {
337337
StatementKind::For(f) => f.get_hir_id(),
338338
}
339339
}
340+
341+
pub fn is_return(&self) -> bool {
342+
match &*self.kind {
343+
StatementKind::Expression(expr) => expr.is_return(),
344+
_ => false,
345+
}
346+
}
340347
}
341348

342349
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -516,6 +523,14 @@ impl Expression {
516523
panic!("Not a literal");
517524
}
518525
}
526+
527+
pub fn is_return(&self) -> bool {
528+
if let ExpressionKind::Return(_) = &*self.kind {
529+
true
530+
} else {
531+
false
532+
}
533+
}
519534
}
520535

521536
#[derive(Debug, Clone, Serialize, Deserialize)]

src/lib/parser/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,10 @@ pub fn parse_assign_left_side(input: Parser) -> Res<Parser, AssignLeftSide> {
668668

669669
pub fn parse_expression(input: Parser) -> Res<Parser, Expression> {
670670
alt((
671+
map(
672+
preceded(terminated(tag("return"), space1), parse_expression),
673+
Expression::new_return,
674+
),
671675
map(
672676
tuple((
673677
parse_unary,
@@ -681,7 +685,6 @@ pub fn parse_expression(input: Parser) -> Res<Parser, Expression> {
681685
map(parse_native_operator, |(op, id1, id2)| {
682686
Expression::new_native_operator(op, id1, id2)
683687
}),
684-
// TODO: Return
685688
))(input)
686689
}
687690

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
main =
2+
let a = 42
3+
return a
4+
28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
42

src/lib/testcases/basic/early_return/main.rk.stdout

Whitespace-only changes.

src/lib/tests.rs

+4
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ fn testcases_basic_simple_struct_main() {
106106
run("testcases/basic/simple_struct/main.rk", include_str!("testcases/basic/simple_struct/main.rk"), include_str!("testcases/basic/simple_struct/main.rk.out"), include_str!("testcases/basic/simple_struct/main.rk.stdout"));
107107
}
108108
#[test]
109+
fn testcases_basic_early_return_main() {
110+
run("testcases/basic/early_return/main.rk", include_str!("testcases/basic/early_return/main.rk"), include_str!("testcases/basic/early_return/main.rk.out"), include_str!("testcases/basic/early_return/main.rk.stdout"));
111+
}
112+
#[test]
109113
fn testcases_basic_bool_false_main() {
110114
run("testcases/basic/bool_false/main.rk", include_str!("testcases/basic/bool_false/main.rk"), include_str!("testcases/basic/bool_false/main.rk.out"), include_str!("testcases/basic/bool_false/main.rk.stdout"));
111115
}

0 commit comments

Comments
 (0)