Skip to content

Commit 6117c7c

Browse files
committed
feat: Add numeric promotion.
1 parent e1f6666 commit 6117c7c

File tree

17 files changed

+786
-34
lines changed

17 files changed

+786
-34
lines changed

src/DataFrame/Functions.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ zScore :: Expr Double -> Expr Double
335335
zScore c = (c - mean c) / stddev c
336336

337337
pow :: (Columnable a, Num a) => Expr a -> Int -> Expr a
338-
pow = (.^^)
338+
pow expr i = lift2Decorated (^) "max" (Just "^") True 8 expr (Lit i)
339339

340340
relu :: (Columnable a, Num a, Ord a) => Expr a -> Expr a
341341
relu = liftDecorated (Prelude.max 0) "relu" Nothing

src/DataFrame/Internal/Nullable.hs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
{-# LANGUAGE AllowAmbiguousTypes #-}
12
{-# LANGUAGE FlexibleContexts #-}
23
{-# LANGUAGE FlexibleInstances #-}
34
{-# LANGUAGE FunctionalDependencies #-}
5+
{-# LANGUAGE ScopedTypeVariables #-}
6+
{-# LANGUAGE TypeApplications #-}
47
{-# LANGUAGE TypeFamilies #-}
58
{-# LANGUAGE TypeOperators #-}
69
{-# LANGUAGE UndecidableInstances #-}
@@ -56,9 +59,15 @@ module DataFrame.Internal.Nullable (
5659

5760
-- * Result-type type family for comparison operators
5861
NullCmpResult,
62+
63+
-- * Numeric widening
64+
NumericWidenOp (..),
65+
widenArithOp,
66+
WidenResult,
5967
) where
6068

6169
import DataFrame.Internal.Column (Columnable)
70+
import DataFrame.Internal.Types (Promote)
6271

6372
{- | Strip one layer of 'Maybe'.
6473
@@ -337,3 +346,47 @@ instance
337346
applyNull2 _ Nothing _ = Nothing
338347
applyNull2 _ _ Nothing = Nothing
339348
applyNull2 f (Just x) (Just y) = Just (f x y)
349+
350+
-- ---------------------------------------------------------------------------
351+
-- Numeric widening
352+
-- ---------------------------------------------------------------------------
353+
354+
{- | Widen two numeric base types to their promoted common type.
355+
356+
When @a ~ b@ the coercions are identity; otherwise one operand is widened
357+
(e.g. 'Int' → 'Double').
358+
-}
359+
class (Columnable (Promote a b)) => NumericWidenOp a b where
360+
widen1 :: a -> Promote a b
361+
widen2 :: b -> Promote a b
362+
363+
-- | Same type: identity coercions.
364+
instance {-# OVERLAPPING #-} (Columnable a) => NumericWidenOp a a where
365+
widen1 = id
366+
widen2 = id
367+
368+
instance NumericWidenOp Int Double where widen1 = fromIntegral; widen2 = id
369+
instance NumericWidenOp Double Int where
370+
widen1 = id
371+
widen2 = fromIntegral
372+
instance NumericWidenOp Float Double where widen1 = realToFrac; widen2 = id
373+
instance NumericWidenOp Double Float where
374+
widen1 = id
375+
widen2 = realToFrac
376+
instance NumericWidenOp Int Float where widen1 = fromIntegral; widen2 = id
377+
instance NumericWidenOp Float Int where
378+
widen1 = id
379+
widen2 = fromIntegral
380+
381+
-- | Apply an arithmetic function after widening both operands to their common type.
382+
widenArithOp ::
383+
forall a b.
384+
(NumericWidenOp a b) =>
385+
(Promote a b -> Promote a b -> Promote a b) ->
386+
a ->
387+
b ->
388+
Promote a b
389+
widenArithOp f x y = f (widen1 @a @b x) (widen2 @a @b y)
390+
391+
-- | Result type of a widening binary operator, accounting for nullable wrappers.
392+
type WidenResult a b = NullLift2Result a b (Promote (BaseType a) (BaseType b))

src/DataFrame/Internal/Types.hs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,18 @@ sFloating :: forall a. (SBoolI (FloatingTypes a)) => SBool (FloatingTypes a)
127127
sFloating = sbool @(FloatingTypes a)
128128

129129
type FloatingIf a = When (FloatingTypes a) (Real a, Fractional a)
130+
131+
{- | Numeric type promotion: resolves the common type for mixed arithmetic.
132+
Double dominates over Float/Int; Float dominates over Int; same types stay unchanged.
133+
-}
134+
type family Promote (a :: Type) (b :: Type) :: Type where
135+
Promote a a = a
136+
Promote Double _ = Double
137+
Promote _ Double = Double
138+
Promote Float _ = Float
139+
Promote _ Float = Float
140+
Promote Int64 _ = Int64
141+
Promote _ Int64 = Int64
142+
Promote Int32 _ = Int32
143+
Promote _ Int32 = Int32
144+
Promote a _ = a

src/DataFrame/Operators.hs

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ import DataFrame.Internal.Expression (
2222
import DataFrame.Internal.Nullable (
2323
BaseType,
2424
NullCmpResult,
25-
NullableArithOp (nullArithOp),
25+
NullLift2Op (applyNull2),
2626
NullableCmpOp (nullCmpOp),
27+
NumericWidenOp,
28+
WidenResult,
29+
widenArithOp,
2730
)
31+
import DataFrame.Internal.Types (Promote)
2832

2933
infix 8 .^^
3034
infix 6 .+, .-
@@ -64,14 +68,17 @@ lit = Lit
6468
@col \@Int "x" .+ col \@(Maybe Int) "y" -- :: Expr (Maybe Int)@
6569
-}
6670
(.+) ::
67-
(NullableArithOp a b c, Num (BaseType a)) =>
71+
( NumericWidenOp (BaseType a) (BaseType b)
72+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
73+
, Num (Promote (BaseType a) (BaseType b))
74+
) =>
6875
Expr a ->
6976
Expr b ->
70-
Expr c
77+
Expr (WidenResult a b)
7178
(.+) =
7279
Binary
7380
( MkBinaryOp
74-
{ binaryFn = nullArithOp (+)
81+
{ binaryFn = applyNull2 (widenArithOp (+))
7582
, binaryName = "nulladd"
7683
, binarySymbol = Just "+"
7784
, binaryCommutative = True
@@ -81,14 +88,17 @@ lit = Lit
8188

8289
-- | Nullable-aware subtraction.
8390
(.-) ::
84-
(NullableArithOp a b c, Num (BaseType a)) =>
91+
( NumericWidenOp (BaseType a) (BaseType b)
92+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
93+
, Num (Promote (BaseType a) (BaseType b))
94+
) =>
8595
Expr a ->
8696
Expr b ->
87-
Expr c
97+
Expr (WidenResult a b)
8898
(.-) =
8999
Binary
90100
( MkBinaryOp
91-
{ binaryFn = nullArithOp (-)
101+
{ binaryFn = applyNull2 (widenArithOp (-))
92102
, binaryName = "nullsub"
93103
, binarySymbol = Just "-"
94104
, binaryCommutative = False
@@ -98,14 +108,17 @@ lit = Lit
98108

99109
-- | Nullable-aware multiplication.
100110
(.*) ::
101-
(NullableArithOp a b c, Num (BaseType a)) =>
111+
( NumericWidenOp (BaseType a) (BaseType b)
112+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
113+
, Num (Promote (BaseType a) (BaseType b))
114+
) =>
102115
Expr a ->
103116
Expr b ->
104-
Expr c
117+
Expr (WidenResult a b)
105118
(.*) =
106119
Binary
107120
( MkBinaryOp
108-
{ binaryFn = nullArithOp (*)
121+
{ binaryFn = applyNull2 (widenArithOp (*))
109122
, binaryName = "nullmul"
110123
, binarySymbol = Just "*"
111124
, binaryCommutative = True
@@ -115,14 +128,17 @@ lit = Lit
115128

116129
-- | Nullable-aware division.
117130
(./) ::
118-
(NullableArithOp a b c, Fractional (BaseType a)) =>
131+
( NumericWidenOp (BaseType a) (BaseType b)
132+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
133+
, Fractional (Promote (BaseType a) (BaseType b))
134+
) =>
119135
Expr a ->
120136
Expr b ->
121-
Expr c
137+
Expr (WidenResult a b)
122138
(./) =
123139
Binary
124140
( MkBinaryOp
125-
{ binaryFn = nullArithOp (/)
141+
{ binaryFn = applyNull2 (widenArithOp (/))
126142
, binaryName = "nulldiv"
127143
, binarySymbol = Just "/"
128144
, binaryCommutative = False
@@ -258,16 +274,44 @@ lit = Lit
258274
}
259275
)
260276

261-
(.^^) :: (Columnable a, Num a) => Expr a -> Int -> Expr a
262-
(.^^) expr i =
277+
(.^^) ::
278+
( Columnable (BaseType a)
279+
, Columnable (BaseType b)
280+
, Fractional (BaseType a)
281+
, Integral (BaseType b)
282+
, NumericWidenOp (BaseType a) (BaseType b)
283+
, NullLift2Op a b (BaseType a) a
284+
, Num (Promote (BaseType a) (BaseType b))
285+
) =>
286+
Expr a -> Expr b -> Expr a
287+
(.^^) =
263288
Binary
264289
( MkBinaryOp
265-
{ binaryFn = (^)
290+
{ binaryFn = applyNull2 (^^)
266291
, binaryName = "pow"
267-
, binarySymbol = Just "^"
292+
, binarySymbol = Just "^^"
293+
, binaryCommutative = False
294+
, binaryPrecedence = 8
295+
}
296+
)
297+
298+
(.^) ::
299+
( Columnable (BaseType a)
300+
, Columnable (BaseType b)
301+
, Num (BaseType a)
302+
, Integral (BaseType b)
303+
, NumericWidenOp (BaseType a) (BaseType b)
304+
, NullLift2Op a b (BaseType a) a
305+
, Num (Promote (BaseType a) (BaseType b))
306+
) =>
307+
Expr a -> Expr b -> Expr a
308+
(.^) =
309+
Binary
310+
( MkBinaryOp
311+
{ binaryFn = applyNull2 (^)
312+
, binaryName = "pow"
313+
, binarySymbol = Just "^^"
268314
, binaryCommutative = False
269315
, binaryPrecedence = 8
270316
}
271317
)
272-
expr
273-
(Lit i)

src/DataFrame/Typed/Expr.hs

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,12 @@ import DataFrame.Internal.Nullable (
129129
NullLift1Result,
130130
NullLift2Op (applyNull2),
131131
NullLift2Result,
132-
NullableArithOp (nullArithOp),
133132
NullableCmpOp (nullCmpOp),
133+
NumericWidenOp,
134+
WidenResult,
135+
widenArithOp,
134136
)
137+
import DataFrame.Internal.Types (Promote)
135138

136139
import DataFrame.Typed.Schema (AssertPresent, Lookup)
137140
import DataFrame.Typed.Types (TExpr (..), TSortOrder (..))
@@ -292,39 +295,71 @@ infix 4 .==, ./=, .<, .<=, .>=, .>
292295
@col \@\"x\" '.+' col \@\"y\" -- :: TExpr cols (Maybe Int) when y :: Maybe Int@
293296
-}
294297
(.+) ::
295-
(NullableArithOp a b c, Num (BaseType a)) =>
298+
( NumericWidenOp (BaseType a) (BaseType b)
299+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
300+
, Num (Promote (BaseType a) (BaseType b))
301+
) =>
296302
TExpr cols a ->
297303
TExpr cols b ->
298-
TExpr cols c
304+
TExpr cols (WidenResult a b)
299305
(.+) (TExpr a) (TExpr b) =
300-
TExpr (Binary (MkBinaryOp (nullArithOp (+)) "nulladd" (Just "+") True 6) a b)
306+
TExpr
307+
( Binary
308+
(MkBinaryOp (applyNull2 (widenArithOp (+))) "nulladd" (Just "+") True 6)
309+
a
310+
b
311+
)
301312

302313
-- | Nullable-aware subtraction.
303314
(.-) ::
304-
(NullableArithOp a b c, Num (BaseType a)) =>
315+
( NumericWidenOp (BaseType a) (BaseType b)
316+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
317+
, Num (Promote (BaseType a) (BaseType b))
318+
) =>
305319
TExpr cols a ->
306320
TExpr cols b ->
307-
TExpr cols c
321+
TExpr cols (WidenResult a b)
308322
(.-) (TExpr a) (TExpr b) =
309-
TExpr (Binary (MkBinaryOp (nullArithOp (-)) "nullsub" (Just "-") False 6) a b)
323+
TExpr
324+
( Binary
325+
(MkBinaryOp (applyNull2 (widenArithOp (-))) "nullsub" (Just "-") False 6)
326+
a
327+
b
328+
)
310329

311330
-- | Nullable-aware multiplication.
312331
(.*) ::
313-
(NullableArithOp a b c, Num (BaseType a)) =>
332+
( NumericWidenOp (BaseType a) (BaseType b)
333+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
334+
, Num (Promote (BaseType a) (BaseType b))
335+
) =>
314336
TExpr cols a ->
315337
TExpr cols b ->
316-
TExpr cols c
338+
TExpr cols (WidenResult a b)
317339
(.*) (TExpr a) (TExpr b) =
318-
TExpr (Binary (MkBinaryOp (nullArithOp (*)) "nullmul" (Just "*") True 7) a b)
340+
TExpr
341+
( Binary
342+
(MkBinaryOp (applyNull2 (widenArithOp (*))) "nullmul" (Just "*") True 7)
343+
a
344+
b
345+
)
319346

320347
-- | Nullable-aware division.
321348
(./) ::
322-
(NullableArithOp a b c, Fractional (BaseType a)) =>
349+
( NumericWidenOp (BaseType a) (BaseType b)
350+
, NullLift2Op a b (Promote (BaseType a) (BaseType b)) (WidenResult a b)
351+
, Fractional (Promote (BaseType a) (BaseType b))
352+
) =>
323353
TExpr cols a ->
324354
TExpr cols b ->
325-
TExpr cols c
355+
TExpr cols (WidenResult a b)
326356
(./) (TExpr a) (TExpr b) =
327-
TExpr (Binary (MkBinaryOp (nullArithOp (/)) "nulldiv" (Just "/") False 7) a b)
357+
TExpr
358+
( Binary
359+
(MkBinaryOp (applyNull2 (widenArithOp (/))) "nulldiv" (Just "/") False 7)
360+
a
361+
b
362+
)
328363

329364
-------------------------------------------------------------------------------
330365
-- Nullable-aware comparison operators (three-valued logic)

0 commit comments

Comments
 (0)