Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 68 additions & 48 deletions src/Juvix/Compiler/Backend/Isabelle/Translation/FromTyped.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Data.List.NonEmpty.Extra qualified as NonEmpty
import Data.Text qualified as T
import Data.Text qualified as Text
import Juvix.Compiler.Backend.Isabelle.Data.Result
Expand Down Expand Up @@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
mkExprCase c@Case {..} = case _caseValue of
ExprIden v ->
case _caseBranches of
CaseBranch {..} :| [] ->
CaseBranch {..} :| _ ->
case _caseBranchPattern of
PatVar v' -> substVar v' v _caseBranchBody
_ -> ExprCase c
_ -> ExprCase c
ExprTuple (Tuple (ExprIden v :| [])) ->
case _caseBranches of
CaseBranch {..} :| [] ->
CaseBranch {..} :| _ ->
case _caseBranchPattern of
PatTuple (Tuple (PatVar v' :| [])) -> substVar v' v _caseBranchBody
_ -> ExprCase c
_ -> ExprCase c
_ -> ExprCase c
_ ->
case _caseBranches of
br@CaseBranch {..} :| _ ->
case _caseBranchPattern of
PatVar _ ->
ExprCase
Case
{ _caseValue = _caseValue,
_caseBranches = br :| []
}
PatTuple (Tuple (PatVar _ :| [])) ->
ExprCase
Case
{ _caseValue = _caseValue,
_caseBranches = br :| []
}
_ -> ExprCase c

goMutualBlock :: Internal.MutualBlock -> [Statement]
goMutualBlock Internal.MutualBlock {..} =
Expand Down Expand Up @@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} =
: goClauses cls
Nested pats npats ->
let rhs = goExpression'' nset' nmap' _lambdaBody
argnames' = fmap getPatternArgName _lambdaPatterns
argnames' = fmap getPatternArgName lambdaPats
vnames =
fmap
( \(idx :: Int, mname) ->
maybe
( defaultName
(getLoc cl)
( disambiguate
(nset' ^. nameSet)
("v_" <> show idx)
)
)
(overNameText (disambiguate (nset' ^. nameSet)))
mname
)
(NonEmpty.zip (nonEmpty' [0 ..]) argnames')
nonEmpty' $
fmap
( \(idx :: Int, mname) ->
maybe
( defaultName
(getLoc cl)
( disambiguate
(nset' ^. nameSet)
("v_" <> show idx)
)
)
(overNameText (disambiguate (nset' ^. nameSet)))
mname
)
(zip [0 ..] argnames')
nset'' = foldl' (flip (over nameSet . HashSet.insert . (^. namePretty))) nset' vnames
remainingBranches = goLambdaClauses'' nset'' nmap' cls
remainingBranches = goLambdaClauses'' nset'' nmap' (Just ty) cls
valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
patTuple = PatTuple (Tuple (nonEmpty' pats))
brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats)
Expand All @@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} =
}
]
where
(npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))
lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns)
(npats0, nset', nmap') = goPatternArgsTop lambdaPats
[] -> []

goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch] -> Pattern -> NonEmpty (Expression, Nested Pattern) -> NonEmpty CaseBranch
Expand Down Expand Up @@ -835,18 +850,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
| patsNum == 0 = goExpression (head _lambdaClauses ^. Internal.lambdaBody)
| otherwise = goLams vars
where
patsNum =
case _lambdaType of
Just ty ->
length
. filterTypeArgs 0 ty
. toList
$ head _lambdaClauses ^. Internal.lambdaPatterns
Nothing ->
length
. filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
. toList
$ head _lambdaClauses ^. Internal.lambdaPatterns
patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal.lambdaPatterns
vars = map (\i -> defaultName (getLoc lam) ("x" <> show i)) [0 .. patsNum - 1]

goLams :: [Name] -> Sem r Expression
Expand Down Expand Up @@ -876,7 +880,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
Tuple
{ _tupleComponents = nonEmpty' vars'
}
brs <- goLambdaClauses (toList _lambdaClauses)
brs <- goLambdaClauses _lambdaType (toList _lambdaClauses)
return $
mkExprCase
Case
Expand Down Expand Up @@ -933,17 +937,29 @@ goModule onlyTypes infoTable Internal.Module {..} =
Internal.CaseBranchRhsExpression e -> goExpression e
Internal.CaseBranchRhsIf {} -> error "unsupported: side conditions"

goLambdaClauses'' :: NameSet -> NameMap -> [Internal.LambdaClause] -> [CaseBranch]
goLambdaClauses'' nset nmap cls =
run $ runReader nset $ runReader nmap $ goLambdaClauses cls

goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => [Internal.LambdaClause] -> Sem r [CaseBranch]
goLambdaClauses = \case
filterLambdaPatternArgs :: Maybe Internal.Expression -> NonEmpty Internal.PatternArg -> [Internal.PatternArg]
filterLambdaPatternArgs mty cls = case mty of
Just ty ->
filterTypeArgs 0 ty
. toList
$ cls
Nothing ->
filter ((/= Internal.Implicit) . (^. Internal.patternArgIsImplicit))
. toList
$ cls

goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal.Expression -> [Internal.LambdaClause] -> [CaseBranch]
goLambdaClauses'' nset nmap mty cls =
run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls

goLambdaClauses :: forall r. (Members '[Reader NameSet, Reader NameMap] r) => Maybe Internal.Expression -> [Internal.LambdaClause] -> Sem r [CaseBranch]
goLambdaClauses mty = \case
[email protected] {..} : cls -> do
(npat, nset, nmap) <- case _lambdaPatterns of
p :| [] -> goPatternArgCase p
let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns
(npat, nset, nmap) <- case lambdaPats of
[p] -> goPatternArgCase p
_ -> do
(npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
(npats, nset, nmap) <- goPatternArgsCase lambdaPats
let npat =
fmap
( \pats ->
Expand All @@ -957,7 +973,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
case npat of
Nested pat [] -> do
body <- withLocalNames nset nmap $ goExpression _lambdaBody
brs <- goLambdaClauses cls
brs <- goLambdaClauses mty cls
return $
CaseBranch
{ _caseBranchPattern = pat,
Expand All @@ -968,7 +984,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) "v")
nset' = over nameSet (HashSet.insert (vname ^. namePretty)) nset
rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls
remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls
let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats)
return
[ CaseBranch
Expand Down Expand Up @@ -1140,7 +1156,11 @@ goModule onlyTypes infoTable Internal.Module {..} =
case HashMap.lookup name (infoTable ^. Internal.infoConstructors) of
Just ctrInfo
| ctrInfo ^. Internal.constructorInfoRecord ->
Just (indName, goRecordFields (getArgtys ctrInfo) args)
case HashMap.lookup indName (infoTable ^. Internal.infoInductives) of
Just indInfo
| length (indInfo ^. Internal.inductiveInfoConstructors) == 1 ->
Just (indName, goRecordFields (getArgtys ctrInfo) args)
_ -> Nothing
where
indName = ctrInfo ^. Internal.constructorInfoInductive
_ -> Nothing
Expand Down
47 changes: 47 additions & 0 deletions tests/positive/Isabelle/Program.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,50 @@ funR4 : R -> R
bf (b1 b2 : Bool) : Bool := not (b1 && b2);

nf (n1 n2 : Int) : Bool := n1 - n2 >= n1 || n2 <= n1 + n2;

-- Nested record patterns

type MessagePacket (MessageType : Type) : Type := mkMessagePacket {
target : Nat;
mailbox : Maybe Nat;
message : MessageType;
};

open MessagePacket;

type EnvelopedMessage (MessageType : Type) : Type :=
mkEnvelopedMessage {
sender : Maybe Nat;
packet : MessagePacket MessageType;
};

open EnvelopedMessage;

type Timer (HandleType : Type): Type := mkTimer {
time : Nat;
handle : HandleType;
};

type Trigger (MessageType : Type) (HandleType : Type) :=
| MessageArrived { envelope : EnvelopedMessage MessageType; }
| Elapsed { timers : List (Timer HandleType) };

open Trigger;

getMessageFromTrigger : {M H : Type} -> Trigger M H -> Maybe M
| (MessageArrived@{
envelope := (mkEnvelopedMessage@{
packet := (mkMessagePacket@{
message := m })})})
:= just m
| _ := nothing;


getMessageFromTrigger' {M H} (t : Trigger M H) : Maybe M :=
case t of
| (MessageArrived@{
envelope := (mkEnvelopedMessage@{
packet := (mkMessagePacket@{
message := m })})})
:= just m
| _ := nothing;
60 changes: 60 additions & 0 deletions tests/positive/Isabelle/isabelle/Program.thy
Original file line number Diff line number Diff line change
Expand Up @@ -241,4 +241,64 @@ fun bf :: "bool \<Rightarrow> bool \<Rightarrow> bool" where
fun nf :: "int \<Rightarrow> int \<Rightarrow> bool" where
"nf n1 n2 = (n1 - n2 \<ge> n1 \<or> n2 \<le> n1 + n2)"

(* Nested record patterns *)
record 'MessageType MessagePacket =
target :: nat
mailbox :: "nat option"
message :: 'MessageType

fun target :: "'MessageType MessagePacket \<Rightarrow> nat" where
"target (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
target'"

fun mailbox :: "'MessageType MessagePacket \<Rightarrow> nat option" where
"mailbox (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
mailbox'"

fun message :: "'MessageType MessagePacket \<Rightarrow> 'MessageType" where
"message (| MessagePacket.target = target', MessagePacket.mailbox = mailbox', MessagePacket.message = message' |) =
message'"

record 'MessageType EnvelopedMessage =
sender :: "nat option"
packet :: "'MessageType MessagePacket"

fun sender :: "'MessageType EnvelopedMessage \<Rightarrow> nat option" where
"sender (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) =
sender'"

fun packet :: "'MessageType EnvelopedMessage \<Rightarrow> 'MessageType MessagePacket" where
"packet (| EnvelopedMessage.sender = sender', EnvelopedMessage.packet = packet' |) =
packet'"

record 'HandleType Timer =
time :: nat
handle :: 'HandleType

fun time :: "'HandleType Timer \<Rightarrow> nat" where
"time (| Timer.time = time', Timer.handle = handle' |) = time'"

fun handle :: "'HandleType Timer \<Rightarrow> 'HandleType" where
"handle (| Timer.time = time', Timer.handle = handle' |) = handle'"

datatype ('MessageType, 'HandleType) Trigger
= MessageArrived "'MessageType EnvelopedMessage" |
Elapsed "('HandleType Timer) list"

fun getMessageFromTrigger :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
"getMessageFromTrigger v_0 =
(case (v_0) of
(MessageArrived v') \<Rightarrow>
(case (EnvelopedMessage.packet v') of
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
v'1 \<Rightarrow> None)"

fun getMessageFromTrigger' :: "('M, 'H) Trigger \<Rightarrow> 'M option" where
"getMessageFromTrigger' t =
(case t of
(MessageArrived v') \<Rightarrow>
(case (EnvelopedMessage.packet v') of
(v'0) \<Rightarrow> Some (MessagePacket.message v'0)) |
v'2 \<Rightarrow> None)"

end
Loading