Skip to content

[ refactor ] ScopedSnocList: Swap Scope on SnocList (Phase 2) #3513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions idris2api.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ modules =
Core.LinearCheck,
Core.Metadata,
Core.Name,
Core.Name.CompatibleVars,
Core.Name.Namespace,
Core.Name.Scoped,
Core.Normalise,
Expand Down Expand Up @@ -191,6 +192,7 @@ modules =
Libraries.Data.NameMap.Traversable,
Libraries.Data.Ordering.Extra,
Libraries.Data.PosMap,
Libraries.Data.SnocList.Extra,
Libraries.Data.SnocList.HasLength,
Libraries.Data.SnocList.LengthMatch,
Libraries.Data.SnocList.SizeOf,
Expand Down
10 changes: 10 additions & 0 deletions libs/base/Data/SnocList.idr
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,13 @@ tailRecAppendIsAppend : (sx, sy : SnocList a) -> tailRecAppend sx sy = sx ++ sy
tailRecAppendIsAppend sx Lin = Refl
tailRecAppendIsAppend sx (sy :< y) =
trans (snocTailRecAppend y sx sy) (cong (:< y) $ tailRecAppendIsAppend sx sy)

||| `reverseOnto` reverses the snoc list and prepends it to the "onto" argument
export
revOnto : (xs, vs : SnocList a) -> reverseOnto xs vs = xs ++ reverse vs
revOnto _ [<] = Refl
revOnto xs (vs :< v) =
do rewrite revOnto (xs :< v) vs
rewrite sym $ appendAssociative xs [<v] (reverse vs)
rewrite revOnto [<v] vs
Refl
6 changes: 3 additions & 3 deletions libs/base/Data/SnocList/HasLength.idr
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ map f Z = Z
map f (S hl) = S (map f hl)

export
sucL : HasLength n sx -> HasLength (S n) ([<x] ++ sx)
sucL Z = S Z
sucL (S n) = S (sucL n)
sucR : HasLength n sx -> HasLength (S n) ([<x] ++ sx)
sucR Z = S Z
sucR (S n) = S (sucR n)

export
hlAppend : HasLength m sx -> HasLength n sy -> HasLength (n + m) (sx ++ sy)
Expand Down
11 changes: 11 additions & 0 deletions libs/base/Data/SnocList/Operations.idr
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ lengthHomomorphism sx (sy :< x) = Calc $
~~ 1 + (length sx + length sy) ...(cong (1+) $ lengthHomomorphism _ _)
~~ length sx + (1 + length sy) ...(plusSuccRightSucc _ _)

export
lengthDistributesOverFish : (sx : SnocList a) -> (ys : List a) ->
length (sx <>< ys) === length sx + length ys
lengthDistributesOverFish sx [] = sym $ plusZeroRightNeutral _
lengthDistributesOverFish sx (y :: ys) = Calc $
|~ length ((sx :< y) <>< ys)
~~ length (sx :< y) + length ys ...( lengthDistributesOverFish (sx :< y) ys)
~~ S (length sx) + length ys ...( Refl )
~~ length sx + S (length ys) ...( plusSuccRightSucc _ _ )
~~ length sx + length (y :: ys) ...( Refl )

-- cons-list operations on snoc-lists

||| Take `n` first elements from `sx`, returning the whole list if
Expand Down
59 changes: 29 additions & 30 deletions src/Compiler/ANF.idr
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import Core.Core
import Core.TT

import Data.List
import Data.SnocList
import Data.Vect
import Libraries.Data.SortedSet
import Libraries.Data.SnocList.Extra

%default covering

Expand Down Expand Up @@ -136,9 +138,12 @@ Show ANFDef where
show args ++ " -> " ++ show ret
show (MkAError exp) = "Error: " ++ show exp

data AVars : List Name -> Type where
Nil : AVars []
(::) : Int -> AVars xs -> AVars (x :: xs)
data AVars : Scoped where
Lin : AVars ScopeEmpty
(:<) : AVars xs -> Int -> AVars (xs :< x)

ScopeEmpty : AVars ScopeEmpty
ScopeEmpty = [<]

data Next : Type where

Expand All @@ -150,8 +155,8 @@ nextVar
pure i

lookup : {idx : _} -> (0 p : IsVar x idx vs) -> AVars vs -> Int
lookup First (x :: xs) = x
lookup (Later p) (x :: xs) = lookup p xs
lookup First (xs :< x) = x
lookup (Later p) (xs :< x) = lookup p xs

bindArgs : {auto v : Ref Next Int} ->
List ANF -> Core (List (AVar, Maybe ANF))
Expand Down Expand Up @@ -187,6 +192,15 @@ mlet fc val sc
= do i <- nextVar
pure $ ALet fc i val (sc (ALocal i))

bindAsFresh :
{auto v : Ref Next Int} ->
(args : List Name) -> AVars vars' ->
Core (List Int, AVars (vars' <>< args))
bindAsFresh [] vs = pure ([], vs)
bindAsFresh (n :: ns) vs
= do i <- nextVar
mapFst (i ::) <$> bindAsFresh ns (vs :< i)

mutual
anfArgs : {vars : _} ->
{auto v : Ref Next Int} ->
Expand All @@ -211,7 +225,7 @@ mutual
_ => ACrash fc "Can't happen (AApp)"
anf vs (LLet fc x val sc)
= do i <- nextVar
let vs' = i :: vs
let vs' = vs :< i
pure $ ALet fc i !(anf vs val) !(anf vs' sc)
anf vs (LCon fc n ci t args)
= anfArgs fc vs args (ACon fc n ci t)
Expand Down Expand Up @@ -241,16 +255,8 @@ mutual
{auto v : Ref Next Int} ->
AVars vars -> LiftedConAlt vars -> Core AConAlt
anfConAlt vs (MkLConAlt n ci t args sc)
= do (is, vs') <- bindArgs args vs
= do (is, vs') <- bindAsFresh args vs
pure $ MkAConAlt n ci t is !(anf vs' sc)
where
bindArgs : (args : List Name) -> AVars vars' ->
Core (List Int, AVars (args ++ vars'))
bindArgs [] vs = pure ([], vs)
bindArgs (n :: ns) vs
= do i <- nextVar
(is, vs') <- bindArgs ns vs
pure (i :: is, i :: vs')

anfConstAlt : {vars : _} ->
{auto v : Ref Next Int} ->
Expand All @@ -262,25 +268,18 @@ export
toANF : LiftedDef -> Core ANFDef
toANF (MkLFun args scope sc)
= do v <- newRef Next (the Int 0)
(iargs, vsNil) <- bindArgs args []
let vs : AVars args = rewrite sym (appendNilRightNeutral args) in
vsNil
(iargs', vs) <- bindArgs scope vs
pure $ MkAFun (iargs ++ reverse iargs') !(anf vs sc)
where
bindArgs : {auto v : Ref Next Int} ->
(args : List Name) -> AVars vars' ->
Core (List Int, AVars (args ++ vars'))
bindArgs [] vs = pure ([], vs)
bindArgs (n :: ns) vs
= do i <- nextVar
(is, vs') <- bindArgs ns vs
pure (i :: is, i :: vs')
(iargs, vsNil) <- bindAsFresh (cast args) [<]
let vs : AVars args
:= rewrite sym $ appendLinLeftNeutral args in
rewrite snocAppendAsFish [<] args in vsNil
(iargs', vs) <- bindAsFresh (cast scope) vs
sc' <- anf (rewrite snocAppendAsFish args scope in vs) sc
pure $ MkAFun (iargs ++ iargs') sc'
toANF (MkLCon t a ns) = pure $ MkACon t a ns
toANF (MkLForeign ccs fargs t) = pure $ MkAForeign ccs fargs t
toANF (MkLError err)
= do v <- newRef Next (the Int 0)
pure $ MkAError !(anf [] err)
pure $ MkAError !(anf ScopeEmpty err)

export
freeVariables : ANF -> SortedSet AVar
Expand Down
65 changes: 35 additions & 30 deletions src/Compiler/CaseOpts.idr
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ import Core.FC
import Core.TT

import Data.List
import Data.SnocList
import Data.Vect

import Libraries.Data.List.SizeOf
import Libraries.Data.SnocList.SizeOf
import Libraries.Data.SnocList.Extra

%default covering

{-
Expand All @@ -32,38 +37,38 @@ case t of

shiftUnder : {args : _} ->
{idx : _} ->
(0 p : IsVar n idx (x :: args ++ vars)) ->
NVar n (args ++ x :: vars)
(0 p : IsVar n idx (vars ++ args :< x)) ->
NVar n (vars :< x ++ args)
shiftUnder First = weakenNVar (mkSizeOf args) (MkNVar First)
shiftUnder (Later p) = insertNVar (mkSizeOf args) (MkNVar p)

shiftVar : {outer, args : Scope} ->
NVar n (outer ++ (x :: args ++ vars)) ->
NVar n (outer ++ (args ++ x :: vars))
shiftVar : {outer : Scope} -> {args : List Name} ->
NVar n ((vars <>< args :< x) ++ outer) ->
NVar n ((vars :< x <>< args) ++ outer)
shiftVar nvar
= let out = mkSizeOf outer in
case locateNVar out nvar of
Left nvar => embed nvar
Right (MkNVar p) => weakenNs out (shiftUnder p)
Right (MkNVar p) => weakenNs out (shiftUndersN (mkSizeOf _) p)

mutual
renameVar : IsVar x i ((vars :< old <>< args) ++ local) ->
IsVar x i ((vars :< new <>< args) ++ local)
renameVar = believe_me -- it's the same index, so just the identity at run time

shiftBinder : {outer, args : _} ->
(new : Name) ->
CExp (outer ++ old :: (args ++ vars)) ->
CExp (outer ++ (args ++ new :: vars))
CExp (((vars <>< args) :< old) ++ outer) ->
CExp ((vars :< new <>< args) ++ outer)
shiftBinder new (CLocal fc p)
= case shiftVar (MkNVar p) of
MkNVar p' => CLocal fc (renameVar p')
where
renameVar : IsVar x i (outer ++ (args ++ (old :: rest))) ->
IsVar x i (outer ++ (args ++ (new :: rest)))
renameVar = believe_me -- it's the same index, so just the identity at run time
shiftBinder new (CRef fc n) = CRef fc n
shiftBinder {outer} new (CLam fc n sc)
= CLam fc n $ shiftBinder {outer = n :: outer} new sc
= CLam fc n $ shiftBinder {outer = outer :< n} new sc
shiftBinder new (CLet fc n inlineOK val sc)
= CLet fc n inlineOK (shiftBinder new val)
$ shiftBinder {outer = n :: outer} new sc
$ shiftBinder {outer = outer :< n} new sc
shiftBinder new (CApp fc f args)
= CApp fc (shiftBinder new f) $ map (shiftBinder new) args
shiftBinder new (CCon fc ci c tag args)
Expand All @@ -87,34 +92,34 @@ mutual

shiftBinderConAlt : {outer, args : _} ->
(new : Name) ->
CConAlt (outer ++ (x :: args ++ vars)) ->
CConAlt (outer ++ (args ++ new :: vars))
CConAlt (((vars <>< args) :< old) ++ outer) ->
CConAlt ((vars :< new <>< args) ++ outer)
shiftBinderConAlt new (MkConAlt n ci t args' sc)
= let sc' : CExp ((args' ++ outer) ++ (x :: args ++ vars))
= rewrite sym (appendAssociative args' outer (x :: args ++ vars)) in sc in
= let sc' : CExp (((vars <>< args) :< old) ++ (outer <>< args'))
= rewrite sym $ snocAppendFishAssociative (vars <>< args :< old) outer args' in sc in
MkConAlt n ci t args' $
rewrite (appendAssociative args' outer (args ++ new :: vars))
in shiftBinder new {outer = args' ++ outer} sc'
rewrite snocAppendFishAssociative (vars :< new <>< args) outer args'
in shiftBinder new {outer = outer <>< args'} sc'

shiftBinderConstAlt : {outer, args : _} ->
(new : Name) ->
CConstAlt (outer ++ (x :: args ++ vars)) ->
CConstAlt (outer ++ (args ++ new :: vars))
CConstAlt (((vars <>< args) :< old) ++ outer) ->
CConstAlt ((vars :< new <>< args) ++ outer)
shiftBinderConstAlt new (MkConstAlt c sc) = MkConstAlt c $ shiftBinder new sc

-- If there's a lambda inside a case, move the variable so that it's bound
-- outside the case block so that we can bind it just once outside the block
liftOutLambda : {args : _} ->
(new : Name) ->
CExp (old :: args ++ vars) ->
CExp (args ++ new :: vars)
liftOutLambda = shiftBinder {outer = []}
CExp (vars <>< args :< old) ->
CExp (vars :< new <>< args)
liftOutLambda = shiftBinder {outer = ScopeEmpty}

-- If all the alternatives start with a lambda, we can have a single lambda
-- binding outside
tryLiftOut : (new : Name) ->
List (CConAlt vars) ->
Maybe (List (CConAlt (new :: vars)))
Maybe (List (CConAlt (vars :< new)))
tryLiftOut new [] = Just []
tryLiftOut new (MkConAlt n ci t args (CLam fc x sc) :: as)
= do as' <- tryLiftOut new as
Expand All @@ -124,7 +129,7 @@ tryLiftOut _ _ = Nothing

tryLiftOutConst : (new : Name) ->
List (CConstAlt vars) ->
Maybe (List (CConstAlt (new :: vars)))
Maybe (List (CConstAlt (vars :< new)))
tryLiftOutConst new [] = Just []
tryLiftOutConst new (MkConstAlt c (CLam fc x sc) :: as)
= do as' <- tryLiftOutConst new as
Expand All @@ -134,7 +139,7 @@ tryLiftOutConst _ _ = Nothing

tryLiftDef : (new : Name) ->
Maybe (CExp vars) ->
Maybe (Maybe (CExp (new :: vars)))
Maybe (Maybe (CExp (vars :< new)))
tryLiftDef new Nothing = Just Nothing
tryLiftDef new (Just (CLam fc x sc))
= let sc' = liftOutLambda {args = []} new sc in
Expand Down Expand Up @@ -313,8 +318,8 @@ doCaseOfCase fc x xalts xdef alts def
updateAlt (MkConAlt n ci t args sc)
= MkConAlt n ci t args $
CConCase fc sc
(map (weakenNs (mkSizeOf args)) alts)
(map (weakenNs (mkSizeOf args)) def)
(map (weakensN (mkSizeOf args)) alts)
(map (weakensN (mkSizeOf args)) def)

updateDef : CExp vars -> CExp vars
updateDef sc = CConCase fc sc alts def
Expand Down
10 changes: 5 additions & 5 deletions src/Compiler/Common.idr
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Ord UsePhase where
public export
record CompileData where
constructor MkCompileData
mainExpr : CExp [] -- main expression to execute. This also appears in
mainExpr : ClosedCExp -- main expression to execute. This also appears in
-- the definitions below as MN "__mainExpression" 0
-- For incremental compilation and for compiling exported
-- names only, this can be set to 'erased'.
Expand Down Expand Up @@ -153,7 +153,7 @@ getMinimalDef (Coded ns bin)
name <- fromBuf b
let def
= MkGlobalDef fc name (Erased fc Placeholder) [] [] [] [] mul
[] (specified Public) (MkTotality Unchecked IsCovering) False
ScopeEmpty (specified Public) (MkTotality Unchecked IsCovering) False
[] Nothing refsR False False True
None cdef Nothing [] Nothing
pure (def, Just (ns, bin))
Expand Down Expand Up @@ -355,8 +355,8 @@ getCompileDataWith exports doLazyAnnots phase_in tm_in
traverse (lambdaLift doLazyAnnots) cseDefs
else pure []

let lifted = (mainname, MkLFun [] [] liftedtm) ::
ldefs ++ concat lifted_in
let lifted = (mainname, MkLFun ScopeEmpty ScopeEmpty liftedtm) ::
(ldefs ++ concat lifted_in)

anf <- if phase >= ANF
then logTime 2 "Get ANF" $ traverse (\ (n, d) => pure (n, !(toANF d))) lifted
Expand Down Expand Up @@ -412,7 +412,7 @@ getCompileData = getCompileDataWith []

export
compileTerm : {auto c : Ref Ctxt Defs} ->
ClosedTerm -> Core (CExp [])
ClosedTerm -> Core ClosedCExp
compileTerm tm_in
= do tm <- toFullNames tm_in
fixArityExp !(compileExp tm)
Expand Down
Loading