diff --git a/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/typing.rs b/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/typing.rs index d6868f4b7c..69a3ba3242 100644 --- a/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/typing.rs +++ b/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/typing.rs @@ -168,7 +168,7 @@ impl Pass<'_> { for item in &array.items { item_type_ids.push(self.typing_of_expression(item).as_type_id()?); } - let element_type = self.types.common_mobile_type(&item_type_ids)?; + let element_type = self.types.type_of_array_literal(&item_type_ids)?; Some(self.types.register_type(Type::FixedSizeArray { element_type, size: array.items.len(), diff --git a/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/visitor.rs b/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/visitor.rs index 7205235d6e..79450d9db3 100644 --- a/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/visitor.rs +++ b/crates/solidity-v2/outputs/cargo/semantic/src/passes/p4_resolve_references/visitor.rs @@ -177,9 +177,9 @@ impl Visitor for Pass<'_> { // TODO(validation) SDR[47]: both true_expression and false_expression should // have the compatible types let type_id = match (true_type_id, false_type_id) { - (Some(true_type_id), Some(false_type_id)) => self - .types - .common_mobile_type(&[true_type_id, false_type_id]), + (Some(true_type_id), Some(false_type_id)) => { + self.types.common_mobile_type(true_type_id, false_type_id) + } _ => None, }; self.binder.set_node_type(node.id(), type_id); diff --git a/crates/solidity-v2/outputs/cargo/semantic/src/passes/tests/typing.rs b/crates/solidity-v2/outputs/cargo/semantic/src/passes/tests/typing.rs index 556be16c91..c02512b629 100644 --- a/crates/solidity-v2/outputs/cargo/semantic/src/passes/tests/typing.rs +++ b/crates/solidity-v2/outputs/cargo/semantic/src/passes/tests/typing.rs @@ -4,7 +4,7 @@ use slang_solidity_v2_common::versions::LanguageVersion; use slang_solidity_v2_ir::ir::{self, NodeIdGenerator}; use super::build_file; -use crate::binder::{Binder, Typing}; +use crate::binder::Binder; use crate::context::SemanticFile; use crate::passes::common::node_id_for_expression_typing; use crate::passes::{ @@ -12,65 +12,118 @@ use crate::passes::{ }; use crate::types::{DataLocation, LiteralKind, Type, TypeId, TypeRegistry}; -/// Parses a Solidity expression as the value of a top-level `uint constant`, -/// runs all semantic passes, and returns a clone of the inferred type assigned -/// to that value expression along with the populated `TypeRegistry`. -fn type_of_value_expression(input: &str) -> (Type, TypeRegistry) { - let (expr_type, types) = try_type_of_value_expression(input); - let expr_type = expr_type.expect("expected resolved type for value expression"); - (expr_type, types) -} - -/// Like `type_of_value_expression`, but returns `None` if the expression -/// did not resolve to a `Typing::Resolved`. Lets tests assert on the -/// "unresolved" outcome without panicking. -fn try_type_of_value_expression(input: &str) -> (Option, TypeRegistry) { - try_type_of_value_expression_in_context("", input) -} - -/// Like `type_of_value_expression`, with extra source-unit-level setup -/// (e.g. free function or contract definitions) prepended before the -/// `uint constant x = …` declaration. -fn type_of_value_expression_in_context(setup: &str, input: &str) -> (Type, TypeRegistry) { - let (expr_type, types) = try_type_of_value_expression_in_context(setup, input); - let expr_type = expr_type.expect("expected resolved type for value expression"); - (expr_type, types) -} +/// Wraps each expression in a no-op expression statement inside the body of an +/// `__test()` function of a synthesized `Test` contract, runs every semantic +/// pass, and returns the typing recorded for each expression (in input order) +/// along with the populated type registry. Non-`Resolved` typings come back +/// as `None`. +/// +/// `contract_context` is optional contract-level setup — state variables, +/// nested struct definitions, sibling member functions, etc. — inserted +/// before the `__test()` definition. +fn type_of_expressions( + language_version: LanguageVersion, + contract_context: Option<&str>, + expressions: &[&str], +) -> (Vec>, TypeRegistry) { + let context_block = contract_context.unwrap_or(""); + let expression_statements = expressions + .iter() + .map(|expr| format!("{expr};")) + .collect::>() + .join("\n"); + let source = format!( + "contract Test {{\n\ + {context_block}\n\ + function __test() internal {{\n\ + {expression_statements}\n\ + }}\n\ + }}\n" + ); -fn try_type_of_value_expression_in_context( - setup: &str, - input: &str, -) -> (Option, TypeRegistry) { - let source = format!("{setup}\nuint constant x = {input};"); let mut id_generator = NodeIdGenerator::default(); - let language_version = LanguageVersion::LATEST; let file = build_file("test.sol", &source, &mut id_generator, language_version); let files = [file]; let mut binder = Binder::default(); let mut types = TypeRegistry::default(); - p1_collect_definitions::run(&files, &mut binder); p2_linearise_contracts::run(&files, &mut binder); p3_type_definitions::run(&files, &mut binder, &mut types, language_version); p4_resolve_references::run(&files, &mut binder, &mut types, language_version); - // The constant declaration is always appended last; earlier members may - // exist when the test passes setup (free functions, contracts, etc.). - let value_expr = match files[0].ir_root().members.last().unwrap() { - ir::SourceUnitMember::ConstantDefinition(definition) => definition.value.as_ref().unwrap(), - other => panic!("expected ConstantDefinition, got {other:?}"), + let contract = match files[0].ir_root().members.last().unwrap() { + ir::SourceUnitMember::ContractDefinition(c) => c, + other => panic!("expected ContractDefinition, got {other:?}"), }; + let function = contract + .members + .iter() + .find_map(|member| match member { + ir::ContractMember::FunctionDefinition(f) + if f.name.as_ref().is_some_and(|n| n.unparse() == "__test") => + { + Some(f) + } + _ => None, + }) + .expect("__test function not found"); + let block = function.body.as_ref().expect("__test has a body"); - let expr_type = match binder.node_typing( - node_id_for_expression_typing(value_expr) - .expect("expression registers its typing in the binder"), - ) { - Typing::Resolved(type_id) => Some(types.get_type_by_id(type_id).clone()), - _ => None, - }; + let typings = block + .statements + .iter() + .filter_map(|stmt| match stmt { + ir::Statement::ExpressionStatement(s) => { + let node_id = node_id_for_expression_typing(&s.expression) + .expect("expression registers its typing in the binder"); + Some( + binder + .node_typing(node_id) + .as_type_id() + .map(|type_id| types.get_type_by_id(type_id)) + .cloned(), + ) + } + _ => None, + }) + .collect(); + + (typings, types) +} + +/// Convenience wrapper for `type_of_expressions` with a single expression and +/// no contract context. Panics if the typing didn't resolve. +fn type_of_expression(expr: &str) -> (Type, TypeRegistry) { + let (expr_type, types) = try_type_of_expression(expr); + ( + expr_type.expect("expected resolved type for expression"), + types, + ) +} + +/// Convenience wrapper for `type_of_expressions` with a single expression and +/// no contract context. Returns `None` if the typing didn't resolve. +fn try_type_of_expression(expr: &str) -> (Option, TypeRegistry) { + let (typings, types) = type_of_expressions(LanguageVersion::LATEST, None, &[expr]); + let typing = typings.into_iter().next().expect("at least one expression"); + (typing, types) +} - (expr_type, types) +/// Like `type_of_expression`, but with contract-level setup (state variables, +/// member functions, …) inserted before the `__test()` function. +fn type_of_expression_in_context(context: &str, expr: &str) -> (Type, TypeRegistry) { + let (expr_type, types) = try_type_of_expression_in_context(context, expr); + ( + expr_type.expect("expected resolved type for expression"), + types, + ) +} + +fn try_type_of_expression_in_context(context: &str, expr: &str) -> (Option, TypeRegistry) { + let (typings, types) = type_of_expressions(LanguageVersion::LATEST, Some(context), &[expr]); + let typing = typings.into_iter().next().expect("at least one expression"); + (typing, types) } fn register_uint_type(types: &mut TypeRegistry, bits: u32) -> TypeId { @@ -82,7 +135,7 @@ fn register_uint_type(types: &mut TypeRegistry, bits: u32) -> TypeId { #[test] fn test_value_bearing_integer_literal_types() { - let (type_, _) = type_of_value_expression("127"); + let (type_, _) = type_of_expression("127"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -90,7 +143,7 @@ fn test_value_bearing_integer_literal_types() { }) ); - let (type_, _) = type_of_value_expression("-128"); + let (type_, _) = type_of_expression("-128"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -100,7 +153,7 @@ fn test_value_bearing_integer_literal_types() { // Hex literals carry source byte width as `HexInteger`, distinct from // decimal `Integer` so the byte-array conversion rule can fire. - let (type_, _) = type_of_value_expression("0xff"); + let (type_, _) = type_of_expression("0xff"); assert_eq!( type_, Type::Literal(LiteralKind::HexInteger { @@ -110,7 +163,7 @@ fn test_value_bearing_integer_literal_types() { ); // Source byte width is preserved across leading zeros. - let (type_, _) = type_of_value_expression("0x0012"); + let (type_, _) = type_of_expression("0x0012"); assert_eq!( type_, Type::Literal(LiteralKind::HexInteger { @@ -120,7 +173,7 @@ fn test_value_bearing_integer_literal_types() { ); // Folding a hex literal demotes it to a plain `Integer` (provenance lost). - let (type_, _) = type_of_value_expression("0x10 + 0"); + let (type_, _) = type_of_expression("0x10 + 0"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -132,7 +185,7 @@ fn test_value_bearing_integer_literal_types() { #[test] fn test_binary_arithmetic_folds_to_narrowed_literal() { // Addition. - let (type_, _) = type_of_value_expression("1 + 1"); + let (type_, _) = type_of_expression("1 + 1"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -141,7 +194,7 @@ fn test_binary_arithmetic_folds_to_narrowed_literal() { ); // Multiplication. - let (type_, _) = type_of_value_expression("3 * 4"); + let (type_, _) = type_of_expression("3 * 4"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -150,7 +203,7 @@ fn test_binary_arithmetic_folds_to_narrowed_literal() { ); // Power. - let (type_, _) = type_of_value_expression("2 ** 10"); + let (type_, _) = type_of_expression("2 ** 10"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -159,7 +212,7 @@ fn test_binary_arithmetic_folds_to_narrowed_literal() { ); // Shift. - let (type_, _) = type_of_value_expression("1 << 32"); + let (type_, _) = type_of_expression("1 << 32"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -168,7 +221,7 @@ fn test_binary_arithmetic_folds_to_narrowed_literal() { ); // Reducible rational arithmetic normalises back to an integer. - let (type_, _) = type_of_value_expression("1.5 * 2"); + let (type_, _) = type_of_expression("1.5 * 2"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -177,7 +230,7 @@ fn test_binary_arithmetic_folds_to_narrowed_literal() { ); // Non-reducing rational division stays rational. - let (type_, _) = type_of_value_expression("5 / 2"); + let (type_, _) = type_of_expression("5 / 2"); assert_eq!( type_, Type::Literal(LiteralKind::Rational { @@ -186,7 +239,7 @@ fn test_binary_arithmetic_folds_to_narrowed_literal() { ); // Negation of a folded constant. - let (type_, _) = type_of_value_expression("-(1 + 1)"); + let (type_, _) = type_of_expression("-(1 + 1)"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -198,7 +251,7 @@ fn test_binary_arithmetic_folds_to_narrowed_literal() { #[test] fn test_binary_bitwise_folds_to_literal() { // OR - let (type_, _) = type_of_value_expression("1 | 2"); + let (type_, _) = type_of_expression("1 | 2"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -207,7 +260,7 @@ fn test_binary_bitwise_folds_to_literal() { ); // AND - let (type_, _) = type_of_value_expression("12 & 10"); + let (type_, _) = type_of_expression("12 & 10"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -216,7 +269,7 @@ fn test_binary_bitwise_folds_to_literal() { ); // XOR - let (type_, _) = type_of_value_expression("6 ^ 3"); + let (type_, _) = type_of_expression("6 ^ 3"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -226,7 +279,7 @@ fn test_binary_bitwise_folds_to_literal() { // Folding hex operands demotes the result to a plain `Integer` // (mirroring the additive folding behaviour). - let (type_, _) = type_of_value_expression("0xf0 | 0x0f"); + let (type_, _) = type_of_expression("0xf0 | 0x0f"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -236,7 +289,7 @@ fn test_binary_bitwise_folds_to_literal() { // Bitwise AND with a negative literal: BigInt uses arbitrary-precision // two's-complement, so `-1 & 0xff` masks to the low byte. - let (type_, _) = type_of_value_expression("(-1) & 0xff"); + let (type_, _) = type_of_expression("(-1) & 0xff"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -245,7 +298,7 @@ fn test_binary_bitwise_folds_to_literal() { ); // Bitwise OR of a folded constant feeds further folding. - let (type_, _) = type_of_value_expression("(1 | 2) ^ 4"); + let (type_, _) = type_of_expression("(1 | 2) ^ 4"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -257,7 +310,7 @@ fn test_binary_bitwise_folds_to_literal() { #[test] fn test_bitwise_not_folds_to_literal() { // ~x = -x - 1 (two's complement on an infinite-precision integer). - let (type_, _) = type_of_value_expression("~1"); + let (type_, _) = type_of_expression("~1"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -265,7 +318,7 @@ fn test_bitwise_not_folds_to_literal() { }) ); - let (type_, _) = type_of_value_expression("~0"); + let (type_, _) = type_of_expression("~0"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -274,7 +327,7 @@ fn test_bitwise_not_folds_to_literal() { ); // Double-complement returns to the original value. - let (type_, _) = type_of_value_expression("~(-1)"); + let (type_, _) = type_of_expression("~(-1)"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -283,7 +336,7 @@ fn test_bitwise_not_folds_to_literal() { ); // Folding `~hex` demotes the result to a plain `Integer`. - let (type_, _) = type_of_value_expression("~0xff"); + let (type_, _) = type_of_expression("~0xff"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -292,7 +345,7 @@ fn test_bitwise_not_folds_to_literal() { ); // `~` of a folded constant. - let (type_, _) = type_of_value_expression("~(1 | 2)"); + let (type_, _) = type_of_expression("~(1 | 2)"); assert_eq!( type_, Type::Literal(LiteralKind::Integer { @@ -304,23 +357,23 @@ fn test_bitwise_not_folds_to_literal() { #[test] fn test_bitwise_operations_unresolved_for_rationals() { // Bitwise binary operators don't apply to non-reducing rationals. - let (type_, _) = try_type_of_value_expression("1.5 | 1"); + let (type_, _) = try_type_of_expression("1.5 | 1"); assert_eq!(type_, None); - let (type_, _) = try_type_of_value_expression("1 & 0.5"); + let (type_, _) = try_type_of_expression("1 & 0.5"); assert_eq!(type_, None); - let (type_, _) = try_type_of_value_expression("0.5 ^ 0.25"); + let (type_, _) = try_type_of_expression("0.5 ^ 0.25"); assert_eq!(type_, None); // Likewise for the unary bitwise NOT. - let (type_, _) = try_type_of_value_expression("~0.5"); + let (type_, _) = try_type_of_expression("~0.5"); assert_eq!(type_, None); } #[test] fn test_implicit_conversion_uses_literal_value() { - let (_, mut types) = type_of_value_expression("0"); + let (_, mut types) = type_of_expression("0"); let int8 = types.register_type(Type::Integer { signed: true, @@ -388,7 +441,7 @@ fn test_implicit_conversion_uses_literal_value() { #[test] fn test_hex_literal_to_byte_array_conversion() { - let (_, mut types) = type_of_value_expression("0"); + let (_, mut types) = type_of_expression("0"); let bytes1 = types.register_type(Type::ByteArray { width: 1 }); let bytes2 = types.register_type(Type::ByteArray { width: 2 }); @@ -446,7 +499,7 @@ fn test_hex_literal_to_byte_array_conversion() { #[test] fn test_conditional_expression_unifies_branch_types() { // Both branches reify to uint8 — common type is uint8. - let (type_, _) = type_of_value_expression("true ? 1 : 2"); + let (type_, _) = type_of_expression("true ? 1 : 2"); assert_eq!( type_, Type::Integer { @@ -456,7 +509,7 @@ fn test_conditional_expression_unifies_branch_types() { ); // uint8 (1) widens to uint16 (256). - let (type_, _) = type_of_value_expression("true ? 1 : 256"); + let (type_, _) = type_of_expression("true ? 1 : 256"); assert_eq!( type_, Type::Integer { @@ -466,7 +519,7 @@ fn test_conditional_expression_unifies_branch_types() { ); // int8 (-1) and int8 (1) — common type is int8. - let (type_, _) = type_of_value_expression("true ? -1 : -128"); + let (type_, _) = type_of_expression("true ? -1 : -128"); assert_eq!( type_, Type::Integer { @@ -476,7 +529,7 @@ fn test_conditional_expression_unifies_branch_types() { ); // Both branches are string literals — both reify to `string memory`. - let (type_, _) = type_of_value_expression(r#"true ? "abc" : "x""#); + let (type_, _) = type_of_expression(r#"true ? "abc" : "x""#); assert_eq!( type_, Type::String { @@ -489,19 +542,19 @@ fn test_conditional_expression_unifies_branch_types() { fn test_conditional_expression_unresolved_when_branches_incompatible() { // uint8 (1) and int8 (-1): neither converts to the other at the same // bit width, so unification fails and the conditional is unresolved. - let (type_, _) = try_type_of_value_expression("true ? 1 : -1"); + let (type_, _) = try_type_of_expression("true ? 1 : -1"); assert_eq!(type_, None); // A non-reducing rational has no `reified` type yet, so any conditional // involving one is unresolved. - let (type_, _) = try_type_of_value_expression("true ? 0.5 : 1"); + let (type_, _) = try_type_of_expression("true ? 0.5 : 1"); assert_eq!(type_, None); } #[test] fn test_array_literal_unifies_element_types() { // Homogeneous uint8 elements. - let (expr_type, types) = type_of_value_expression("[1, 2, 3]"); + let (expr_type, types) = type_of_expression("[1, 2, 3]"); let Type::FixedSizeArray { element_type, size, @@ -515,7 +568,7 @@ fn test_array_literal_unifies_element_types() { assert_eq!(element_type, types.uint8()); // Mixed widths widen to the largest required. - let (expr_type, mut types) = type_of_value_expression("[1, 256, 3]"); + let (expr_type, mut types) = type_of_expression("[1, 256, 3]"); let Type::FixedSizeArray { element_type, size, .. } = expr_type @@ -526,7 +579,7 @@ fn test_array_literal_unifies_element_types() { assert_eq!(element_type, register_uint_type(&mut types, 16)); // Negative values force the result to a signed type. - let (expr_type, mut types) = type_of_value_expression("[-1, -2]"); + let (expr_type, mut types) = type_of_expression("[-1, -2]"); let Type::FixedSizeArray { element_type, .. } = expr_type else { panic!("expected FixedSizeArray, got {expr_type:?}"); }; @@ -539,7 +592,7 @@ fn test_array_literal_unifies_element_types() { ); // String literal arrays reify each element to `string memory`. - let (expr_type, types) = type_of_value_expression(r#"["abc", "x"]"#); + let (expr_type, types) = type_of_expression(r#"["abc", "x"]"#); let Type::FixedSizeArray { element_type, size, .. } = expr_type @@ -553,32 +606,48 @@ fn test_array_literal_unifies_element_types() { #[test] fn test_array_literal_unresolved_when_elements_incompatible() { // uint8 (1) and int8 (-1) cannot be unified (same bit width, opposite sign). - let (type_, _) = try_type_of_value_expression("[1, -1]"); + let (type_, _) = try_type_of_expression("[1, -1]"); assert_eq!(type_, None); // Non-reducing rationals don't reify yet — array unification fails. - let (type_, _) = try_type_of_value_expression("[0.5, 1]"); + let (type_, _) = try_type_of_expression("[0.5, 1]"); assert_eq!(type_, None); } #[test] fn test_conditional_expression_unifies_byte_arrays() { - let (expr_type, types) = type_of_value_expression("true ? bytes32(0) : bytes32(1)"); + let (expr_type, types) = type_of_expression("true ? bytes32(0) : bytes32(1)"); assert_eq!(expr_type, *types.get_type_by_id(types.bytes32())); } #[test] fn test_conditional_expression_widens_byte_arrays() { - let (expr_type, types) = type_of_value_expression("true ? bytes20(0) : bytes32(0)"); + let (expr_type, types) = type_of_expression("true ? bytes20(0) : bytes32(0)"); assert_eq!(expr_type, *types.get_type_by_id(types.bytes32())); - let (expr_type, types) = type_of_value_expression("true ? bytes32(0) : bytes20(0)"); + let (expr_type, types) = type_of_expression("true ? bytes32(0) : bytes20(0)"); assert_eq!(expr_type, *types.get_type_by_id(types.bytes32())); } #[test] fn test_array_literal_unifies_byte_array_elements() { - let (expr_type, types) = type_of_value_expression("[bytes32(0), bytes32(1)]"); + let (expr_type, types) = type_of_expression("[bytes32(0), bytes32(1)]"); + let Type::FixedSizeArray { + element_type, + size, + location, + } = expr_type + else { + panic!("expected FixedSizeArray, got {expr_type:?}"); + }; + assert_eq!(size, 2); + assert_eq!(location, DataLocation::Memory); + assert_eq!(element_type, types.bytes32()); +} + +#[test] +fn test_array_literal_unifies_byte_array_and_literal_zero() { + let (expr_type, types) = type_of_expression("[bytes32(0), 0]"); let Type::FixedSizeArray { element_type, size, @@ -592,12 +661,121 @@ fn test_array_literal_unifies_byte_array_elements() { assert_eq!(element_type, types.bytes32()); } +#[test] +fn test_conditional_expression_does_not_unify_byte_array_and_literal_zero() { + let (type_, _) = try_type_of_expression("true ? bytes32(0) : 0"); + assert_eq!(type_, None); +} + +#[test] +fn test_array_literal_does_not_unify_when_literal_is_first_and_byte_array_follows() { + // The first element of the array is used to find the common type + // Matches solc behaviour + let (type_, _) = try_type_of_expression("[0, bytes32(0)]"); + assert_eq!(type_, None); +} + +#[test] +fn test_array_literal_widens_past_first_element_integer_type() { + let (expr_type, mut types) = type_of_expression("[uint8(0), 256]"); + let Type::FixedSizeArray { + element_type, size, .. + } = expr_type + else { + panic!("expected FixedSizeArray, got {expr_type:?}"); + }; + assert_eq!(size, 2); + assert_eq!(element_type, register_uint_type(&mut types, 16)); +} + +#[test] +fn test_array_literal_unifies_byte_array_and_matching_hex_literal() { + let (expr_type, types) = type_of_expression("[bytes1(0x01), 0x01]"); + let Type::FixedSizeArray { + element_type, size, .. + } = expr_type + else { + panic!("expected FixedSizeArray, got {expr_type:?}"); + }; + assert_eq!(size, 2); + assert_eq!(element_type, types.bytes1()); +} + +#[test] +fn test_conditional_expression_loses_hex_literal_specialness() { + let (type_, _) = try_type_of_expression("true ? bytes1(0x01) : 0x01"); + assert_eq!(type_, None); +} + +#[test] +fn test_conditional_expression_widens_literal_to_concrete_integer() { + let (expr_type, types) = type_of_expression("true ? uint256(0) : 0"); + assert_eq!(expr_type, *types.get_type_by_id(types.uint256())); + + let (expr_type, types) = type_of_expression("true ? 0 : uint256(0)"); + assert_eq!(expr_type, *types.get_type_by_id(types.uint256())); +} + +#[test] +fn test_conditional_expression_unifies_mappings() { + let (expr_type, types) = try_type_of_expression_in_context( + "mapping(uint => uint) m1; mapping(uint => uint) m2;", + "true ? m1 : m2", + ); + let Some(Type::Mapping { + key_type_id, + value_type_id, + }) = expr_type + else { + panic!("expected Mapping, got {expr_type:?}"); + }; + assert_eq!(key_type_id, types.uint256()); + assert_eq!(value_type_id, types.uint256()); +} + +#[test] +fn test_conditional_expression_unifies_literal_tuples() { + let (expr_type, types) = type_of_expression("true ? (1, 2) : (3, 4)"); + let Type::Tuple { types: tuple_types } = expr_type else { + panic!("expected Tuple, got {expr_type:?}"); + }; + + assert_eq!(tuple_types.len(), 2); + assert_eq!(tuple_types[0], types.uint8()); + assert_eq!(tuple_types[1], types.uint8()); +} + +#[test] +fn test_mappings_only_unify_on_equal_elements() { + // Mappings must match on key and value types + let (expr_type, _) = try_type_of_expression_in_context( + "mapping(uint => int128) m1; mapping(uint => int256) m2;", + "true ? m1 : m2", + ); + assert_eq!(None, expr_type); +} + +#[test] +fn test_array_literal_rejects_mapping_element() { + let (type_, _) = try_type_of_expression_in_context( + "mapping(uint => uint) m1; mapping(uint => uint) m2;", + "[m1, m2]", + ); + assert_eq!(type_, None); +} + +#[test] +fn test_array_literal_does_not_unify_byte_array_and_non_zero_literal() { + let (type_, _) = try_type_of_expression("[bytes32(0), 1]"); + assert_eq!(type_, None); +} + #[test] fn test_bitwise_or_widens_byte_arrays() { - let (expr_type, types) = type_of_value_expression("bytes20(0) | bytes32(0)"); + let (expr_type, types) = type_of_expression("bytes20(0) | bytes32(0)"); assert_eq!(expr_type, *types.get_type_by_id(types.bytes32())); - let (expr_type, types) = type_of_value_expression("bytes32(0) | bytes20(0)"); + let (expr_type, types) = type_of_expression("bytes32(0) | bytes20(0)"); assert_eq!(expr_type, *types.get_type_by_id(types.bytes32())); } @@ -607,7 +785,7 @@ fn test_overload_resolution_widens_byte_array_argument() { function pick(bytes32 a) pure returns (uint8) { a; return 1; } function pick(string memory a) pure returns (uint16) { a; return 2; } "; - let (type_, _) = type_of_value_expression_in_context(setup, "pick(bytes20(0))"); + let (type_, _) = type_of_expression_in_context(setup, "pick(bytes20(0))"); assert_eq!( type_, Type::Integer { @@ -623,7 +801,7 @@ fn test_overload_resolution_rejects_byte_array_narrowing() { function pick(bytes20 a) pure returns (uint8) { a; return 1; } function pick(string memory a) pure returns (uint16) { a; return 2; } "; - let (type_, _) = try_type_of_value_expression_in_context(setup, "pick(bytes32(0))"); + let (type_, _) = try_type_of_expression_in_context(setup, "pick(bytes32(0))"); // Neither overload matches: `bytes32` does not convert to `bytes20` nor // to `string`. The call is unresolved. assert_eq!(type_, None); @@ -631,184 +809,100 @@ fn test_overload_resolution_rejects_byte_array_narrowing() { #[test] fn test_conditional_expression_unifies_booleans() { - let (type_, _) = type_of_value_expression("true ? true : false"); + let (type_, _) = type_of_expression("true ? true : false"); assert_eq!(type_, Type::Boolean); } #[test] fn test_string_literal_byte_count_with_escapes() { // Plain ASCII: one byte per char. - let (type_, _) = type_of_value_expression(r#""abc""#); + let (type_, _) = type_of_expression(r#""abc""#); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 3 })); // Each `\n`, `\t`, etc. decodes to a single byte. - let (type_, _) = type_of_value_expression(r#""\n\t\\""#); + let (type_, _) = type_of_expression(r#""\n\t\\""#); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 3 })); // `\xNN` escapes decode to one byte each, regardless of the 4-char source // length per escape. - let (type_, _) = type_of_value_expression(r#""\x41\x42""#); + let (type_, _) = type_of_expression(r#""\x41\x42""#); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 2 })); // Line continuations (`\`) decode to nothing. - let (type_, _) = type_of_value_expression("\"a\\\nb\""); + let (type_, _) = type_of_expression("\"a\\\nb\""); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 2 })); // Concatenated string literals: byte counts add up across pieces. - let (type_, _) = type_of_value_expression(r#""abc" "de""#); + let (type_, _) = type_of_expression(r#""abc" "de""#); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 5 })); } #[test] fn test_unicode_string_literal_byte_count() { // ASCII unicode-string literal: one byte per char. - let (type_, _) = type_of_value_expression(r#"unicode"abc""#); + let (type_, _) = type_of_expression(r#"unicode"abc""#); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 3 })); // Multi-byte UTF-8 passes through with its full byte length: // `€` is 3 bytes in UTF-8. - let (type_, _) = type_of_value_expression(r#"unicode"€""#); + let (type_, _) = type_of_expression(r#"unicode"€""#); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 3 })); // `\uNNNN` escapes decode to their UTF-8 byte length: // `\u20AC` (€) → 3 bytes, `\u00A2` (¢) → 2 bytes, `\u0024` ($) → 1 byte. - let (type_, _) = type_of_value_expression(r#"unicode"\u20AC\u00A2\u0024""#); + let (type_, _) = type_of_expression(r#"unicode"\u20AC\u00A2\u0024""#); assert_eq!(type_, Type::Literal(LiteralKind::String { bytes: 6 })); } #[test] fn test_hex_string_literal_byte_count() { // Pairs of hex digits, no separators: one byte per pair. - let (type_, _) = type_of_value_expression(r#"hex"414243""#); + let (type_, _) = type_of_expression(r#"hex"414243""#); assert_eq!(type_, Type::Literal(LiteralKind::HexString { bytes: 3 })); // Underscore separators don't contribute to the decoded length. - let (type_, _) = type_of_value_expression(r#"hex"41_42""#); + let (type_, _) = type_of_expression(r#"hex"41_42""#); assert_eq!(type_, Type::Literal(LiteralKind::HexString { bytes: 2 })); // Concatenated hex string literals: byte counts add up across pieces. - let (type_, _) = type_of_value_expression(r#"hex"4142" hex"43""#); + let (type_, _) = type_of_expression(r#"hex"4142" hex"43""#); assert_eq!(type_, Type::Literal(LiteralKind::HexString { bytes: 3 })); } -/// Locates a function definition by `contract_name` and `function_name` within -/// a source unit. Panics if either the contract or the function is not found. -fn find_contract_function<'a>( - source_unit: &'a ir::SourceUnit, - contract_name: &str, - function_name: &str, -) -> &'a ir::FunctionDefinition { - let contract = source_unit - .members - .iter() - .find_map(|member| match member { - ir::SourceUnitMember::ContractDefinition(contract) - if contract.name.unparse() == contract_name => - { - Some(contract) - } - _ => None, - }) - .unwrap_or_else(|| panic!("contract {contract_name} not found")); - - contract - .members - .iter() - .find_map(|member| match member { - ir::ContractMember::FunctionDefinition(function) - if function - .name - .as_ref() - .is_some_and(|name| name.unparse() == function_name) => - { - Some(function) - } - _ => None, - }) - .unwrap_or_else(|| panic!("function {function_name} not found in contract {contract_name}")) -} - -/// Resolves expression statement types in the body of the given function, in -/// source order. Skips non-expression statements. -fn expression_statement_types( - function: &ir::FunctionDefinition, - binder: &Binder, - types: &TypeRegistry, -) -> Vec { - let body = function.body.as_ref().expect("function has no body"); - body.statements - .iter() - .filter_map(|stmt| { - let ir::Statement::ExpressionStatement(expression_statement) = stmt else { - return None; - }; - let node_id = node_id_for_expression_typing(&expression_statement.expression) - .expect("expression has no NodeId for typing"); - let Typing::Resolved(type_id) = binder.node_typing(node_id) else { - panic!("expression did not resolve to a type"); - }; - Some(types.get_type_by_id(type_id).clone()) - }) - .collect() -} - #[test] fn test_data_locations_of_state_variable_and_getter_accesses() { - const CONTENTS: &str = r#" -contract Test { - struct Foo { - bytes xs; - } - bytes public bs; - Foo public foo; - - function test(Test t) internal view { - bs; // bytes storage - foo.xs; // bytes storage - t.bs(); // bytes memory - t.foo(); // bytes memory - } -} -"#; - - let mut id_generator = NodeIdGenerator::default(); - let language_version = LanguageVersion::LATEST; - let file = build_file("test.sol", CONTENTS, &mut id_generator, language_version); - let files = [file]; - - let mut binder = Binder::default(); - let mut types = TypeRegistry::default(); - - p1_collect_definitions::run(&files, &mut binder); - p2_linearise_contracts::run(&files, &mut binder); - p3_type_definitions::run(&files, &mut binder, &mut types, language_version); - p4_resolve_references::run(&files, &mut binder, &mut types, language_version); - - let function = find_contract_function(files[0].ir_root(), "Test", "test"); - let expression_types = expression_statement_types(function, &binder, &types); - // In source order: - // - `bs;` — internal access to a `bytes` storage variable: `bytes storage`. - // - `foo.xs;` — `xs` is declared with `Inherited` location inside the + // - `bs` — internal access to a `bytes` storage variable: `bytes storage`. + // - `foo.xs` — `xs` is declared with `Inherited` location inside the // struct; the member access propagates the operand's storage location. - // - `t.bs();` — external call to the auto-generated getter of `bytes bs`; + // - `t.bs()` — external call to the auto-generated getter of `bytes bs`; // the returned reference type lives in memory. - // - `t.foo();` — external call to the auto-generated getter of `Foo foo`. + // - `t.foo()` — external call to the auto-generated getter of `Foo foo`. // `Foo` has a single returnable field (`bytes xs`), so the getter // returns just `bytes`, again in memory. + let (typings, _) = type_of_expressions( + LanguageVersion::LATEST, + Some( + "struct Foo { bytes xs; }\n\ + bytes public bs;\n\ + Foo public foo;\n\ + Test t;", + ), + &["bs", "foo.xs", "t.bs()", "t.foo()"], + ); let expected = vec![ - Type::Bytes { + Some(Type::Bytes { location: DataLocation::Storage, - }, - Type::Bytes { + }), + Some(Type::Bytes { location: DataLocation::Storage, - }, - Type::Bytes { + }), + Some(Type::Bytes { location: DataLocation::Memory, - }, - Type::Bytes { + }), + Some(Type::Bytes { location: DataLocation::Memory, - }, + }), ]; - assert_eq!(expression_types, expected); + assert_eq!(typings, expected); } diff --git a/crates/solidity-v2/outputs/cargo/semantic/src/types/mod.rs b/crates/solidity-v2/outputs/cargo/semantic/src/types/mod.rs index c1dbc09e05..13a2e8e291 100644 --- a/crates/solidity-v2/outputs/cargo/semantic/src/types/mod.rs +++ b/crates/solidity-v2/outputs/cargo/semantic/src/types/mod.rs @@ -103,6 +103,25 @@ pub enum LiteralKind { Address, } +impl LiteralKind { + /// Returns the non-literal `Type` this literal flows into when its source + /// position needs a concrete EVM type. + pub(crate) fn mobile_type(&self) -> Option { + match self { + LiteralKind::Integer { value } | LiteralKind::HexInteger { value, .. } => { + numbers::smallest_integer_type_to_fit(value) + } + // TODO: not supported yet, but narrow the rational type to the + // smallest fixed/ufixed available (eg. 1.2 -> ufixed8x1). + LiteralKind::Rational { .. } => None, + LiteralKind::HexString { .. } | LiteralKind::String { .. } => Some(Type::String { + location: DataLocation::Memory, + }), + LiteralKind::Address => Some(Type::Address { payable: false }), + } + } +} + #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct FunctionType { pub definition_id: Option, // this may point to a FunctionDefinition @@ -335,30 +354,10 @@ impl Type { } } - /// Returns a new non-literal `Type` this can flow into. For use when - /// computing the type of literal arrays or conditional branches. - pub(crate) fn mobile_type(&self) -> Option { - match self { - Type::Literal( - LiteralKind::Integer { value } | LiteralKind::HexInteger { value, .. }, - ) => numbers::smallest_integer_type_to_fit(value), - Type::Literal(LiteralKind::Rational { .. }) => { - // TODO: not supported yet, but narrow the rational type to the - // smallest fixed/ufixed available (eg. 1.2 -> ufixed8x1). - None - } - Type::Literal(LiteralKind::HexString { .. } | LiteralKind::String { .. }) => { - Some(Type::String { - location: DataLocation::Memory, - }) - } - Type::Literal(LiteralKind::Address) => Some(Type::Address { payable: false }), - - // Some values cannot be elements of arrays - Type::Mapping { .. } | Type::Tuple { .. } | Type::Void => None, - - // Return self for all other cases - _ => Some(self.clone()), - } + /// Whether this type can appear as the element type of an array literal. + /// + /// TODO: This probably has a better way to resolve it, looking at the storage location + pub(crate) fn can_be_array_element(&self) -> bool { + !matches!(self, Type::Mapping { .. } | Type::Tuple { .. } | Type::Void) } } diff --git a/crates/solidity-v2/outputs/cargo/semantic/src/types/registry.rs b/crates/solidity-v2/outputs/cargo/semantic/src/types/registry.rs index a9cfbb3fad..54506d8e35 100644 --- a/crates/solidity-v2/outputs/cargo/semantic/src/types/registry.rs +++ b/crates/solidity-v2/outputs/cargo/semantic/src/types/registry.rs @@ -130,6 +130,7 @@ impl TypeRegistry { self.types.get_index(type_id.0).unwrap() } + #[allow(clippy::too_many_lines)] pub(crate) fn implicitly_convertible_to( &self, from_type_id: TypeId, @@ -141,19 +142,6 @@ impl TypeRegistry { let from_type = self.get_type_by_id(from_type_id); let to_type = self.get_type_by_id(to_type_id); - self.internal_implicitly_convertible_to(from_type, to_type) - } - - #[allow(clippy::too_many_lines)] - fn internal_implicitly_convertible_to(&self, from_type: &Type, to_type: &Type) -> bool { - // The public `implicitly_convertible_to` already - // short-circuits on `TypeId` equality, but this function is also called - // directly by `common_mobile_type` with `&Type` values constructed from - // `mobile_type()`, bypassing that short-circuit. - if from_type == to_type { - return true; - } - match (from_type, to_type) { ( Type::Address { @@ -552,30 +540,75 @@ impl TypeRegistry { self.register_type(type_with_location) } + /// Computes the mobile type of `type_id` and returns its `TypeId`. + pub(crate) fn compute_mobile_type(&mut self, type_id: TypeId) -> Option { + match self.get_type_by_id(type_id).clone() { + Type::Literal(kind) => { + let mobile = kind.mobile_type()?; + Some(self.register_type(mobile)) + } + Type::Tuple { types: element_ids } => { + let mobile_ids: Option> = element_ids + .iter() + .map(|id| self.compute_mobile_type(*id)) + .collect(); + Some(self.register_type(Type::Tuple { types: mobile_ids? })) + } + _ => Some(type_id), + } + } + + /// Returns the common type of two types, making them mobile first. + /// + /// Stricter than [`Self::type_of_array_literal`] — a literal that + /// flows into `bytesN` via a literal-specific rule won't + /// flow that way through a ternary, because the literal is mobile-typed + /// to an integer first. + /// + /// Matches solc's ternary semantics. + pub(crate) fn common_mobile_type(&mut self, left: TypeId, right: TypeId) -> Option { + let left_mobile = self.compute_mobile_type(left)?; + let right_mobile = self.compute_mobile_type(right)?; + if self.implicitly_convertible_to(left_mobile, right_mobile) { + return Some(right_mobile); + } + if self.implicitly_convertible_to(right_mobile, left_mobile) { + return Some(left_mobile); + } + None + } + /// Return a type that can be stored in the EVM and can hold values of all /// the given types. The first element dictates the type class. Returns - /// `None` if the types cannot be reified or they are not compatible. This - /// is used to unify types of literal arrays and conditional branches. - pub(crate) fn common_mobile_type(&mut self, type_ids: &[TypeId]) -> Option { - if type_ids.is_empty() { + /// `None` if the types cannot be reified, they are not compatible, or they don't belong + /// in an array literal. + /// + /// This is used to unify types of literal arrays. + /// Only the first element is mobile-typed unconditionally. + pub(crate) fn type_of_array_literal(&mut self, type_ids: &[TypeId]) -> Option { + let (first_id, rest) = type_ids.split_first()?; + if !self.get_type_by_id(*first_id).can_be_array_element() { + // TODO(validation) SDR[750]: Error if the element type can't be an array element return None; } - let initial_type = self.get_type_by_id(type_ids[0]); - let mut element_type = initial_type.mobile_type()?; - - for item_type_id in &type_ids[1..] { - let item_type = self.get_type_by_id(*item_type_id).mobile_type()?; - if self.internal_implicitly_convertible_to(&item_type, &element_type) { - // ok, `element_type` can already hold `item_type` - } else if self.internal_implicitly_convertible_to(&element_type, &item_type) { - // `item_type` is "bigger" - element_type = item_type; - } else { - // TODO(validation) SDR[1741]: types are not compatible + let mut element_type_id = self.compute_mobile_type(*first_id)?; + for &item_type_id in rest { + if self.implicitly_convertible_to(item_type_id, element_type_id) { + // Item already fits the accumulator + continue; + } + if !self.get_type_by_id(item_type_id).can_be_array_element() { + // TODO(validation) SDR[750]: Error if the element type can't be an array element + return None; + } + let item_mobile_type_id = self.compute_mobile_type(item_type_id)?; + if !self.implicitly_convertible_to(element_type_id, item_mobile_type_id) { + // TODO(validation) SDR[1741,1353]: types are not compatible return None; } + element_type_id = item_mobile_type_id; } - Some(self.register_type(element_type)) + Some(element_type_id) } // Returns true if a function type overrides another