47 changes: 20 additions & 27 deletions src/compiler_utils.rs
@@ -112,6 +112,7 @@ impl ToIds for () {
}
}

#[allow(clippy::implicit_hasher)]
impl<T: ToIds> ToIds for FxHashMap<String, T> {
fn to_ids(&self) -> Vec<NodeIndex> {
self.values().flat_map(|i| i.to_ids()).collect()
@@ -192,7 +193,7 @@ impl<C: Compiler + Debug> Compiler for Looped<C> {
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut remap: T) {
let mut linearized = None;
loop {
self.0.compile(graph, &mut remap);
let _cur_output = self.0.compile(graph, &mut remap);
graph.toposort();
if linearized == graph.linearized_graph {
break;
Expand Down Expand Up @@ -317,7 +318,7 @@ tuple_impls!(
// Helpers

impl Graph {
/// Add op on the graph, and get back a NewOp
/// Add `op` on the graph, and get back a `NewOp`
///
/// ```rust
/// use luminal::prelude::*;
@@ -337,7 +338,7 @@ impl Graph {
num_srcs: 0,
}
}
/// Add op on the graph, and get back a NewOp. Just like add_op, except a boxed op is expected.
/// Add `op` on the graph, and get back a `NewOp`. Just like `add_op`, except a boxed op is expected.
pub fn add_boxed_op(&mut self, op: Box<dyn Operator + 'static>) -> NewOp<'_> {
self.linearized_graph = None;
NewOp {
@@ -407,11 +408,7 @@ impl Graph {
}
if show_shapes
&& new_graph.contains_node(id_map[&edge.target()])
&& edge
.weight()
.as_data()
.map(|d| !d.2.is_empty())
.unwrap_or_default()
&& edge.weight().as_data().is_some_and(|d| !d.2.is_empty())
{
new_graph
.node_weight_mut(id_map[&edge.target()])
@@ -712,7 +709,7 @@ fn backtrack_match(
mapping.insert(pattern_root, main_root);
let main_parents = get_parents(main_graph, main_root, |e| !e.weight().is_schedule());
'pattern_loop: for pattern_parent in get_parents(pattern_graph, pattern_root, |_| true) {
for parent in main_parents.iter() {
for parent in &main_parents {
if mapping.values().any(|&v| v == *parent) {
// This main node was used already, skip it
continue;
@@ -764,24 +761,21 @@ fn test_node(
return false;
}
for (a, b) in a_sh.iter().zip(b_sh.dims().into_iter()) {
match a.to_usize() {
Some(n) => {
if b.to_usize().map(|i| i != n).unwrap_or(true) {
return false;
}
if let Some(n) = a.to_usize() {
if b.to_usize() != Some(n) {
return false;
}
None => {
let c = a
.to_symbols()
.pop()
.expect("Selector dimension must be either a symbol or number");
if let Some(expected) = shape_map.get(&c) {
if b != *expected {
return false;
}
} else {
shape_map.insert(c, b);
} else {
let c = a
.to_symbols()
.pop()
.expect("Selector dimension must be either a symbol or number");
if let Some(expected) = shape_map.get(&c) {
if b != *expected {
return false;
}
} else {
shape_map.insert(c, b);
}
}
}
@@ -955,6 +949,5 @@ pub fn debug() -> bool {
std::env::var("DEBUG")
.unwrap_or_default()
.parse::<i32>()
.map(|i| i == 1)
.unwrap_or_default()
.is_ok_and(|i| i == 1)
}
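
For reference, a standalone sketch (not part of the change) showing that the `is_some_and` / `is_ok_and` combinators adopted in this file are equivalent to the `map(..).unwrap_or_default()` chains they replace:

// Standalone sketch, not part of this diff: the combinators adopted above
// behave like the `map(..).unwrap_or_default()` chains they replace.
fn main() {
    let opt: Option<usize> = Some(3);
    // `Option::is_some_and` returns false for None, like unwrap_or_default did.
    assert_eq!(
        opt.map(|i| i == 3).unwrap_or_default(),
        opt.is_some_and(|i| i == 3)
    );

    let res: Result<i32, std::num::ParseIntError> = "1".parse::<i32>();
    // `Result::is_ok_and` returns false on Err, mirroring the old chain in `debug()`.
    assert_eq!(
        res.clone().map(|i| i == 1).unwrap_or_default(),
        res.is_ok_and(|i| i == 1)
    );
}
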
52 changes: 12 additions & 40 deletions src/generic_compiler.rs
@@ -34,23 +34,11 @@ impl Compiler for CSE {
while eliminated {
eliminated = false;
let mut srcs_set: HashMap<Vec<NodeIndex>, Vec<NodeIndex>> = HashMap::new();
for node in graph.graph.node_indices().collect_vec() {
if graph
.graph
.node_weight(node)
.unwrap()
.as_any()
.is::<Function>()
{
for node in graph.collect_node_indices() {
if graph.this_node_is::<Function>(node) {
continue;
}
let srcs = graph
.graph
.edges_directed(node, petgraph::Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
.map(|e| e.source())
.collect_vec();
let srcs = graph.get_incomings(node);

if let Some(other_nodes) = srcs_set.get(&srcs) {
for other_node in other_nodes {
@@ -108,40 +96,23 @@ impl Compiler for RemoveSingleReductions {
type Output = ();
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut ids: T) {
for node in graph.graph.node_indices().collect::<Vec<_>>() {
let dim = if let Some(red) = graph
.graph
.node_weight(node)
.unwrap()
.as_any()
.downcast_ref::<SumReduce>()
{
let dim = if let Some(red) = graph.get_this_node_is::<SumReduce>(node) {
Some(red.0)
} else {
graph
.graph
.node_weight(node)
.unwrap()
.as_any()
.downcast_ref::<MaxReduce>()
.map(|red| red.0)
graph.get_this_node_is::<MaxReduce>(node).map(|red| red.0)
};
if let Some(dim) = dim {
if graph
.graph
.edges_directed(node, Direction::Incoming)
.next()
.map(|e| {
e.weight()
.as_data()
.map(|w| {
w.2.dims[w.2.indexes[dim]]
.to_usize()
.map(|i| i == 1)
.unwrap_or_default()
})
.unwrap_or_default()
.is_some_and(|e| {
e.weight().as_data().is_some_and(|w| {
w.2.dims[w.2.indexes[dim]]
.to_usize()
.is_some_and(|i| i == 1)
})
})
.unwrap_or_default()
{
let upstream = graph
.graph
@@ -243,6 +214,7 @@ pub struct ArithmeticElimination;

impl Compiler for ArithmeticElimination {
type Output = ();
#[allow(clippy::too_many_lines)]
fn compile<T: ToIdsMut>(&self, graph: &mut Graph, mut ids: T) {
// x + 0, 0 + x
let zero = constant(0.);
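
The CSE and RemoveSingleReductions rewrites above lean on the `Graph` helpers added in src/graph.rs below. A rough illustrative sketch of how a pass can use them; the pass name and merge logic are placeholders, and the imports assume the usual luminal prelude:

// Illustrative sketch only; `Graph`, `Function`, and the helper methods
// (`collect_node_indices`, `this_node_is`, `get_incomings`) are the ones
// introduced in src/graph.rs in this diff.
use std::collections::HashMap;

use luminal::prelude::*;
use petgraph::stable_graph::NodeIndex;

fn dedupe_sketch(graph: &mut Graph) {
    // Map from a node's (ordered) source list to the first node seen with it.
    let mut seen: HashMap<Vec<NodeIndex>, NodeIndex> = HashMap::new();
    for node in graph.collect_node_indices() {
        // Skip opaque user functions, exactly as CSE does above.
        if graph.this_node_is::<Function>(node) {
            continue;
        }
        // Data-dependency sources, sorted by input order via the new helper.
        let srcs = graph.get_incomings(node);
        if let Some(&existing) = seen.get(&srcs) {
            // A candidate duplicate of `existing`; a real pass would compare
            // the operators and reroute edges here.
            let _ = existing;
        } else {
            seen.insert(srcs, node);
        }
    }
}
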
149 changes: 141 additions & 8 deletions src/graph.rs
@@ -8,8 +8,12 @@ use std::{
use super::compiler_utils::{ToIds, ToIdsMut};
use colored::Colorize;
use itertools::Itertools;
use petgraph::{stable_graph::StableGraph, visit::EdgeRef, Direction};
use rustc_hash::{FxHashMap, FxHashSet};
use petgraph::{
stable_graph::StableGraph,
visit::{EdgeRef, IntoEdgeReferences},
Direction,
};
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};

pub type StorageGraph = StableGraph<Box<dyn Operator>, Dependency>;

@@ -27,7 +31,8 @@ pub struct Graph {
pub graph: StorageGraph,
/// Tensors marked in this set will not get deleted when the graph is run
pub no_delete: FxHashSet<NodeIndex>,
/// Tensors marked in this set need to be retrieved later (mostly for optimizers to insert copy back calls, the graph itself doesn't treat these differently)
/// Tensors marked in this set need to be retrieved later
/// (mostly for optimizers to insert copy back calls, the graph itself doesn't treat these differently)
pub to_retrieve: FxHashMap<NodeIndex, (u8, ShapeTracker)>,
/// A cached list of nodes to run, source nodes, and view nodes to delete after execution.
#[allow(clippy::type_complexity)]
@@ -205,7 +210,7 @@ impl Graph {
get_source_tensors(&self.no_delete, &mut self.tensors, src_ids, &consumers);

// Substitute in the dyn dims
for (_, st) in srcs.iter_mut() {
for (_, st) in &mut srcs {
st.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack);
}

@@ -230,7 +235,7 @@ impl Graph {
self.toposort();
}
let mut dim_stack = Vec::new();
for (node, src_ids) in self.linearized_graph.as_ref().unwrap().iter() {
for (node, src_ids) in self.linearized_graph.as_ref().unwrap() {
if self.tensors.contains_key(&(*node, 0)) {
continue;
}
@@ -245,7 +250,7 @@ impl Graph {
.collect_vec();

// Substitute in the dyn dims
for (_, st) in srcs.iter_mut() {
for (_, st) in &mut srcs {
st.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack);
}

@@ -284,7 +289,7 @@ impl Graph {
(width.saturating_sub(" Executing ".len())) / 2
);
let start = std::time::Instant::now();
for (node, src_ids) in self.linearized_graph.as_ref().unwrap().iter() {
for (node, src_ids) in self.linearized_graph.as_ref().unwrap() {
if self.tensors.contains_key(&(*node, 0)) {
continue;
}
@@ -295,7 +300,7 @@ impl Graph {
get_source_tensors(&self.no_delete, &mut self.tensors, src_ids, &consumers);

// Substitute in the dyn dims
for (_, st) in srcs.iter_mut() {
for (_, st) in &mut srcs {
st.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack);
}

@@ -363,6 +368,134 @@ impl Graph {
println!("Total: {}", format_duration(&start.elapsed()).bold());
self.reset();
}

/// Is the `Operator` associated with this `node` of type `F`?
/// Assumes this `node` exists in the graph.
#[inline]
pub fn this_node_is<F: 'static>(&self, node: NodeIndex) -> bool {
self.graph.node_weight(node).unwrap().as_any().is::<F>()
}

/// If the `Operator` associated with this `node` is of type `F`,
/// return a reference to it, otherwise `None`.
/// Assumes this `node` exists in the graph.
#[inline]
pub fn get_this_node_is<F: 'static>(&self, node: NodeIndex) -> Option<&F> {
self.graph
.node_weight(node)
.unwrap()
.as_any()
.downcast_ref::<F>()
}

/// Gather the nodes that have an edge going to `node` (assumed to exist)
/// where the connecting edge is a data dependency.
/// The results are sorted by the `input_order`
/// field of the data dependencies.
#[inline]
pub fn get_incomings(&self, node: NodeIndex) -> Vec<NodeIndex> {
self.graph
.edges_directed(node, petgraph::Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
.map(|e| e.source())
.collect_vec()
}

/// All of the node indices
#[inline]
pub fn collect_node_indices(&self) -> Vec<NodeIndex> {
self.graph.node_indices().collect_vec()
}

/// The shape trackers for all the sources of `cur_node` (assumed to exist)
#[inline]
pub fn get_source_shapes(&self, cur_node: &NodeIndex) -> Vec<ShapeTracker> {
self.get_sources(*cur_node)
.into_iter()
.map(|(_, _, a)| a)
.collect_vec()
}

/// Return all entries in `to_retrieve` whose operators are not of type `FType`
#[inline]
pub fn do_to_retrieve<FType: 'static>(&self) -> Vec<(NodeIndex, (u8, ShapeTracker))> {
self.to_retrieve
.iter()
.map(|(a, b)| (*a, *b))
// Filter to non-FType
.filter(|(n, _)| !self.node_weight(*n).unwrap().as_any().is::<FType>())
.collect::<Vec<_>>()
}

pub fn to_retrieve_graph_tensors(&mut self) -> impl Iterator<Item = GraphTensor> + '_ {
let mut retrieving = self
.to_retrieve
.iter()
.map(|(a, (b, c))| (*a, (*b, *c)))
.collect_vec();
retrieving.sort_by_key(|(a, (_, _))| *a);
retrieving
.into_iter()
.map(|(id, (_, shape))| GraphTensor::from_id(id, shape, self))
}

/// CAUTION: all `GraphTensor`s that refer to `self` or `other`
/// will be unusable from here on
pub fn disjoint_union(mut self, mut other: Self) -> Self {
let mut other_remap = FxHashMap::<NodeIndex, NodeIndex>::with_capacity_and_hasher(
other.graph.node_count(),
FxBuildHasher,
);
for other_node_idx in other.collect_node_indices() {
let other_node_weight = other.graph.node_weight_mut(other_node_idx).unwrap();
let mut dummy: Box<dyn Operator> = Box::new(Add);
core::mem::swap(&mut dummy, other_node_weight);
let new_node_idx = self.graph.add_node(dummy);
other_remap.insert(other_node_idx, new_node_idx);
}
for other_edge_ref in other.edge_references() {
let (a, b) = (other_edge_ref.source(), other_edge_ref.target());
let (a_new, b_new) = (*other_remap.get(&a).unwrap(), *other_remap.get(&b).unwrap());
self.graph.add_edge(a_new, b_new, *other_edge_ref.weight());
}
self.tensors.extend(
other
.tensors
.into_iter()
.map(|((a, b), c)| ((*other_remap.get(&a).unwrap(), b), c)),
);
let bad_keys: Vec<char> = self
.dyn_map
.keys()
.filter_map(|key| {
if other.dyn_map.contains_key(key)
&& other.dyn_map.get(key) != self.dyn_map.get(key)
{
Some(*key)
} else {
None
}
})
.collect();
assert!(bad_keys.is_empty());
self.dyn_map.extend(other.dyn_map);
self.no_delete.extend(
other
.no_delete
.into_iter()
.map(|a| *other_remap.get(&a).unwrap()),
);
self.to_retrieve.extend(
other
.to_retrieve
.into_iter()
.map(|(a, (b, c))| (*other_remap.get(&a).unwrap(), (b, c))),
);
self.linearized_graph = None;
self.consumers_map = None;
self
}
}

impl Deref for Graph {
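
Finally, a hedged usage sketch for the new `disjoint_union` and `to_retrieve_graph_tensors`. Only methods added in this diff are used; how the two graphs were built is left out, and the `luminal::prelude` import is an assumption:

// Illustrative sketch only: merge two independently built graphs.
// Per the CAUTION above, GraphTensors created before the union must not be
// reused afterwards.
use luminal::prelude::*;

fn merge(a: Graph, b: Graph) -> Graph {
    // Every node index from `b` is remapped into `a`'s index space; tensors,
    // dyn_map, no_delete and to_retrieve entries are carried across, and the
    // method asserts that any dyn-dim key present in both maps agrees.
    let mut merged = a.disjoint_union(b);
    // The merged retrieval set can then be walked as GraphTensors, e.g. to
    // hand them to a compiler or executor.
    for tensor in merged.to_retrieve_graph_tensors() {
        let _ = tensor;
    }
    merged
}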