Skip to content

Commit a5af3eb

Browse files
committed
refactor: Reuse comparison operators in Subset.
1 parent 6ca55b4 commit a5af3eb

File tree

1 file changed

+16
-74
lines changed

1 file changed

+16
-74
lines changed

src/DataFrame/Operations/Subset.hs

Lines changed: 16 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import DataFrame.Internal.Interpreter
4444
import DataFrame.Operations.Core
4545
import DataFrame.Operations.Merge ()
4646
import DataFrame.Operations.Transformations (apply)
47+
import DataFrame.Operators
4748
import System.Random
4849
import Type.Reflection
4950
import Prelude hiding (filter, take)
@@ -345,23 +346,12 @@ sample :: (RandomGen g) => g -> Double -> DataFrame -> DataFrame
345346
sample pureGen p df =
346347
let
347348
rand = generateRandomVector pureGen (fst (dataframeDimensions df))
349+
cRand = col @Double "__rand__"
348350
in
349351
df
350-
& insertUnboxedVector "__rand__" rand
351-
& filterWhere
352-
( Binary
353-
( MkBinaryOp
354-
{ binaryFn = (>=)
355-
, binaryName = "geq"
356-
, binarySymbol = Just ">="
357-
, binaryCommutative = False
358-
, binaryPrecedence = 1
359-
}
360-
)
361-
(Col @Double "__rand__")
362-
(Lit (1 - p))
363-
)
364-
& exclude ["__rand__"]
352+
& insertUnboxedVector (name cRand) rand
353+
& filterWhere (cRand .>=. Lit (1 - p))
354+
& exclude [name cRand]
365355

366356
{- | Split a dataset into two. The first element of the tuple gets a random sample of roughly a fraction p of the rows (0 <= p <= 1) and the second gets the remaining rows (roughly 1 - p). This is useful for creating test and train splits.
367357
@@ -377,38 +367,16 @@ randomSplit ::
377367
randomSplit pureGen p df =
378368
let
379369
rand = generateRandomVector pureGen (fst (dataframeDimensions df))
380-
withRand = df & insertUnboxedVector "__rand__" rand
370+
cRand = col @Double "__rand__"
371+
withRand = df & insertUnboxedVector (name cRand) rand
381372
in
382373
( withRand
383-
& filterWhere
384-
( Binary
385-
( MkBinaryOp
386-
{ binaryFn = (<=)
387-
, binaryName = "leq"
388-
, binarySymbol = Just "<="
389-
, binaryCommutative = False
390-
, binaryPrecedence = 1
391-
}
392-
)
393-
(Col @Double "__rand__")
394-
(Lit p)
395-
)
396-
& exclude ["__rand__"]
374+
& filterWhere (cRand .<=. Lit p)
375+
& exclude [name cRand]
397376
, withRand
398377
& filterWhere
399-
( Binary
400-
( MkBinaryOp
401-
{ binaryFn = (>)
402-
, binaryName = "gt"
403-
, binarySymbol = Just ">"
404-
, binaryCommutative = False
405-
, binaryPrecedence = 1
406-
}
407-
)
408-
(Col @Double "__rand__")
409-
(Lit p)
410-
)
411-
& exclude ["__rand__"]
378+
(cRand .>. Lit p)
379+
& exclude [name cRand]
412380
)
413381

414382
{- | Creates n folds of a dataframe.
@@ -424,46 +392,20 @@ kFolds :: (RandomGen g) => g -> Int -> DataFrame -> [DataFrame]
424392
kFolds pureGen folds df =
425393
let
426394
rand = generateRandomVector pureGen (fst (dataframeDimensions df))
427-
withRand = df & insertUnboxedVector "__rand__" rand
395+
cRand = col @Double "__rand__"
396+
withRand = df & insertUnboxedVector (name cRand) rand
428397
partitionSize = 1 / fromIntegral folds
429398
singleFold n d =
430-
d
431-
& filterWhere
432-
( Binary
433-
( MkBinaryOp
434-
{ binaryFn = (>=)
435-
, binaryName = "geq"
436-
, binarySymbol = Just ">="
437-
, binaryCommutative = False
438-
, binaryPrecedence = 1
439-
}
440-
)
441-
(Col @Double "__rand__")
442-
(Lit (fromIntegral n * partitionSize))
443-
)
399+
d & filterWhere (cRand .>=. Lit (fromIntegral n * partitionSize))
444400
go (-1) _ = []
445401
go n d =
446402
let
447403
d' = singleFold n d
448-
d'' =
449-
d
450-
& filterWhere
451-
( Binary
452-
( MkBinaryOp
453-
{ binaryFn = (<)
454-
, binaryName = "lt"
455-
, binarySymbol = Just "<"
456-
, binaryCommutative = False
457-
, binaryPrecedence = 1
458-
}
459-
)
460-
(Col @Double "__rand__")
461-
(Lit (fromIntegral n * partitionSize))
462-
)
404+
d'' = d & filterWhere (cRand .<. Lit (fromIntegral n * partitionSize))
463405
in
464406
d' : go (n - 1) d''
465407
in
466-
map (exclude ["__rand__"]) (go (folds - 1) withRand)
408+
map (exclude [name cRand]) (go (folds - 1) withRand)
467409

468410
generateRandomVector :: (RandomGen g) => g -> Int -> VU.Vector Double
469411
generateRandomVector pureGen k = VU.fromList $ go pureGen k

0 commit comments

Comments
 (0)