Skip to content

Commit 882e769

Browse files
authored
refactor: decompose llvm_codegen into functional modules (#262)
Split the monolithic `llvm_codegen.rs` into specialized submodules for better organization and maintainability. Changes: - **Module Structure**: - `llvm_temporary/llvm_codegen/` (New directory) - `mod.rs`: Re-exports public functions. - `ir.rs`: Core AST traversal and module/function setup (`generate_ir`). - `address.rs`: Memory address calculation (`generate_address_ir`). - `types.rs`: Type conversion (`wave_type_to_llvm_type`) and `VariableInfo`. - `consts.rs`: Constant creation (`create_llvm_const_value`). - `format.rs`: Format string parsing (`wave_format_to_c`, `wave_format_to_scanf`). - `legacy.rs`: Deprecated `TokenType`-based functions (kept for compatibility). - **Refactoring**: - Moved relevant functions from `llvm_codegen.rs` to their respective new files. - Updated `wave_format_to_c` signature to accept `Context` reference (required for float type checking). - Updated `gen_print_format_ir` in `statement/io.rs` to pass `context` to `wave_format_to_c`. - **Cleanup**: Removed the original single `llvm_codegen.rs` file. This modularization clarifies the responsibilities of the codegen backend and simplifies future extensions. Signed-off-by: LunaStev <luna@lunastev.org>
1 parent 6030073 commit 882e769

File tree

9 files changed

+479
-442
lines changed

9 files changed

+479
-442
lines changed

llvm_temporary/src/llvm_temporary/llvm_codegen.rs

Lines changed: 0 additions & 441 deletions
This file was deleted.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use inkwell::context::Context;
2+
use inkwell::values::PointerValue;
3+
4+
use parser::ast::Expression;
5+
use std::collections::HashMap;
6+
7+
use super::types::VariableInfo;
8+
9+
pub fn generate_address_ir<'ctx>(
10+
context: &'ctx Context,
11+
builder: &'ctx inkwell::builder::Builder<'ctx>,
12+
expr: &Expression,
13+
variables: &mut HashMap<String, VariableInfo<'ctx>>,
14+
module: &'ctx inkwell::module::Module<'ctx>,
15+
) -> PointerValue<'ctx> {
16+
match expr {
17+
Expression::Grouped(inner) => {
18+
generate_address_ir(context, builder, inner, variables, module)
19+
}
20+
21+
Expression::Variable(name) => {
22+
let var_info = variables
23+
.get(name)
24+
.unwrap_or_else(|| panic!("Variable {} not found", name));
25+
26+
var_info.ptr
27+
}
28+
29+
Expression::Deref(inner_expr) => {
30+
let mut inner: &Expression = inner_expr.as_ref();
31+
while let Expression::Grouped(g) = inner {
32+
inner = g.as_ref();
33+
}
34+
35+
match inner {
36+
Expression::Variable(var_name) => {
37+
let ptr_to_ptr = variables
38+
.get(var_name)
39+
.unwrap_or_else(|| panic!("Variable {} not found", var_name))
40+
.ptr;
41+
42+
let actual_ptr = builder.build_load(ptr_to_ptr, "deref_target").unwrap();
43+
actual_ptr.into_pointer_value()
44+
}
45+
_ => panic!("Cannot take address: deref target is not a variable"),
46+
}
47+
}
48+
49+
_ => panic!("Cannot take address of this expression"),
50+
}
51+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use inkwell::context::Context;
2+
use inkwell::types::BasicTypeEnum;
3+
use inkwell::values::{BasicValue, BasicValueEnum};
4+
5+
use parser::ast::{Expression, Literal, WaveType};
6+
use std::collections::HashMap;
7+
8+
use super::types::wave_type_to_llvm_type;
9+
10+
pub(super) fn create_llvm_const_value<'ctx>(
11+
context: &'ctx Context,
12+
ty: &WaveType,
13+
expr: &Expression,
14+
) -> BasicValueEnum<'ctx> {
15+
let struct_types = HashMap::new();
16+
let llvm_type = wave_type_to_llvm_type(context, ty, &struct_types);
17+
match (expr, llvm_type) {
18+
(Expression::Literal(Literal::Number(n)), BasicTypeEnum::IntType(int_ty)) => {
19+
int_ty.const_int(*n as u64, true).as_basic_value_enum()
20+
}
21+
(Expression::Literal(Literal::Float(f)), BasicTypeEnum::FloatType(float_ty)) => {
22+
float_ty.const_float(*f).as_basic_value_enum()
23+
}
24+
_ => panic!("Constant expression must be a literal of a compatible type."),
25+
}
26+
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
use inkwell::context::Context;
2+
use inkwell::types::{AnyTypeEnum, BasicTypeEnum};
3+
4+
pub fn wave_format_to_c<'ctx>(
5+
context: &'ctx Context,
6+
format: &str,
7+
arg_types: &[BasicTypeEnum<'ctx>],
8+
) -> String {
9+
let mut result = String::new();
10+
let mut chars = format.chars().peekable();
11+
let mut arg_index = 0;
12+
13+
while let Some(c) = chars.next() {
14+
if c == '{' {
15+
if let Some('}') = chars.peek() {
16+
chars.next(); // consume '}'
17+
18+
let ty = arg_types
19+
.get(arg_index)
20+
.unwrap_or_else(|| panic!("Missing argument for format"));
21+
22+
let fmt = match ty {
23+
BasicTypeEnum::IntType(int_ty) => {
24+
let bits = int_ty.get_bit_width();
25+
match bits {
26+
8 => "%hhd",
27+
16 => "%hd",
28+
32 => "%d",
29+
64 => "%ld",
30+
128 => "%lld",
31+
_ => "%d",
32+
}
33+
}
34+
BasicTypeEnum::FloatType(float_ty) => {
35+
if *float_ty == context.f32_type() {
36+
"%f"
37+
} else {
38+
"%lf"
39+
}
40+
}
41+
BasicTypeEnum::PointerType(_) => "%p",
42+
_ => panic!("Unsupported type in format"),
43+
};
44+
45+
result.push_str(fmt);
46+
arg_index += 1;
47+
continue;
48+
}
49+
}
50+
51+
result.push(c);
52+
}
53+
54+
result
55+
}
56+
57+
pub fn wave_format_to_scanf(format: &str, arg_types: &[AnyTypeEnum]) -> String {
58+
let mut result = String::new();
59+
let mut chars = format.chars().peekable();
60+
let mut arg_index = 0;
61+
62+
while let Some(c) = chars.next() {
63+
if c == '{' {
64+
if let Some('}') = chars.peek() {
65+
chars.next(); // consume '}'
66+
67+
let ty = arg_types
68+
.get(arg_index)
69+
.unwrap_or_else(|| panic!("Missing argument for format"));
70+
71+
let fmt = match ty {
72+
AnyTypeEnum::IntType(_) => "%d",
73+
AnyTypeEnum::FloatType(_) => "%f",
74+
AnyTypeEnum::PointerType(_) => {
75+
panic!("Cannot input into a pointer type directly")
76+
}
77+
_ => panic!("Unsupported type in scanf format"),
78+
};
79+
80+
result.push_str(fmt);
81+
arg_index += 1;
82+
continue;
83+
}
84+
}
85+
86+
result.push(c);
87+
}
88+
89+
result
90+
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
use inkwell::context::Context;
2+
use inkwell::passes::{PassManager, PassManagerBuilder};
3+
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
4+
use inkwell::values::{BasicValue, BasicValueEnum, FunctionValue};
5+
use inkwell::OptimizationLevel;
6+
7+
use parser::ast::{ASTNode, FunctionNode, Mutability, VariableNode, WaveType};
8+
use std::collections::HashMap;
9+
10+
use crate::llvm_temporary::statement::generate_statement_ir;
11+
12+
use super::consts::create_llvm_const_value;
13+
use super::types::{wave_type_to_llvm_type, VariableInfo};
14+
15+
pub unsafe fn generate_ir(ast_nodes: &[ASTNode]) -> String {
16+
let context: &'static Context = Box::leak(Box::new(Context::create()));
17+
let module: &'static _ = Box::leak(Box::new(context.create_module("main")));
18+
let builder: &'static _ = Box::leak(Box::new(context.create_builder()));
19+
20+
let pass_manager_builder = PassManagerBuilder::create();
21+
pass_manager_builder.set_optimization_level(OptimizationLevel::Aggressive);
22+
let pass_manager: PassManager<inkwell::module::Module> = PassManager::create(());
23+
pass_manager_builder.populate_module_pass_manager(&pass_manager);
24+
25+
let struct_types: HashMap<String, inkwell::types::StructType> = HashMap::new();
26+
let mut struct_field_indices: HashMap<String, HashMap<String, u32>> = HashMap::new();
27+
28+
let mut global_consts: HashMap<String, BasicValueEnum> = HashMap::new();
29+
30+
for ast in ast_nodes {
31+
if let ASTNode::Variable(VariableNode {
32+
name,
33+
type_name,
34+
initial_value,
35+
mutability,
36+
}) = ast
37+
{
38+
if *mutability == Mutability::Const {
39+
let initial_value = initial_value
40+
.as_ref()
41+
.expect("Constant must be initialized.");
42+
let const_val = create_llvm_const_value(context, type_name, initial_value);
43+
global_consts.insert(name.clone(), const_val);
44+
}
45+
}
46+
}
47+
48+
let mut struct_types: HashMap<String, inkwell::types::StructType> = HashMap::new();
49+
for ast in ast_nodes {
50+
if let ASTNode::Struct(struct_node) = ast {
51+
let field_types: Vec<BasicTypeEnum> = struct_node
52+
.fields
53+
.iter()
54+
.map(|(_, ty)| wave_type_to_llvm_type(context, ty, &struct_types))
55+
.collect();
56+
let struct_ty = context.struct_type(&field_types, false);
57+
struct_types.insert(struct_node.name.clone(), struct_ty);
58+
59+
let mut index_map = HashMap::new();
60+
for (i, (field_name, _)) in struct_node.fields.iter().enumerate() {
61+
index_map.insert(field_name.clone(), i as u32);
62+
}
63+
struct_field_indices.insert(struct_node.name.clone(), index_map);
64+
}
65+
}
66+
67+
let mut proto_functions: Vec<(String, FunctionNode)> = Vec::new();
68+
for ast in ast_nodes {
69+
if let ASTNode::ProtoImpl(proto_impl) = ast {
70+
for method in &proto_impl.methods {
71+
let new_name = format!("{}_{}", proto_impl.target, method.name);
72+
let mut new_fn = method.clone();
73+
new_fn.name = new_name.clone();
74+
proto_functions.push((new_name, new_fn));
75+
}
76+
}
77+
}
78+
79+
let mut functions: HashMap<String, FunctionValue> = HashMap::new();
80+
81+
let function_nodes: Vec<FunctionNode> = ast_nodes
82+
.iter()
83+
.filter_map(|ast| {
84+
if let ASTNode::Function(f) = ast {
85+
Some(f.clone())
86+
} else {
87+
None
88+
}
89+
})
90+
.chain(proto_functions.iter().map(|(_, f)| f.clone()))
91+
.collect();
92+
93+
for FunctionNode {
94+
name,
95+
parameters,
96+
return_type,
97+
..
98+
} in &function_nodes
99+
{
100+
let param_types: Vec<BasicMetadataTypeEnum> = parameters
101+
.iter()
102+
.map(|p| wave_type_to_llvm_type(context, &p.param_type, &struct_types).into())
103+
.collect();
104+
105+
let fn_type = match return_type {
106+
None | Some(WaveType::Void) => context.void_type().fn_type(&param_types, false),
107+
Some(wave_ret_ty) => {
108+
let llvm_ret_type = wave_type_to_llvm_type(context, wave_ret_ty, &struct_types);
109+
llvm_ret_type.fn_type(&param_types, false)
110+
}
111+
};
112+
113+
let function = module.add_function(name, fn_type, None);
114+
functions.insert(name.clone(), function);
115+
}
116+
117+
for func_node in &function_nodes {
118+
let function = *functions.get(&func_node.name).unwrap();
119+
let entry_block = context.append_basic_block(function, "entry");
120+
builder.position_at_end(entry_block);
121+
122+
let mut variables: HashMap<String, VariableInfo> = HashMap::new();
123+
let mut string_counter = 0;
124+
let mut loop_exit_stack = vec![];
125+
let mut loop_continue_stack = vec![];
126+
127+
for (i, param) in func_node.parameters.iter().enumerate() {
128+
let llvm_type = wave_type_to_llvm_type(context, &param.param_type, &struct_types);
129+
let alloca = builder.build_alloca(llvm_type, &param.name).unwrap();
130+
let param_val = function.get_nth_param(i as u32).unwrap();
131+
builder.build_store(alloca, param_val).unwrap();
132+
133+
variables.insert(
134+
param.name.clone(),
135+
VariableInfo {
136+
ptr: alloca,
137+
mutability: Mutability::Let,
138+
ty: param.param_type.clone(),
139+
},
140+
);
141+
}
142+
143+
for stmt in &func_node.body {
144+
if let ASTNode::Statement(_) | ASTNode::Variable(_) = stmt {
145+
generate_statement_ir(
146+
context,
147+
builder,
148+
module,
149+
&mut string_counter,
150+
stmt,
151+
&mut variables,
152+
&mut loop_exit_stack,
153+
&mut loop_continue_stack,
154+
function,
155+
&global_consts,
156+
&struct_types,
157+
&struct_field_indices,
158+
);
159+
} else {
160+
panic!("Unsupported node inside function '{}'", func_node.name);
161+
}
162+
}
163+
164+
let current_block = builder.get_insert_block().unwrap();
165+
if current_block.get_terminator().is_none() {
166+
let is_void_like = match &func_node.return_type {
167+
None => true,
168+
Some(WaveType::Void) => true,
169+
_ => false,
170+
};
171+
172+
if is_void_like {
173+
builder.build_return(None).unwrap();
174+
} else {
175+
panic!(
176+
"Non-void function '{}' is missing a return statement",
177+
func_node.name
178+
);
179+
// builder.build_unreachable().unwrap();
180+
}
181+
}
182+
}
183+
184+
pass_manager.run_on(module);
185+
module.print_to_string().to_string()
186+
}

0 commit comments

Comments
 (0)