Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions crates/cairo-lang-lowering/src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ use crate::objects::{
};
use crate::{
Block, BlockEnd, Location, Lowered, MatchArm, MatchEnumInfo, MatchEnumValue, MatchInfo,
StatementDesnap, StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
VarRemapping, VarUsage, Variable,
StatementDesnap, StatementEnumConstruct, StatementIntoBox, StatementSnapshot,
StatementStructConstruct, VarRemapping, VarUsage, Variable,
};

type LookupCache = (CacheLookups, Vec<(DefsFunctionWithBodyIdCached, MultiLoweringCached)>);
Expand Down Expand Up @@ -909,6 +909,9 @@ enum StatementCached {
// Enums.
EnumConstruct(StatementEnumConstructCached),

// Boxing.
IntoBox(StatementIntoBoxCached),

Snapshot(StatementSnapshotCached),
Desnap(StatementDesnapCached),
}
Expand All @@ -933,6 +936,9 @@ impl StatementCached {
Statement::Desnap(stmt) => {
StatementCached::Desnap(StatementDesnapCached::new(stmt, ctx))
}
Statement::IntoBox(stmt) => {
StatementCached::IntoBox(StatementIntoBoxCached::new(stmt, ctx))
}
}
}
fn embed<'db>(self, ctx: &mut CacheLoadingContext<'db>) -> Statement<'db> {
Expand All @@ -946,6 +952,7 @@ impl StatementCached {
StatementCached::EnumConstruct(stmt) => Statement::EnumConstruct(stmt.embed(ctx)),
StatementCached::Snapshot(stmt) => Statement::Snapshot(stmt.embed(ctx)),
StatementCached::Desnap(stmt) => Statement::Desnap(stmt.embed(ctx)),
StatementCached::IntoBox(stmt) => Statement::IntoBox(stmt.embed(ctx)),
}
}
}
Expand Down Expand Up @@ -1246,6 +1253,23 @@ impl StatementDesnapCached {
}
}

#[derive(Serialize, Deserialize)]
struct StatementIntoBoxCached {
input: VarUsageCached,
output: usize,
}
impl StatementIntoBoxCached {
fn new<'db>(stmt: StatementIntoBox<'db>, ctx: &mut CacheSavingContext<'db>) -> Self {
Self { input: VarUsageCached::new(stmt.input, ctx), output: stmt.output.index() }
}
fn embed<'db>(self, ctx: &mut CacheLoadingContext<'db>) -> StatementIntoBox<'db> {
StatementIntoBox {
input: self.input.embed(ctx),
output: ctx.lowered_variables_id[self.output],
}
}
}

#[derive(Serialize, Deserialize, Clone)]
struct LocationCached {
/// The stable location of the object.
Expand Down
3 changes: 2 additions & 1 deletion crates/cairo-lang-lowering/src/concretize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ pub fn concretize_lowered<'db>(
Statement::Snapshot(_)
| Statement::Desnap(_)
| Statement::StructConstruct(_)
| Statement::StructDestructure(_) => {}
| Statement::StructDestructure(_)
| Statement::IntoBox(_) => {}
}
}
if let BlockEnd::Match { info } = &mut block.end {
Expand Down
15 changes: 13 additions & 2 deletions crates/cairo-lang-lowering/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use itertools::Itertools;
use salsa::Database;

use crate::objects::{
MatchExternInfo, Statement, StatementCall, StatementConst, StatementStructDestructure,
VariableId,
MatchExternInfo, Statement, StatementCall, StatementConst, StatementIntoBox,
StatementStructDestructure, VariableId,
};
use crate::{
Block, BlockEnd, Lowered, MatchArm, MatchEnumInfo, MatchEnumValue, MatchInfo, StatementDesnap,
Expand Down Expand Up @@ -188,6 +188,7 @@ impl<'db> DebugWithDb<'db> for Statement<'db> {
Statement::EnumConstruct(stmt) => stmt.fmt(f, ctx),
Statement::Snapshot(stmt) => stmt.fmt(f, ctx),
Statement::Desnap(stmt) => stmt.fmt(f, ctx),
Statement::IntoBox(stmt) => stmt.fmt(f, ctx),
}
}
}
Expand Down Expand Up @@ -365,3 +366,13 @@ impl<'db> DebugWithDb<'db> for StatementDesnap<'db> {
write!(f, ")")
}
}

impl<'db> DebugWithDb<'db> for StatementIntoBox<'db> {
type Db = LoweredFormatter<'db>;

fn fmt(&self, f: &mut std::fmt::Formatter<'_>, ctx: &Self::Db) -> std::fmt::Result {
write!(f, "into_box(")?;
self.input.fmt(f, ctx)?;
write!(f, ")")
}
}
22 changes: 21 additions & 1 deletion crates/cairo-lang-lowering/src/lower/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

use cairo_lang_semantic as semantic;
use cairo_lang_semantic::ConcreteVariant;
use cairo_lang_semantic::corelib::core_box_ty;
use cairo_lang_semantic::items::constant::ConstValueId;
use cairo_lang_utils::{Intern, extract_matches};
use itertools::chain;
Expand All @@ -15,7 +16,7 @@ use crate::objects::{
Statement, StatementCall, StatementConst, StatementStructConstruct, StatementStructDestructure,
VarUsage,
};
use crate::{StatementDesnap, StatementEnumConstruct, StatementSnapshot};
use crate::{StatementDesnap, StatementEnumConstruct, StatementIntoBox, StatementSnapshot};

#[derive(Clone, Default)]
pub struct StatementsBuilder<'db> {
Expand Down Expand Up @@ -256,3 +257,22 @@ impl<'db> StructConstruct<'db> {
VarUsage { var_id: output, location: self.location }
}
}

/// Generator for [StatementIntoBox].
pub struct IntoBox<'db> {
pub input: VarUsage<'db>,
pub location: LocationId<'db>,
}
impl<'db> IntoBox<'db> {
pub fn add(
self,
ctx: &mut LoweringContext<'db, '_>,
builder: &mut StatementsBuilder<'db>,
) -> VarUsage<'db> {
let input_ty = ctx.variables[self.input.var_id].ty;
let output_ty = core_box_ty(ctx.db, input_ty);
let output = ctx.new_var(VarRequest { ty: output_ty, location: self.location });
builder.push_statement(Statement::IntoBox(StatementIntoBox { input: self.input, output }));
VarUsage { var_id: output, location: self.location }
}
}
12 changes: 11 additions & 1 deletion crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,7 @@ fn perform_function_call<'db>(
function_call_info;

// If the function is not extern, simply call it.
if function.try_get_extern_function_id(ctx.db).is_none() {
let Some(extern_function_id) = function.try_get_extern_function_id(ctx.db) else {
let call_result = generators::Call {
function: function.lowered(ctx.db),
inputs,
Expand Down Expand Up @@ -1390,6 +1390,16 @@ fn perform_function_call<'db>(

// Extern function.
assert!(coupon_input.is_none(), "Extern functions cannot have a __coupon__ argument.");

// Handle into_box specially - emit IntoBox instead of a call.
let info = ctx.db.core_info();
if extern_function_id == info.into_box {
assert!(extra_ret_tys.is_empty(), "into_box should not have extra return types");
let input = inputs.into_iter().exactly_one().expect("into_box expects exactly one input");
let res = generators::IntoBox { input, location }.add(ctx, &mut builder.statements);
return Ok((vec![], LoweredExpr::AtVariable(res)));
}

let ret_tys = extern_facade_return_tys(ctx, ret_ty);
let call_result = generators::Call {
function: function.lowered(ctx.db),
Expand Down
18 changes: 18 additions & 0 deletions crates/cairo-lang-lowering/src/objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ pub enum Statement<'db> {

Snapshot(StatementSnapshot<'db>),
Desnap(StatementDesnap<'db>),

// Boxing.
IntoBox(StatementIntoBox<'db>),
}
impl<'db> Statement<'db> {
pub fn inputs(&self) -> &[VarUsage<'db>] {
Expand All @@ -305,6 +308,7 @@ impl<'db> Statement<'db> {
Statement::EnumConstruct(stmt) => std::slice::from_ref(&stmt.input),
Statement::Snapshot(stmt) => std::slice::from_ref(&stmt.input),
Statement::Desnap(stmt) => std::slice::from_ref(&stmt.input),
Statement::IntoBox(stmt) => std::slice::from_ref(&stmt.input),
}
}

Expand All @@ -317,6 +321,7 @@ impl<'db> Statement<'db> {
Statement::EnumConstruct(stmt) => std::slice::from_mut(&mut stmt.input),
Statement::Snapshot(stmt) => std::slice::from_mut(&mut stmt.input),
Statement::Desnap(stmt) => std::slice::from_mut(&mut stmt.input),
Statement::IntoBox(stmt) => std::slice::from_mut(&mut stmt.input),
}
}

Expand All @@ -329,6 +334,7 @@ impl<'db> Statement<'db> {
Statement::EnumConstruct(stmt) => std::slice::from_ref(&stmt.output),
Statement::Snapshot(stmt) => stmt.outputs.as_slice(),
Statement::Desnap(stmt) => std::slice::from_ref(&stmt.output),
Statement::IntoBox(stmt) => std::slice::from_ref(&stmt.output),
}
}

Expand All @@ -341,6 +347,7 @@ impl<'db> Statement<'db> {
Statement::EnumConstruct(stmt) => std::slice::from_mut(&mut stmt.output),
Statement::Snapshot(stmt) => stmt.outputs.as_mut_slice(),
Statement::Desnap(stmt) => std::slice::from_mut(&mut stmt.output),
Statement::IntoBox(stmt) => std::slice::from_mut(&mut stmt.output),
}
}
pub fn location(&self) -> Option<LocationId<'db>> {
Expand All @@ -353,6 +360,7 @@ impl<'db> Statement<'db> {
Statement::EnumConstruct(stmt) => Some(stmt.input.location),
Statement::Snapshot(stmt) => Some(stmt.input.location),
Statement::Desnap(stmt) => Some(stmt.input.location),
Statement::IntoBox(stmt) => Some(stmt.input.location),
}
}
pub fn location_mut(&mut self) -> Option<&mut LocationId<'db>> {
Expand All @@ -364,6 +372,7 @@ impl<'db> Statement<'db> {
Statement::EnumConstruct(stmt) => Some(&mut stmt.input.location),
Statement::Snapshot(stmt) => Some(&mut stmt.input.location),
Statement::Desnap(stmt) => Some(&mut stmt.input.location),
Statement::IntoBox(stmt) => Some(&mut stmt.input.location),
}
}
}
Expand Down Expand Up @@ -470,6 +479,15 @@ pub struct StatementDesnap<'db> {
pub output: VariableId,
}

/// A statement that constructs a box from a value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StatementIntoBox<'db> {
/// The value to box.
pub input: VarUsage<'db>,
/// The variable to bind the boxed value to.
pub output: VariableId,
}

/// An arm of a match statement.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MatchArm<'db> {
Expand Down
37 changes: 20 additions & 17 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::utils::InliningStrategy;
use crate::{
Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, MatchArm, MatchEnumInfo,
MatchExternInfo, MatchInfo, Statement, StatementCall, StatementConst, StatementDesnap,
StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableArena, VariableId,
};

Expand Down Expand Up @@ -308,6 +308,25 @@ impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
};
self.var_info.insert(*output, value);
}
Statement::IntoBox(StatementIntoBox { input, output }) => {
let var_info = self.var_info.get(&input.var_id);
let const_value = match var_info {
Some(VarInfo::Const(val)) => Some(*val),
Some(VarInfo::Snapshot(info)) => {
try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
}
_ => None,
};
let var_info =
var_info.cloned().or_else(|| var_info_if_copy(self.variables, *input));
if let Some(var_info) = var_info {
self.var_info.insert(*output, VarInfo::Box(var_info.into()));
}

if let Some(const_value) = const_value {
*stmt = Statement::Const(StatementConst::new_boxed(const_value, *output));
}
}
}
}

Expand Down Expand Up @@ -601,19 +620,6 @@ impl<'db, 'mt> ConstFoldingContext<'db, 'mt> {
self.storage_base_address_const.concretize(db, vec![arg]).lowered(db);
}
None
} else if id == self.into_box {
let input = stmt.inputs[0];
let var_info = self.var_info.get(&input.var_id);
let const_value = match var_info {
Some(VarInfo::Const(val)) => Some(*val),
Some(VarInfo::Snapshot(info)) => {
try_extract_matches!(info.as_ref(), VarInfo::Const).copied()
}
_ => None,
};
let var_info = var_info.cloned().or_else(|| var_info_if_copy(self.variables, input))?;
self.var_info.insert(stmt.outputs[0], VarInfo::Box(var_info.into()));
Some(Statement::Const(StatementConst::new_boxed(const_value?, stmt.outputs[0])))
} else if id == self.unbox {
if let VarInfo::Box(inner) = self.var_info.get(&stmt.inputs[0].var_id)? {
let inner = inner.as_ref().clone();
Expand Down Expand Up @@ -1438,8 +1444,6 @@ pub struct ConstFoldingLibfuncInfo<'db> {
felt_mul: ExternFunctionId<'db>,
/// The `felt252_div` libfunc.
felt_div: ExternFunctionId<'db>,
/// The `into_box` libfunc.
into_box: ExternFunctionId<'db>,
/// The `unbox` libfunc.
unbox: ExternFunctionId<'db>,
/// The `box_forward_snapshot` libfunc.
Expand Down Expand Up @@ -1598,7 +1602,6 @@ impl<'db> ConstFoldingLibfuncInfo<'db> {
felt_add: core.extern_function_id("felt252_add"),
felt_mul: core.extern_function_id("felt252_mul"),
felt_div: core.extern_function_id("felt252_div"),
into_box: box_module.extern_function_id("into_box"),
unbox: box_module.extern_function_id("unbox"),
box_forward_snapshot: box_module.generic_function_id("box_forward_snapshot"),
eq_fns,
Expand Down
13 changes: 11 additions & 2 deletions crates/cairo-lang-lowering/src/optimizations/dedup_blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::ids::FunctionId;
use crate::utils::{Rebuilder, RebuilderEx};
use crate::{
Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
StatementStructDestructure, VarRemapping, VarUsage, VariableArena, VariableId,
};

Expand Down Expand Up @@ -59,7 +59,10 @@ enum CanonicStatement<'db> {
input: CanonicVar,
output: CanonicVar,
},

BoxConstruct {
input: CanonicVar,
output: CanonicVar,
},
Snapshot {
input: CanonicVar,
outputs: [CanonicVar; 2],
Expand Down Expand Up @@ -163,6 +166,12 @@ impl<'db, 'a> CanonicBlockBuilder<'db, 'a> {
input: self.handle_input(input),
output: self.handle_output(output),
},
Statement::IntoBox(StatementIntoBox { input, output }) => {
CanonicStatement::BoxConstruct {
input: self.handle_input(input),
output: self.handle_output(output),
}
}
}
}
}
Expand Down
Loading
Loading