3
3
4
4
module System.Win32.Async.Socket
5
5
( sendBuf
6
+ , sendBufTo
6
7
, recvBuf
8
+ , recvBufFrom
7
9
, connect
8
10
, accept
9
11
) where
@@ -12,15 +14,20 @@ module System.Win32.Async.Socket
12
14
import Control.Concurrent
13
15
import Control.Exception
14
16
import Data.Word
17
+ import GHC.IO.Exception (IOErrorType (InvalidArgument ))
18
+ import System.IO.Error
15
19
16
- import Foreign.Ptr (Ptr )
17
- import Foreign.Marshal.Alloc (alloca )
20
+ import Foreign.Ptr (Ptr , castPtr )
21
+ import Foreign.Marshal.Alloc (alloca , allocaBytes )
18
22
import Foreign.Storable (Storable (poke ))
19
23
20
24
import Network.Socket (Socket , SockAddr )
21
25
import qualified Network.Socket as Socket
26
+ import Network.Socket.Address (SocketAddress (.. ))
22
27
23
28
import System.Win32.Types
29
+ import System.Win32.Mem (zeroMemory )
30
+
24
31
import System.Win32.Async.WSABuf
25
32
import System.Win32.Async.IOData
26
33
import System.Win32.Async.ErrCode
@@ -47,6 +54,30 @@ sendBuf sock buf size = Socket.withFdSocket sock $ \fd ->
47
54
Left e -> return $ ErrorAsync (ErrorCode e)
48
55
else return $ ErrorSync (WsaErrorCode errorCode) False
49
56
57
+
58
+ sendBufTo :: SocketAddress sa
59
+ => Socket -- ^ Socket
60
+ -> Ptr Word8 -- ^ data to send
61
+ -> Int -- ^ size of the data
62
+ -> sa -- ^ address to send to
63
+ -> IO Int
64
+ sendBufTo sock buf size sa =
65
+ Socket. withFdSocket sock $ \ fd ->
66
+ withSocketAddress sa $ \ sa_ptr sa_size->
67
+ alloca $ \ bufsPtr ->
68
+ withIOCPData " sendBufTo" (FDSocket fd) $ \ lpOverlapped waitVar -> do
69
+ poke bufsPtr WSABuf {buf, len = fromIntegral size}
70
+ sendResult <- c_WSASendTo fd bufsPtr 1 nullPtr 0 sa_ptr sa_size lpOverlapped nullPtr
71
+ errorCode <- wsaGetLastError
72
+ if sendResult == 0 || errorCode == wSA_IO_PENDING
73
+ then do
74
+ iocpResult <- takeMVar waitVar
75
+ case iocpResult of
76
+ Right numBytes -> return $ ResultAsync numBytes
77
+ Left e -> return $ ErrorAsync (ErrorCode e)
78
+ else return $ ErrorSync (WsaErrorCode errorCode) False
79
+
80
+
50
81
-- | Unfortunatelly `connect` using interruptible ffi is not interruptible.
51
82
-- Instead we run the `Socket.connect` in a dedicated thread and block on an
52
83
-- 'MVar'.
@@ -102,3 +133,65 @@ recvBuf sock buf size =
102
133
Right numBytes -> return $ ResultAsync numBytes
103
134
Left e -> return $ ErrorAsync (ErrorCode e)
104
135
else return $ ErrorSync (WsaErrorCode errorCode) False
136
+
137
+
138
+ recvBufFrom :: SocketAddress sa => Socket -> Ptr Word8 -> Int -> IO (Int , sa )
139
+ recvBufFrom _ _ size | size <= 0 =
140
+ ioError (mkInvalidRecvArgError " System.Win32.Async.Socket.recvBufFrom" )
141
+ recvBufFrom sock buf size =
142
+ Socket. withFdSocket sock $ \ fd ->
143
+ withIOCPData " recvBufFrom" (FDSocket fd) $ \ lpOverlapped waitVar ->
144
+ withNewSocketAddress $ \ saPtr saSize ->
145
+ alloca $ \ saSizePtr ->
146
+ alloca $ \ wsaBufPtr ->
147
+ alloca $ \ lpFlags -> do
148
+ poke saSizePtr (fromIntegral saSize)
149
+ poke wsaBufPtr (WSABuf (fromIntegral size) buf)
150
+ poke lpFlags 0
151
+ recvResult <-
152
+ c_WSARecvFrom fd wsaBufPtr 1
153
+ nullPtr lpFlags
154
+ saPtr saSizePtr
155
+ lpOverlapped nullPtr
156
+ errorCode <- wsaGetLastError
157
+ if recvResult == 0 || errorCode == wSA_IO_PENDING
158
+ then do
159
+ iocpResult <- takeMVar waitVar
160
+ case iocpResult of
161
+ Right numBytes -> do
162
+ -- if we catch IO exception and use `getPeerName` as the
163
+ -- `network` package does it throws `WSAENOTCONN` exception,
164
+ -- hiding the initial exception.
165
+ sockAddr <- peekSocketAddress saPtr
166
+ return $ ResultAsync (numBytes, sockAddr)
167
+ Left e -> return $ ErrorAsync (ErrorCode e)
168
+ else return $ ErrorSync (WsaErrorCode errorCode) False
169
+
170
+
171
+ --
172
+ -- Utils
173
+ --
174
+
175
+ -- | Copied from `Network.Socket.Types.withSocketAddress`.
176
+ --
177
+ withSocketAddress :: SocketAddress sa => sa -> (Ptr sa -> Int -> IO a ) -> IO a
178
+ withSocketAddress addr f = do
179
+ let sz = sizeOfSocketAddress addr
180
+ allocaBytes sz $ \ p -> pokeSocketAddress p addr >> f (castPtr p) sz
181
+
182
+ -- sizeof(struct sockaddr_storage) which has enough space to contain
183
+ -- sockaddr_in, sockaddr_in6 and sockaddr_un.
184
+ sockaddrStorageLen :: Int
185
+ sockaddrStorageLen = 128
186
+
187
+ -- | Copied from `Network.Socket.Types.withNewSocketAddress`.
188
+ --
189
+ withNewSocketAddress :: SocketAddress sa => (Ptr sa -> Int -> IO a ) -> IO a
190
+ withNewSocketAddress f = allocaBytes sockaddrStorageLen $ \ ptr -> do
191
+ zeroMemory ptr $ fromIntegral sockaddrStorageLen
192
+ f ptr sockaddrStorageLen
193
+
194
+ mkInvalidRecvArgError :: String -> IOError
195
+ mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError
196
+ InvalidArgument
197
+ loc Nothing Nothing ) " non-positive length"
0 commit comments