diff --git a/cardano-ping/src/Cardano/Network/Ping.hs b/cardano-ping/src/Cardano/Network/Ping.hs index 2947d4fb096..d82959cf5ac 100644 --- a/cardano-ping/src/Cardano/Network/Ping.hs +++ b/cardano-ping/src/Cardano/Network/Ping.hs @@ -674,7 +674,7 @@ pingClient stdout stderr PingOpts{..} versions peer = bracket let peerStr' = TL.pack peerStr unless pingOptsQuiet $ TL.hPutStrLn IO.stdout $ peerStr' <> " " <> (showNetworkRtt $ toSample t0_e t0_s) - bearer <- getBearer makeSocketBearer sduTimeout nullTracer sd + bearer <- getBearer makeSocketBearer sduTimeout nullTracer sd Nothing !t1_s <- write bearer timeoutfn $ wrap handshakeNum InitiatorDir (handshakeReq versions pingOptsHandshakeQuery) (msg, !t1_e) <- nextMsg bearer timeoutfn handshakeNum diff --git a/network-mux/bench/socket_read_write/Main.hs b/network-mux/bench/socket_read_write/Main.hs new file mode 100644 index 00000000000..06b4a346199 --- /dev/null +++ b/network-mux/bench/socket_read_write/Main.hs @@ -0,0 +1,318 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedStrings #-} + +import Control.Exception (bracket) +import Control.Concurrent.Class.MonadSTM.Strict +import Data.Functor (void) +import Control.Monad (forever, replicateM_, when, unless) +import Control.Monad.Class.MonadAsync +import Control.Monad.Class.MonadTimer.SI +import Control.Tracer +import Data.Int +import Data.Word +import Network.Socket qualified as Socket +import Network.Socket (Socket) +import Network.Socket.ByteString.Lazy qualified as Socket (recv) +import Data.ByteString.Builder (Builder, toLazyByteString) +import Data.ByteString.Lazy qualified as BL +import Test.Tasty.Bench + +import Network.Mux.Bearer +import Network.Mux.Egress +import Network.Mux.Ingress +import Network.Mux +import Network.Mux.Types + +import Network.Mux.Timeout (withTimeoutSerial) + +activeTracer :: Tracer IO a +activeTracer = nullTracer +--activeTracer = showTracing stdoutTracer + +sduTimeout :: DiffTime +sduTimeout = 10 + +numberOfPackets :: Int64 +numberOfPackets = 100000 + +totalPayloadLen :: Int64 -> Int64 +totalPayloadLen sndSize = sndSize * numberOfPackets + +-- | Run a client that connects to the specified addr. +-- Signals the message sndSize to the server by writing it +-- in the provided TMVar. +readBenchmark :: StrictTMVar IO Int64 -> Int64 -> Socket.SockAddr -> IO () +readBenchmark sndSizeV sndSize addr = do + bracket + (Socket.socket Socket.AF_INET Socket.Stream Socket.defaultProtocol) + Socket.close + (\sd -> do + atomically $ putTMVar sndSizeV sndSize + Socket.connect sd addr + withReadBufferIO (\buffer -> do + bearer <- getBearer makeSocketBearer sduTimeout activeTracer sd buffer + + let chan = bearerAsChannel bearer (MiniProtocolNum 42) InitiatorDir + doRead (totalPayloadLen sndSize) chan 0 + ) + ) + where + doRead :: Int64 -> ByteChannel IO -> Int64 -> IO () + doRead maxData _ cnt | cnt >= maxData = return () + doRead maxData chan !cnt = do + msg_m <- recv chan + case msg_m of + Just msg -> doRead maxData chan (cnt + BL.length msg) + Nothing -> error "doRead: nullread" + +-- | Like readDemuxerBenchmark but it doesn't empty the ingress queue until +-- all data has been sent. +readDemuxerQueueBenchmark :: StrictTMVar IO Int64 -> Int64 -> Socket.SockAddr -> IO () +readDemuxerQueueBenchmark sndSizeV sndSize addr = do + bracket + (Socket.socket Socket.AF_INET Socket.Stream Socket.defaultProtocol) + Socket.close + (\sd -> do + atomically $ putTMVar sndSizeV sndSize + + Socket.connect sd addr + withReadBufferIO (\buffer -> do + bearer <- getBearer makeSocketBearer sduTimeout activeTracer sd buffer + ms42 <- mkMiniProtocolState 42 + withAsync (demuxer [ms42] bearer) $ \aid -> do + doRead 0xa5 (totalPayloadLen sndSize) (miniProtocolIngressQueue ms42) + cancel aid + ) + ) + where + doRead :: Word8 -> Int64 -> StrictTVar IO (Int64, Builder) -> IO () + doRead tag maxData queue = do + msg <- atomically $ do + (l,b) <- readTVar queue + if l == maxData + then + return (toLazyByteString b) + else + retry + if BL.all ( == tag) msg + then return () + else error "corrupt stream" + +-- Like readBenchmark but uses a demuxer thread +readDemuxerBenchmark :: StrictTMVar IO Int64 -> Int64 -> Socket.SockAddr -> IO () +readDemuxerBenchmark sndSizeV sndSize addr = do + bracket + (Socket.socket Socket.AF_INET Socket.Stream Socket.defaultProtocol) + Socket.close + (\sd -> do + atomically $ putTMVar sndSizeV sndSize + + Socket.connect sd addr + withReadBufferIO (\buffer -> do + bearer <- getBearer makeSocketBearer sduTimeout activeTracer sd buffer + ms42 <- mkMiniProtocolState 42 + ms41 <- mkMiniProtocolState 41 + withAsync (demuxer [ms41, ms42] bearer) $ \aid -> do + withAsync (doRead 42 (totalPayloadLen sndSize) (miniProtocolIngressQueue ms42) 0) $ \aid42 -> do + withAsync (doRead 41 (totalPayloadLen 10) (miniProtocolIngressQueue ms41) 0) $ \aid41 -> do + _ <- waitBoth aid42 aid41 + cancel aid + return () + ) + ) + where + doRead :: Word8 -> Int64 -> StrictTVar IO (Int64, Builder) -> Int64 -> IO () + doRead _ maxData _ cnt | cnt >= maxData = return () + doRead tag maxData queue !cnt = do + msg <- atomically $ do + (l,b) <- readTVar queue + if l == 0 + then retry + else do + writeTVar queue (0, mempty) + return (toLazyByteString b) + if BL.all ( == tag) msg + then doRead tag maxData queue (cnt + BL.length msg) + else error "corrupt stream" + +mkMiniProtocolState :: MonadSTM m => Word16 -> m (MiniProtocolState 'InitiatorMode m) +mkMiniProtocolState num = do + mpq <- newTVarIO (0, mempty) + mpv <- newTVarIO StatusRunning + + let mpi = MiniProtocolInfo (MiniProtocolNum num) InitiatorDirectionOnly + (MiniProtocolLimits maxBound) + return $ MiniProtocolState mpi mpq mpv + +-- | Run a server that accept connections on `ad`. +startServer :: StrictTMVar IO Int64 -> Socket -> IO () +startServer sndSizeV ad = forever $ do + (sd, _) <- Socket.accept ad + withReadBufferIO (\buffer -> do + bearer <- getBearer makeSocketBearer sduTimeout activeTracer sd buffer + sndSize <- atomically $ takeTMVar sndSizeV + + let chan = bearerAsChannel bearer (MiniProtocolNum 42) ResponderDir + payload = BL.replicate sndSize 0xa5 + maxData = totalPayloadLen sndSize + numberOfSdus = fromIntegral $ maxData `div` sndSize + replicateM_ numberOfSdus $ do + send chan payload + ) +-- | Like startServer but it uses the `writeMany` function +-- for vector IO. +startServerMany :: StrictTMVar IO Int64 -> Socket -> IO () +startServerMany sndSizeV ad = forever $ do + (sd, _) <- Socket.accept ad + withReadBufferIO (\buffer -> do + bearer <- getBearer makeSocketBearer sduTimeout activeTracer sd buffer + sndSize <- atomically $ takeTMVar sndSizeV + + let maxData = totalPayloadLen sndSize + numberOfSdus = fromIntegral $ maxData `div` sndSize + numberOfCalls = numberOfSdus `div` 10 + runtSdus = numberOfSdus `mod` 10 + + withTimeoutSerial $ \timeoutFn -> do + replicateM_ numberOfCalls $ do + let sdus = replicate 10 $ wrap $ BL.replicate sndSize 0xa5 + void $ writeMany bearer timeoutFn sdus + when (runtSdus > 0) $ do + let sdus = replicate runtSdus $ wrap $ BL.replicate sndSize 0xa5 + void $ writeMany bearer timeoutFn sdus + ) + where + -- wrap a 'ByteString' as 'SDU' + wrap :: BL.ByteString -> SDU + wrap blob = SDU { + -- it will be filled when the 'SDU' is send by the 'bearer' + msHeader = SDUHeader { + mhTimestamp = RemoteClockModel 0, + mhNum = MiniProtocolNum 42, + mhDir = ResponderDir, + mhLength = fromIntegral $ BL.length blob + }, + msBlob = blob + } + +-- | Run a server that accept connections on `ad`. +-- It will send streams of data over the 41 and 42 miniprotocol. +-- Multiplexing is done with a separate thread running +-- the Egress.muxer function. +startServerEgresss :: StrictTMVar IO Int64 -> Socket -> IO () +startServerEgresss sndSizeV ad = forever $ do + (sd, _) <- Socket.accept ad + withReadBufferIO (\buffer -> do + bearer <-getBearer makeSocketBearer sduTimeout activeTracer sd buffer + sndSize <- atomically $ takeTMVar sndSizeV + eq <- atomically $ newTBQueue 100 + w42 <- newTVarIO BL.empty + w41 <- newTVarIO BL.empty + + let maxData = totalPayloadLen sndSize + numberOfSdus = fromIntegral $ maxData `div` sndSize + numberOfCalls = numberOfSdus `div` 10 :: Int + runtSdus = numberOfSdus `mod` 10 :: Int + + withAsync (muxer eq bearer) $ \aid -> do + + replicateM_ numberOfCalls $ do + let payload42s = replicate 10 $ BL.replicate sndSize 42 + let payload41s = replicate 10 $ BL.replicate 10 41 + mapM_ (sendToMux w42 eq (MiniProtocolNum 42) ResponderDir) payload42s + mapM_ (sendToMux w41 eq (MiniProtocolNum 41) ResponderDir) payload41s + when (runtSdus > 0) $ do + let payload42s = replicate runtSdus $ BL.replicate sndSize 42 + let payload41s = replicate runtSdus $ BL.replicate 10 41 + mapM_ (sendToMux w42 eq (MiniProtocolNum 42) ResponderDir) payload42s + mapM_ (sendToMux w41 eq (MiniProtocolNum 41) ResponderDir) payload41s + + -- Wait for the egress queue to empty + atomically $ do + r42 <- readTVar w42 + r41 <- readTVar w42 + unless (BL.null r42 || BL.null r41) retry + + -- when the client is done they will close the socket + -- and we will read zero bytes. + _ <- Socket.recv sd 128 + + cancel aid + ) + where + sendToMux :: StrictTVar IO BL.ByteString -> EgressQueue IO -> MiniProtocolNum -> MiniProtocolDir + -> BL.ByteString -> IO () + sendToMux w eq mc md msg = do + atomically $ do + buf <- readTVar w + if BL.length buf < 0x3ffff + then do + let wasEmpty = BL.null buf + writeTVar w (BL.append buf msg) + when wasEmpty $ + writeTBQueue eq (TLSRDemand mc md $ Wanton w) + else retry + +setupServer :: Socket -> IO Socket.SockAddr +setupServer ad = do + muxAddress:_ <- Socket.getAddrInfo Nothing (Just "127.0.0.1") (Just "0") + Socket.setSocketOption ad Socket.ReuseAddr 1 + Socket.bind ad (Socket.addrAddress muxAddress) + addr <- Socket.getSocketName ad + Socket.listen ad 3 + + return addr + +-- Main function to run the benchmarks +main :: IO () +main = do + bracket + (do + ad1 <- Socket.socket Socket.AF_INET Socket.Stream Socket.defaultProtocol + ad2 <- Socket.socket Socket.AF_INET Socket.Stream Socket.defaultProtocol + ad3 <- Socket.socket Socket.AF_INET Socket.Stream Socket.defaultProtocol + + return (ad1, ad2, ad3) + ) + (\(ad1, ad2, ad3) -> do + Socket.close ad1 + Socket.close ad2 + Socket.close ad3 + ) + (\(ad1, ad2, ad3) -> do + sndSizeV <- newEmptyTMVarIO + sndSizeMV <- newEmptyTMVarIO + sndSizeEV <- newEmptyTMVarIO + addr <- setupServer ad1 + addrM <- setupServer ad2 + addrE <- setupServer ad3 + + withAsync (startServer sndSizeV ad1) $ \said -> do + withAsync (startServerMany sndSizeMV ad2) $ \saidM -> do + withAsync (startServerEgresss sndSizeEV ad3) $ \saidE -> do + defaultMain [ + -- Suggested Max SDU size for Socket bearer + bench "Read/Write Benchmark 12288 byte SDUs" $ nfIO $ readBenchmark sndSizeV 12288 addr + -- Payload size for ChainSync's RequestNext + , bench "Read/Write Benchmark 914 byte SDUs" $ nfIO $ readBenchmark sndSizeV 914 addr + -- Payload size for ChainSync's RequestNext + , bench "Read/Write Benchmark 10 byte SDUs" $ nfIO $ readBenchmark sndSizeV 10 addr + + -- Send batches of SDUs at the same time + , bench "Read/Write-Many Benchmark 12288 byte SDUs" $ nfIO $ readBenchmark sndSizeMV 12288 addrM + , bench "Read/Write-Many Benchmark 914 byte SDUs" $ nfIO $ readBenchmark sndSizeMV 914 addrM + , bench "Read/Write-Many Benchmark 10 byte SDUs" $ nfIO $ readBenchmark sndSizeMV 10 addrM + + -- Use standard muxer and demuxer + , bench "Read/Write Mux Benchmark 800+10 byte SDUs" $ nfIO $ readDemuxerBenchmark sndSizeEV 800 addrE + , bench "Read/Write Mux Benchmark 12288+10 byte SDUs" $ nfIO $ readDemuxerBenchmark sndSizeEV 12288 addrE + + -- Use standard demuxer + , bench "Read/Write Demuxer Queuing Benchmark 10 byte SDUs" $ nfIO $ readDemuxerQueueBenchmark sndSizeV 10 addr + , bench "Read/Write Demuxer Queuing Benchmark 256 byte SDUs" $ nfIO $ readDemuxerQueueBenchmark sndSizeV 256 addr + ] + cancel said + cancel saidM + cancel saidE + ) diff --git a/network-mux/demo/mux-demo.hs b/network-mux/demo/mux-demo.hs index e80a2dbc76e..c709cabde1c 100644 --- a/network-mux/demo/mux-demo.hs +++ b/network-mux/demo/mux-demo.hs @@ -101,7 +101,7 @@ server = associateWithIOManager ioManager (Left hpipe) Win32.Async.connectNamedPipe hpipe void $ forkIO $ do - bearer <- getBearer Mx.makeNamedPipeBearer (-1) nullTracer hpipe + bearer <- getBearer Mx.makeNamedPipeBearer (-1) nullTracer hpipe Nothing serverWorker bearer `finally` closeHandle hpipe #else @@ -113,7 +113,7 @@ server = do forever $ do (sock', _addr) <- Socket.accept sock void $ forkIO $ do - bearer <- getBearer Mx.makeSocketBearer 1.0 nullTracer sock' + bearer <- getBearer Mx.makeSocketBearer 1.0 nullTracer sock' Nothing serverWorker bearer `finally` Socket.close sock' #endif @@ -167,13 +167,13 @@ client n msg = fILE_FLAG_OVERLAPPED Nothing associateWithIOManager ioManager (Left hpipe) - bearer <- getBearer Mx.makeNamedPipeBearer (-1) nullTracer hpipe + bearer <- getBearer Mx.makeNamedPipeBearer (-1) nullTracer hpipe Nothing clientWorker bearer n msg #else client n msg = do sock <- Socket.socket AF_UNIX Socket.Stream Socket.defaultProtocol Socket.connect sock (SockAddrUnix pipeName) - bearer <- getBearer Mx.makeSocketBearer 1.0 nullTracer sock + bearer <- getBearer Mx.makeSocketBearer 1.0 nullTracer sock Nothing clientWorker bearer n msg #endif diff --git a/network-mux/network-mux.cabal b/network-mux/network-mux.cabal index bfab6b14bda..df747e48eed 100644 --- a/network-mux/network-mux.cabal +++ b/network-mux/network-mux.cabal @@ -195,3 +195,36 @@ executable mux-demo build-depends: directory, network, + +benchmark socket-read-write + type: exitcode-stdio-1.0 + hs-source-dirs: bench/socket_read_write + main-is: Main.hs + other-modules: + + build-depends: + base >=4.14 && <4.21, + bytestring, + contra-tracer, + io-classes, + network, + network-mux, + si-timers, + strict-stm, + tasty-bench + + default-extensions: ImportQualifiedPost + ghc-options: + -threaded + -rtsopts + -fproc-alignment=64 + -Wall + -Wcompat + -Wincomplete-uni-patterns + -Wincomplete-record-updates + -Wpartial-fields + -Widentities + -Wredundant-constraints + -Wunused-packages + + default-language: Haskell2010 diff --git a/network-mux/src/Network/Mux.hs b/network-mux/src/Network/Mux.hs index 34f66cb1ef0..9312b894094 100644 --- a/network-mux/src/Network/Mux.hs +++ b/network-mux/src/Network/Mux.hs @@ -48,6 +48,7 @@ module Network.Mux , WithBearer (..) ) where +import Data.ByteString.Builder (lazyByteString, toLazyByteString) import Data.ByteString.Lazy qualified as BL import Data.Int (Int64) import Data.Map (Map) @@ -143,7 +144,7 @@ mkMiniProtocolState :: MonadSTM m => MiniProtocolInfo mode -> m (MiniProtocolState mode m) mkMiniProtocolState miniProtocolInfo = do - miniProtocolIngressQueue <- newTVarIO BL.empty + miniProtocolIngressQueue <- newTVarIO (0, mempty) miniProtocolStatusVar <- newTVarIO StatusIdle return MiniProtocolState { miniProtocolInfo, @@ -203,6 +204,7 @@ data Group = MuxJob -- run :: forall m mode. ( MonadAsync m + , MonadDelay m , MonadFork m , MonadLabelledSTM m , Alternative (STM m) @@ -309,7 +311,8 @@ miniProtocolJob tracer egressQueue `orElse` throwSTM (BlockedOnCompletionVar miniProtocolNum) case remainder of Just trailing -> - modifyTVar miniProtocolIngressQueue (BL.append trailing) + modifyTVar miniProtocolIngressQueue (\(l, b) -> + (l + BL.length trailing, b <> (lazyByteString trailing))) Nothing -> pure () @@ -519,8 +522,8 @@ monitor tracer timeout jobpool egressQueue cmdQueue muxStatus = checkNonEmptyQueue :: IngressQueue m -> STM m () checkNonEmptyQueue q = do - buf <- readTVar q - check (not (BL.null buf)) + (l, _) <- readTVar q + check (l /= 0) protocolKey :: MiniProtocolState mode m -> MiniProtocolKey protocolKey MiniProtocolState { @@ -604,10 +607,10 @@ muxChannel tracer egressQueue want@(Wanton w) mc md q = -- matching ingress queue. This is the same queue the 'demux' thread writes to. traceWith tracer $ TraceChannelRecvStart mc blob <- atomically $ do - blob <- readTVar q - if blob == BL.empty + (l, blob) <- readTVar q + if l == 0 then retry - else writeTVar q BL.empty >> return blob + else writeTVar q (0, mempty) >> return (toLazyByteString blob) -- say $ printf "recv mid %s mode %s blob len %d" (show mid) (show md) (BL.length blob) traceWith tracer $ TraceChannelRecvEnd mc (fromIntegral $ BL.length blob) return $ Just blob diff --git a/network-mux/src/Network/Mux/Bearer.hs b/network-mux/src/Network/Mux/Bearer.hs index 732a7a1e0fd..b2a85e2ed48 100644 --- a/network-mux/src/Network/Mux/Bearer.hs +++ b/network-mux/src/Network/Mux/Bearer.hs @@ -15,17 +15,21 @@ module Network.Mux.Bearer #if defined(mingw32_HOST_OS) , makeNamedPipeBearer #endif + , withReadBufferIO ) where import Control.Monad.Class.MonadSTM +import Control.Concurrent.Class.MonadSTM.Strict import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime.SI import Control.Tracer (Tracer) +import Data.ByteString.Lazy qualified as BL import Network.Socket (Socket) #if defined(mingw32_HOST_OS) import System.Win32 (HANDLE) #endif +import Foreign.Marshal.Alloc import Network.Mux.Bearer.Pipe import Network.Mux.Bearer.Queues @@ -45,22 +49,38 @@ newtype MakeBearer m fd = MakeBearer { -- tracer -> fd -- file descriptor + -> Maybe (ReadBuffer m) + -- Optional Readbuffer -> m (Bearer m) } - pureBearer :: Applicative m - => (DiffTime -> Tracer m Trace -> fd -> Bearer m) - -> DiffTime -> Tracer m Trace -> fd -> m (Bearer m) -pureBearer f = \sduTimeout tr fd -> pure (f sduTimeout tr fd) + => (DiffTime -> Tracer m Trace -> fd -> Maybe (ReadBuffer m) -> Bearer m) + -> DiffTime -> Tracer m Trace -> fd -> Maybe (ReadBuffer m) -> m (Bearer m) +pureBearer f = \sduTimeout rb tr fd -> pure (f sduTimeout rb tr fd) + makeSocketBearer :: MakeBearer IO Socket -makeSocketBearer = MakeBearer $ pureBearer (socketAsBearer size) +makeSocketBearer = MakeBearer $ (\sduTimeout tr fd rb -> do + return $ socketAsBearer size batch rb sduTimeout tr fd) where size = SDUSize 12_288 + batch = 131_072 + +withReadBufferIO :: (Maybe (ReadBuffer IO) -> IO b) + -> IO b +withReadBufferIO f = allocaBytesAligned size 8 $ \ptr -> do + v <- atomically $ newTVar BL.empty + f $ Just $ ReadBuffer v ptr size + where + -- Maximum amount of data read in one call. + -- Corresponds to the default readbuffer size on Linux. + -- We want it larger than 64Kbyte, but not too large since + -- it is a memory overhead per mux bearer in an application. + size = 131_072 makePipeChannelBearer :: MakeBearer IO PipeChannel -makePipeChannelBearer = MakeBearer $ pureBearer (\_ -> pipeAsBearer size) +makePipeChannelBearer = MakeBearer $ pureBearer (\_ tr fd _ -> pipeAsBearer size tr fd) where size = SDUSize 32_768 @@ -69,13 +89,13 @@ makeQueueChannelBearer :: ( MonadSTM m , MonadThrow m ) => MakeBearer m (QueueChannel m) -makeQueueChannelBearer = MakeBearer $ pureBearer (\_ -> queueChannelAsBearer size) +makeQueueChannelBearer = MakeBearer $ pureBearer (\_ tr q _-> queueChannelAsBearer size tr q) where size = SDUSize 1_280 #if defined(mingw32_HOST_OS) makeNamedPipeBearer :: MakeBearer IO HANDLE -makeNamedPipeBearer = MakeBearer $ pureBearer (\_ -> namedPipeAsBearer size) +makeNamedPipeBearer = MakeBearer $ pureBearer (\_ tr fd _-> namedPipeAsBearer size tr fd) where size = SDUSize 24_576 #endif diff --git a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs index 13620703c32..d1e756d88c6 100644 --- a/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs +++ b/network-mux/src/Network/Mux/Bearer/AttenuatedChannel.hs @@ -273,10 +273,12 @@ attenuationChannelAsBearer :: forall m. -> Bearer m attenuationChannelAsBearer sduSize sduTimeout muxTracer chan = Bearer { - read = readMux, - write = writeMux, + read = readMux, + write = writeMux, + writeMany = writeMuxMany, sduSize, - name = "attenuation-channel" + batchSize = fromIntegral $ getSDUSize sduSize, + name = "attenuation-channel" } where readMux :: TimeoutFn m -> m (SDU, Time) @@ -311,6 +313,12 @@ attenuationChannelAsBearer sduSize sduTimeout muxTracer chan = traceWith muxTracer TraceSendEnd return ts + writeMuxMany :: TimeoutFn m -> [SDU] -> m Time + writeMuxMany timeoutFn sdus = do + ts <- getMonotonicTime + mapM_ (writeMux timeoutFn) sdus + return ts + -- -- Trace -- diff --git a/network-mux/src/Network/Mux/Bearer/NamedPipe.hs b/network-mux/src/Network/Mux/Bearer/NamedPipe.hs index 3d68442ef81..30753531844 100644 --- a/network-mux/src/Network/Mux/Bearer/NamedPipe.hs +++ b/network-mux/src/Network/Mux/Bearer/NamedPipe.hs @@ -33,16 +33,18 @@ namedPipeAsBearer :: Mx.SDUSize -> Mx.Bearer IO namedPipeAsBearer sduSize tracer h = Mx.Bearer { - Mx.read = readNamedPipe, - Mx.write = writeNamedPipe, - Mx.sduSize = sduSize, - Mx.name = "named-pipe" + Mx.read = readNamedPipe, + Mx.write = writeNamedPipe, + Mx.writeMany = writeNamedPipeMany, + Mx.sduSize = sduSize, + Mx.batchSize = fromIntegral $ Mx.getSDUSize sduSize, + Mx.name = "named-pipe" } where readNamedPipe :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time) readNamedPipe _ = do traceWith tracer Mx.TraceRecvHeaderStart - hbuf <- recvLen' True 8 [] + hbuf <- recvLen' True Mx.msHeaderLength [] case Mx.decodeSDU hbuf of Left e -> throwIO e Right header@Mx.SDU { Mx.msHeader } -> do @@ -80,3 +82,9 @@ namedPipeAsBearer sduSize tracer h = `catch` Mx.handleIOException "writeHandle errored" traceWith tracer Mx.TraceSendEnd return ts + + writeNamedPipeMany :: Mx.TimeoutFn IO -> [Mx.SDU] -> IO Time + writeNamedPipeMany timeoutFn sdus = do + ts <- getMonotonicTime + mapM_ (writeNamedPipe timeoutFn) sdus + return ts diff --git a/network-mux/src/Network/Mux/Bearer/Pipe.hs b/network-mux/src/Network/Mux/Bearer/Pipe.hs index deceb447662..174d0f302fb 100644 --- a/network-mux/src/Network/Mux/Bearer/Pipe.hs +++ b/network-mux/src/Network/Mux/Bearer/Pipe.hs @@ -75,16 +75,18 @@ pipeAsBearer -> Bearer IO pipeAsBearer sduSize tracer channel = Mx.Bearer { - Mx.read = readPipe, - Mx.write = writePipe, - Mx.sduSize = sduSize, - Mx.name = "pipe" + Mx.read = readPipe, + Mx.write = writePipe, + Mx.writeMany = writePipeMany, + Mx.sduSize = sduSize, + Mx.name = "pipe", + Mx.batchSize = fromIntegral $ Mx.getSDUSize sduSize } where readPipe :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time) readPipe _ = do traceWith tracer Mx.TraceRecvHeaderStart - hbuf <- recvLen' 8 [] + hbuf <- recvLen' (fromIntegral Mx.msHeaderLength) [] case Mx.decodeSDU hbuf of Left e -> throwIO e Right header@Mx.SDU { Mx.msHeader } -> do @@ -118,3 +120,9 @@ pipeAsBearer sduSize tracer channel = traceWith tracer Mx.TraceSendEnd return ts + writePipeMany :: Mx.TimeoutFn IO -> [Mx.SDU] -> IO Time + writePipeMany timeoutFn sdus = do + ts <- getMonotonicTime + mapM_ (writePipe timeoutFn) sdus + return ts + diff --git a/network-mux/src/Network/Mux/Bearer/Queues.hs b/network-mux/src/Network/Mux/Bearer/Queues.hs index a512863fa40..29c7769eef1 100644 --- a/network-mux/src/Network/Mux/Bearer/Queues.hs +++ b/network-mux/src/Network/Mux/Bearer/Queues.hs @@ -40,10 +40,12 @@ queueChannelAsBearer -> Bearer m queueChannelAsBearer sduSize tracer QueueChannel { writeQueue, readQueue } = do Mx.Bearer { - Mx.read = readMux, - Mx.write = writeMux, - Mx.sduSize = sduSize, - Mx.name = "queue-channel" + Mx.read = readMux, + Mx.write = writeMux, + Mx.writeMany = writeMuxMany, + Mx.sduSize = sduSize, + Mx.batchSize = 2 * (fromIntegral $ Mx.getSDUSize sduSize), + Mx.name = "queue-channel" } where readMux :: Mx.TimeoutFn m -> m (Mx.SDU, Time) @@ -70,3 +72,9 @@ queueChannelAsBearer sduSize tracer QueueChannel { writeQueue, readQueue } = do traceWith tracer Mx.TraceSendEnd return ts + writeMuxMany :: Mx.TimeoutFn m -> [Mx.SDU] -> m Time + writeMuxMany timeoutFn sdus = do + ts <- getMonotonicTime + mapM_ (writeMux timeoutFn) sdus + return ts + diff --git a/network-mux/src/Network/Mux/Bearer/Socket.hs b/network-mux/src/Network/Mux/Bearer/Socket.hs index 709a9a35dd9..04b7b639dad 100644 --- a/network-mux/src/Network/Mux/Bearer/Socket.hs +++ b/network-mux/src/Network/Mux/Bearer/Socket.hs @@ -2,6 +2,7 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} module Network.Mux.Bearer.Socket (socketAsBearer) where @@ -11,6 +12,7 @@ import Control.Tracer import Data.ByteString.Lazy qualified as BL import Data.Int +import Control.Concurrent.Class.MonadSTM.Strict import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTime.SI import Control.Monad.Class.MonadTimer.SI hiding (timeout) @@ -18,6 +20,9 @@ import Control.Monad.Class.MonadTimer.SI hiding (timeout) import Network.Socket qualified as Socket #if !defined(mingw32_HOST_OS) import Network.Socket.ByteString.Lazy qualified as Socket (recv, sendAll) +import Network.Socket.ByteString qualified as Socket (sendMany) +import Data.ByteString.Internal (create) +import Foreign.Marshal.Utils #else import System.Win32.Async.Socket.ByteString.Lazy qualified as Win32.Async #endif @@ -45,26 +50,28 @@ import Network.Mux.TCPInfo (SocketOption (TCPInfoSocketOption)) -- socketAsBearer :: Mx.SDUSize + -> Int + -> Maybe (Mx.ReadBuffer IO) -> DiffTime -> Tracer IO Mx.Trace -> Socket.Socket -> Bearer IO -socketAsBearer sduSize sduTimeout tracer sd = +socketAsBearer sduSize batchSize readBuffer_m sduTimeout tracer sd = Mx.Bearer { - Mx.read = readSocket, - Mx.write = writeSocket, - Mx.sduSize = sduSize, - Mx.name = "socket-bearer" + Mx.read = readSocket, + Mx.write = writeSocket, + Mx.writeMany = writeSocketMany, + Mx.sduSize = sduSize, + Mx.batchSize = batchSize, + Mx.name = "socket-bearer" } where - hdrLenght = 8 - readSocket :: Mx.TimeoutFn IO -> IO (Mx.SDU, Time) readSocket timeout = do traceWith tracer Mx.TraceRecvHeaderStart -- Wait for the first part of the header without any timeout - h0 <- recvAtMost True hdrLenght + h0 <- recvAtMost True Mx.msHeaderLength -- Optionally wait at most sduTimeout seconds for the complete SDU. r_m <- timeout sduTimeout $ recvRem h0 @@ -76,7 +83,7 @@ socketAsBearer sduSize sduTimeout tracer sd = recvRem :: BL.ByteString -> IO (Mx.SDU, Time) recvRem !h0 = do - hbuf <- recvLen' (hdrLenght - BL.length h0) [h0] + hbuf <- recvLen' (Mx.msHeaderLength - BL.length h0) [h0] case Mx.decodeSDU hbuf of Left e -> throwIO e Right header@Mx.SDU { Mx.msHeader } -> do @@ -97,26 +104,77 @@ socketAsBearer sduSize sduTimeout tracer sd = recvAtMost :: Bool -> Int64 -> IO BL.ByteString recvAtMost waitingOnNxtHeader l = do traceWith tracer $ Mx.TraceRecvStart $ fromIntegral l + + case readBuffer_m of + Nothing -> -- No read buffer available; read directly from socket + recvFromSocket l + Just Mx.ReadBuffer{..} -> do + availableData <- atomically $ do + buf <- readTVar rbVar + if BL.length buf >= l + then do + let (toProcess, remaining) = BL.splitAt l buf + writeTVar rbVar remaining + return toProcess + else do + writeTVar rbVar BL.empty + return buf + + if BL.null availableData + then do + -- Not data in buffer; read more from socket + when (not waitingOnNxtHeader) $ + -- Don't let the kernel wake us up until there is + -- at least l bytes of data. + Socket.setSocketOption sd Socket.RecvLowWater $ fromIntegral l + newBuf <- recvFromSocket $ fromIntegral rbSize + atomically $ modifyTVar rbVar (`BL.append` newBuf) + when (not waitingOnNxtHeader) $ + Socket.setSocketOption sd Socket.RecvLowWater 1 + recvAtMost waitingOnNxtHeader l + else do + traceWith tracer $ Mx.TraceRecvEnd $ fromIntegral $ BL.length availableData + return availableData + where +#if !defined(mingw32_HOST_OS) + -- Read at most `min rbSize maxLen` bytes from the socket + -- into rbBuf. + -- Creates and returns a Bytestring matching the exact size + -- of the number of bytes read. + recvBuf :: Mx.ReadBuffer IO -> Int64 -> IO BL.ByteString + recvBuf Mx.ReadBuffer{..} maxLen = do + len <- Socket.recvBuf sd rbBuf (min rbSize $ fromIntegral maxLen) + traceWith tracer $ Mx.TraceRecvRaw len + if len > 0 + then do + bs <- create len (\dest -> copyBytes dest rbBuf len) + return $ BL.fromStrict bs + else return $ BL.empty +#endif + + recvFromSocket :: Int64 -> IO BL.ByteString + recvFromSocket len = do #if defined(mingw32_HOST_OS) - buf <- Win32.Async.recv sd (fromIntegral l) + buf <- Win32.Async.recv sd (fromIntegral len) #else - buf <- Socket.recv sd l + buf <- (case readBuffer_m of + Nothing -> Socket.recv sd len + Just readBuffer -> recvBuf readBuffer len + ) #endif - `catch` Mx.handleIOException "recv errored" - if BL.null buf - then do - when waitingOnNxtHeader $ - {- This may not be an error, but could be an orderly shutdown. - - We wait 1 seconds to give the mux protocols time to perform - - a clean up and exit. - -} - threadDelay 1 - throwIO $ Mx.BearerClosed (show sd ++ - " closed when reading data, waiting on next header " ++ - show waitingOnNxtHeader) - else do - traceWith tracer $ Mx.TraceRecvEnd (fromIntegral $ BL.length buf) - return buf + `catch` Mx.handleIOException "recv errored" + if BL.null buf + then do + when waitingOnNxtHeader $ + {- This may not be an error, but could be an orderly shutdown. + - We wait 1 seconds to give the mux protocols time to perform + - a clean up and exit. + -} + threadDelay 1 + throwIO $ Mx.BearerClosed (show sd ++ + " closed when reading data, waiting on next header " ++ + show waitingOnNxtHeader) + else return buf writeSocket :: Mx.TimeoutFn IO -> Mx.SDU -> IO Time writeSocket timeout sdu = do @@ -148,3 +206,34 @@ socketAsBearer sduSize sduTimeout tracer sd = #endif return ts + writeSocketMany :: Mx.TimeoutFn IO -> [Mx.SDU] -> IO Time +#if defined(mingw32_HOST_OS) + writeSocketMany timeout sdus = do + ts <- getMonotonicTime + mapM_ (writeSocket timeout) sdus + return ts +#else + writeSocketMany timeout sdus = do + ts <- getMonotonicTime + let ts32 = Mx.timestampMicrosecondsLow32Bits ts + buf = map (Mx.encodeSDU . + (\sdu -> Mx.setTimestamp sdu (Mx.RemoteClockModel ts32))) sdus + r <- timeout ((fromIntegral $ length sdus) * sduTimeout) $ + Socket.sendMany sd (concatMap BL.toChunks buf) + `catch` Mx.handleIOException "sendAll errored" + case r of + Nothing -> do + traceWith tracer Mx.TraceSDUWriteTimeoutException + throwIO Mx.SDUWriteTimeout + Just _ -> do + traceWith tracer Mx.TraceSendEnd +#if defined(linux_HOST_OS) && defined(MUX_TRACE_TCPINFO) + -- If it was possible to detect if the TraceTCPInfo was + -- enable we wouldn't have to hide the getSockOpt + -- syscall in this ifdef. Instead we would only call it if + -- we knew that the information would be traced. + tcpi <- Socket.getSockOpt sd TCPInfoSocketOption + traceWith tracer $ Mx.TraceTCPInfo tcpi (sum $ map (Mx.mhLength . Mx.msHeader) sdus) +#endif + return ts +#endif diff --git a/network-mux/src/Network/Mux/Codec.hs b/network-mux/src/Network/Mux/Codec.hs index c8857d333ee..0b95bd35c0b 100644 --- a/network-mux/src/Network/Mux/Codec.hs +++ b/network-mux/src/Network/Mux/Codec.hs @@ -47,10 +47,13 @@ decodeSDU buf = case Bin.runGetOrFail dec buf of Left (_, _, e) -> Left $ SDUDecodeError e Right (_, _, h) -> - Right $ SDU { - msHeader = h - , msBlob = BL.empty - } + if mhLength h > 0 + then + Right $ SDU { + msHeader = h + , msBlob = BL.empty + } + else Left $ SDUDecodeError "short SDU" where dec = do mhTimestamp <- RemoteClockModel <$> Bin.getWord32be diff --git a/network-mux/src/Network/Mux/Egress.hs b/network-mux/src/Network/Mux/Egress.hs index c5629c694b4..7c0b4bc7cca 100644 --- a/network-mux/src/Network/Mux/Egress.hs +++ b/network-mux/src/Network/Mux/Egress.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} @@ -20,6 +21,7 @@ import Control.Concurrent.Class.MonadSTM.Strict import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadThrow import Control.Monad.Class.MonadTimer.SI hiding (timeout) +import Control.Monad.Class.MonadTime.SI import Network.Mux.Timeout import Network.Mux.Types @@ -131,7 +133,9 @@ newtype Wanton m = Wanton { want :: StrictTVar m BL.ByteString } -- that each active demand gets a `maxSDU`s work of data processed -- each time it gets to the front of the queue muxer - :: ( MonadAsync m + :: forall m void. + ( MonadAsync m + , MonadDelay m , MonadFork m , MonadMask m , MonadThrow (STM m) @@ -140,11 +144,45 @@ muxer => EgressQueue m -> Bearer m -> m void -muxer egressQueue bearer = +muxer egressQueue Bearer { writeMany, sduSize, batchSize } = withTimeoutSerial $ \timeout -> forever $ do + start <- getMonotonicTime TLSRDemand mpc md d <- atomically $ readTBQueue egressQueue - processSingleWanton egressQueue bearer timeout mpc md d + sdu <- processSingleWanton egressQueue sduSize mpc md d + sdus <- buildBatch [sdu] (sduLength sdu) + void $ writeMany timeout (reverse sdus) + end <- getMonotonicTime + empty <- atomically $ isEmptyTBQueue egressQueue + when (empty) $ do + let delta = diffTime end start + threadDelay (loopInterval - delta) + + where + loopInterval :: DiffTime + loopInterval = 0.001 + + maxSDUsPerBatch :: Int + maxSDUsPerBatch = 100 + + sduLength :: SDU -> Int + sduLength sdu = fromIntegral msHeaderLength + fromIntegral (msLength sdu) + + -- Build a batch of SDUs to submit in one go to the bearer. + -- The egress queue is still processed one SDU at the time + -- to ensure that we don't cause starvation. + -- The batch size is either limited by the bearer + -- (e.g the SO_SNDBUF for Socket) or number of SDUs. + -- + buildBatch sdus _ | length sdus >= maxSDUsPerBatch = return sdus + buildBatch sdus sdusLength | sdusLength >= batchSize = return sdus + buildBatch sdus !sdusLength = do + demand_m <- atomically $ tryReadTBQueue egressQueue + case demand_m of + Just (TLSRDemand mpc md d) -> do + sdu <- processSingleWanton egressQueue sduSize mpc md d + buildBatch (sdu:sdus) (sdusLength + sduLength sdu) + Nothing -> return sdus -- | Pull a `maxSDU`s worth of data out out the `Wanton` - if there is -- data remaining requeue the `TranslocationServiceRequest` (this @@ -152,18 +190,17 @@ muxer egressQueue bearer = -- first. processSingleWanton :: MonadSTM m => EgressQueue m - -> Bearer m - -> TimeoutFn m + -> SDUSize -> MiniProtocolNum -> MiniProtocolDir -> Wanton m - -> m () -processSingleWanton egressQueue Bearer { write, sduSize } - timeout mpc md wanton = do + -> m SDU +processSingleWanton egressQueue (SDUSize sduSize) + mpc md wanton = do blob <- atomically $ do -- extract next SDU d <- readTVar (want wanton) - let (frag, rest) = BL.splitAt (fromIntegral (getSDUSize sduSize)) d + let (frag, rest) = BL.splitAt (fromIntegral sduSize) d -- if more to process then enqueue remaining work if BL.null rest then writeTVar (want wanton) BL.empty @@ -184,5 +221,5 @@ processSingleWanton egressQueue Bearer { write, sduSize } }, msBlob = blob } - void $ write timeout sdu + return sdu --paceTransmission tNow diff --git a/network-mux/src/Network/Mux/Ingress.hs b/network-mux/src/Network/Mux/Ingress.hs index d7aa4185450..0d9f9a44cde 100644 --- a/network-mux/src/Network/Mux/Ingress.hs +++ b/network-mux/src/Network/Mux/Ingress.hs @@ -12,6 +12,7 @@ module Network.Mux.Ingress ) where import Data.Array +import Data.ByteString.Builder.Internal (lazyByteStringInsert, lazyByteStringThreshold) import Data.ByteString.Lazy qualified as BL import Data.List (nub) @@ -115,9 +116,16 @@ demuxer ptcls bearer = throwIO (InitiatorOnly (msNum sdu)) Just (MiniProtocolDispatchInfo q qMax) -> atomically $ do - buf <- readTVar q - if BL.length buf + BL.length (msBlob sdu) <= fromIntegral qMax - then writeTVar q $ BL.append buf (msBlob sdu) + (len, buf) <- readTVar q + let len' = len + BL.length (msBlob sdu) + if len' <= fromIntegral qMax + then do + let buf' = if len == 0 + then -- Don't copy the payload if the queue was empty + lazyByteStringInsert $ msBlob sdu + else -- Copy payloads smaller than 128 bytes + buf <> (lazyByteStringThreshold 128 $ msBlob sdu) + writeTVar q $ (len', buf') else throwSTM $ IngressQueueOverRun (msNum sdu) (msDir sdu) lookupMiniProtocol :: MiniProtocolDispatch m diff --git a/network-mux/src/Network/Mux/Trace.hs b/network-mux/src/Network/Mux/Trace.hs index 39a554d94aa..579d90131ca 100644 --- a/network-mux/src/Network/Mux/Trace.hs +++ b/network-mux/src/Network/Mux/Trace.hs @@ -125,6 +125,7 @@ data Trace = | TraceRecvHeaderEnd SDUHeader | TraceRecvDeltaQObservation SDUHeader Time | TraceRecvDeltaQSample Double Int Int Double Double Double Double String + | TraceRecvRaw Int | TraceRecvStart Int | TraceRecvEnd Int | TraceSendStart SDUHeader @@ -159,6 +160,7 @@ instance Show Trace where (unRemoteClockModel mhTimestamp) (show ts) mhLength show (TraceRecvDeltaQSample d sp so dqs dqvm dqvs estR sdud) = printf "Bearer DeltaQ Sample: duration %.3e packets %d sumBytes %d DeltaQ_S %.3e DeltaQ_VMean %.3e DeltaQ_VVar %.3e DeltaQ_estR %.3e sizeDist %s" d sp so dqs dqvm dqvs estR sdud + show (TraceRecvRaw len) = printf "Bearer Receive Raw: length %d" len show (TraceRecvStart len) = printf "Bearer Receive Start: length %d" len show (TraceRecvEnd len) = printf "Bearer Receive End: length %d" len show (TraceSendStart SDUHeader { mhTimestamp, mhNum, mhDir, mhLength }) = printf "Bearer Send Start: ts: 0x%08x (%s) %s length %d" diff --git a/network-mux/src/Network/Mux/Types.hs b/network-mux/src/Network/Mux/Types.hs index a39ec824974..aa60b2b7f21 100644 --- a/network-mux/src/Network/Mux/Types.hs +++ b/network-mux/src/Network/Mux/Types.hs @@ -36,18 +36,23 @@ module Network.Mux.Types , msNum , msDir , msLength + , msHeaderLength , RemoteClockModel (..) , remoteClockPrecision , RuntimeError (..) + , ReadBuffer (..) ) where import Prelude hiding (read) import Control.Exception (Exception, SomeException) +import Data.ByteString.Builder (Builder) import Data.ByteString.Lazy qualified as BL import Data.Functor (void) +import Data.Int import Data.Ix (Ix (..)) import Data.Word +import Foreign.Ptr (Ptr) import Quiet import GHC.Generics (Generic) @@ -168,7 +173,7 @@ data Status -- Mux internal types -- -type IngressQueue m = StrictTVar m BL.ByteString +type IngressQueue m = StrictTVar m (Int64, Builder) -- | The index of a protocol in a MuxApplication, used for array indices newtype MiniProtocolIx = MiniProtocolIx Int @@ -221,6 +226,9 @@ msDir = mhDir . msHeader msLength :: SDU -> Word16 msLength = mhLength . msHeader +-- | Size of a MuxHeader in Bytes +msHeaderLength :: Int64 +msHeaderLength = 8 -- | Low level access to underlying socket or pipe. There are three smart -- constructors: @@ -232,10 +240,14 @@ msLength = mhLength . msHeader data Bearer m = Bearer { -- | Timestamp and send SDU. write :: TimeoutFn m -> SDU -> m Time + -- | Timestamp and send many SDUs. + , writeMany :: TimeoutFn m -> [SDU] -> m Time -- | Read a SDU , read :: TimeoutFn m -> m (SDU, Time) -- | Return a suitable SDU payload size. , sduSize :: SDUSize + -- | Return a suitable batch size + , batchSize :: Int -- | Name of the bearer , name :: String } @@ -288,3 +300,16 @@ data RuntimeError = deriving Show instance Exception RuntimeError + +-- | ReadBuffer for Mux Bearers +-- +-- This is used to read more data than what currently need in one syscall. +-- Any extra data read is cached in rbVar until the next read request. +data ReadBuffer m = ReadBuffer { + -- | Read cache + rbVar :: StrictTVar m BL.ByteString + -- | Buffer, used by the kernel to write the received data into. + , rbBuf :: Ptr Word8 + -- | Size of `rbBuf`. + , rbSize :: Int + } diff --git a/network-mux/test/Test/Mux.hs b/network-mux/test/Test/Mux.hs index f8049b8cb2f..c2a1f17ad54 100644 --- a/network-mux/test/Test/Mux.hs +++ b/network-mux/test/Test/Mux.hs @@ -265,7 +265,7 @@ instance Arbitrary ArbitrarySDU where ts <- arbitrary mid <- choose (6, 0x7fff) -- ClientChainSynWithBlocks with 5 is the highest valid mid mode <- oneof [return 0x0, return 0x8000] - len <- arbitrary + len <- choose (1, 0xffff) p <- arbitrary return $ ArbitraryInvalidSDU (InvalidSDU (Mx.RemoteClockModel ts) (mid .|. mode) len @@ -274,11 +274,12 @@ instance Arbitrary ArbitrarySDU where invalidLenght = do ts <- arbitrary mid <- arbitrary - len <- arbitrary - realLen <- choose (0, 7) -- Size of mux header is 8 + realLen <- choose (0, Mx.msHeaderLength) + len <- if realLen == Mx.msHeaderLength then return 0 + else arbitrary p <- arbitrary - return $ ArbitraryInvalidSDU (InvalidSDU (Mx.RemoteClockModel ts) mid len realLen p) + return $ ArbitraryInvalidSDU (InvalidSDU (Mx.RemoteClockModel ts) mid len (fromIntegral realLen) p) (Mx.SDUDecodeError "") instance Arbitrary Mx.BearerState where @@ -403,10 +404,12 @@ prop_mux_snd_recv_bi (DummyRun messages) = ioProperty $ do (-1) clientTracer QueueChannel { writeQueue = client_w, readQueue = client_r } + Nothing serverBearer <- getBearer makeQueueChannelBearer (-1) serverTracer QueueChannel { writeQueue = server_w, readQueue = server_r } + Nothing let clientApps = [ MiniProtocolInfo { miniProtocolNum = Mx.MiniProtocolNum 2, @@ -509,10 +512,12 @@ prop_mux_snd_recv_compat messages = ioProperty $ do (-1) clientTracer QueueChannel { writeQueue = client_w, readQueue = client_r } + Nothing serverBearer <- getBearer makeQueueChannelBearer (-1) serverTracer QueueChannel { writeQueue = server_w, readQueue = server_r } + Nothing (verify, client_mp, server_mp) <- setupMiniReqRspCompat (return ()) endMpsVar messages @@ -772,10 +777,12 @@ runWithQueues initApps respApps = do (-1) clientTracer QueueChannel { writeQueue = client_w, readQueue = client_r } + Nothing serverBearer <- getBearer makeQueueChannelBearer (-1) serverTracer QueueChannel { writeQueue = server_w, readQueue = server_r } + Nothing runMuxApplication initApps clientBearer respApps serverBearer runWithPipe :: RunMuxApplications @@ -811,8 +818,8 @@ runWithPipe initApps respApps = let clientChannel = Mx.pipeChannelFromNamedPipe hCli serverChannel = Mx.pipeChannelFromNamedPipe hSrv - clientBearer <- getBearer makePipeChannelBearer (-1) clientTracer clientChannel - serverBearer <- getBearer makePipeChannelBearer (-1) serverTracer serverChannel + clientBearer <- getBearer makePipeChannelBearer (-1) clientTracer clientChannel Nothing + serverBearer <- getBearer makePipeChannelBearer (-1) serverTracer serverChannel Nothing Win32.Async.connectNamedPipe hSrv runMuxApplication initApps clientBearer respApps serverBearer @@ -828,8 +835,8 @@ runWithPipe initApps respApps = let clientChannel = Mx.pipeChannelFromHandles rCli wSrv serverChannel = Mx.pipeChannelFromHandles rSrv wCli - clientBearer <- getBearer makePipeChannelBearer (-1) clientTracer clientChannel - serverBearer <- getBearer makePipeChannelBearer (-1) serverTracer serverChannel + clientBearer <- getBearer makePipeChannelBearer (-1) clientTracer clientChannel Nothing + serverBearer <- getBearer makePipeChannelBearer (-1) serverTracer serverChannel Nothing runMuxApplication initApps clientBearer respApps serverBearer #endif @@ -917,10 +924,12 @@ prop_mux_starvation (Uneven response0 response1) = (-1) clientTracer QueueChannel { writeQueue = client_w, readQueue = client_r } + Nothing serverBearer <- getBearer makeQueueChannelBearer (-1) serverTracer QueueChannel { writeQueue = server_w, readQueue = server_r } + Nothing (client_short, server_short) <- setupMiniReqRsp (waitOnAllClients activeMpsVar 2) $ DummyTrace [(request, response1)] @@ -1038,6 +1047,7 @@ encodeInvalidMuxSDU sdu = prop_demux_sdu :: forall m. ( Alternative (STM m) , MonadAsync m + , MonadDelay m , MonadFork m , MonadLabelledSTM m , MonadMask m @@ -1152,6 +1162,7 @@ prop_demux_sdu a = do QueueChannel { writeQueue = server_w, readQueue = server_r } + Nothing serverMux <- Mx.new [serverApp] serverRes <- Mx.runMiniProtocol serverMux (Mx.miniProtocolNum serverApp) (Mx.miniProtocolDir serverApp) @@ -1376,11 +1387,13 @@ prop_mux_start_mX apps runTime = do (-1) nullTracer QueueChannel { writeQueue = mux_w, readQueue = mux_r } + Nothing peerBearer <- getBearer makeQueueChannelBearer (-1) nullTracer QueueChannel { writeQueue = mux_r, readQueue = mux_w } + Nothing prop_mux_start_m bearer (triggerApp peerBearer) checkRes apps runTime where @@ -1429,6 +1442,7 @@ prop_mux_restart_m (DummyRestartingInitiatorApps apps) = do (-1) nullTracer QueueChannel { writeQueue = mux_w, readQueue = mux_r } + Nothing let minis = map (appToInfo Mx.InitiatorDirectionOnly . fst) apps mux <- Mx.new minis @@ -1470,11 +1484,13 @@ prop_mux_restart_m (DummyRestartingResponderApps rapps) = do (-1) nullTracer QueueChannel { writeQueue = mux_w, readQueue = mux_r } + Nothing peerBearer <- getBearer makeQueueChannelBearer (-1) nullTracer QueueChannel { writeQueue = mux_r, readQueue = mux_w } + Nothing let apps = map fst rapps minis = map (appToInfo Mx.ResponderDirectionOnly) apps @@ -1518,11 +1534,13 @@ prop_mux_restart_m (DummyRestartingInitiatorResponderApps rapps) = do (-1) nullTracer QueueChannel { writeQueue = mux_w, readQueue = mux_r } + Nothing peerBearer <- getBearer makeQueueChannelBearer (-1) nullTracer QueueChannel { writeQueue = mux_r, readQueue = mux_w } + Nothing let apps = map fst rapps initMinis = map (appToInfo Mx.InitiatorDirection) apps respMinis = map (appToInfo Mx.ResponderDirection) apps @@ -1780,22 +1798,23 @@ data ClientOrServer = Client | Server deriving Show -data NetworkCtx sock m = NetworkCtx { +data NetworkCtx sock m b = NetworkCtx { ncSocket :: m sock, ncClose :: sock -> m (), - ncMuxBearer :: sock -> m (Mx.Bearer m) + ncMuxBearer :: sock -> (Mx.Bearer m -> m b) -> m b } -withNetworkCtx :: MonadThrow m => NetworkCtx sock m -> (Mx.Bearer m -> m a) -> m a +withNetworkCtx :: MonadThrow m => NetworkCtx sock m a -> (Mx.Bearer m -> m a) -> m a withNetworkCtx NetworkCtx { ncSocket, ncClose, ncMuxBearer } k = - bracket ncSocket ncClose (\sock -> ncMuxBearer sock >>= k) + bracket ncSocket ncClose (\sock -> ncMuxBearer sock k) close_experiment :: forall sock acc req resp m. ( Alternative (STM m) , MonadAsync m + , MonadDelay m , MonadFork m , MonadLabelledSTM m , MonadMask m @@ -1812,8 +1831,8 @@ close_experiment -> FaultInjection -> Tracer m (ClientOrServer, TraceSendRecv (MsgReqResp req resp)) -> Tracer m (ClientOrServer, Mx.Trace) - -> NetworkCtx sock m - -> NetworkCtx sock m + -> NetworkCtx sock m (Either SomeException (Either [resp] [resp])) + -> NetworkCtx sock m (Either SomeException ()) -> [req] -> (acc -> req -> (acc, resp)) -> acc @@ -2051,7 +2070,11 @@ prop_mux_close_io fault reqs fn acc = ioProperty $ withIOManager $ \iocp -> do associateWithIOManager iocp (Right sock) return sock, ncClose = Socket.close, - ncMuxBearer = getBearer makeSocketBearer 10 nullTracer + ncMuxBearer = \sd k -> withReadBufferIO (\buffer -> do + bearer <- getBearer makeSocketBearer 10 nullTracer sd buffer + k bearer + ) + } clientCtx = NetworkCtx { ncSocket = do @@ -2064,7 +2087,11 @@ prop_mux_close_io fault reqs fn acc = ioProperty $ withIOManager $ \iocp -> do Socket.close sock return sock, ncClose = Socket.close, - ncMuxBearer = getBearer makeSocketBearer 10 nullTracer + ncMuxBearer = \sd k -> withReadBufferIO (\buffer -> do + bearer <- getBearer makeSocketBearer 10 nullTracer sd buffer + k bearer + ) + } close_experiment True @@ -2113,18 +2140,18 @@ prop_mux_close_sim fault (Positive sduSize_) reqs fn acc = clientCtx = NetworkCtx { ncSocket = return chann, ncClose = acClose, - ncMuxBearer = pure - . attenuationChannelAsBearer - sduSize sduTimeout - nullTracer + ncMuxBearer = \fd k -> + k $ attenuationChannelAsBearer + sduSize sduTimeout + nullTracer fd } serverCtx = NetworkCtx { ncSocket = return chann', ncClose = acClose, - ncMuxBearer = pure - . attenuationChannelAsBearer - sduSize sduTimeout - nullTracer + ncMuxBearer = \fd k -> + k $ attenuationChannelAsBearer + sduSize sduTimeout + nullTracer fd } close_experiment False diff --git a/ouroboros-network-framework/demo/connection-manager.hs b/ouroboros-network-framework/demo/connection-manager.hs index af64c266c90..1b07ec0d818 100644 --- a/ouroboros-network-framework/demo/connection-manager.hs +++ b/ouroboros-network-framework/demo/connection-manager.hs @@ -236,6 +236,7 @@ withBidirectionalConnectionManager snocket makeBearer socket CM.addressType = \_ -> Just IPv4Address, CM.snocket = snocket, CM.makeBearer = makeBearer, + CM.withBuffer = \f -> f Nothing, CM.configureSocket = \_ _ -> return (), CM.timeWaitTimeout = timeWaitTimeout, CM.outboundIdleTimeout = protocolIdleTimeout, diff --git a/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs index 7b267b25204..c24dfa9be0f 100644 --- a/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/io-tests/Test/Ouroboros/Network/Socket.hs @@ -349,7 +349,7 @@ prop_socket_recv_error f rerr = localAddress = Socket.addrAddress muxAddress, remoteAddress } - bearer <- Mx.getBearer Mx.makeSocketBearer timeout nullTracer sd' + bearer <- Mx.getBearer Mx.makeSocketBearer timeout nullTracer sd' Nothing _ <- async $ do threadDelay 0.1 atomically $ putTMVar lock () @@ -449,7 +449,7 @@ prop_socket_send_error rerr = let sduTimeout = if rerr == SendSDUTimeout then 0.10 else (-1) -- No timeout blob = BL.pack $ replicate 0xffff 0xa5 - bearer <- Mx.getBearer Mx.makeSocketBearer sduTimeout nullTracer sd' + bearer <- Mx.getBearer Mx.makeSocketBearer sduTimeout nullTracer sd' Nothing Mx.withTimeoutSerial $ \timeout -> -- send maximum mux sdus until we've filled the window. replicateM 100 $ do diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs index 9ffc52a5eee..5b64a2a7f8f 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/ConnectionManager.hs @@ -345,12 +345,14 @@ newtype FD m = FD { fdState :: StrictTVar m FDState } makeFDBearer :: MonadDelay m => MakeBearer m (FD m) -makeFDBearer = MakeBearer $ \_ _ _ -> +makeFDBearer = MakeBearer $ \_ _ _ _ -> return Mx.Bearer { - Mx.write = \_ _ -> getMonotonicTime, - Mx.read = \_ -> forever (threadDelay 3600), - Mx.sduSize = Mx.SDUSize 1500, - Mx.name = "FD" + Mx.write = \_ _ -> getMonotonicTime, + Mx.writeMany = \_ _ -> getMonotonicTime, + Mx.read = \_ -> forever (threadDelay 3600), + Mx.sduSize = Mx.SDUSize 1500, + batchSize = 1500, + Mx.name = "FD" } -- | We only keep exceptions here which should not be handled by the test @@ -610,7 +612,7 @@ mkConnectionHandler snocket = handler where handler :: ConnectionHandlerFn handlerTrace (FD m) Addr (Handle m) Void Version VersionData m - handler _ fd promise _ ConnectionId { remoteAddress } _ = + handler _ fd promise _ ConnectionId { remoteAddress } _ _ = MaskedAction $ \unmask -> do threadId <- myThreadId let addr = getTestAddress remoteAddress @@ -767,6 +769,7 @@ prop_valid_transitions (Fixed rnd) (SkewedBool bindToLocalAddress) scheduleMap = CM.addressType = \_ -> Just IPv4Address, CM.snocket = snocket, CM.makeBearer = makeFDBearer, + CM.withBuffer = \f -> f Nothing, CM.configureSocket = \_ _ -> return (), CM.connectionDataFlow = id, CM.prunePolicy = simplePrunePolicy, diff --git a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs index be1b414504a..dc9b214dab9 100644 --- a/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/sim-tests/Test/Ouroboros/Network/Socket.hs @@ -351,7 +351,7 @@ prop_socket_recv_error f rerr = localAddress = Socket.addrAddress muxAddress, remoteAddress } - bearer <- Mx.getBearer Mx.makeSocketBearer timeout nullTracer sd' + bearer <- Mx.getBearer Mx.makeSocketBearer timeout nullTracer sd' Nothing _ <- async $ do threadDelay 0.1 atomically $ putTMVar lock () @@ -451,7 +451,7 @@ prop_socket_send_error rerr = let sduTimeout = if rerr == SendSDUTimeout then 0.10 else (-1) -- No timeout blob = BL.pack $ replicate 0xffff 0xa5 - bearer <- Mx.getBearer Mx.makeSocketBearer sduTimeout nullTracer sd' + bearer <- Mx.getBearer Mx.makeSocketBearer sduTimeout nullTracer sd' Nothing withTimeoutSerial $ \timeout -> -- send maximum mux sdus until we've filled the window. replicateM 100 $ do diff --git a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs index 5929de0771d..15be2b8eae7 100644 --- a/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/sim-tests/Test/Simulation/Network/Snocket.hs @@ -37,6 +37,7 @@ import Codec.Serialise (Serialise) import Codec.Serialise qualified as Serialise import Data.ByteString.Lazy (ByteString) +import Data.ByteString.Lazy qualified as BL import Data.Foldable (traverse_) import Data.Functor (void) import Data.Map qualified as Map @@ -228,7 +229,7 @@ clientServerSimulation payloads = (accepted, accept1) <- runAccept accept0 case accepted of Accepted fd' remoteAddr -> do - bearer <- getBearer makeFDBearer 10 nullTracer fd' + bearer <- getBearer makeFDBearer 10 nullTracer fd' Nothing thread <- async $ handleConnection bearer remoteAddr `finally` close snocket fd' @@ -307,7 +308,7 @@ clientServerSimulation payloads = (\channel -> runPeer tr codecReqResp channel clientPeer) - bearer <- Mx.getBearer makeFDBearer 10 nullTracer fd + bearer <- Mx.getBearer makeFDBearer 10 nullTracer fd Nothing -- kill mux as soon as the client returns withAsync @@ -558,6 +559,7 @@ prop_simultaneous_open defaultBearerInfo = -- prop_self_connect :: ByteString -> Property prop_self_connect payload = + BL.length payload > 0 && BL.length payload <= 0xffff ==> runSimOrThrow sim where addr :: TestAddress Int @@ -575,7 +577,7 @@ prop_self_connect payload = $ \fd -> do bind snocket fd addr connect snocket fd addr - bearer <- getBearer makeFDBearer 10 nullTracer fd + bearer <- getBearer makeFDBearer 10 nullTracer fd Nothing let channel = bearerAsChannel bearer (MiniProtocolNum 0) InitiatorDir send channel payload payload' <- recv channel diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs index 650b4299043..ad61b7bdf87 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionHandler.hs @@ -288,6 +288,7 @@ makeConnectionHandler muxTracer singMuxMode connectionId@ConnectionId { localAddress , remoteAddress } mkMuxBearer + withBuffer = MaskedAction { runWithUnmask } where runWithUnmask :: (forall x. m x -> m x) -> m () @@ -298,7 +299,7 @@ makeConnectionHandler muxTracer singMuxMode , "-" , show remoteAddress ]) - handshakeBearer <- mkMuxBearer sduHandshakeTimeout socket + handshakeBearer <- mkMuxBearer sduHandshakeTimeout socket Nothing hsResult <- unmask (runHandshakeClient handshakeBearer connectionId @@ -331,9 +332,11 @@ makeConnectionHandler muxTracer singMuxMode hVersionData = agreedOptions } atomically $ writePromise (Right $ HandshakeConnectionResult handle (versionNumber, agreedOptions)) - bearer <- mkMuxBearer sduTimeout socket - Mx.run (Mx.WithBearer connectionId `contramap` muxTracer) - mux bearer + withBuffer (\buffer -> do + bearer <- mkMuxBearer sduTimeout socket buffer + Mx.run (Mx.WithBearer connectionId `contramap` muxTracer) + mux bearer + ) Right (HandshakeQueryResult vMap) -> do atomically $ writePromise (Right HandshakeConnectionQuery) @@ -357,6 +360,7 @@ makeConnectionHandler muxTracer singMuxMode connectionId@ConnectionId { localAddress , remoteAddress } mkMuxBearer + withBuffer = MaskedAction { runWithUnmask } where runWithUnmask :: (forall x. m x -> m x) -> m () @@ -367,7 +371,7 @@ makeConnectionHandler muxTracer singMuxMode , "-" , show remoteAddress ]) - handshakeBearer <- mkMuxBearer sduHandshakeTimeout socket + handshakeBearer <- mkMuxBearer sduHandshakeTimeout socket Nothing hsResult <- unmask (runHandshakeServer handshakeBearer connectionId @@ -401,9 +405,11 @@ makeConnectionHandler muxTracer singMuxMode hVersionData = agreedOptions } atomically $ writePromise (Right $ HandshakeConnectionResult handle (versionNumber, agreedOptions)) - bearer <- mkMuxBearer sduTimeout socket - Mx.run (Mx.WithBearer connectionId `contramap` muxTracer) + withBuffer (\buffer -> do + bearer <- mkMuxBearer sduTimeout socket buffer + Mx.run (Mx.WithBearer connectionId `contramap` muxTracer) mux bearer + ) Right (HandshakeQueryResult vMap) -> do atomically $ writePromise (Right HandshakeConnectionQuery) traceWith tracer $ TrHandshakeQuery vMap diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs index 5eff5430783..a3ff65acb34 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Core.hs @@ -114,6 +114,9 @@ data Arguments handlerTrace socket peerAddr handle handleError versionNumber ver -- makeBearer :: MakeBearer m socket, + -- | With a ReadBuffer + withBuffer :: ((Maybe (Mx.ReadBuffer m) -> m ()) -> m ()), + -- | Socket configuration. -- configureSocket :: socket -> Maybe peerAddr -> m (), @@ -395,6 +398,7 @@ with args@Arguments { addressType, snocket, makeBearer, + withBuffer, configureSocket, timeWaitTimeout, outboundIdleTimeout, @@ -623,7 +627,8 @@ with args@Arguments { (\bearerTimeout -> getBearer makeBearer bearerTimeout - (Mx.WithBearer connId `contramap` muxTracer))) + (Mx.WithBearer connId `contramap` muxTracer)) + withBuffer) unmask `finally` cleanup where diff --git a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs index a11a6bb6a85..b11e4a73e0a 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/ConnectionManager/Types.hs @@ -356,7 +356,8 @@ type ConnectionHandlerFn handlerTrace socket peerAddr handle handleError version -> PromiseWriter m (Either handleError (HandshakeConnectionResult handle (versionNumber, versionData))) -> Tracer m handlerTrace -> ConnectionId peerAddr - -> (DiffTime -> socket -> m (Mux.Bearer m)) + -> (DiffTime -> socket -> Maybe (Mux.ReadBuffer m) -> m (Mux.Bearer m)) + -> ((Maybe (Mux.ReadBuffer m) -> m ()) -> m ()) -> MaskedAction m () data HandshakeConnectionResult handle version diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs index 419d89ca842..3c36141a9ef 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Snocket.hs @@ -413,11 +413,11 @@ makeLocalRawBearer = MakeRawBearer (return . localSocketToRawBearer) makeLocalBearer :: MakeBearer IO LocalSocket #if defined(mingw32_HOST_OS) -makeLocalBearer = MakeBearer $ \sduTimeout tracer LocalSocket { getLocalHandle = fd } -> - getBearer makeNamedPipeBearer sduTimeout tracer fd +makeLocalBearer = MakeBearer $ \sduTimeout tracer LocalSocket { getLocalHandle = fd } rb -> + getBearer makeNamedPipeBearer sduTimeout tracer fd rb #else -makeLocalBearer = MakeBearer $ \sduTimeout tracer (LocalSocket fd) -> - getBearer makeSocketBearer sduTimeout tracer fd +makeLocalBearer = MakeBearer $ \sduTimeout tracer (LocalSocket fd) rb -> + getBearer makeSocketBearer sduTimeout tracer fd rb #endif -- | System dependent LocalSnocket diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs index 247ed78b9e2..5dd1c72abef 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs @@ -409,7 +409,7 @@ connectToNodeWithMux' muxTracer <- initDeltaQTracer' $ Mx.WithBearer connectionId `contramap` nctMuxTracer ts_start <- getMonotonicTime - handshakeBearer <- Mx.getBearer makeBearer sduHandshakeTimeout muxTracer sd + handshakeBearer <- Mx.getBearer makeBearer sduHandshakeTimeout muxTracer sd Nothing app_e <- runHandshakeClient handshakeBearer @@ -436,10 +436,12 @@ connectToNodeWithMux' Right (HandshakeNegotiationResult app versionNumber agreedOptions) -> do traceWith muxTracer $ Mx.TraceHandshakeClientEnd (diffTime ts_end ts_start) - bearer <- Mx.getBearer makeBearer sduTimeout muxTracer sd - mux <- Mx.new (toMiniProtocolInfos app) - withAsync (Mx.run muxTracer mux bearer) $ \aid -> - k connectionId versionNumber agreedOptions app mux aid + Mx.withReadBufferIO (\buffer -> do + bearer <- Mx.getBearer makeBearer sduTimeout muxTracer sd buffer + mux <- Mx.new (toMiniProtocolInfos app) + withAsync (Mx.run muxTracer mux bearer) $ \aid -> + k connectionId versionNumber agreedOptions app mux aid + ) Right (HandshakeQueryResult _vMap) -> do traceWith muxTracer $ Mx.TraceHandshakeClientEnd (diffTime ts_end ts_start) @@ -584,7 +586,7 @@ beginConnection makeBearer muxTracer handshakeTracer handshakeCodec handshakeTim traceWith muxTracer' $ Mx.TraceHandshakeStart - handshakeBearer <- Mx.getBearer makeBearer sduHandshakeTimeout muxTracer' sd + handshakeBearer <- Mx.getBearer makeBearer sduHandshakeTimeout muxTracer' sd Nothing app_e <- runHandshakeServer handshakeBearer @@ -610,10 +612,12 @@ beginConnection makeBearer muxTracer handshakeTracer handshakeCodec handshakeTim Right (HandshakeNegotiationResult (SomeResponderApplication app) versionNumber agreedOptions) -> do traceWith muxTracer' Mx.TraceHandshakeServerEnd - bearer <- Mx.getBearer makeBearer sduTimeout muxTracer' sd - mux <- Mx.new (toMiniProtocolInfos app) - withAsync (Mx.run muxTracer' mux bearer) $ \aid -> - void $ simpleMuxCallback connectionId versionNumber agreedOptions app mux aid + Mx.withReadBufferIO (\buffer -> do + bearer <- Mx.getBearer makeBearer sduTimeout muxTracer' sd buffer + mux <- Mx.new (toMiniProtocolInfos app) + withAsync (Mx.run muxTracer' mux bearer) $ \aid -> + void $ simpleMuxCallback connectionId versionNumber agreedOptions app mux aid + ) Right (HandshakeQueryResult _vMap) -> do traceWith muxTracer' Mx.TraceHandshakeServerEnd diff --git a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs index 2140d70e274..8a791d246e1 100644 --- a/ouroboros-network-framework/src/Simulation/Network/Snocket.hs +++ b/ouroboros-network-framework/src/Simulation/Network/Snocket.hs @@ -679,7 +679,7 @@ makeFDBearer :: forall addr m. , Show addr ) => MakeBearer m (FD m (TestAddress addr)) -makeFDBearer = MakeBearer $ \sduTimeout muxTracer FD { fdVar } -> do +makeFDBearer = MakeBearer $ \sduTimeout muxTracer FD { fdVar } _ -> do fd_ <- atomically (readTVar fdVar) case fd_ of FDUninitialised {} -> diff --git a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Experiments.hs b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Experiments.hs index a6a49ee14e6..a0b78211e3a 100644 --- a/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Experiments.hs +++ b/ouroboros-network-framework/testlib/Test/Ouroboros/Network/ConnectionManager/Experiments.hs @@ -291,6 +291,7 @@ withInitiatorOnlyConnectionManager name timeouts trTracer tracer stdGen snocket CM.addressType = \_ -> Just IPv4Address, CM.snocket = snocket, CM.makeBearer = makeBearer, + CM.withBuffer = \f -> f Nothing, CM.configureSocket = \_ _ -> return (), CM.connectionDataFlow = \(DataFlowProtocolData df _) -> df, CM.prunePolicy = simplePrunePolicy, @@ -481,6 +482,7 @@ withBidirectionalConnectionManager name timeouts CM.addressType = \_ -> Just IPv4Address, CM.snocket = snocket, CM.makeBearer = makeBearer, + CM.withBuffer = \f -> f Nothing, CM.configureSocket = \sock _ -> confSock sock, CM.timeWaitTimeout = tTimeWaitTimeout timeouts, CM.outboundIdleTimeout = tOutboundIdleTimeout timeouts, diff --git a/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs b/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs index 151bfa67e40..d544f7299fa 100644 --- a/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs +++ b/ouroboros-network-protocols/testlib/Ouroboros/Network/Protocol/Handshake/Test.hs @@ -1329,11 +1329,13 @@ prop_channel_simultaneous_open_sim codec versionDataCodec nullTracer -- (("client",) `contramap` Tracer Debug.traceShowM) fdConn + Nothing bearer' <- Mx.getBearer makeFDBearer 1 nullTracer -- (("server",) `contramap` Tracer Debug.traceShowM) fdConn' + Nothing let chann = bearerAsChannel bearer (MiniProtocolNum 0) InitiatorDir chann' = bearerAsChannel bearer' (MiniProtocolNum 0) InitiatorDir res <- prop_channel_simultaneous_open diff --git a/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs b/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs index 5866ab49e3b..1e9cb868b75 100644 --- a/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs +++ b/ouroboros-network/io-tests/Test/Ouroboros/Network/Pipe.hs @@ -193,8 +193,8 @@ demo chain0 updates = do , ChainSync.chainSyncServerPeer server ) - clientBearer <- Mx.getBearer Mx.makePipeChannelBearer (-1) activeTracer chan1 - serverBearer <- Mx.getBearer Mx.makePipeChannelBearer (-1) activeTracer chan2 + clientBearer <- Mx.getBearer Mx.makePipeChannelBearer (-1) activeTracer chan1 Nothing + serverBearer <- Mx.getBearer Mx.makePipeChannelBearer (-1) activeTracer chan2 Nothing _ <- async $ do clientMux <- Mx.new (toMiniProtocolInfos consumerApp) diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs index a16be96ed2c..6a51cafe772 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Mux.hs @@ -156,12 +156,14 @@ demo chain0 updates delay = do Mx.QueueChannel { Mx.writeQueue = client_w, Mx.readQueue = client_r } + Nothing serverBearer <- Mx.getBearer Mx.makeQueueChannelBearer (-1) activeTracer Mx.QueueChannel { Mx.writeQueue = server_w, Mx.readQueue = server_r } + Nothing clientAsync <- async $ do clientMux <- Mx.new (toMiniProtocolInfos consumerApp) diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs index 6315144978d..7988a04245e 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet.hs @@ -2927,7 +2927,7 @@ prop_diffusion_async_demotions ioSimTrace traceNumber = demotionOpportunitiesTooLong :: Signal (Set NtNAddr) demotionOpportunitiesTooLong = - Signal.keyedTimeout 1 id demotionOpportunities + Signal.keyedTimeout 10 id demotionOpportunities in signalProperty 20 show Set.null diff --git a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Node.hs b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Node.hs index ba823f49749..4e9c36e57c5 100644 --- a/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Node.hs +++ b/ouroboros-network/sim-tests-lib/Test/Ouroboros/Network/Testnet/Node.hs @@ -210,6 +210,7 @@ run blockGeneratorArgs limits ni na tracersExtra tracerBlockFetch = interfaces = Diff.P2P.Interfaces { Diff.P2P.diNtnSnocket = iNtnSnocket ni , Diff.P2P.diNtnBearer = iNtnBearer ni + , Diff.P2P.diWithBuffer = \f -> f Nothing , Diff.P2P.diNtnConfigureSocket = \_ _ -> return () , Diff.P2P.diNtnConfigureSystemdSocket = \_ _ -> return () diff --git a/ouroboros-network/src/Ouroboros/Network/Diffusion/P2P.hs b/ouroboros-network/src/Ouroboros/Network/Diffusion/P2P.hs index 15cd06fb02c..9054c912291 100644 --- a/ouroboros-network/src/Ouroboros/Network/Diffusion/P2P.hs +++ b/ouroboros-network/src/Ouroboros/Network/Diffusion/P2P.hs @@ -66,6 +66,8 @@ import Network.Socket (Socket) import Network.Socket qualified as Socket import Network.Mux qualified as Mx +import Network.Mux.Types (ReadBuffer) +import Network.Mux.Bearer (withReadBufferIO) import Ouroboros.Network.Snocket (FileDescriptor, LocalAddress, LocalSocket (..), Snocket, localSocketFileDescriptor, @@ -465,6 +467,10 @@ data Interfaces ntnFd ntnAddr ntnVersion ntnVersionData diNtnBearer :: Mx.MakeBearer m ntnFd, + -- | readbuffer + diWithBuffer + :: ((Maybe (ReadBuffer m) -> m ()) -> m ()), + -- | node-to-node socket configuration -- -- It is used by both inbound and outbound connection. The address is @@ -619,6 +625,7 @@ runM runM Interfaces { diNtnSnocket , diNtnBearer + , diWithBuffer , diNtnConfigureSocket , diNtnConfigureSystemdSocket , diNtnHandshakeArguments @@ -804,6 +811,7 @@ runM Interfaces CM.addressType = const Nothing, CM.snocket = diNtcSnocket, CM.makeBearer = diNtcBearer, + CM.withBuffer = diWithBuffer, CM.configureSocket = \_ _ -> return (), CM.timeWaitTimeout = local_TIME_WAIT_TIMEOUT, CM.outboundIdleTimeout = local_PROTOCOL_IDLE_TIMEOUT, @@ -935,6 +943,7 @@ runM Interfaces CM.addressType = diNtnAddressType, CM.snocket = diNtnSnocket, CM.makeBearer = diNtnBearer, + CM.withBuffer = diWithBuffer, CM.configureSocket = diNtnConfigureSocket, CM.connectionDataFlow = diNtnDataFlow, CM.prunePolicy = prunePolicy, @@ -1304,6 +1313,7 @@ run tracers tracersExtra args argsExtra apps appsExtra = do Interfaces { diNtnSnocket = Snocket.socketSnocket iocp, diNtnBearer = makeSocketBearer, + diWithBuffer = withReadBufferIO, diNtnConfigureSocket = configureSocket, diNtnConfigureSystemdSocket = configureSystemdSocket