@@ -362,7 +362,7 @@ CheckStatus check_for(AstFor* fornode) {
362362 return Check_Success ;
363363}
364364
365- static b32 add_case_to_switch_statement (AstSwitch * switchnode , u64 case_value , AstBlock * block , OnyxFilePos pos ) {
365+ static b32 add_case_to_switch_statement (AstSwitch * switchnode , u64 case_value , AstSwitchCase * casestmt , OnyxFilePos pos ) {
366366 assert (switchnode -> switch_kind == Switch_Kind_Integer || switchnode -> switch_kind == Switch_Kind_Union );
367367
368368 switchnode -> min_case = bh_min (switchnode -> min_case , case_value );
@@ -373,7 +373,7 @@ static b32 add_case_to_switch_statement(AstSwitch* switchnode, u64 case_value, A
373373 return 1 ;
374374 }
375375
376- bh_imap_put (& switchnode -> case_map , case_value , (u64 ) block );
376+ bh_imap_put (& switchnode -> case_map , case_value , (u64 ) casestmt );
377377 return 0 ;
378378}
379379
@@ -488,7 +488,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
488488 AstSwitchCase * sc = switchnode -> cases [i ];
489489
490490 if (sc -> capture && bh_arr_length (sc -> values ) != 1 ) {
491- ERROR (sc -> token -> pos , "Expected exactly one value in switch-case when using a capture, i.e. `case X => Y { ... }`." );
491+ ERROR (sc -> token -> pos , "Expected exactly one value in switch-case when using a capture, i.e. `case value: X { ... }`." );
492492 }
493493
494494 if (sc -> capture && switchnode -> switch_kind != Switch_Kind_Union ) {
@@ -519,7 +519,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
519519
520520 // NOTE: This is inclusive!!!!
521521 fori (case_value , lower , upper + 1 ) {
522- if (add_case_to_switch_statement (switchnode , case_value , sc -> block , rl -> token -> pos ))
522+ if (add_case_to_switch_statement (switchnode , case_value , sc , rl -> token -> pos ))
523523 return Check_Error ;
524524 }
525525
@@ -573,7 +573,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
573573 if (!is_valid )
574574 ERROR_ ((* value )-> token -> pos , "Case statement expected compile time known integer. Got '%s'." , onyx_ast_node_kind_string ((* value )-> kind ));
575575
576- if (add_case_to_switch_statement (switchnode , integer_value , sc -> block , sc -> block -> token -> pos ))
576+ if (add_case_to_switch_statement (switchnode , integer_value , sc , sc -> block -> token -> pos ))
577577 return Check_Error ;
578578
579579 break ;
@@ -592,7 +592,7 @@ CheckStatus check_switch(AstSwitch* switchnode) {
592592 if (found ) break ;
593593
594594 CaseToBlock ctb ;
595- ctb .block = sc -> block ;
595+ ctb .casestmt = sc ;
596596 ctb .original_value = * value ;
597597 ctb .comparison = make_binary_op (context .ast_alloc , Binary_Op_Equal , switchnode -> expr , * value );
598598 ctb .comparison -> token = (* value )-> token ;
@@ -607,13 +607,52 @@ CheckStatus check_switch(AstSwitch* switchnode) {
607607 sc -> flags |= Ast_Flag_Has_Been_Checked ;
608608
609609 check_switch_case_block :
610- CHECK (block , sc -> block );
610+ if (switchnode -> is_expr ) {
611+ if (!sc -> body_is_expr ) {
612+ onyx_report_error (sc -> token -> pos , Error_Critical , "Inside a switch expression, all cases must return a value." );
613+ ERROR (sc -> token -> pos , "Change the case statement to look like 'case X => expr'." );
614+ }
615+ } else {
616+ if (sc -> body_is_expr ) {
617+ ERROR (sc -> token -> pos , "This kind of case statement is only allowed in switch expressions, not switch statements." );
618+ }
619+ }
620+
621+ if (sc -> body_is_expr ) {
622+ CHECK (expression , & sc -> expr );
623+ if (switchnode -> type == NULL ) {
624+ switchnode -> type = resolve_expression_type (sc -> expr );
625+ } else {
626+ TYPE_CHECK (& sc -> expr , switchnode -> type ) {
627+ ERROR_ (sc -> token -> pos , "Expected case expression to be of type '%s', got '%s'." ,
628+ type_get_name (switchnode -> type ),
629+ type_get_name (sc -> expr -> type ));
630+ }
631+ }
632+
633+ } else {
634+ CHECK (block , sc -> block );
635+ }
611636
612637 switchnode -> yield_return_index += 1 ;
613638 }
614639
615640 if (switchnode -> default_case ) {
616- CHECK (block , switchnode -> default_case );
641+ if (switchnode -> is_expr ) {
642+ AstTyped * * default_case = (AstTyped * * ) & switchnode -> default_case ;
643+ CHECK (expression , default_case );
644+
645+ if (switchnode -> type ) {
646+ TYPE_CHECK (default_case , switchnode -> type ) {
647+ ERROR_ ((* default_case )-> token -> pos , "Expected case expression to be of type '%s', got '%s'." ,
648+ type_get_name (switchnode -> type ),
649+ type_get_name ((* default_case )-> type ));
650+ }
651+ }
652+
653+ } else {
654+ CHECK (block , switchnode -> default_case );
655+ }
617656
618657 } else if (switchnode -> switch_kind == Switch_Kind_Union ) {
619658 // If there is no default case, and this is a union switch,
@@ -2403,6 +2442,10 @@ CheckStatus check_expression(AstTyped** pexpr) {
24032442 ERROR_ (cl -> token -> pos , "Cannot pass '%b' by pointer because it is not an l-value." , cl -> token -> text , cl -> token -> length );
24042443 }
24052444
2445+ if (cl -> captured_value -> kind == Ast_Kind_Local ) {
2446+ cl -> captured_value -> flags |= Ast_Flag_Address_Taken ;
2447+ }
2448+
24062449 expr -> type = type_make_pointer (context .ast_alloc , cl -> captured_value -> type );
24072450
24082451 } else {
@@ -2411,14 +2454,22 @@ CheckStatus check_expression(AstTyped** pexpr) {
24112454 break ;
24122455 }
24132456
2457+ case Ast_Kind_Switch : {
2458+ AstSwitch * switch_node = (AstSwitch * ) expr ;
2459+ assert (switch_node -> is_expr );
2460+
2461+ CHECK (switch , switch_node );
2462+ break ;
2463+ }
2464+
2465+ case Ast_Kind_Switch_Case : break ;
24142466 case Ast_Kind_File_Contents : break ;
24152467 case Ast_Kind_Overloaded_Function : break ;
24162468 case Ast_Kind_Enum_Value : break ;
24172469 case Ast_Kind_Polymorphic_Proc : break ;
24182470 case Ast_Kind_Package : break ;
24192471 case Ast_Kind_Error : break ;
24202472 case Ast_Kind_Unary_Field_Access : break ;
2421- case Ast_Kind_Switch_Case : break ;
24222473 case Ast_Kind_Foreign_Block : break ;
24232474 case Ast_Kind_Zero_Value : break ;
24242475 case Ast_Kind_Interface : break ;
@@ -3567,8 +3618,10 @@ CheckStatus check_constraint(AstConstraint *constraint) {
35673618 }
35683619
35693620 assert (constraint -> interface -> entity && constraint -> interface -> entity -> scope );
3621+ assert (constraint -> interface -> scope );
3622+ assert (constraint -> interface -> scope -> parent == constraint -> interface -> entity -> scope );
35703623
3571- constraint -> scope = scope_create (context .ast_alloc , constraint -> interface -> entity -> scope , constraint -> token -> pos );
3624+ constraint -> scope = scope_create (context .ast_alloc , constraint -> interface -> scope , constraint -> token -> pos );
35723625
35733626 if (bh_arr_length (constraint -> type_args ) != bh_arr_length (constraint -> interface -> params )) {
35743627 ERROR_ (constraint -> token -> pos , "Wrong number of arguments given to interface. Expected %d, got %d." ,
0 commit comments