Skip to content

Commit 19bf351

Browse files
authored
Merge pull request #111 from tomsmeding/ptx-fix-deviceptr
ptx: Retain arrays until kernel is done executing
2 parents 3025e30 + cd94310 commit 19bf351

23 files changed

Lines changed: 306 additions & 190 deletions

File tree

accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute.hs

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ import qualified Data.Array.Accelerate.LLVM.Native.Debug as Debug
5353

5454
import Control.Concurrent ( myThreadId )
5555
import Control.Concurrent.Extra ( getThreadId )
56-
import Control.Monad.State ( gets )
56+
import Control.Monad.Reader ( asks )
5757
import Control.Monad.Trans ( liftIO )
5858
import Data.ByteString.Short ( ShortByteString )
5959
import Data.IORef ( newIORef, readIORef, writeIORef )
@@ -139,7 +139,7 @@ simpleOp
139139
simpleOp name repr NativeR{..} gamma aenv sh = do
140140
let fun = nativeExecutable !# name
141141
param = TupRsingle $ ParamRarray repr
142-
Native{..} <- gets llvmTarget
142+
Native{..} <- asks llvmTarget
143143
future <- new
144144
result <- allocateRemote repr sh
145145
scheduleOp fun gamma aenv (arrayRshape repr) sh param result
@@ -167,7 +167,7 @@ mapOp inplace repr tp NativeR{..} gamma aenv input = do
167167
shr = arrayRshape repr
168168
repr' = ArrayR shr tp
169169
param = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr)
170-
Native{..} <- gets llvmTarget
170+
Native{..} <- asks llvmTarget
171171
future <- new
172172
result <- case inplace of
173173
Just Refl -> return input
@@ -201,7 +201,7 @@ transformOp
201201
-> Par Native (Future (Array sh' b))
202202
transformOp repr repr' NativeR{..} gamma aenv sh' input = do
203203
let fun = nativeExecutable !# "transform"
204-
Native{..} <- gets llvmTarget
204+
Native{..} <- asks llvmTarget
205205
future <- new
206206
result <- allocateRemote repr' sh'
207207
let param = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr)
@@ -300,7 +300,7 @@ foldAllOp
300300
-> Delayed (Vector e)
301301
-> Par Native (Future (Scalar e))
302302
foldAllOp tp NativeR{..} gamma aenv arr = do
303-
Native{..} <- gets llvmTarget
303+
Native{..} <- asks llvmTarget
304304
future <- new
305305
result <- allocateRemote (ArrayR dim0 tp) ()
306306
let
@@ -343,7 +343,7 @@ foldDimOp
343343
-> Delayed (Array (sh, Int) e)
344344
-> Par Native (Future (Array sh e))
345345
foldDimOp repr NativeR{..} gamma aenv arr@(delayedShape -> (sh, _)) = do
346-
Native{..} <- gets llvmTarget
346+
Native{..} <- asks llvmTarget
347347
future <- new
348348
result <- allocateRemote repr sh
349349
let
@@ -371,7 +371,7 @@ foldSegOp
371371
-> Delayed (Segments i)
372372
-> Par Native (Future (Array (sh, Int) e))
373373
foldSegOp iR repr NativeR{..} gamma aenv input@(delayedShape -> (sh, _)) segments@(delayedShape -> ((), ss)) = do
374-
Native{..} <- gets llvmTarget
374+
Native{..} <- asks llvmTarget
375375
future <- new
376376
let
377377
n = ss-1
@@ -428,7 +428,7 @@ scanCore
428428
-> Delayed (Array (sh, Int) e)
429429
-> Par Native (Future (Array (sh, Int) e))
430430
scanCore repr NativeR{..} gamma aenv m input@(delayedShape -> (sz, n)) = do
431-
Native{..} <- gets llvmTarget
431+
Native{..} <- asks llvmTarget
432432
future <- new
433433
result <- allocateRemote repr (sz, m)
434434
--
@@ -527,7 +527,7 @@ scan'Core repr NativeR{..} gamma aenv input@(delayedShape -> sh@(sz, n)) = do
527527
paramA = TupRsingle $ ParamRarray repr
528528
paramA' = TupRsingle $ ParamRarray repr'
529529
--
530-
Native{..} <- gets llvmTarget
530+
Native{..} <- asks llvmTarget
531531
future <- new
532532
result <- allocateRemote repr sh
533533
sums <- allocateRemote repr' sz
@@ -608,7 +608,7 @@ permuteOp inplace repr shr' NativeR{..} gamma aenv defaults@(shape -> shOut) inp
608608
let
609609
ArrayR shr tp = repr
610610
repr' = ArrayR shr' tp
611-
Native{..} <- gets llvmTarget
611+
Native{..} <- asks llvmTarget
612612
future <- new
613613
result <- if inplace
614614
then Debug.trace Debug.dump_exec "exec: permute/inplace" $ return defaults
@@ -701,7 +701,7 @@ stencilCore
701701
-> params
702702
-> Par Native (Future (Array sh e))
703703
stencilCore repr NativeR{..} gamma aenv halo sh paramsR params = do
704-
Native{..} <- gets llvmTarget
704+
Native{..} <- asks llvmTarget
705705
future <- new
706706
result <- allocateRemote repr sh
707707
let
@@ -815,7 +815,7 @@ scheduleOp
815815
-> Maybe Action
816816
-> Par Native ()
817817
scheduleOp fun gamma aenv shr sz paramsR params done = do
818-
Native{..} <- gets llvmTarget
818+
Native{..} <- asks llvmTarget
819819
let
820820
splits = numWorkers workers - 1
821821
minsize = case shr of
@@ -842,7 +842,7 @@ scheduleOpWith
842842
-> Maybe Action -- run after the last piece completes
843843
-> Par Native ()
844844
scheduleOpWith splits minsize fun gamma aenv shr sz paramsR params done = do
845-
Native{..} <- gets llvmTarget
845+
Native{..} <- asks llvmTarget
846846
job <- mkJob splits minsize fun gamma aenv shr (empty shr) sz paramsR params done
847847
liftIO $ schedule workers job
848848

@@ -858,7 +858,7 @@ scheduleOpUsing
858858
-> Maybe Action
859859
-> Par Native ()
860860
scheduleOpUsing ranges fun gamma aenv shr paramsR params jobDone = do
861-
Native{..} <- gets llvmTarget
861+
Native{..} <- asks llvmTarget
862862
job <- mkJobUsing ranges fun gamma aenv shr paramsR params jobDone
863863
liftIO $ schedule workers job
864864

@@ -919,7 +919,7 @@ mkTasksUsing
919919
-> params
920920
-> Par Native (Seq Action)
921921
mkTasksUsing ranges (name, f) gamma aenv shr paramsR params = do
922-
arg <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
922+
(arg, ()) <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
923923
return $ flip fmap ranges $ \(_,u,v) -> do
924924
sched (string % " " % parenthesised string % " -> " % parenthesised string) (S8.unpack name) (showShape shr u) (showShape shr v)
925925
let argU = marshalShape' @Native shr u
@@ -937,7 +937,7 @@ mkTasksUsingIndex
937937
-> params
938938
-> Par Native (Seq Action)
939939
mkTasksUsingIndex ranges (name, f) gamma aenv shr paramsR params = do
940-
arg <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
940+
(arg, ()) <- marshalParams' @Native (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
941941
return $ flip fmap ranges $ \(i,u,v) -> do
942942
sched (string % " " % parenthesised string % " -> " % parenthesised string) (S8.unpack name) (showShape shr u) (showShape shr v)
943943
let argU = marshalShape' @Native shr u

accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Async.hs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ import Data.Array.Accelerate.LLVM.State
3333
-- standard library
3434
import Control.Concurrent
3535
import Control.Monad.Cont
36-
import Control.Monad.State
36+
import Control.Monad.Reader
3737
import Data.IORef
3838
import Data.Sequence ( Seq )
3939
import qualified Data.Sequence as Seq
@@ -78,7 +78,7 @@ data IVar a
7878
instance Async Native where
7979
type FutureR Native = Future
8080
newtype Par Native a = Par { runPar :: ContT () (LLVM Native) a }
81-
deriving ( Functor, Applicative, Monad, MonadIO, MonadCont, MonadState Native )
81+
deriving ( Functor, Applicative, Monad, MonadIO, MonadCont, MonadReader Native )
8282

8383
{-# INLINE new #-}
8484
{-# INLINE newFull #-}
@@ -93,7 +93,7 @@ instance Async Native where
9393
{-# INLINE get #-}
9494
get (Future ref) =
9595
callCC $ \k -> do
96-
native <- gets llvmTarget
96+
native <- asks llvmTarget
9797
next <- liftIO . atomicModifyIORef' ref $ \case
9898
Empty -> (Blocked (Seq.singleton (evalParIO native . k)), reschedule)
9999
Blocked ks -> (Blocked (ks Seq.|> evalParIO native . k), reschedule)
@@ -102,7 +102,7 @@ instance Async Native where
102102

103103
{-# INLINE put #-}
104104
put future ref = do
105-
Native{..} <- gets llvmTarget
105+
Native{..} <- asks llvmTarget
106106
liftIO (putIO workers future ref)
107107

108108
{-# INLINE liftPar #-}

accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Execute/Marshal.hs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
{-# LANGUAGE GADTs #-}
66
{-# LANGUAGE MultiParamTypeClasses #-}
77
{-# LANGUAGE TemplateHaskell #-}
8+
{-# LANGUAGE TupleSections #-}
89
{-# LANGUAGE TypeApplications #-}
910
{-# LANGUAGE TypeFamilies #-}
1011
{-# OPTIONS_GHC -fno-warn-orphans #-}
@@ -34,9 +35,10 @@ import qualified Foreign.LibFFI as FFI
3435

3536
instance Marshal Native where
3637
type ArgR Native = FFI.Arg
38+
type MarshalCleanup Native = ()
3739
marshalInt = $( case finiteBitSize (undefined::Int) of
3840
32 -> [| FFI.argInt32 . fromIntegral |]
3941
64 -> [| FFI.argInt64 . fromIntegral |]
4042
_ -> error "I don't know what architecture I am" )
41-
marshalScalarData' _ = return . DL.singleton . FFI.argPtr . unsafeUniqueArrayPtr
43+
marshalScalarData' _ = return . (,()) . DL.singleton . FFI.argPtr . unsafeUniqueArrayPtr
4244

accelerate-llvm-native/src/Data/Array/Accelerate/LLVM/Native/Link.hs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ import Data.Array.Accelerate.LLVM.Native.Link.Cache
3333
import Data.Array.Accelerate.LLVM.Native.Link.Object
3434
import Data.Array.Accelerate.LLVM.Native.Link.Runtime
3535

36-
import Control.Monad.State
36+
import Control.Monad.Reader
3737
import Prelude hiding ( lookup )
3838

3939

@@ -48,7 +48,7 @@ instance Link Native where
4848
--
4949
link :: ObjectR Native -> LLVM Native (ExecutableR Native)
5050
link (ObjectR uid nms _ so) = do
51-
cache <- gets linkCache
51+
cache <- asks linkCache
5252
funs <- liftIO $ dlsym uid cache (loadSharedObject nms so)
5353
return $! NativeR funs
5454

accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Data.hs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,8 @@ import qualified Data.Array.Accelerate.LLVM.PTX.Array.Prim as Prim
4444

4545
import Control.Applicative
4646
import Control.Monad
47-
import Control.Monad.Reader
48-
import Control.Monad.State ( gets )
47+
import Control.Monad.IO.Class ( liftIO )
48+
import Control.Monad.Reader ( asks )
4949
import System.IO.Unsafe
5050
import Prelude
5151

@@ -99,7 +99,7 @@ copyToHostLazy (TupRpair r1 r2) (f1, f2) = do
9999
a2 <- copyToHostLazy r2 f2
100100
return (a1, a2)
101101
copyToHostLazy (TupRsingle (ArrayR shr tp)) future = do
102-
ptx <- gets llvmTarget
102+
ptx <- asks llvmTarget
103103
liftIO $ do
104104
Array sh adata <- wait future
105105

accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Prim.hs

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ pokeArrayAsync !t !n !ad
116116
let !src = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
117117
!bytes = n * bytesElt (TupRsingle (SingleScalarType t))
118118
--
119-
stream <- asks ptxStream
119+
stream <- asksParState ptxStream
120120
result <- liftPar $
121121
withLifetime stream $ \st ->
122122
withDevicePtr t ad $ \dst ->
@@ -150,7 +150,7 @@ indexArrayAsync !n !t !ad_src !i
150150
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
151151
!dst = CUDA.HostPtr (unsafeUniqueArrayPtr ad_dst)
152152
--
153-
stream <- asks ptxStream
153+
stream <- asksParState ptxStream
154154
result <- liftPar $
155155
withLifetime stream $ \st ->
156156
withDevicePtr t ad_src $ \src ->
@@ -179,7 +179,7 @@ peekArrayAsync !t !n !ad
179179
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
180180
!dst = CUDA.HostPtr (unsafeUniqueArrayPtr ad)
181181
--
182-
stream <- asks ptxStream
182+
stream <- asksParState ptxStream
183183
result <- liftPar $
184184
withLifetime stream $ \st ->
185185
withDevicePtr t ad $ \src ->
@@ -208,7 +208,7 @@ copyArrayAsync !t !n !ad_src !ad_dst
208208
= do
209209
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
210210
--
211-
stream <- asks ptxStream
211+
stream <- asksParState ptxStream
212212
result <- liftPar $
213213
withLifetime stream $ \st ->
214214
withDevicePtr t ad_src $ \src ->
@@ -287,7 +287,7 @@ memsetArrayAsync !t !n !v !ad
287287
= do
288288
let !bytes = n * bytesElt (TupRsingle (SingleScalarType t))
289289
--
290-
stream <- asks ptxStream
290+
stream <- asksParState ptxStream
291291
result <- liftPar $
292292
withLifetime stream $ \st ->
293293
withDevicePtr t ad $ \ptr ->
@@ -350,7 +350,7 @@ nonblocking !stream !action = do
350350
return (Nothing, future)
351351

352352
else do
353-
future <- Future <$> liftIO (newIORef (Pending event Nothing result))
353+
future <- Future <$> liftIO (newIORef (Pending event (return ()) result))
354354
return (Just event, future)
355355

356356
{-# INLINE withLifetime #-}

accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/Array/Remote.hs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ import qualified Foreign.CUDA.Driver as CUDA
4343
import qualified Foreign.CUDA.Driver.Stream as CUDA
4444

4545
import Control.Exception
46-
import Control.Monad.State
46+
import Control.Monad.Reader
4747
import Data.Text.Lazy.Builder
4848
import Formatting hiding ( bytes )
4949
import qualified Formatting as F
@@ -63,7 +63,7 @@ instance Remote.RemoteMemory (LLVM PTX) where
6363
mallocRemote n
6464
| n <= 0 = return (Just CUDA.nullDevPtr)
6565
| otherwise = do
66-
name <- gets ptxDeviceName
66+
name <- asks ptxDeviceName
6767
liftIO $ do
6868
ep <- try (CUDA.mallocArray n)
6969
case ep of
@@ -114,7 +114,7 @@ malloc
114114
-> Bool
115115
-> LLVM PTX Bool
116116
malloc !tp !ad !n !frozen = do
117-
PTX{..} <- gets llvmTarget
117+
PTX{..} <- asks llvmTarget
118118
Remote.malloc ptxMemoryTable tp ad frozen n
119119

120120

@@ -127,7 +127,7 @@ withRemote
127127
-> (CUDA.DevicePtr (ScalarArrayDataR e) -> LLVM PTX (Maybe Event, r))
128128
-> LLVM PTX (Maybe r)
129129
withRemote !tp !ad !f = do
130-
PTX{..} <- gets llvmTarget
130+
PTX{..} <- asks llvmTarget
131131
Remote.withRemote ptxMemoryTable tp ad f
132132

133133

accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ import qualified Data.Array.Accelerate.LLVM.Internal.LLVMPretty as LP
9090

9191
import Control.Applicative
9292
import Control.Monad ( void )
93-
import Control.Monad.State ( gets )
93+
import Control.Monad.Reader ( asks )
9494
import Data.Bits
9595
import Data.Proxy
9696
import Data.String
@@ -139,7 +139,7 @@ laneMask_ge = specialPTXReg "llvm.nvvm.read.ptx.sreg.lanemask.ge"
139139
--
140140
warpId :: CodeGen PTX (Operands Int32)
141141
warpId = do
142-
dev <- liftCodeGen $ gets ptxDeviceProperties
142+
dev <- liftCodeGen $ asks ptxDeviceProperties
143143
tid <- threadIdx
144144
A.quot integralType tid (A.liftInt32 (P.fromIntegral (CUDA.warpSize dev)))
145145

@@ -245,7 +245,7 @@ __syncwarp = __syncwarp_mask (liftWord32 0xffffffff)
245245
__syncwarp_mask :: HasCallStack => Operands Word32 -> CodeGen PTX ()
246246
__syncwarp_mask mask = do
247247
llvmver <- getLLVMversion
248-
dev <- liftCodeGen $ gets ptxDeviceProperties
248+
dev <- liftCodeGen $ asks ptxDeviceProperties
249249
case (computeCapability dev >= Compute 7 0, llvmver >= 6) of
250250
(True, True) -> void $ call (Lam primType (op primType mask) (Body VoidType (Just Tail) "llvm.nvvm.bar.warp.sync")) [NoUnwind, NoDuplicate, Convergent]
251251
(True, False) -> internalError "LLVM-6.0 or above is required for Volta devices and later"
@@ -506,7 +506,7 @@ shfl_op
506506
-> Operands a -- value to give
507507
-> CodeGen PTX (Operands a) -- value received
508508
shfl_op sop t delta val = do
509-
dev <- liftCodeGen $ gets ptxDeviceProperties
509+
dev <- liftCodeGen $ asks ptxDeviceProperties
510510

511511
let
512512
-- The CUDA __shfl* instruction take an optional final parameter
@@ -762,7 +762,7 @@ makeOpenAcc
762762
-> CodeGen PTX ()
763763
-> CodeGen PTX (IROpenAcc PTX aenv a)
764764
makeOpenAcc uid name param kernel = do
765-
dev <- liftCodeGen $ gets ptxDeviceProperties
765+
dev <- liftCodeGen $ asks ptxDeviceProperties
766766
makeOpenAccWith (simpleLaunchConfig dev) uid name param kernel
767767

768768
-- | Create a single kernel program with the given launch analysis information.

accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Fold.hs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ import LLVM.AST.Type.Representation
4646
import qualified Foreign.CUDA.Analysis as CUDA
4747

4848
import Control.Monad ( (>=>) )
49-
import Control.Monad.State ( gets )
49+
import Control.Monad.Reader ( asks )
5050
import Data.String ( fromString )
5151
import Data.Bits as P
5252
import Prelude as P
@@ -105,7 +105,7 @@ mkFoldAll
105105
-> MIRDelayed PTX aenv (Vector e) -- ^ input data
106106
-> CodeGen PTX (IROpenAcc PTX aenv (Scalar e))
107107
mkFoldAll uid aenv tp combine mseed macc = do
108-
dev <- liftCodeGen $ gets ptxDeviceProperties
108+
dev <- liftCodeGen $ asks ptxDeviceProperties
109109
foldr1 (+++) <$> sequence [ mkFoldAllS uid dev aenv tp combine mseed macc
110110
, mkFoldAllM1 uid dev aenv tp combine macc
111111
, mkFoldAllM2 uid dev aenv tp combine mseed
@@ -303,7 +303,7 @@ mkFoldDim
303303
-> MIRDelayed PTX aenv (Array (sh, Int) e) -- ^ input data
304304
-> CodeGen PTX (IROpenAcc PTX aenv (Array sh e))
305305
mkFoldDim uid aenv repr@(ArrayR shr tp) combine mseed marr = do
306-
dev <- liftCodeGen $ gets ptxDeviceProperties
306+
dev <- liftCodeGen $ asks ptxDeviceProperties
307307
--
308308
let
309309
(arrOut, paramOut) = mutableArray repr "out"

0 commit comments

Comments
 (0)