Skip to content

Commit 9a26599

Browse files
committed
feat: Allow expression casting.
1 parent 2e70358 commit 9a26599

File tree

5 files changed

+220
-58
lines changed

5 files changed

+220
-58
lines changed

src/DataFrame/Functions.hs

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ import Data.Int
3636
import qualified Data.List as L
3737
import qualified Data.Map as M
3838
import qualified Data.Maybe as Maybe
39+
import qualified Data.Set as S
3940
import qualified Data.Text as T
4041
import Data.Time
4142
import qualified Data.Vector as V
4243
import qualified Data.Vector.Unboxed as VU
4344
import Data.Word
44-
import qualified Data.Set as S
4545
import qualified DataFrame.IO.CSV as CSV
4646
import qualified DataFrame.IO.Parquet as Parquet
4747
import DataFrame.IO.Parquet.Thrift
@@ -157,6 +157,32 @@ unsafeCast name =
157157
"unsafeCast"
158158
(fromRight (error "unsafeCast: unexpected Nothing in column"))
159159

160+
castExpr ::
161+
forall b src. (Columnable b, Columnable src) => Expr src -> Expr (Maybe b)
162+
castExpr = CastExprWith @b @(Maybe b) @src "castExpr" (either (const Nothing) Just)
163+
164+
castExprWithDefault ::
165+
forall b src. (Columnable b, Columnable src) => b -> Expr src -> Expr b
166+
castExprWithDefault def =
167+
CastExprWith @b @b @src
168+
("castExprWithDefault:" <> T.pack (show def))
169+
(fromRight def)
170+
171+
castExprEither ::
172+
forall b src.
173+
(Columnable b, Columnable src) => Expr src -> Expr (Either T.Text b)
174+
castExprEither =
175+
CastExprWith @b @(Either T.Text b) @src
176+
"castExprEither"
177+
(either (Left . T.pack) Right)
178+
179+
unsafeCastExpr ::
180+
forall b src. (Columnable b, Columnable src) => Expr src -> Expr b
181+
unsafeCastExpr =
182+
CastExprWith @b @b @src
183+
"unsafeCastExpr"
184+
(fromRight (error "unsafeCastExpr: unexpected Nothing in column"))
185+
160186
liftDecorated ::
161187
(Columnable a, Columnable b) =>
162188
(a -> b) -> T.Text -> Maybe T.Text -> Expr a -> Expr b
@@ -544,19 +570,22 @@ declareColumnsFromParquetFile path = do
544570
files <- liftIO $ filterM (fmap Prelude.not . doesDirectoryExist) matches
545571
metas <- liftIO $ mapM (fmap fst . Parquet.readMetadataFromPath) files
546572
let nullableCols :: S.Set T.Text
547-
nullableCols = S.fromList
548-
[ T.pack (last colPath)
549-
| meta <- metas
550-
, rg <- rowGroups meta
551-
, cc <- rowGroupColumns rg
552-
, let cm = columnMetaData cc
553-
colPath = columnPathInSchema cm
554-
, Prelude.not (null colPath)
555-
, columnNullCount (columnStatistics cm) > 0
556-
]
557-
let df = foldl (\acc meta -> acc <> schemaToEmptyDataFrame nullableCols (schema meta))
558-
DataFrame.Internal.DataFrame.empty
559-
metas
573+
nullableCols =
574+
S.fromList
575+
[ T.pack (last colPath)
576+
| meta <- metas
577+
, rg <- rowGroups meta
578+
, cc <- rowGroupColumns rg
579+
, let cm = columnMetaData cc
580+
colPath = columnPathInSchema cm
581+
, Prelude.not (null colPath)
582+
, columnNullCount (columnStatistics cm) > 0
583+
]
584+
let df =
585+
foldl
586+
(\acc meta -> acc <> schemaToEmptyDataFrame nullableCols (schema meta))
587+
DataFrame.Internal.DataFrame.empty
588+
metas
560589
declareColumns df
561590

562591
schemaToEmptyDataFrame :: S.Set T.Text -> [SchemaElement] -> DataFrame
@@ -566,11 +595,12 @@ schemaToEmptyDataFrame nullableCols elems =
566595

567596
schemaElemToColumn :: S.Set T.Text -> SchemaElement -> (T.Text, Column)
568597
schemaElemToColumn nullableCols elem =
569-
let name = elementName elem
598+
let name = elementName elem
570599
isNull = name `S.member` nullableCols
571-
col = if isNull
572-
then emptyNullableColumnForType (elementType elem)
573-
else emptyColumnForType (elementType elem)
600+
col =
601+
if isNull
602+
then emptyNullableColumnForType (elementType elem)
603+
else emptyColumnForType (elementType elem)
574604
in (name, col)
575605

576606
emptyColumnForType :: TType -> Column
@@ -588,16 +618,16 @@ emptyColumnForType = \case
588618

589619
emptyNullableColumnForType :: TType -> Column
590620
emptyNullableColumnForType = \case
591-
BOOL -> fromList @(Maybe Bool) []
592-
BYTE -> fromList @(Maybe Word8) []
593-
I16 -> fromList @(Maybe Int16) []
594-
I32 -> fromList @(Maybe Int32) []
595-
I64 -> fromList @(Maybe Int64) []
596-
I96 -> fromList @(Maybe Int64) []
597-
FLOAT -> fromList @(Maybe Float) []
621+
BOOL -> fromList @(Maybe Bool) []
622+
BYTE -> fromList @(Maybe Word8) []
623+
I16 -> fromList @(Maybe Int16) []
624+
I32 -> fromList @(Maybe Int32) []
625+
I64 -> fromList @(Maybe Int64) []
626+
I96 -> fromList @(Maybe Int64) []
627+
FLOAT -> fromList @(Maybe Float) []
598628
DOUBLE -> fromList @(Maybe Double) []
599629
STRING -> fromList @(Maybe T.Text) []
600-
other -> error $ "Unsupported parquet type for column: " <> show other
630+
other -> error $ "Unsupported parquet type for column: " <> show other
601631

602632
declareColumnsFromCsvWithOpts :: CSV.ReadOptions -> String -> DecsQ
603633
declareColumnsFromCsvWithOpts opts path = do

src/DataFrame/Internal/Expression.hs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ data Expr a where
6262
T.Text ->
6363
(Either String a -> b) ->
6464
Expr b
65+
CastExprWith ::
66+
(Columnable a, Columnable b, Columnable src) =>
67+
T.Text ->
68+
(Either String a -> b) ->
69+
Expr src ->
70+
Expr b
6571
Lit :: (Columnable a) => a -> Expr a
6672
Unary ::
6773
(Columnable a, Columnable b) => UnaryOp b a -> Expr b -> Expr a
@@ -235,6 +241,7 @@ instance (Show a) => Show (Expr a) where
235241
show :: forall a. (Show a) => Expr a -> String
236242
show (Col name) = "(col @" ++ show (typeRep @a) ++ " " ++ show name ++ ")"
237243
show (CastWith name tag _) = "(castWith " ++ show tag ++ " " ++ show name ++ ")"
244+
show (CastExprWith tag _ inner) = "(castExprWith " ++ show tag ++ " " ++ show inner ++ ")"
238245
show (Lit value) = "(lit (" ++ show value ++ "))"
239246
show (If cond l r) = "(ifThenElse " ++ show cond ++ " " ++ show l ++ " " ++ show r ++ ")"
240247
show (Unary op value) = "(" ++ T.unpack (unaryName op) ++ " " ++ show value ++ ")"
@@ -247,6 +254,7 @@ normalize :: (Eq a, Ord a, Show a, Typeable a) => Expr a -> Expr a
247254
normalize expr = case expr of
248255
Col name -> Col name
249256
CastWith n t f -> CastWith n t f
257+
CastExprWith t f e -> CastExprWith t f (normalize e)
250258
Lit val -> Lit val
251259
If cond th el -> If (normalize cond) (normalize th) (normalize el)
252260
Unary op e -> Unary op (normalize e)
@@ -270,6 +278,7 @@ compareExpr e1 e2 = compare (exprKey e1) (exprKey e2)
270278
exprKey :: Expr a -> String
271279
exprKey (Col name) = "0:" ++ T.unpack name
272280
exprKey (CastWith name tag _) = "0CW:" ++ T.unpack name ++ ":" ++ T.unpack tag
281+
exprKey (CastExprWith tag _ _) = "0CE:" ++ T.unpack tag
273282
exprKey (Lit val) = "1:" ++ show val
274283
exprKey (If c t e) = "2:" ++ exprKey c ++ exprKey t ++ exprKey e
275284
exprKey (Unary op e) = "3:" ++ T.unpack (unaryName op) ++ exprKey e
@@ -288,6 +297,7 @@ instance (Eq a, Columnable a) => Eq (Expr a) where
288297
eqNormalized :: Expr a -> Expr a -> Bool
289298
eqNormalized (Col n1) (Col n2) = n1 == n2
290299
eqNormalized (CastWith n1 t1 _) (CastWith n2 t2 _) = n1 == n2 && t1 == t2
300+
eqNormalized (CastExprWith t1 _ e1) (CastExprWith t2 _ e2) = t1 == t2 && e1 `exprEq` e2
291301
eqNormalized (Lit v1) (Lit v2) = v1 == v2
292302
eqNormalized (If c1 t1 e1) (If c2 t2 e2) =
293303
c1 == c2 && t1 `exprEq` t2 && e1 `exprEq` e2
@@ -306,6 +316,7 @@ instance (Ord a, Columnable a) => Ord (Expr a) where
306316
compare e1 e2 = case (e1, e2) of
307317
(Col n1, Col n2) -> compare n1 n2
308318
(CastWith n1 t1 _, CastWith n2 t2 _) -> compare n1 n2 <> compare t1 t2
319+
(CastExprWith t1 _ _, CastExprWith t2 _ _) -> compare t1 t2
309320
(Lit v1, Lit v2) -> compare v1 v2
310321
(If c1 t1 e1', If c2 t2 e2') ->
311322
compare c1 c2 <> exprComp t1 t2 <> exprComp e1' e2'
@@ -320,6 +331,8 @@ instance (Ord a, Columnable a) => Ord (Expr a) where
320331
(_, Col _) -> GT
321332
(CastWith{}, _) -> LT
322333
(_, CastWith{}) -> GT
334+
(CastExprWith{}, _) -> LT
335+
(_, CastExprWith{}) -> GT
323336
(Lit _, _) -> LT
324337
(_, Lit _) -> GT
325338
(Unary{}, _) -> LT
@@ -348,6 +361,7 @@ replaceExpr new old expr = case testEquality (typeRep @b) (typeRep @c) of
348361
replace' = case expr of
349362
(Col _) -> expr
350363
(CastWith{}) -> expr
364+
(CastExprWith t f e) -> CastExprWith t f (replaceExpr new old e)
351365
(Lit _) -> expr
352366
(If cond l r) ->
353367
If (replaceExpr new old cond) (replaceExpr new old l) (replaceExpr new old r)
@@ -358,6 +372,7 @@ replaceExpr new old expr = case testEquality (typeRep @b) (typeRep @c) of
358372
eSize :: Expr a -> Int
359373
eSize (Col _) = 1
360374
eSize (CastWith{}) = 1
375+
eSize (CastExprWith _ _ e) = 1 + eSize e
361376
eSize (Lit _) = 1
362377
eSize (If c l r) = 1 + eSize c + eSize l + eSize r
363378
eSize (Unary _ e) = 1 + eSize e
@@ -367,6 +382,7 @@ eSize (Agg strategy expr) = eSize expr + 1
367382
getColumns :: Expr a -> [T.Text]
368383
getColumns (Col cName) = [cName]
369384
getColumns (CastWith name _ _) = [name]
385+
getColumns (CastExprWith _ _ e) = getColumns e
370386
getColumns expr@(Lit _) = []
371387
getColumns (If cond l r) = getColumns cond <> getColumns l <> getColumns r
372388
getColumns (Unary op value) = getColumns value
@@ -383,6 +399,7 @@ prettyPrint = go 0 0
383399
go depth prec expr = case expr of
384400
Col name -> T.unpack name
385401
CastWith name _ _ -> T.unpack name
402+
CastExprWith tag _ inner -> T.unpack tag ++ "(" ++ go depth 0 inner ++ ")"
386403
Lit value -> show value
387404
If cond t e ->
388405
let inner =

src/DataFrame/Internal/Interpreter.hs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ promoteColumnWith ::
255255
promoteColumnWith onResult col
256256
| hasElemType @b col = Right col
257257
| hasElemType @a col = mapColumn @a (onResult . Right) col
258+
| Just result <- tryMaybeWrap @a @b onResult col = result
258259
| otherwise =
259260
case testEquality (typeRep @a) (typeRep @Double) of
260261
Just Refl -> promoteToDoubleWith onResult col
@@ -416,6 +417,39 @@ tryParseWith onResult col = case col of
416417
Nothing -> castMismatch @c @b
417418
UnboxedColumn (_ :: VU.Vector c) -> castMismatch @c @b
418419

420+
{- | When the output type @b@ is @Maybe c@ (or @Maybe (Maybe c)@) and the
421+
column stores plain @c@ values, wrap each element in 'Just'.
422+
The @Maybe (Maybe c)@ case applies join semantics: instead of producing
423+
a double-wrapped column, a @Maybe c@ column is returned, so
424+
@castExpr \@(Maybe Double)@ on a @Double@ column yields @Maybe Double@
425+
rather than @Maybe (Maybe Double)@.
426+
Returns 'Nothing' when neither condition holds.
427+
-}
428+
tryMaybeWrap ::
429+
forall a b.
430+
(Columnable a, Columnable b) =>
431+
(Either String a -> b) -> Column -> Maybe (Either DataFrameException Column)
432+
tryMaybeWrap _onResult col = case col of
433+
UnboxedColumn (v :: VU.Vector c) ->
434+
let wrapped = V.map Just (VG.convert v) :: V.Vector (Maybe c)
435+
in case testEquality (typeRep @b) (typeRep @(Maybe c)) of
436+
Just Refl -> Just $ Right $ fromVector @b wrapped
437+
Nothing ->
438+
-- join: b = Maybe (Maybe c) → produce Maybe c column
439+
case testEquality (typeRep @b) (typeRep @(Maybe (Maybe c))) of
440+
Just _ -> Just $ Right $ fromVector @(Maybe c) wrapped
441+
Nothing -> Nothing
442+
BoxedColumn (v :: V.Vector c) ->
443+
let wrapped = V.map Just v :: V.Vector (Maybe c)
444+
in case testEquality (typeRep @b) (typeRep @(Maybe c)) of
445+
Just Refl -> Just $ Right $ fromVector @b wrapped
446+
Nothing ->
447+
case testEquality (typeRep @b) (typeRep @(Maybe (Maybe c))) of
448+
Just _ -> Just $ Right $ fromVector @(Maybe c) wrapped
449+
Nothing -> Nothing
450+
-- OptionalColumn and NullableColumn are already handled by the hasElemType guards above.
451+
_ -> Nothing
452+
419453
castMismatch ::
420454
forall src tgt.
421455
(Typeable src, Typeable tgt) =>
@@ -483,6 +517,17 @@ eval (GroupCtx gdf) (CastWith name _tag onResult) =
483517
Just c -> do
484518
promoted <- promoteColumnWith onResult c
485519
Right $ Group (sliceGroups promoted (offsets gdf) (valueIndices gdf))
520+
-- CastExprWith -----------------------------------------------------------
521+
522+
eval ctx (CastExprWith _tag onResult (inner :: Expr src)) = do
523+
v <- eval @src ctx inner
524+
case v of
525+
Scalar s ->
526+
Flat <$> promoteColumnWith onResult (fromList @src [s])
527+
Flat col ->
528+
Flat <$> promoteColumnWith onResult col
529+
Group gs ->
530+
Group <$> V.mapM (promoteColumnWith onResult) gs
486531
-- Unary ------------------------------------------------------------------
487532

488533
eval ctx expr@(Unary (op :: UnaryOp b a) inner) = addContext expr $ do

src/DataFrame/Typed/Expr.hs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ module DataFrame.Typed.Expr (
9292
maximum,
9393
collect,
9494

95+
-- * Cast / coercion expressions
96+
castExpr,
97+
castExprWithDefault,
98+
castExprEither,
99+
unsafeCastExpr,
100+
95101
-- * Named expression helper
96102
as,
97103

@@ -100,6 +106,7 @@ module DataFrame.Typed.Expr (
100106
desc,
101107
) where
102108

109+
import Data.Either (fromRight)
103110
import Data.Proxy (Proxy (..))
104111
import Data.String (IsString (..))
105112
import qualified Data.Text as T
@@ -413,6 +420,50 @@ maximum (TExpr e) = TExpr (Agg (FoldAgg "maximum" Nothing max) e)
413420
collect :: (Columnable a) => TExpr cols a -> TExpr cols [a]
414421
collect (TExpr e) = TExpr (Agg (FoldAgg "collect" (Just []) (flip (:))) e)
415422

423+
-------------------------------------------------------------------------------
424+
-- Cast / coercion expressions
425+
-------------------------------------------------------------------------------
426+
427+
castExpr ::
428+
forall b cols src.
429+
(Columnable b, Columnable src) => TExpr cols src -> TExpr cols (Maybe b)
430+
castExpr (TExpr e) =
431+
TExpr
432+
(CastExprWith @b @(Maybe b) @src "castExpr" (either (const Nothing) Just) e)
433+
434+
castExprWithDefault ::
435+
forall b cols src.
436+
(Columnable b, Columnable src) => b -> TExpr cols src -> TExpr cols b
437+
castExprWithDefault def (TExpr e) =
438+
TExpr
439+
( CastExprWith @b @b @src
440+
("castExprWithDefault:" <> T.pack (show def))
441+
(fromRight def)
442+
e
443+
)
444+
445+
castExprEither ::
446+
forall b cols src.
447+
(Columnable b, Columnable src) => TExpr cols src -> TExpr cols (Either T.Text b)
448+
castExprEither (TExpr e) =
449+
TExpr
450+
( CastExprWith @b @(Either T.Text b) @src
451+
"castExprEither"
452+
(either (Left . T.pack) Right)
453+
e
454+
)
455+
456+
unsafeCastExpr ::
457+
forall b cols src.
458+
(Columnable b, Columnable src) => TExpr cols src -> TExpr cols b
459+
unsafeCastExpr (TExpr e) =
460+
TExpr
461+
( CastExprWith @b @b @src
462+
"unsafeCastExpr"
463+
(fromRight (error "unsafeCastExpr: unexpected Nothing in column"))
464+
e
465+
)
466+
416467
-------------------------------------------------------------------------------
417468
-- Named expression helper
418469
-------------------------------------------------------------------------------

0 commit comments

Comments
 (0)