Skip to content

Commit 0b619be

Browse files
iohk-bors[bot]coot
andauthored
Merge #1862
1862: Added sendTo and recvFrom to Win32-network r=dcoutts a=coot This also includes `withSocketsDo` call in `withIOManager`. Fixes #1457 Co-authored-by: Marcin Szamotulski <[email protected]>
2 parents 690e825 + d20aa8a commit 0b619be

File tree

5 files changed

+259
-28
lines changed

5 files changed

+259
-28
lines changed

Win32-network/src/System/Win32/Async/IOManager.hsc

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -127,26 +127,27 @@ instance Exception IOManagerError
127127
-- TODO: add a tracer which logs when `dequeueCompletionPackets' errors
128128
--
129129
withIOManager :: (IOCompletionPort -> IO r) -> IO r
130-
withIOManager k = do
131-
tid <- myThreadId
132-
bracket
133-
(createIOCompletionPort maxBound)
134-
closeIOCompletionPort
135-
$ \iocp -> do
136-
-- The 'c_GetQueuedCompletionStatus' is not interruptible so we
137-
-- cannot simply use `withAsync` pattern here (the main thread will
138-
-- deadlock when trying to kill the io-manager thread).
139-
-- But note that 'closeIOCopletionPort' will terminate the io-manager
140-
-- thread (we cover this scenario in the 'test_closeIOCP' test).
141-
_ <-
142-
forkOS
143-
$ void $ dequeueCompletionPackets iocp
144-
`catch`
145-
\(e :: IOException) -> do
146-
-- throw IOExceptoin's back to the thread which started 'IOManager'
147-
throwTo tid (IOManagerError e)
148-
throwIO e
149-
k iocp
130+
withIOManager k =
131+
Socket.withSocketsDo $ do
132+
tid <- myThreadId
133+
bracket
134+
(createIOCompletionPort maxBound)
135+
closeIOCompletionPort
136+
$ \iocp -> do
137+
-- The 'c_GetQueuedCompletionStatus' is not interruptible so we
138+
-- cannot simply use `withAsync` pattern here (the main thread will
139+
-- deadlock when trying to kill the io-manager thread).
140+
-- But note that 'closeIOCopletionPort' will terminate the io-manager
141+
-- thread (we cover this scenario in the 'test_closeIOCP' test).
142+
_ <-
143+
forkOS
144+
$ void $ dequeueCompletionPackets iocp
145+
`catch`
146+
\(e :: IOException) -> do
147+
-- throw IOExceptoin's back to the thread which started 'IOManager'
148+
throwTo tid (IOManagerError e)
149+
throwIO e
150+
k iocp
150151

151152

152153
data IOCompletionException

Win32-network/src/System/Win32/Async/Socket.hs

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
module System.Win32.Async.Socket
55
( sendBuf
6+
, sendBufTo
67
, recvBuf
8+
, recvBufFrom
79
, connect
810
, accept
911
) where
@@ -12,15 +14,20 @@ module System.Win32.Async.Socket
1214
import Control.Concurrent
1315
import Control.Exception
1416
import Data.Word
17+
import GHC.IO.Exception (IOErrorType(InvalidArgument))
18+
import System.IO.Error
1519

16-
import Foreign.Ptr (Ptr)
17-
import Foreign.Marshal.Alloc (alloca)
20+
import Foreign.Ptr (Ptr, castPtr)
21+
import Foreign.Marshal.Alloc (alloca, allocaBytes)
1822
import Foreign.Storable (Storable (poke))
1923

2024
import Network.Socket (Socket, SockAddr)
2125
import qualified Network.Socket as Socket
26+
import Network.Socket.Address (SocketAddress (..))
2227

2328
import System.Win32.Types
29+
import System.Win32.Mem (zeroMemory)
30+
2431
import System.Win32.Async.WSABuf
2532
import System.Win32.Async.IOData
2633
import System.Win32.Async.ErrCode
@@ -47,6 +54,30 @@ sendBuf sock buf size = Socket.withFdSocket sock $ \fd ->
4754
Left e -> return $ ErrorAsync (ErrorCode e)
4855
else return $ ErrorSync (WsaErrorCode errorCode) False
4956

57+
58+
sendBufTo :: SocketAddress sa
59+
=> Socket -- ^ Socket
60+
-> Ptr Word8 -- ^ data to send
61+
-> Int -- ^ size of the data
62+
-> sa -- ^ address to send to
63+
-> IO Int
64+
sendBufTo sock buf size sa =
65+
Socket.withFdSocket sock $ \fd ->
66+
withSocketAddress sa $ \sa_ptr sa_size->
67+
alloca $ \bufsPtr ->
68+
withIOCPData "sendBufTo" (FDSocket fd) $ \lpOverlapped waitVar -> do
69+
poke bufsPtr WSABuf {buf, len = fromIntegral size}
70+
sendResult <- c_WSASendTo fd bufsPtr 1 nullPtr 0 sa_ptr sa_size lpOverlapped nullPtr
71+
errorCode <- wsaGetLastError
72+
if sendResult == 0 || errorCode == wSA_IO_PENDING
73+
then do
74+
iocpResult <- takeMVar waitVar
75+
case iocpResult of
76+
Right numBytes -> return $ ResultAsync numBytes
77+
Left e -> return $ ErrorAsync (ErrorCode e)
78+
else return $ ErrorSync (WsaErrorCode errorCode) False
79+
80+
5081
-- | Unfortunatelly `connect` using interruptible ffi is not interruptible.
5182
-- Instead we run the `Socket.connect` in a dedicated thread and block on an
5283
-- 'MVar'.
@@ -102,3 +133,65 @@ recvBuf sock buf size =
102133
Right numBytes -> return $ ResultAsync numBytes
103134
Left e -> return $ ErrorAsync (ErrorCode e)
104135
else return $ ErrorSync (WsaErrorCode errorCode) False
136+
137+
138+
recvBufFrom :: SocketAddress sa => Socket -> Ptr Word8 -> Int -> IO (Int, sa)
139+
recvBufFrom _ _ size | size <= 0 =
140+
ioError (mkInvalidRecvArgError "System.Win32.Async.Socket.recvBufFrom")
141+
recvBufFrom sock buf size =
142+
Socket.withFdSocket sock $ \fd ->
143+
withIOCPData "recvBufFrom" (FDSocket fd) $ \lpOverlapped waitVar ->
144+
withNewSocketAddress $ \saPtr saSize ->
145+
alloca $ \saSizePtr ->
146+
alloca $ \wsaBufPtr ->
147+
alloca $ \lpFlags -> do
148+
poke saSizePtr (fromIntegral saSize)
149+
poke wsaBufPtr (WSABuf (fromIntegral size) buf)
150+
poke lpFlags 0
151+
recvResult <-
152+
c_WSARecvFrom fd wsaBufPtr 1
153+
nullPtr lpFlags
154+
saPtr saSizePtr
155+
lpOverlapped nullPtr
156+
errorCode <- wsaGetLastError
157+
if recvResult == 0 || errorCode == wSA_IO_PENDING
158+
then do
159+
iocpResult <- takeMVar waitVar
160+
case iocpResult of
161+
Right numBytes -> do
162+
-- if we catch IO exception and use `getPeerName` as the
163+
-- `network` package does it throws `WSAENOTCONN` exception,
164+
-- hiding the initial exception.
165+
sockAddr <- peekSocketAddress saPtr
166+
return $ ResultAsync (numBytes, sockAddr)
167+
Left e -> return $ ErrorAsync (ErrorCode e)
168+
else return $ ErrorSync (WsaErrorCode errorCode) False
169+
170+
171+
--
172+
-- Utils
173+
--
174+
175+
-- | Copied from `Network.Socket.Types.withSocketAddress`.
176+
--
177+
withSocketAddress :: SocketAddress sa => sa -> (Ptr sa -> Int -> IO a) -> IO a
178+
withSocketAddress addr f = do
179+
let sz = sizeOfSocketAddress addr
180+
allocaBytes sz $ \p -> pokeSocketAddress p addr >> f (castPtr p) sz
181+
182+
-- sizeof(struct sockaddr_storage) which has enough space to contain
183+
-- sockaddr_in, sockaddr_in6 and sockaddr_un.
184+
sockaddrStorageLen :: Int
185+
sockaddrStorageLen = 128
186+
187+
-- | Copied from `Network.Socket.Types.withNewSocketAddress`.
188+
--
189+
withNewSocketAddress :: SocketAddress sa => (Ptr sa -> Int -> IO a) -> IO a
190+
withNewSocketAddress f = allocaBytes sockaddrStorageLen $ \ptr -> do
191+
zeroMemory ptr $ fromIntegral sockaddrStorageLen
192+
f ptr sockaddrStorageLen
193+
194+
mkInvalidRecvArgError :: String -> IOError
195+
mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError
196+
InvalidArgument
197+
loc Nothing Nothing) "non-positive length"

Win32-network/src/System/Win32/Async/Socket/ByteString.hs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
module System.Win32.Async.Socket.ByteString
22
( send
33
, sendAll
4+
, sendTo
5+
, sendAllTo
46
, recv
7+
, recvFrom
58
) where
69

710
import Control.Exception
@@ -12,10 +15,11 @@ import qualified Data.ByteString as BS
1215
import qualified Data.ByteString.Internal as BS (createAndTrim)
1316
import qualified Data.ByteString.Unsafe as BS (unsafeUseAsCStringLen)
1417
import Foreign.Ptr (castPtr)
18+
import Foreign.Marshal.Alloc (allocaBytes)
1519
import GHC.IO.Exception (IOErrorType(..))
1620
import System.IO.Error (mkIOError, ioeSetErrorString)
1721

18-
import Network.Socket (Socket)
22+
import Network.Socket (Socket, SockAddr)
1923

2024
import System.Win32.Async.Socket
2125

@@ -44,6 +48,25 @@ sendAll sock bs = do
4448
$ sendAll sock (BS.drop sent bs)
4549

4650

51+
sendTo :: Socket
52+
-> ByteString
53+
-> SockAddr
54+
-> IO Int
55+
sendTo sock bs sa =
56+
BS.unsafeUseAsCStringLen bs $ \(str, size) ->
57+
sendBufTo sock (castPtr str) size sa
58+
59+
60+
sendAllTo :: Socket
61+
-> ByteString
62+
-> SockAddr
63+
-> IO ()
64+
sendAllTo _ bs _ | BS.null bs = return ()
65+
sendAllTo sock bs sa = do
66+
sent <- sendTo sock bs sa
67+
when (sent >= 0) $ sendAllTo sock (BS.drop sent bs) sa
68+
69+
4770
-- | Recv a 'ByteString' from a socket, which must be in a connected state, and
4871
-- must be associated with an IO completion port via
4972
-- 'System.Win32.Async.IOManager.associateWithIOCompletionProt'. It may return
@@ -58,3 +81,13 @@ recv _sock size | size <= 0 =
5881
(mkIOError InvalidArgument "System.Win32.Async.Socket.ByteString.recv" Nothing Nothing)
5982
"non-positive length"
6083
recv sock size = BS.createAndTrim size $ \ptr -> recvBuf sock ptr size
84+
85+
86+
recvFrom :: Socket -- ^ Socket
87+
-> Int -- ^ Maximum number of bytes to receive
88+
-> IO (ByteString, SockAddr) -- ^ Data received and sender address
89+
recvFrom sock size =
90+
allocaBytes size $ \ptr -> do
91+
(len, sockAddr) <- recvBufFrom sock (castPtr ptr) size
92+
str <- BS.packCStringLen (ptr, len)
93+
return (str, sockAddr)

Win32-network/src/System/Win32/Async/Socket/Syscalls.hs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
module System.Win32.Async.Socket.Syscalls
44
( SOCKET
55
, c_WSASend
6+
, c_WSASendTo
67
, c_WSARecv
8+
, c_WSARecvFrom
79
) where
810

911
import Foreign (Ptr)
@@ -28,13 +30,39 @@ foreign import ccall unsafe "WSASend"
2830
-> Ptr () -- ^ lpCompletionRouting
2931
-> IO CInt
3032

33+
34+
foreign import ccall unsafe "WSASendTo"
35+
c_WSASendTo :: SOCKET
36+
-> Ptr WSABuf -- ^ lpBuffers
37+
-> DWORD -- ^ dwBufferCount
38+
-> LPDWORD -- ^ lpNumberOfBytesSent
39+
-> DWORD -- ^ dwFlags
40+
-> Ptr sa -- ^ lpTo
41+
-> Int -- ^ iToLen (size in bytes of `lpTo`)
42+
-> LPWSAOVERLAPPED -- ^ lpOverlapped
43+
-> Ptr () -- ^ lpCompletionRouting
44+
-> IO CInt
45+
3146

3247
foreign import ccall unsafe "WSARecv"
3348
c_WSARecv :: SOCKET -- ^ socket
3449
-> Ptr WSABuf -- ^ lpBuffers
3550
-> DWORD -- ^ dwBufferCount
36-
-> LPDWORD -- ^ lpNumberOfBytesRecvd
51+
-> LPDWORD -- ^ lpNumberOfBytesReceived
3752
-> LPDWORD -- ^ lpFlags
3853
-> LPWSAOVERLAPPED -- ^ lpOverlapped
3954
-> Ptr () -- ^ lpCompletionRouting
4055
-> IO CInt
56+
57+
58+
foreign import ccall unsafe "WSARecvFrom"
59+
c_WSARecvFrom :: SOCKET -- ^ socket
60+
-> Ptr WSABuf -- ^ lpBuffers
61+
-> DWORD -- ^ dwBufferCount
62+
-> LPDWORD -- ^ lpNumberOfBytesReceived
63+
-> LPDWORD -- ^ lpFlags
64+
-> Ptr sa -- ^ lpFrom
65+
-> Ptr Int -- ^ iFromLen (size in bytes of `lpFrom`)
66+
-> LPWSAOVERLAPPED -- ^ lpOverlapped
67+
-> Ptr () -- ^ lpCompletionRouting
68+
-> IO CInt

0 commit comments

Comments
 (0)