Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/luminal_2_eggplant/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/target
*.dot
54 changes: 54 additions & 0 deletions crates/luminal_2_eggplant/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
[package]
name = "luminal_2_eggplant"
version = "0.1.0"
edition = "2024"
rust-version = "1.89"

[features]
default = []
cuda = ["dep:cudarc", "dep:luminal_cuda"]
metal = ["dep:objc2", "dep:objc2-metal", "dep:objc2-foundation"]

[dependencies]
luminal = { path = "../../" }
luminal_cuda = { path = "../luminal_cuda", optional = true }
cudarc = { version = "0.16.6", features = [
"f16",
"cuda-12080",
], optional = true }
#metal-rs = { version = "0.28.0", package = "metal", optional=true }
objc2 = { version = "0.6.2", optional = true }
objc2-metal = { version = "0.3.1", optional = true }
itertools = "0.14.0"
urlencoding = "2.1.3"
webbrowser = "1.0.4"
regex = "1.11.1"
serde_json = "1.0.140"
colored = "3.0.0"
generational-box = "0.6.2"
egg = "0.9.5"
symbolic_expressions = "5.0.3"
rustc-hash = "2.1.1"
rand = "0.9.1"
egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", branch = "main" }
indexmap = "2.9.0"
serde = "1.0.219"
ratatui = "0.29.0"
crossterm = "0.29.0"
libc = "0.2.174"
unicode-width = "0.2"
objc2-foundation = { version = "0.3.1", optional = true }
eframe = "0.28"
egui = "0.28"
anyhow = "1.0.99"

# eggplant related
eggplant = "0.2.6"
derive_more = { version = "2.0.1", features = [
"deref_mut",
"deref",
"into_iterator",
"debug",
] }
strum = { version = "0.27.2", features = ["strum_macros"] }
strum_macros = "0.27.2"
105 changes: 105 additions & 0 deletions crates/luminal_2_eggplant/src/datatypes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use eggplant::egglog;
use eggplant::prelude::*;
use serde::Deserialize;
use serde::Serialize;

// datatype((datatype Expression(MNum i64:args_name "num")(MVar String:args_name "name")(MAdd Expression Expression:args_name "l,r")(MSub Expression Expression:args_name "l,r")(MMul Expression Expression:args_name "l,r")(MDiv Expression Expression:args_name "l,r")(MMod Expression Expression:args_name "l,r")(MMin Expression Expression:args_name "l,r")(MMax Expression Expression:args_name "l,r")(MAnd Expression Expression:args_name "l,r")(MOr Expression Expression:args_name "l,r")(MGte Expression Expression:args_name "l,r")(MLt Expression Expression:args_name "l,r")(MFloorTo Expression Expression:args_name "l,r")(MReplace Expression Expression Expression:args_name "l,r,rpl")(MAccum String:args_name "name")))#[doc = "DSl Generated"]
#[eggplant::dsl]
pub enum Expr {
MNum { num: i64 },
MVar { name: String },
MAdd { l: Expr, r: Expr },
MSub { l: Expr, r: Expr },
MMul { l: Expr, r: Expr },
MDiv { l: Expr, r: Expr },
MMod { l: Expr, r: Expr },
MMin { l: Expr, r: Expr },
MMax { l: Expr, r: Expr },
MAnd { l: Expr, r: Expr },
MOr { l: Expr, r: Expr },
MGte { l: Expr, r: Expr },
MLt { l: Expr, r: Expr },
MFloorTo { l: Expr, r: Expr },
MReplace { l: Expr, r: Expr, rpl: Expr },
MAccum { name: String },
}

#[eggplant::base_ty]
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq, Default)]
pub enum UnOp {
Exp2,
Log2,
Sqrt,
Sin,
Recip,
Neg,
#[default]
Unknown,
}
#[eggplant::base_ty]
#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq, Default)]
pub enum BinOp {
Add,
Sub,
Mul,
#[default]
Unknown,
}

#[eggplant::dsl(base=BinOp,base=UnOp)]
enum IR {
GMEM {
name: String,
},
LoopIn {
ir: IR,
l: Expr,
r: Expr,
},
LoopOut {
ir: IR,
l: Expr,
r: Expr,
},
Unary {
op: UnOp,
ir: IR,
},
Binary {
op: BinOp,
l: IR,
r: IR,
},
SwapLoops {
ir: IR,
level: i64,
},
TileLoop {
ir: IR,
level: i64,
},
MergeLoops {
ir: IR,
level: i64,
},
TCMatmul {
inp_a: IR,
inp_b: IR,
a_k_stride: Expr,
b_k_stride: Expr,
a_inner_stride: Expr,
b_inner_stride: Expr,
c_inner_stride: Expr,
num_k_loops: Expr,
},
TiledMatmulInputA {
ir: IR,
num: i64,
expr: Expr,
},
TiledMatmulInputB {
ir: IR,
num: i64,
expr: Expr,
},
}
28 changes: 28 additions & 0 deletions crates/luminal_2_eggplant/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
mod datatypes;
mod rules;
mod shortcut;
use datatypes::*;
use eggplant::{prelude::*, tx_rx_vt_pr};

tx_rx_vt_pr!(MyTx, MyPatRec);
fn main() {
// let expr: Expr<MyTx, ()> = (MNum::new(4) * 3) + 2;
let expr: Expr<MyTx, _> = MNum::new(4);
let ruleset = MyTx::new_ruleset("expr");
rules::add_rules::<MyTx>(ruleset);
expr.commit();
MyTx::run_ruleset(ruleset, RunConfig::Sat);
}

#[test]
fn test_const_fold() {
let expr: Expr<MyTx, _> = MNum::new(3) * MNum::new(4) + MNum::new(2);
expr.commit();
let ruleset = MyTx::new_ruleset("expr");
rules::add_rules::<MyTx>(ruleset);
MyTx::run_ruleset(ruleset, RunConfig::Sat);

let ans: Expr<MyTx, _> = MNum::new(12);
ans.commit();
assert!(MyTx::canonical_raw(&expr) == MyTx::canonical_raw(&ans))
}
100 changes: 100 additions & 0 deletions crates/luminal_2_eggplant/src/rules/basic_expr_rules.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use crate::datatypes::*;
use eggplant::prelude::*;
use eggplant::wrap::{G, RuleCtx, RuleSetId};
macro_rules! fold {
($ty:ident,$f:expr,$pat_name:ident,$ruleset:ident) => {
T::add_rule(
stringify!($pat_name),
$ruleset,
|| {
let l = MNum::query();
let r = MNum::query();
let p = $ty::query(&l, &r);
#[eggplant::pat_vars_catch]
struct $pat_name {
l: MNum,
r: MNum,
p: $ty,
}
},
|ctx, pat| {
let cal = $f(ctx.devalue(pat.l.num), ctx.devalue(pat.r.num));
let op_value = ctx.insert_m_num(cal);
ctx.union(pat.p, op_value);
},
);
};
}

macro_rules! commu {
($ty:ident,$f:expr,$pat_name:ident,$ruleset:ident) => {
T::add_rule(
stringify!($pat_name),
$ruleset,
|| {
let l = Expr::query_leaf();
let r = Expr::query_leaf();
let p = $ty::query(&l, &r);
#[eggplant::pat_vars_catch]
struct $pat_name {
l: Expr,
r: Expr,
p: $ty,
}
},
|ctx, pat| {
let op = $f(ctx, pat.l, pat.r);
ctx.union(pat.p, op);
},
);
};
}

macro_rules! assoc {
($ty:ident,$f:expr,$pat_name:ident,$ruleset:ident) => {
T::add_rule(
stringify!($pat_name),
$ruleset,
|| {
let ll = Expr::query_leaf();
let lr = Expr::query_leaf();
let l = $ty::query(&ll, &lr);
let r = Expr::query_leaf();
let p = $ty::query(&l, &r);
#[eggplant::pat_vars_catch]
struct $pat_name {
ll: Expr,
lr: Expr,
r: Expr,
p: $ty,
}
},
|ctx, pat| {
let r = $f(ctx, pat.lr, pat.r);
let p = $f(ctx, pat.ll, r);
ctx.union(pat.p, p);
},
);
};
}
pub fn assoc<T: G>(ruleset: RuleSetId) {
assoc!(MAdd, RuleCtx::insert_m_add, AddAssocPat, ruleset);
assoc!(MMul, RuleCtx::insert_m_mul, MulAssocPat, ruleset);
}

pub fn commu<T: G>(ruleset: RuleSetId) {
commu!(MAdd, RuleCtx::insert_m_add, AddCommuPat, ruleset);
commu!(MMul, RuleCtx::insert_m_mul, MulCommuPat, ruleset);
}

pub fn const_fold<T: G>(ruleset: RuleSetId) {
use std::cmp::*;
use std::ops::*;

fold!(MAdd, Add::add, AddPat, ruleset);
fold!(MSub, Sub::sub, SubPat, ruleset);
fold!(MMul, Mul::mul, MulPat, ruleset);
fold!(MMax, max, MaxPat, ruleset);
fold!(MMin, min, MinPat, ruleset);
fold!(MAnd, BitAnd::bitand, BitAndPat, ruleset);
}
7 changes: 7 additions & 0 deletions crates/luminal_2_eggplant/src/rules/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use eggplant::wrap::{G, RuleSetId};
mod basic_expr_rules;
pub fn add_rules<T: G>(ruleset: RuleSetId) {
basic_expr_rules::commu::<T>(ruleset);
basic_expr_rules::const_fold::<T>(ruleset);
basic_expr_rules::assoc::<T>(ruleset);
}
Loading