Skip to content

Commit e0727a3

Browse files
authored
Merge pull request #188 from luminal-ai/visualization
Modules for visualization, serialized_egraph, egglog_utils, example for visualization
2 parents 8a4b2c3 + f9599ce commit e0727a3

File tree

22 files changed

+595
-20
lines changed

22 files changed

+595
-20
lines changed

Cargo.toml

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "luminal"
33
version = "0.2.0"
4-
edition = "2024"
4+
edition.workspace = true
55
rust-version = "1.85"
66
description = "Deep learning at the speed of light."
77
license = "MIT OR Apache-2.0"
@@ -22,18 +22,23 @@ regex = "1.9.5"
2222
rustc-hash = "2.1.1"
2323
uuid = { version = "1.7.0", features = ["v4"] }
2424
as-any = "0.3.1"
25-
egg = "0.9.5"
2625
symbolic_expressions = "5.0.3"
2726
serde = { version = "1.0.202", features = ["derive"] }
2827
thread_local = "1.1.8"
2928
generational-box = "0.5.6"
3029
serde_json = "1.0.140"
30+
egg = "0.9.5"
3131
egglog = "1.0.0"
3232
egglog-ast = "1.0.0"
33-
egraph-serialize = "0.3.0"
33+
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
3434
tracing = "0.1.43"
3535
paste = "1.0.15"
3636
pretty-duration = "0.1.1"
37+
anyhow = "1.0"
38+
graphviz-rust = { version = "0.9", default-features = false}
39+
40+
[workspace.package]
41+
edition = "2024"
3742

3843
[dev-dependencies]
3944
candle-core = "0.9.1"
@@ -43,13 +48,13 @@ ordered-float = "5.1.0"
4348
[workspace]
4449
members = [
4550
"examples/llama",
46-
#"examples/*",
51+
"examples/visualization",
4752
"crates/luminal_nn",
4853
"crates/luminal_cuda",
4954
"crates/luminal_training",
5055
"docs/company",
5156
]
5257
exclude = [
53-
"examples/yolo_v8",
58+
5459
"crates/luminal_cuda",
5560
]

crates/luminal_cuda/src/block/ops.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ use super::CustomState;
44
use cudarc::driver::{CudaStream, DevicePtr};
55
use itertools::Itertools;
66
use luminal::{
7-
graph::{extract_expr, extract_expr_list, SerializedEGraph},
7+
graph::{extract_expr, extract_expr_list},
88
prelude::ENodeId,
9+
serialized_egraph::SerializedEGraph,
910
shape::Expression,
1011
utils::{
1112
flatten_mul_strides, CStructBuilder, EgglogOp, LLIROp,

crates/luminal_cuda/src/kernel/ops.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ use cudarc::{
66
};
77
use itertools::Itertools;
88
use luminal::{
9-
graph::{extract_dtype, extract_expr, extract_expr_list, SerializedEGraph},
9+
graph::{extract_dtype, extract_expr, extract_expr_list},
1010
op::DType,
1111
prelude::ENodeId,
12+
serialized_egraph::SerializedEGraph,
1213
shape::Expression,
1314
utils::{
1415
flatten_mul_strides, EgglogOp, LLIROp,

crates/luminal_nn/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#![allow(unused_imports)]
2+
13
mod activation;
24
pub use activation::*;
35
mod convolution;

examples/llama/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ luminal_cuda = { path = "../../crates/luminal_cuda" }
1212
itertools = "0.12.1"
1313
tokenizers = "0.15.2"
1414
tracing = "0.1.43"
15-
rustc-hash = "2.1.1"
1615
tracing-subscriber = {version="0.3", features=["env-filter"]}
1716
tracing-perfetto-sdk-layer = "0.13.0"
1817
tracing-perfetto-sdk-schema = "0.13.0"

examples/llama/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ use itertools::Itertools;
44
use luminal::{
55
graph::{Graph, Runtime},
66
op::DType,
7+
prelude::FxHashMap,
78
};
89
use luminal_cuda::{
910
block::IntoBlockOp,
1011
runtime::{record_exec_timings_to_file, CudaRuntime, CustomState},
1112
};
1213
use model::*;
13-
use rustc_hash::*;
1414
use std::{fs::File, io::Write, time::Duration};
1515
use tokenizers::Tokenizer;
1616
use tracing::{span, Level};

examples/visualization/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
*.dot
2+
*.html
3+
*.svg

examples/visualization/Cargo.toml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
[package]
2+
name = "visualization"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[features]
7+
8+
[dependencies]
9+
anyhow = "1.0"
10+
egglog = "1.0"
11+
egglog-ast = "1.0.0"
12+
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
13+
itertools = "0.12.1"
14+
luminal = { path = "../.." }
15+
luminal_cuda = { path = "../../crates/luminal_cuda" }
16+
luminal_nn = { path = "../../crates/luminal_nn" }
17+
rustc-hash = "2.1"
18+
tokenizers = "0.15.2"
19+
tracing = "0.1.43"
20+
tracing-appender = "0.2.4"
21+
tracing-perfetto-sdk-layer = "0.13.0"
22+
tracing-perfetto-sdk-schema = "0.13.0"
23+
tracing-perfetto-sdk-sys = "0.13.0"
24+
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

examples/visualization/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Visualization in Luminal
2+
3+
## Design Choices
4+
Luminal produces intermediate files rather than complete visualizations
5+
6+
The two primary file types are:
7+
- `.html` files
8+
- `.dot` files
9+
10+
These files enable interactive viewing which is often necessary for making visualizations interpretable.
11+
12+
## VSCode Extensions
13+
We recommend the following extensions for VSCode users.
14+
The integrated nature of these extensions makes viewing these files easy even on remote machines via ssh.
15+
16+
- `Live Preview` by microsoft.
17+
- `Graphviz Interactive Preview` by tintinweb
18+
19+
## Example Provided
20+
In the example program, as simple program is defined.
21+
From this a HLIR graph is created and visualized.
22+
A saturated EGraph is created and visualized.
23+
Finally an LLIR graph is extracted and visualized.

examples/visualization/src/main.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
use std::fs;
2+
3+
use luminal::{
4+
self,
5+
graph::{hlir_to_egglog, Graph, Runtime},
6+
prelude::*,
7+
serialized_egraph::SerializedEGraph,
8+
visualization::{ToDot, ToHtml},
9+
};
10+
use luminal_cuda::runtime::{CudaRuntime, CustomState};
11+
12+
use egglog::{prelude::RustSpan, var, EGraph};
13+
use egglog_ast::span::Span;
14+
use rustc_hash::FxHashMap;
15+
16+
fn main() {
17+
// Create a new graph
18+
let mut cx = Graph::new();
19+
20+
// Create input tensor using constant values
21+
22+
let (m, n, k) = (4096, 14336, 9192);
23+
24+
let a = cx.tensor((m, k));
25+
let b = cx.tensor((k, n));
26+
27+
let _c = a.matmul(b);
28+
29+
let ctx = luminal_cuda::cudarc::driver::CudaContext::new(0).unwrap();
30+
ctx.bind_to_thread().unwrap();
31+
let _stream = ctx.default_stream();
32+
let _custom_state: FxHashMap<String, CustomState> = FxHashMap::default();
33+
34+
println!("Visualizing HLIR");
35+
fs::write("HLIR.dot", cx.graph.to_dot().unwrap()).unwrap();
36+
37+
println!("Building and Saturating EGraph");
38+
cx.build_search_space::<CudaRuntime>();
39+
40+
let (program, root) = hlir_to_egglog(&cx);
41+
42+
let mut ops = <CudaRuntime as Runtime>::Ops::into_vec();
43+
ops.extend(<luminal::op::Ops as IntoEgglogOp>::into_vec());
44+
45+
let mut egglog_obj: EGraph = egglog::EGraph::default();
46+
47+
// setup the rules and datatypes
48+
egglog_obj
49+
.parse_and_run_program(None, luminal::egglog_utils::BASE)
50+
.unwrap();
51+
egglog_obj
52+
.parse_and_run_program(None, &luminal::egglog_utils::op_defs_string(&ops))
53+
.unwrap();
54+
egglog_obj
55+
.parse_and_run_program(None, &luminal::egglog_utils::op_rewrites_string(&ops))
56+
.unwrap();
57+
egglog_obj
58+
.parse_and_run_program(None, luminal::egglog_utils::BASE_CLEANUP)
59+
.unwrap();
60+
egglog_obj
61+
.parse_and_run_program(None, &luminal::egglog_utils::op_cleanups_string(&ops))
62+
.unwrap();
63+
64+
// load the program
65+
egglog_obj.parse_and_run_program(None, &program).unwrap();
66+
67+
// run the graph
68+
egglog_obj
69+
.parse_and_run_program(None, luminal::egglog_utils::RUN_SCHEDULE)
70+
.unwrap();
71+
72+
// EGraph Optimization Complete
73+
println!("Visualizing EGraph");
74+
// save the egraph visualizations
75+
fs::write("egraph.html", egglog_obj.to_html().unwrap()).unwrap();
76+
fs::write("egraph.dot", egglog_obj.to_dot().unwrap()).unwrap();
77+
78+
let (sort, value) = egglog_obj.eval_expr(&var!(root)).unwrap();
79+
let s_egraph = SerializedEGraph::new(&egglog_obj, vec![(sort, value)]);
80+
let llir_graphs = egglog_to_llir(&s_egraph, &ops, 100);
81+
82+
let example_llir_graph = llir_graphs.last().unwrap();
83+
84+
println!("Visualizing LLIR Graph");
85+
fs::write("LLIR.dot", example_llir_graph.to_dot().unwrap()).unwrap();
86+
}

0 commit comments

Comments
 (0)