@@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where
22
33import Data.HashMap.Strict qualified as HashMap
44import Data.HashSet qualified as HashSet
5- import Data.List.NonEmpty.Extra qualified as NonEmpty
65import Data.Text qualified as T
76import Data.Text qualified as Text
87import Juvix.Compiler.Backend.Isabelle.Data.Result
@@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
9594 mkExprCase c@ Case {.. } = case _caseValue of
9695 ExprIden v ->
9796 case _caseBranches of
98- CaseBranch {.. } :| [] ->
97+ CaseBranch {.. } :| _ ->
9998 case _caseBranchPattern of
10099 PatVar v' -> substVar v' v _caseBranchBody
101100 _ -> ExprCase c
102- _ -> ExprCase c
103101 ExprTuple (Tuple (ExprIden v :| [] )) ->
104102 case _caseBranches of
105- CaseBranch {.. } :| [] ->
103+ CaseBranch {.. } :| _ ->
106104 case _caseBranchPattern of
107105 PatTuple (Tuple (PatVar v' :| [] )) -> substVar v' v _caseBranchBody
108106 _ -> ExprCase c
109- _ -> ExprCase c
110- _ -> ExprCase c
107+ _ ->
108+ case _caseBranches of
109+ br@ CaseBranch {.. } :| _ ->
110+ case _caseBranchPattern of
111+ PatVar _ ->
112+ ExprCase
113+ Case
114+ { _caseValue = _caseValue,
115+ _caseBranches = br :| []
116+ }
117+ PatTuple (Tuple (PatVar _ :| [] )) ->
118+ ExprCase
119+ Case
120+ { _caseValue = _caseValue,
121+ _caseBranches = br :| []
122+ }
123+ _ -> ExprCase c
111124
112125 goMutualBlock :: Internal. MutualBlock -> [Statement ]
113126 goMutualBlock Internal. MutualBlock {.. } =
@@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} =
243256 : goClauses cls
244257 Nested pats npats ->
245258 let rhs = goExpression'' nset' nmap' _lambdaBody
246- argnames' = fmap getPatternArgName _lambdaPatterns
259+ argnames' = fmap getPatternArgName lambdaPats
247260 vnames =
248- fmap
249- ( \ (idx :: Int , mname ) ->
250- maybe
251- ( defaultName
252- (getLoc cl)
253- ( disambiguate
254- (nset' ^. nameSet)
255- (" v_" <> show idx)
256- )
257- )
258- (overNameText (disambiguate (nset' ^. nameSet)))
259- mname
260- )
261- (NonEmpty. zip (nonEmpty' [0 .. ]) argnames')
261+ nonEmpty' $
262+ fmap
263+ ( \ (idx :: Int , mname ) ->
264+ maybe
265+ ( defaultName
266+ (getLoc cl)
267+ ( disambiguate
268+ (nset' ^. nameSet)
269+ (" v_" <> show idx)
270+ )
271+ )
272+ (overNameText (disambiguate (nset' ^. nameSet)))
273+ mname
274+ )
275+ (zip [0 .. ] argnames')
262276 nset'' = foldl' (flip (over nameSet . HashSet. insert . (^. namePretty))) nset' vnames
263- remainingBranches = goLambdaClauses'' nset'' nmap' cls
277+ remainingBranches = goLambdaClauses'' nset'' nmap' ( Just ty) cls
264278 valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
265279 patTuple = PatTuple (Tuple (nonEmpty' pats))
266280 brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats)
@@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} =
275289 }
276290 ]
277291 where
278- (npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))
292+ lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns)
293+ (npats0, nset', nmap') = goPatternArgsTop lambdaPats
279294 [] -> []
280295
281296 goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch ] -> Pattern -> NonEmpty (Expression , Nested Pattern ) -> NonEmpty CaseBranch
@@ -828,18 +843,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
828843 | patsNum == 0 = goExpression (head _lambdaClauses ^. Internal. lambdaBody)
829844 | otherwise = goLams vars
830845 where
831- patsNum =
832- case _lambdaType of
833- Just ty ->
834- length
835- . filterTypeArgs 0 ty
836- . toList
837- $ head _lambdaClauses ^. Internal. lambdaPatterns
838- Nothing ->
839- length
840- . filter ((/= Internal. Implicit ) . (^. Internal. patternArgIsImplicit))
841- . toList
842- $ head _lambdaClauses ^. Internal. lambdaPatterns
846+ patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal. lambdaPatterns
843847 vars = map (\ i -> defaultName (getLoc lam) (" x" <> show i)) [0 .. patsNum - 1 ]
844848
845849 goLams :: [Name ] -> Sem r Expression
@@ -869,7 +873,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
869873 Tuple
870874 { _tupleComponents = nonEmpty' vars'
871875 }
872- brs <- goLambdaClauses (toList _lambdaClauses)
876+ brs <- goLambdaClauses _lambdaType (toList _lambdaClauses)
873877 return $
874878 mkExprCase
875879 Case
@@ -926,17 +930,29 @@ goModule onlyTypes infoTable Internal.Module {..} =
926930 Internal. CaseBranchRhsExpression e -> goExpression e
927931 Internal. CaseBranchRhsIf {} -> error " unsupported: side conditions"
928932
929- goLambdaClauses'' :: NameSet -> NameMap -> [Internal. LambdaClause ] -> [CaseBranch ]
930- goLambdaClauses'' nset nmap cls =
931- run $ runReader nset $ runReader nmap $ goLambdaClauses cls
932-
933- goLambdaClauses :: forall r . (Members '[Reader NameSet , Reader NameMap ] r ) => [Internal. LambdaClause ] -> Sem r [CaseBranch ]
934- goLambdaClauses = \ case
933+ filterLambdaPatternArgs :: Maybe Internal. Expression -> NonEmpty Internal. PatternArg -> [Internal. PatternArg ]
934+ filterLambdaPatternArgs mty cls = case mty of
935+ Just ty ->
936+ filterTypeArgs 0 ty
937+ . toList
938+ $ cls
939+ Nothing ->
940+ filter ((/= Internal. Implicit ) . (^. Internal. patternArgIsImplicit))
941+ . toList
942+ $ cls
943+
944+ goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal. Expression -> [Internal. LambdaClause ] -> [CaseBranch ]
945+ goLambdaClauses'' nset nmap mty cls =
946+ run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls
947+
948+ goLambdaClauses :: forall r . (Members '[Reader NameSet , Reader NameMap ] r ) => Maybe Internal. Expression -> [Internal. LambdaClause ] -> Sem r [CaseBranch ]
949+ goLambdaClauses mty = \ case
935950 cl@ Internal. LambdaClause {.. } : cls -> do
936- (npat, nset, nmap) <- case _lambdaPatterns of
937- p :| [] -> goPatternArgCase p
951+ let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns
952+ (npat, nset, nmap) <- case lambdaPats of
953+ [p] -> goPatternArgCase p
938954 _ -> do
939- (npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
955+ (npats, nset, nmap) <- goPatternArgsCase lambdaPats
940956 let npat =
941957 fmap
942958 ( \ pats ->
@@ -950,7 +966,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
950966 case npat of
951967 Nested pat [] -> do
952968 body <- withLocalNames nset nmap $ goExpression _lambdaBody
953- brs <- goLambdaClauses cls
969+ brs <- goLambdaClauses mty cls
954970 return $
955971 CaseBranch
956972 { _caseBranchPattern = pat,
@@ -961,7 +977,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
961977 let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) " v" )
962978 nset' = over nameSet (HashSet. insert (vname ^. namePretty)) nset
963979 rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody
964- remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls
980+ remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls
965981 let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats)
966982 return
967983 [ CaseBranch
@@ -1133,7 +1149,11 @@ goModule onlyTypes infoTable Internal.Module {..} =
11331149 case HashMap. lookup name (infoTable ^. Internal. infoConstructors) of
11341150 Just ctrInfo
11351151 | ctrInfo ^. Internal. constructorInfoRecord ->
1136- Just (indName, goRecordFields (getArgtys ctrInfo) args)
1152+ case HashMap. lookup indName (infoTable ^. Internal. infoInductives) of
1153+ Just indInfo
1154+ | length (indInfo ^. Internal. inductiveInfoConstructors) == 1 ->
1155+ Just (indName, goRecordFields (getArgtys ctrInfo) args)
1156+ _ -> Nothing
11371157 where
11381158 indName = ctrInfo ^. Internal. constructorInfoInductive
11391159 _ -> Nothing
0 commit comments