Skip to content

Commit be5e58e

Browse files
committed
proper postorder traversal
1 parent aefd612 commit be5e58e

File tree

1 file changed

+48
-120
lines changed
  • language/src/Circuit/Language

1 file changed

+48
-120
lines changed

Diff for: language/src/Circuit/Language/Expr.hs

+48-120
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,8 @@ instance (Pretty f, Pretty i) => Pretty (Node i f) where
586586
pretty (NSplit e n) = "split" <+> pretty e <+> pretty n
587587
pretty (NJoin e) = "join" <+> pretty e
588588
pretty (NBundle es) = "bundle" <+> pretty (toList es)
589-
pretty (NAtIndex e idx) = pretty e <+> "!" <> pretty idx
590-
pretty (NUpdateAtIndex e idx v) = pretty e <+> "!" <> pretty idx <+> ":=" <+> pretty v
589+
pretty (NAtIndex e idx) = pretty e <+> "!!" <+> pretty idx
590+
pretty (NUpdateAtIndex e idx v) = pretty e <+> "!!" <+> pretty idx <+> ":=" <+> pretty v
591591

592592
deriving instance (Show i, Show f) => Show (Node i f)
593593

@@ -635,7 +635,7 @@ untypeBinOp BXors = UBXor
635635

636636
--------------------------------------------------------------------------------
637637

638-
reifyGraph :: Expr i f ty -> Seq (Hash, Node i f)
638+
reifyGraph :: Expr i f ty -> Seq (Hash, Node i f)
639639
reifyGraph e =
640640
gbsEdges $ execState (buildGraph_ e) (GraphBuilderState mempty mempty)
641641

@@ -645,149 +645,76 @@ data GraphBuilderState i f = GraphBuilderState
645645
}
646646

647647
{-# SCC buildGraph_ #-}
648-
buildGraph_ :: forall i f ty. Expr i f ty -> State (GraphBuilderState i f) Hash
648+
buildGraph_ :: forall i f ty. Expr i f ty -> State (GraphBuilderState i f) Hash
649649
buildGraph_ expr =
650650
getId expr <$ case expr of
651-
EVal h v -> do
652-
ns <- gets gbsSharedNodes
653-
unless (h `Set.member` ns) $ do
651+
EVal h v ->
652+
unlessM (hasBeenVisited h) $ do
654653
let n = NVal (rawVal v)
655-
modify $ \s ->
656-
s
657-
{ gbsSharedNodes = Set.insert h ns,
658-
gbsEdges = gbsEdges s |> (h, n)
659-
}
660-
EVar h v -> do
661-
ns <- gets gbsSharedNodes
662-
unless (h `Set.member` ns) $ do
654+
markVisited h n
655+
EVar h v ->
656+
unlessM (hasBeenVisited h) $ do
663657
let n = NVar (rawWire v)
664-
modify $ \s ->
665-
s
666-
{ gbsSharedNodes = Set.insert h ns,
667-
gbsEdges = gbsEdges s |> (h, n)
668-
}
669-
EUnOp h op e -> do
670-
ns <- gets gbsSharedNodes
671-
unless (h `Set.member` ns) $ do
672-
modify $ \s ->
673-
s
674-
{ gbsSharedNodes = Set.insert h ns
675-
}
658+
markVisited h n
659+
EUnOp h op e ->
660+
unlessM (hasBeenVisited h) $ do
676661
e' <- buildGraph_ e
677662
let n = NUnOp (untypeUnOp op) e'
678-
modify $ \s ->
679-
s
680-
{ gbsEdges = gbsEdges s |> (h, n)
681-
}
682-
EBinOp h op e1 e2 -> do
683-
ns <- gets gbsSharedNodes
684-
unless (h `Set.member` ns) $ do
685-
modify $ \s ->
686-
s
687-
{ gbsSharedNodes = Set.insert h ns
688-
}
663+
markVisited h n
664+
EBinOp h op e1 e2 ->
665+
unlessM (hasBeenVisited h) $ do
689666
e1' <- buildGraph_ e1
690667
e2' <- buildGraph_ e2
691668
let n = NBinOp (untypeBinOp op) e1' e2'
692-
modify $ \s ->
693-
s
694-
{ gbsEdges = gbsEdges s |> (h, n)
695-
}
696-
EIf h b t f -> do
697-
ns <- gets gbsSharedNodes
698-
unless (h `Set.member` ns) $ do
699-
modify $ \s ->
700-
s
701-
{ gbsSharedNodes = Set.insert h ns
702-
}
669+
markVisited h n
670+
EIf h b t f ->
671+
unlessM (hasBeenVisited h) $ do
703672
b' <- buildGraph_ b
704673
t' <- buildGraph_ t
705674
f' <- buildGraph_ f
706675
let n = NIf b' t' f'
707-
modify $ \s ->
708-
s
709-
{ gbsEdges = gbsEdges s |> (h, n)
710-
}
711-
EEq h l r -> do
712-
ns <- gets gbsSharedNodes
713-
unless (h `Set.member` ns) $ do
714-
modify $ \s ->
715-
s
716-
{ gbsSharedNodes = Set.insert h ns
717-
}
676+
markVisited h n
677+
EEq h l r ->
678+
unlessM (hasBeenVisited h) $ do
718679
l' <- buildGraph_ l
719680
r' <- buildGraph_ r
720681
let n = NEq l' r'
721-
modify $ \s ->
722-
s
723-
{ gbsEdges = gbsEdges s |> (h, n)
724-
}
725-
ESplit h i -> do
726-
ns <- gets gbsSharedNodes
727-
unless (h `Set.member` ns) $ do
728-
modify $ \s ->
729-
s
730-
{ gbsSharedNodes = Set.insert h ns
731-
}
682+
markVisited h n
683+
ESplit h i ->
684+
unlessM (hasBeenVisited h) $ do
732685
i' <- buildGraph_ i
733686
let n = NSplit i' (fromIntegral $ natVal (Proxy @(NBits f)))
734-
modify $ \s ->
735-
s
736-
{ gbsEdges = gbsEdges s |> (h, n)
737-
}
738-
EJoin h i -> do
739-
ns <- gets gbsSharedNodes
740-
unless (h `Set.member` ns) $ do
741-
modify $ \s ->
742-
s
743-
{ gbsSharedNodes = Set.insert h ns
744-
}
687+
markVisited h n
688+
EJoin h i ->
689+
unlessM (hasBeenVisited h) $ do
745690
i' <- buildGraph_ i
746691
let n = NJoin i'
747-
modify $ \s ->
748-
s
749-
{ gbsEdges = gbsEdges s |> (h, n)
750-
}
751-
EBundle h b -> do
752-
ns <- gets gbsSharedNodes
753-
unless (h `Set.member` ns) $ do
754-
modify $ \s ->
755-
s
756-
{ gbsSharedNodes = Set.insert h ns
757-
}
692+
markVisited h n
693+
EBundle h b ->
694+
unlessM (hasBeenVisited h) $ do
758695
b' <- SV.fromSized <$> traverse buildGraph_ b
759696
let n = NBundle b'
760-
modify $ \s ->
761-
s
762-
{ gbsEdges = gbsEdges s |> (h, n)
763-
}
764-
EAtIndex h e i -> do
765-
ns <- gets gbsSharedNodes
766-
unless (h `Set.member` ns) $ do
767-
modify $ \s ->
768-
s
769-
{ gbsSharedNodes = Set.insert h ns
770-
}
697+
markVisited h n
698+
EAtIndex h e i ->
699+
unlessM (hasBeenVisited h) $ do
771700
e' <- buildGraph_ e
772701
let n = NAtIndex e' (fromIntegral i)
773-
modify $ \s ->
774-
s
775-
{ gbsEdges = gbsEdges s |> (h, n)
776-
}
777-
EUpdateAtIndex h e i v -> do
778-
ns <- gets gbsSharedNodes
779-
unless (h `Set.member` ns) $ do
780-
modify $ \s ->
781-
s
782-
{ gbsSharedNodes = Set.insert h ns
783-
}
702+
markVisited h n
703+
EUpdateAtIndex h e i v ->
704+
unlessM (hasBeenVisited h) $ do
784705
e' <- buildGraph_ e
785706
v' <- buildGraph_ v
786707
let n = NUpdateAtIndex e' (fromIntegral i) v'
787-
modify $ \s ->
788-
s
789-
{ gbsEdges = gbsEdges s |> (h, n)
790-
}
708+
markVisited h n
709+
where
710+
hasBeenVisited h = gets $ Set.member h . gbsSharedNodes
711+
{-# INLINE hasBeenVisited #-}
712+
markVisited h n = modify $ \s ->
713+
s
714+
{ gbsSharedNodes = Set.insert h (gbsSharedNodes s)
715+
, gbsEdges = gbsEdges s |> (h, n)
716+
}
717+
{-# INLINE markVisited #-}
791718

792719
--------------------------------------------------------------------------------
793720

@@ -902,11 +829,12 @@ evalNode lookupVar vars h node =
902829
assertField x
903830
| V.length x == 1 = pure $ V.head x
904831
| otherwise = throwError $ TypeErr "expected field, got vector"
832+
{-# INLINE assertField #-}
905833

906834
assertFromCache :: Hash -> EvalM i f (V.Vector f)
907835
assertFromCache i = do
908836
m <- get
909837
case Map.lookup i m of
910838
Just ws -> pure ws
911839
Nothing -> throwError $ MissingFromCache i
912-
{-# INLINE assertFromCache #-}
840+
{-# INLINE assertFromCache #-}

0 commit comments

Comments
 (0)