@@ -2,7 +2,6 @@ module Juvix.Compiler.Backend.Isabelle.Translation.FromTyped where
22
33import Data.HashMap.Strict qualified as HashMap
44import Data.HashSet qualified as HashSet
5- import Data.List.NonEmpty.Extra qualified as NonEmpty
65import Data.Text qualified as T
76import Data.Text qualified as Text
87import Juvix.Compiler.Backend.Isabelle.Data.Result
@@ -95,19 +94,33 @@ goModule onlyTypes infoTable Internal.Module {..} =
9594 mkExprCase c@ Case {.. } = case _caseValue of
9695 ExprIden v ->
9796 case _caseBranches of
98- CaseBranch {.. } :| [] ->
97+ CaseBranch {.. } :| _ ->
9998 case _caseBranchPattern of
10099 PatVar v' -> substVar v' v _caseBranchBody
101100 _ -> ExprCase c
102- _ -> ExprCase c
103101 ExprTuple (Tuple (ExprIden v :| [] )) ->
104102 case _caseBranches of
105- CaseBranch {.. } :| [] ->
103+ CaseBranch {.. } :| _ ->
106104 case _caseBranchPattern of
107105 PatTuple (Tuple (PatVar v' :| [] )) -> substVar v' v _caseBranchBody
108106 _ -> ExprCase c
109- _ -> ExprCase c
110- _ -> ExprCase c
107+ _ ->
108+ case _caseBranches of
109+ br@ CaseBranch {.. } :| _ ->
110+ case _caseBranchPattern of
111+ PatVar _ ->
112+ ExprCase
113+ Case
114+ { _caseValue = _caseValue,
115+ _caseBranches = br :| []
116+ }
117+ PatTuple (Tuple (PatVar _ :| [] )) ->
118+ ExprCase
119+ Case
120+ { _caseValue = _caseValue,
121+ _caseBranches = br :| []
122+ }
123+ _ -> ExprCase c
111124
112125 goMutualBlock :: Internal. MutualBlock -> [Statement ]
113126 goMutualBlock Internal. MutualBlock {.. } =
@@ -243,24 +256,25 @@ goModule onlyTypes infoTable Internal.Module {..} =
243256 : goClauses cls
244257 Nested pats npats ->
245258 let rhs = goExpression'' nset' nmap' _lambdaBody
246- argnames' = fmap getPatternArgName _lambdaPatterns
259+ argnames' = fmap getPatternArgName lambdaPats
247260 vnames =
248- fmap
249- ( \ (idx :: Int , mname ) ->
250- maybe
251- ( defaultName
252- (getLoc cl)
253- ( disambiguate
254- (nset' ^. nameSet)
255- (" v_" <> show idx)
256- )
257- )
258- (overNameText (disambiguate (nset' ^. nameSet)))
259- mname
260- )
261- (NonEmpty. zip (nonEmpty' [0 .. ]) argnames')
261+ nonEmpty' $
262+ fmap
263+ ( \ (idx :: Int , mname ) ->
264+ maybe
265+ ( defaultName
266+ (getLoc cl)
267+ ( disambiguate
268+ (nset' ^. nameSet)
269+ (" v_" <> show idx)
270+ )
271+ )
272+ (overNameText (disambiguate (nset' ^. nameSet)))
273+ mname
274+ )
275+ (zip [0 .. ] argnames')
262276 nset'' = foldl' (flip (over nameSet . HashSet. insert . (^. namePretty))) nset' vnames
263- remainingBranches = goLambdaClauses'' nset'' nmap' cls
277+ remainingBranches = goLambdaClauses'' nset'' nmap' ( Just ty) cls
264278 valTuple = ExprTuple (Tuple (fmap ExprIden vnames))
265279 patTuple = PatTuple (Tuple (nonEmpty' pats))
266280 brs = goNestedBranches (getLoc cl) valTuple rhs remainingBranches patTuple (nonEmpty' npats)
@@ -275,7 +289,8 @@ goModule onlyTypes infoTable Internal.Module {..} =
275289 }
276290 ]
277291 where
278- (npats0, nset', nmap') = goPatternArgsTop (filterTypeArgs 0 ty (toList _lambdaPatterns))
292+ lambdaPats = filterTypeArgs 0 ty (toList _lambdaPatterns)
293+ (npats0, nset', nmap') = goPatternArgsTop lambdaPats
279294 [] -> []
280295
281296 goNestedBranches :: Interval -> Expression -> Expression -> [CaseBranch ] -> Pattern -> NonEmpty (Expression , Nested Pattern ) -> NonEmpty CaseBranch
@@ -835,18 +850,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
835850 | patsNum == 0 = goExpression (head _lambdaClauses ^. Internal. lambdaBody)
836851 | otherwise = goLams vars
837852 where
838- patsNum =
839- case _lambdaType of
840- Just ty ->
841- length
842- . filterTypeArgs 0 ty
843- . toList
844- $ head _lambdaClauses ^. Internal. lambdaPatterns
845- Nothing ->
846- length
847- . filter ((/= Internal. Implicit ) . (^. Internal. patternArgIsImplicit))
848- . toList
849- $ head _lambdaClauses ^. Internal. lambdaPatterns
853+ patsNum = length $ filterLambdaPatternArgs _lambdaType $ head _lambdaClauses ^. Internal. lambdaPatterns
850854 vars = map (\ i -> defaultName (getLoc lam) (" x" <> show i)) [0 .. patsNum - 1 ]
851855
852856 goLams :: [Name ] -> Sem r Expression
@@ -876,7 +880,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
876880 Tuple
877881 { _tupleComponents = nonEmpty' vars'
878882 }
879- brs <- goLambdaClauses (toList _lambdaClauses)
883+ brs <- goLambdaClauses _lambdaType (toList _lambdaClauses)
880884 return $
881885 mkExprCase
882886 Case
@@ -933,17 +937,29 @@ goModule onlyTypes infoTable Internal.Module {..} =
933937 Internal. CaseBranchRhsExpression e -> goExpression e
934938 Internal. CaseBranchRhsIf {} -> error " unsupported: side conditions"
935939
936- goLambdaClauses'' :: NameSet -> NameMap -> [Internal. LambdaClause ] -> [CaseBranch ]
937- goLambdaClauses'' nset nmap cls =
938- run $ runReader nset $ runReader nmap $ goLambdaClauses cls
939-
940- goLambdaClauses :: forall r . (Members '[Reader NameSet , Reader NameMap ] r ) => [Internal. LambdaClause ] -> Sem r [CaseBranch ]
941- goLambdaClauses = \ case
940+ filterLambdaPatternArgs :: Maybe Internal. Expression -> NonEmpty Internal. PatternArg -> [Internal. PatternArg ]
941+ filterLambdaPatternArgs mty cls = case mty of
942+ Just ty ->
943+ filterTypeArgs 0 ty
944+ . toList
945+ $ cls
946+ Nothing ->
947+ filter ((/= Internal. Implicit ) . (^. Internal. patternArgIsImplicit))
948+ . toList
949+ $ cls
950+
951+ goLambdaClauses'' :: NameSet -> NameMap -> Maybe Internal. Expression -> [Internal. LambdaClause ] -> [CaseBranch ]
952+ goLambdaClauses'' nset nmap mty cls =
953+ run $ runReader nset $ runReader nmap $ goLambdaClauses mty cls
954+
955+ goLambdaClauses :: forall r . (Members '[Reader NameSet , Reader NameMap ] r ) => Maybe Internal. Expression -> [Internal. LambdaClause ] -> Sem r [CaseBranch ]
956+ goLambdaClauses mty = \ case
942957 cl@ Internal. LambdaClause {.. } : cls -> do
943- (npat, nset, nmap) <- case _lambdaPatterns of
944- p :| [] -> goPatternArgCase p
958+ let lambdaPats = filterLambdaPatternArgs mty _lambdaPatterns
959+ (npat, nset, nmap) <- case lambdaPats of
960+ [p] -> goPatternArgCase p
945961 _ -> do
946- (npats, nset, nmap) <- goPatternArgsCase (toList _lambdaPatterns)
962+ (npats, nset, nmap) <- goPatternArgsCase lambdaPats
947963 let npat =
948964 fmap
949965 ( \ pats ->
@@ -957,7 +973,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
957973 case npat of
958974 Nested pat [] -> do
959975 body <- withLocalNames nset nmap $ goExpression _lambdaBody
960- brs <- goLambdaClauses cls
976+ brs <- goLambdaClauses mty cls
961977 return $
962978 CaseBranch
963979 { _caseBranchPattern = pat,
@@ -968,7 +984,7 @@ goModule onlyTypes infoTable Internal.Module {..} =
968984 let vname = defaultName (getLoc cl) (disambiguate (nset ^. nameSet) " v" )
969985 nset' = over nameSet (HashSet. insert (vname ^. namePretty)) nset
970986 rhs <- withLocalNames nset' nmap $ goExpression _lambdaBody
971- remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses cls
987+ remainingBranches <- withLocalNames nset' nmap $ goLambdaClauses mty cls
972988 let brs' = goNestedBranches (getLoc vname) (ExprIden vname) rhs remainingBranches pat (nonEmpty' npats)
973989 return
974990 [ CaseBranch
@@ -1140,7 +1156,11 @@ goModule onlyTypes infoTable Internal.Module {..} =
11401156 case HashMap. lookup name (infoTable ^. Internal. infoConstructors) of
11411157 Just ctrInfo
11421158 | ctrInfo ^. Internal. constructorInfoRecord ->
1143- Just (indName, goRecordFields (getArgtys ctrInfo) args)
1159+ case HashMap. lookup indName (infoTable ^. Internal. infoInductives) of
1160+ Just indInfo
1161+ | length (indInfo ^. Internal. inductiveInfoConstructors) == 1 ->
1162+ Just (indName, goRecordFields (getArgtys ctrInfo) args)
1163+ _ -> Nothing
11441164 where
11451165 indName = ctrInfo ^. Internal. constructorInfoInductive
11461166 _ -> Nothing
0 commit comments