Skip to content

Commit 21a1521

Browse files
committed
type hint
1 parent 619e72e commit 21a1521

File tree

3 files changed

+60
-28
lines changed

3 files changed

+60
-28
lines changed

src/Juvix/Compiler/Core/Translation/FromInternal.hs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1437,9 +1437,7 @@ goExpression = \case
14371437
md <- getModule
14381438
return (goLiteral (fromJust $ getInfoLiteralIntToNat md) (fromJust $ getInfoLiteralIntToInt md) l)
14391439
Internal.ExpressionIden i -> goIden i
1440-
Internal.ExpressionApplication a -> do
1441-
traceM ("FromInternal: " <> Internal.ppTrace a)
1442-
goApplication a
1440+
Internal.ExpressionApplication a -> goApplication a
14431441
Internal.ExpressionSimpleLambda l -> goSimpleLambda l
14441442
Internal.ExpressionLambda l -> goLambda l
14451443
Internal.ExpressionCase l -> goCase l

src/Juvix/Compiler/Internal/Translation/FromInternal/Analysis/TypeChecking/CheckerNew.hs

+54-25
Original file line numberDiff line numberDiff line change
@@ -600,10 +600,11 @@ checkExpression ::
600600
Expression ->
601601
Sem r Expression
602602
checkExpression expectedTy e = do
603-
let hint = TypeHint {
604-
_typeHint = Just expectedTy,
605-
_typeHintTypeNatural = False
606-
}
603+
let hint =
604+
TypeHint
605+
{ _typeHint = Just expectedTy,
606+
_typeHintTypeNatural = False
607+
}
607608
e' <- inferExpression' hint e
608609
let inferredType = e' ^. typedType
609610
whenJustM (matchTypes expectedTy inferredType) (const (err e'))
@@ -1161,7 +1162,7 @@ inferLeftAppExpression mhint e = case e of
11611162

11621163
goLiteral :: LiteralLoc -> Sem r TypedExpression
11631164
goLiteral lit@(WithLoc i l) = case l of
1164-
LitNumeric v -> outHole v >> typedLitNumeric v
1165+
LitNumeric v -> typedLitNumeric v
11651166
LitInteger {} -> do
11661167
ty <- getIntTy
11671168
return $
@@ -1184,46 +1185,70 @@ inferLeftAppExpression mhint e = case e of
11841185
_typedType = ExpressionIden (IdenAxiom str)
11851186
}
11861187
where
1188+
unaryNatural :: Natural -> Sem r TypedExpression
1189+
unaryNatural n = do
1190+
natTy <- getNatTy
1191+
zero' <- mkBuiltinConstructor BuiltinNatZero
1192+
suc' <- mkBuiltinConstructor BuiltinNatSuc
1193+
let mkSuc :: Expression -> Expression
1194+
mkSuc num = suc' @@ num
1195+
return
1196+
TypedExpression
1197+
{ _typedExpression = iterateNat n mkSuc zero',
1198+
_typedType = natTy
1199+
}
1200+
11871201
typedLitNumeric :: Integer -> Sem r TypedExpression
11881202
typedLitNumeric v
1189-
| v < 0 = getIntTy >>= typedLit LitInteger BuiltinFromInt
1190-
| otherwise = getNatTy >>= typedLit LitNatural BuiltinFromNat
1203+
| mhint ^. typeHintTypeNatural, v >= 0 = unaryNatural (fromInteger v)
1204+
| otherwise = do
1205+
castHole v
1206+
if
1207+
| v < 0 -> getIntTy >>= typedLit LitInteger BuiltinFromInt
1208+
| otherwise -> getNatTy >>= typedLit LitNatural BuiltinFromNat
11911209
where
11921210
typedLit :: (Integer -> Literal) -> BuiltinFunction -> Expression -> Sem r TypedExpression
11931211
typedLit litt blt ty = do
11941212
from <- getBuiltinNameTypeChecker i blt
11951213
ihole <- freshHoleImpl i ImplicitInstance
11961214
let ty' = maybe ty (adjustLocation i) (mhint ^. typeHint)
1197-
-- inferExpression' (Just ty') $
1198-
inferExpression' todo $
1215+
inferExpression' (mkTypeHint (Just ty')) $
11991216
foldApplication
12001217
(ExpressionIden (IdenFunction from))
12011218
[ ApplicationArg Implicit ty',
12021219
ApplicationArg ImplicitInstance ihole,
12031220
ApplicationArg Explicit (ExpressionLiteral (WithLoc i (litt v)))
12041221
]
12051222

1223+
mkBuiltinIden :: (IsBuiltin a) => (Name -> Iden) -> a -> Sem r Expression
1224+
mkBuiltinIden mkIden = fmap (ExpressionIden . mkIden) . getBuiltinNameTypeChecker i
1225+
1226+
mkBuiltinConstructor :: BuiltinConstructor -> Sem r Expression
1227+
mkBuiltinConstructor = mkBuiltinIden IdenConstructor
1228+
12061229
mkBuiltinInductive :: BuiltinInductive -> Sem r Expression
1207-
mkBuiltinInductive = fmap (ExpressionIden . IdenInductive) . getBuiltinNameTypeChecker i
1230+
mkBuiltinInductive = mkBuiltinIden IdenInductive
12081231

12091232
getIntTy :: Sem r Expression
12101233
getIntTy = mkBuiltinInductive BuiltinInt
12111234

12121235
getNatTy :: Sem r Expression
12131236
getNatTy = mkBuiltinInductive BuiltinNat
12141237

1215-
outHole :: Integer -> Sem r ()
1216-
outHole v
1217-
| v < 0 = case mhint of
1218-
Just (ExpressionHole h) ->
1219-
output CastHole {_castHoleHole = h, _castHoleType = CastInt}
1220-
_ ->
1221-
return ()
1222-
| otherwise = case mhint of
1223-
Just (ExpressionHole h) ->
1224-
output CastHole {_castHoleHole = h, _castHoleType = CastNat}
1225-
_ ->
1226-
return ()
1238+
castHole :: Integer -> Sem r ()
1239+
castHole v =
1240+
case mhint ^. typeHint of
1241+
Just (ExpressionHole h) ->
1242+
let outCastHole ty =
1243+
output
1244+
CastHole
1245+
{ _castHoleHole = h,
1246+
_castHoleType = ty
1247+
}
1248+
in if
1249+
| v < 0 -> outCastHole CastInt
1250+
| otherwise -> outCastHole CastNat
1251+
_ -> return ()
12271252

12281253
goIden :: Iden -> Sem r TypedExpression
12291254
goIden i = case i of
@@ -1247,11 +1272,15 @@ inferLeftAppExpression mhint e = case e of
12471272
holesHelper :: forall r. (Members '[Reader InfoTable, Reader BuiltinsTable, ResultBuilder, Reader LocalVars, Error TypeCheckerError, NameIdGen, Inference, Output TypedInstanceHole, Termination, Output CastHole, Reader InsertedArgsStack] r) => TypeHint -> Expression -> Sem r TypedExpression
12481273
holesHelper mhint expr = do
12491274
let (f, args) = unfoldExpressionApp expr
1250-
hint
1251-
| null args = mhint
1252-
| otherwise = set typeHint Nothing mhint
1275+
hint <- execState mhint $ do
1276+
unless (null args) (modify (set typeHint Nothing))
1277+
f' <- weakNormalize f
1278+
case f' of
1279+
ExpressionIden IdenInductive {} -> modify (set typeHintTypeNatural True)
1280+
_ -> return ()
12531281
arityCheckBuiltins f args
12541282
fTy <- inferLeftAppExpression hint f
1283+
12551284
iniBuilderType <- mkInitBuilderType fTy
12561285
let iniArg :: ApplicationArg -> AppBuilderArg
12571286
iniArg a =

src/Juvix/Prelude/Base/Foundation.hs

+5
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,11 @@ massert b = assert b (pure ())
660660
iterateN :: Int -> (a -> a) -> a -> a
661661
iterateN n f = (!! n) . iterate f
662662

663+
iterateNat :: Natural -> (a -> a) -> a -> a
664+
iterateNat n f x
665+
| n == 0 = x
666+
| otherwise = f (iterateNat (n - 1) f x)
667+
663668
nubHashableNonEmpty :: (Hashable a) => NonEmpty a -> NonEmpty a
664669
nubHashableNonEmpty = nonEmpty' . HashSet.toList . HashSet.fromList . toList
665670

0 commit comments

Comments
 (0)