Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion front/parser/src/import.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ast::{ASTNode, StatementNode};
use crate::ast::{ASTNode};
use crate::parse;
use error::error::{WaveError, WaveErrorKind};
use lexer::Lexer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ pub(crate) fn gen<'ctx, 'a>(
iv.as_basic_value_enum()
}

Some(BasicTypeEnum::ArrayType(at)) => {
let elem = at.get_element_type();
return gen(env, lit, Some(elem));
}

Some(BasicTypeEnum::PointerType(ptr_ty)) => {
let s = v.as_str();
let (neg, raw) = parse_signed_decimal(s);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::ExprGenEnv;
use inkwell::values::{BasicValue, BasicValueEnum};
use parser::ast::{Expression, WaveType};
use parser::ast::{Expression};
use crate::llvm_temporary::llvm_codegen::generate_address_ir;

pub(crate) fn gen_struct_literal<'ctx, 'a>(
Expand Down
4 changes: 2 additions & 2 deletions llvm_temporary/src/llvm_temporary/llvm_codegen/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use inkwell::values::{BasicValue, BasicValueEnum};
use parser::ast::{Expression, Literal, WaveType};
use std::collections::HashMap;

use super::types::wave_type_to_llvm_type;
use super::types::{wave_type_to_llvm_type, TypeFlavor};

fn parse_signed_decimal<'a>(s: &'a str) -> (bool, &'a str) {
if let Some(rest) = s.strip_prefix('-') {
Expand All @@ -28,7 +28,7 @@ pub(super) fn create_llvm_const_value<'ctx>(
expr: &Expression,
) -> BasicValueEnum<'ctx> {
let struct_types = HashMap::new();
let llvm_type = wave_type_to_llvm_type(context, ty, &struct_types);
let llvm_type = wave_type_to_llvm_type(context, ty, &struct_types, TypeFlavor::AbiC);

match (expr, llvm_type) {
// new: int literal is string-based
Expand Down
14 changes: 7 additions & 7 deletions llvm_temporary/src/llvm_temporary/llvm_codegen/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::collections::HashMap;
use crate::llvm_temporary::statement::generate_statement_ir;

use super::consts::create_llvm_const_value;
use super::types::{wave_type_to_llvm_type, VariableInfo};
use super::types::{wave_type_to_llvm_type, TypeFlavor, VariableInfo};

fn build_struct_field_map(ast: &[ASTNode]) -> HashMap<String, Vec<String>> {
let mut m = HashMap::new();
Expand Down Expand Up @@ -81,7 +81,7 @@ pub unsafe fn generate_ir(ast_nodes: &[ASTNode]) -> String {
let field_types: Vec<BasicTypeEnum> = struct_node
.fields
.iter()
.map(|(_, ty)| wave_type_to_llvm_type(context, ty, &struct_types))
.map(|(_, ty)| wave_type_to_llvm_type(context, ty, &struct_types, TypeFlavor::AbiC))
.collect();

st.set_body(&field_types, false);
Expand Down Expand Up @@ -134,13 +134,13 @@ pub unsafe fn generate_ir(ast_nodes: &[ASTNode]) -> String {
{
let param_types: Vec<BasicMetadataTypeEnum> = parameters
.iter()
.map(|p| wave_type_to_llvm_type(context, &p.param_type, &struct_types).into())
.map(|p| wave_type_to_llvm_type(context, &p.param_type, &struct_types, TypeFlavor::AbiC).into())
.collect();

let fn_type = match return_type {
None | Some(WaveType::Void) => context.void_type().fn_type(&param_types, false),
Some(wave_ret_ty) => {
let llvm_ret_type = wave_type_to_llvm_type(context, wave_ret_ty, &struct_types);
let llvm_ret_type = wave_type_to_llvm_type(context, wave_ret_ty, &struct_types, TypeFlavor::AbiC);
llvm_ret_type.fn_type(&param_types, false)
}
};
Expand All @@ -149,13 +149,13 @@ pub unsafe fn generate_ir(ast_nodes: &[ASTNode]) -> String {
let param_types: Vec<BasicMetadataTypeEnum> = ext
.params
.iter()
.map(|(_, ty)| wave_type_to_llvm_type(context, ty, &struct_types).into())
.map(|(_, ty)| wave_type_to_llvm_type(context, ty, &struct_types, TypeFlavor::AbiC).into())
.collect();

let fn_type = match &ext.return_type {
WaveType::Void => context.void_type().fn_type(&param_types, false),
ret_ty => {
let llvm_ret = wave_type_to_llvm_type(context, ret_ty, &struct_types);
let llvm_ret = wave_type_to_llvm_type(context, ret_ty, &struct_types, TypeFlavor::AbiC);
llvm_ret.fn_type(&param_types, false)
}
};
Expand Down Expand Up @@ -184,7 +184,7 @@ pub unsafe fn generate_ir(ast_nodes: &[ASTNode]) -> String {
let mut loop_continue_stack = vec![];

for (i, param) in func_node.parameters.iter().enumerate() {
let llvm_type = wave_type_to_llvm_type(context, &param.param_type, &struct_types);
let llvm_type = wave_type_to_llvm_type(context, &param.param_type, &struct_types, TypeFlavor::AbiC);
let alloca = builder.build_alloca(llvm_type, &param.name).unwrap();
let param_val = function.get_nth_param(i as u32).unwrap();
builder.build_store(alloca, param_val).unwrap();
Expand Down
54 changes: 33 additions & 21 deletions llvm_temporary/src/llvm_temporary/llvm_codegen/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ use std::collections::HashMap;

pub type StructFieldMap = HashMap<String, HashMap<String, u32>>;

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TypeFlavor {
Value,
AbiC,
}

pub fn build_field_map(fields: &[(String, parser::ast::WaveType)]) -> HashMap<String, u32> {
let mut m = HashMap::new();
for (i, (name, _ty)) in fields.iter().enumerate() {
Expand All @@ -32,40 +38,46 @@ pub fn wave_type_to_llvm_type<'ctx>(
context: &'ctx Context,
wave_type: &WaveType,
struct_types: &HashMap<String, inkwell::types::StructType<'ctx>>,
flavor: TypeFlavor,
) -> BasicTypeEnum<'ctx> {
match wave_type {
WaveType::Int(bits) => context
.custom_width_int_type(*bits as u32)
.as_basic_type_enum(),
WaveType::Uint(bits) => context
.custom_width_int_type(*bits as u32)
.as_basic_type_enum(),
WaveType::Int(bits) | WaveType::Uint(bits) => {
context.custom_width_int_type(*bits as u32).as_basic_type_enum()
}

WaveType::Float(bits) => match bits {
32 => context.f32_type().as_basic_type_enum(),
64 => context.f64_type().as_basic_type_enum(),
_ => panic!("Unsupported float bit width: {}", bits),
},
WaveType::Bool => context.bool_type().as_basic_type_enum(),
WaveType::Char => context.i8_type().as_basic_type_enum(),
WaveType::Byte => context.i8_type().as_basic_type_enum(),
WaveType::Void => context.i8_type().as_basic_type_enum(), // fallback (shouldn't be used)
WaveType::Pointer(inner) => wave_type_to_llvm_type(context, inner, struct_types)

WaveType::Bool => {
if flavor == TypeFlavor::AbiC {
context.i8_type().as_basic_type_enum()
} else {
context.bool_type().as_basic_type_enum()
}
}

WaveType::Char | WaveType::Byte => context.i8_type().as_basic_type_enum(),

WaveType::Void => context.i8_type().as_basic_type_enum(),

WaveType::Pointer(inner) => wave_type_to_llvm_type(context, inner, struct_types, flavor)
.ptr_type(AddressSpace::default())
.as_basic_type_enum(),

WaveType::Array(inner, size) => {
let inner_ty = wave_type_to_llvm_type(context, inner, struct_types);
let inner_ty = wave_type_to_llvm_type(context, inner, struct_types, flavor);
inner_ty.array_type(*size as u32).as_basic_type_enum()
}
WaveType::String => context
.i8_type()
.ptr_type(AddressSpace::default())

WaveType::String => context.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum(),

WaveType::Struct(name) => struct_types
.get(name)
.unwrap_or_else(|| panic!("Struct type '{}' not found", name))
.as_basic_type_enum(),
WaveType::Struct(name) => {
let struct_ty = struct_types
.get(name)
.unwrap_or_else(|| panic!("Struct type '{}' not found", name));
struct_ty.as_basic_type_enum()
}
}
}

Expand Down
111 changes: 86 additions & 25 deletions llvm_temporary/src/llvm_temporary/statement/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use inkwell::module::Module;
use inkwell::values::{BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallableValue, PointerValue};
use inkwell::{InlineAsmDialect};
use parser::ast::{Expression, Literal};
use std::collections::{HashMap, HashSet};
use std::collections::{HashMap};
use inkwell::types::{AnyTypeEnum, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, StringRadix};
use crate::llvm_temporary::llvm_codegen::plan::*;

Expand Down Expand Up @@ -82,11 +82,29 @@ pub(super) fn gen_asm_stmt_ir<'ctx>(
let iv = val.into_int_value();
let target_ty = context.custom_width_int_type(bits);

if iv.get_type() != target_ty {
val = builder
.build_int_truncate(iv, target_ty, "asm_in_trunc")
.unwrap()
.as_basic_value_enum();
let src_bits = iv.get_type().get_bit_width();
let dst_bits = bits;

if src_bits != dst_bits {
if src_bits > dst_bits {
val = builder
.build_int_truncate(iv, target_ty, "asm_in_trunc")
.unwrap()
.as_basic_value_enum();
} else {
let signed = infer_signedness(inp.value, variables).unwrap_or(false);
val = if signed {
builder
.build_int_s_extend(iv, target_ty, "asm_in_sext")
.unwrap()
.as_basic_value_enum()
} else {
builder
.build_int_z_extend(iv, target_ty, "asm_in_zext")
.unwrap()
.as_basic_value_enum()
};
}
}
}
}
Expand All @@ -101,9 +119,19 @@ pub(super) fn gen_asm_stmt_ir<'ctx>(
let mut out_tys: Vec<BasicTypeEnum<'ctx>> = Vec::with_capacity(plan.outputs.len());

for o in &plan.outputs {
let (place, ty) = resolve_out_place_and_type(context, builder, variables, o.target);
let (place, dst_ty) = resolve_out_place_and_type(context, builder, variables, o.target);

let mut asm_ty = dst_ty;
if let Some(reg) = extract_reg_from_constraint(&o.reg_norm) {
if let Some(bits) = reg_width_bits(&reg) {
if dst_ty.is_int_type() {
asm_ty = context.custom_width_int_type(bits).as_basic_type_enum();
}
}
}

out_places.push(place);
out_tys.push(ty);
out_tys.push(asm_ty);
}

let fn_type = if out_tys.is_empty() {
Expand Down Expand Up @@ -152,6 +180,36 @@ pub(super) fn gen_asm_stmt_ir<'ctx>(
}
}

fn infer_signedness<'ctx>(
expr: &Expression,
variables: &HashMap<String, VariableInfo<'ctx>>,
) -> Option<bool> {
match expr {
Expression::Variable(name) => variables.get(name).map(|v| match &v.ty {
parser::ast::WaveType::Int(_) => true,
parser::ast::WaveType::Uint(_) => false,
_ => true,
}),
Expression::Grouped(inner) => infer_signedness(inner, variables),
Expression::Deref(inner) => {
if let Expression::Variable(name) = inner.as_ref() {
variables.get(name).and_then(|v| match &v.ty {
parser::ast::WaveType::Pointer(inner_ty) => match inner_ty.as_ref() {
parser::ast::WaveType::Int(_) => Some(true),
parser::ast::WaveType::Uint(_) => Some(false),
_ => None,
},
_ => None,
})
} else { None }
}
Expression::Literal(Literal::Int(s)) => {
if s.trim_start().starts_with('-') { Some(true) } else { None }
}
_ => None,
}
}

fn resolve_out_place_and_type<'ctx>(
context: &'ctx inkwell::context::Context,
builder: &'ctx inkwell::builder::Builder<'ctx>,
Expand Down Expand Up @@ -245,29 +303,32 @@ fn coerce_basic_value_for_store<'ctx>(
}

// pointer <- pointer/int
if dst_ty.is_pointer_type() {
let dst_ptr = dst_ty.into_pointer_type();
if dst_ty.is_int_type() {
let dst_int = dst_ty.into_int_type();

if value.is_int_value() {
let v = value.into_int_value();
let src_bits = v.get_type().get_bit_width();
let dst_bits = dst_int.get_bit_width();

if src_bits == dst_bits {
return v.as_basic_value_enum();
} else if src_bits > dst_bits {
return builder.build_int_truncate(v, dst_int, "asm_int_trunc").unwrap().as_basic_value_enum();
} else {
return builder.build_int_z_extend(v, dst_int, "asm_int_zext").unwrap().as_basic_value_enum();
}
}

if value.is_pointer_value() {
return builder
.build_bit_cast(value.into_pointer_value(), dst_ptr, "asm_ptr_cast")
.unwrap()
.as_basic_value_enum();
return builder.build_ptr_to_int(value.into_pointer_value(), dst_int, "asm_ptr_to_int").unwrap().as_basic_value_enum();
}

if value.is_int_value() {
return builder
.build_int_to_ptr(value.into_int_value(), dst_ptr, "asm_int_to_ptr")
.unwrap()
.as_basic_value_enum();
if value.is_float_value() {
return builder.build_float_to_signed_int(value.into_float_value(), dst_int, "asm_fptosi").unwrap().as_basic_value_enum();
}

panic!(
"Cannot coerce asm output '{}' from {:?} to pointer {:?}",
name,
value.get_type(),
dst_ty
);
panic!("Cannot coerce asm output '{}' to int {:?}", name, dst_ty);
}

// int <- int/pointer/float
Expand Down
5 changes: 3 additions & 2 deletions llvm_temporary/src/llvm_temporary/statement/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use inkwell::types::{BasicTypeEnum, StructType};
use inkwell::values::{BasicValue, BasicValueEnum};
use parser::ast::{Expression, VariableNode, WaveType};
use std::collections::HashMap;
use crate::llvm_temporary::llvm_codegen::types::TypeFlavor;

#[derive(Copy, Clone, Debug)]
pub enum CoercionMode {
Expand Down Expand Up @@ -111,7 +112,7 @@ pub(super) fn gen_variable_ir<'ctx>(
} = var_node;

unsafe {
let llvm_type = wave_type_to_llvm_type(context, type_name, struct_types);
let llvm_type = wave_type_to_llvm_type(context, type_name, struct_types, TypeFlavor::AbiC);
let alloca = builder.build_alloca(llvm_type, name).unwrap();

if let (WaveType::Array(element_type, size), Some(Expression::ArrayLiteral(values))) =
Expand All @@ -125,7 +126,7 @@ pub(super) fn gen_variable_ir<'ctx>(
);
}

let llvm_element_type = wave_type_to_llvm_type(context, element_type, struct_types);
let llvm_element_type = wave_type_to_llvm_type(context, element_type, struct_types, TypeFlavor::AbiC);

for (i, value_expr) in values.iter().enumerate() {
let value = generate_expression_ir(
Expand Down