@@ -22,14 +22,19 @@ module Ouroboros.Network.Driver.Limits
22
22
-- * Normal peers
23
23
, runPeerWithLimits
24
24
, runPipelinedPeerWithLimits
25
+ , runPeerWithLimitsRnd
26
+ , runPipelinedPeerWithLimitsRnd
25
27
, TraceSendRecv (.. )
26
28
-- * Driver utilities
27
29
, driverWithLimits
28
30
, runConnectedPeersWithLimits
29
31
, runConnectedPipelinedPeersWithLimits
32
+ , runConnectedPeersWithLimitsRnd
33
+ , runConnectedPipelinedPeersWithLimitsRnd
30
34
) where
31
35
32
36
import Data.Maybe (fromMaybe )
37
+ import System.Random
33
38
34
39
import Control.Monad.Class.MonadAsync
35
40
import Control.Monad.Class.MonadFork
@@ -105,6 +110,65 @@ driverWithLimits tracer timeoutFn
105
110
Nothing -> throwIO (ExceededTimeLimit tok)
106
111
107
112
113
+ driverWithLimitsRnd :: forall ps (pr :: PeerRole ) failure bytes m .
114
+ ( MonadThrow m
115
+ , ShowProxy ps
116
+ , forall (st' :: ps ) tok . tok ~ StateToken st' => Show tok
117
+ , Show failure
118
+ )
119
+ => Tracer m (TraceSendRecv ps )
120
+ -> TimeoutFn m
121
+ -> StdGen
122
+ -> Codec ps failure m bytes
123
+ -> ProtocolSizeLimits ps bytes
124
+ -> (StdGen -> ProtocolTimeLimits ps )
125
+ -> Channel m bytes
126
+ -> Driver ps pr (Maybe bytes , StdGen ) m
127
+ driverWithLimitsRnd tracer timeoutFn rnd0
128
+ Codec {encode, decode}
129
+ ProtocolSizeLimits {sizeLimitForState, dataSize}
130
+ genProtocolTimeLimits
131
+ channel@ Channel {send} =
132
+ Driver { sendMessage, recvMessage, initialDState = (Nothing , rnd0) }
133
+ where
134
+ sendMessage :: forall (st :: ps ) (st' :: ps ).
135
+ StateTokenI st
136
+ => ActiveState st
137
+ => WeHaveAgencyProof pr st
138
+ -> Message ps st st'
139
+ -> m ()
140
+ sendMessage ! _ msg = do
141
+ send (encode msg)
142
+ traceWith tracer (TraceSendMsg (AnyMessage msg))
143
+
144
+
145
+ recvMessage :: forall (st :: ps ).
146
+ StateTokenI st
147
+ => ActiveState st
148
+ => TheyHaveAgencyProof pr st
149
+ -> (Maybe bytes , StdGen )
150
+ -> m (SomeMessage st , (Maybe bytes , StdGen ))
151
+ recvMessage ! _ (trailing, ! rnd) = do
152
+ let tok = stateToken
153
+ decoder <- decode tok
154
+ let sizeLimit = sizeLimitForState @ st stateToken
155
+
156
+ let (rnd', rnd'') = split rnd
157
+ ProtocolTimeLimits {timeLimitForState} = genProtocolTimeLimits rnd''
158
+ timeLimit = fromMaybe (- 1 ) $ timeLimitForState @ st stateToken
159
+ result <- timeoutFn timeLimit $
160
+ runDecoderWithLimit sizeLimit dataSize
161
+ channel trailing decoder
162
+
163
+ case result of
164
+ Just (Right (x@ (SomeMessage msg), trailing')) -> do
165
+ traceWith tracer (TraceRecvMsg (AnyMessage msg))
166
+ return (x, (trailing', rnd'))
167
+ Just (Left (Just failure)) -> throwIO (DecoderFailure tok failure)
168
+ Just (Left Nothing ) -> throwIO (ExceededSizeLimit tok)
169
+ Nothing -> throwIO (ExceededTimeLimit tok)
170
+
171
+
108
172
runDecoderWithLimit
109
173
:: forall m bytes failure a . Monad m
110
174
=> Word
@@ -152,7 +216,8 @@ runDecoderWithLimit limit size Channel{recv} =
152
216
Just bs -> do let sz' = sz + size bs
153
217
go sz' Nothing =<< k (Just bs)
154
218
155
-
219
+ -- | Run a peer with limits.
220
+ --
156
221
runPeerWithLimits
157
222
:: forall ps (st :: ps ) pr failure bytes m a .
158
223
( MonadAsync m
@@ -175,6 +240,37 @@ runPeerWithLimits tracer codec slimits tlimits channel peer =
175
240
withTimeoutSerial $ \ timeoutFn ->
176
241
let driver = driverWithLimits tracer timeoutFn codec slimits tlimits channel
177
242
in runPeerWithDriver driver peer
243
+
244
+
245
+ -- | Run a peer with limits. 'ProtocolTimeLimits' have access to
246
+ -- a pseudorandom generator.
247
+ --
248
+ runPeerWithLimitsRnd
249
+ :: forall ps (st :: ps ) pr failure bytes m a .
250
+ ( MonadAsync m
251
+ , MonadFork m
252
+ , MonadMask m
253
+ , MonadThrow (STM m )
254
+ , MonadTimer m
255
+ , ShowProxy ps
256
+ , forall (st' :: ps ) stok . stok ~ StateToken st' => Show stok
257
+ , Show failure
258
+ )
259
+ => Tracer m (TraceSendRecv ps )
260
+ -> StdGen
261
+ -> Codec ps failure m bytes
262
+ -> ProtocolSizeLimits ps bytes
263
+ -> (StdGen -> ProtocolTimeLimits ps )
264
+ -> Channel m bytes
265
+ -> Peer ps pr NonPipelined st m a
266
+ -> m (a , Maybe bytes )
267
+ runPeerWithLimitsRnd tracer rnd codec slimits tlimits channel peer =
268
+ withTimeoutSerial $ \ timeoutFn ->
269
+ let driver = driverWithLimitsRnd tracer timeoutFn rnd codec slimits tlimits channel
270
+ in (\ (a, (trailing, _)) -> (a, trailing))
271
+ <$> runPeerWithDriver driver peer
272
+
273
+
178
274
-- | Run a pipelined peer with the given channel via the given codec.
179
275
--
180
276
-- This runs the peer to completion (if the protocol allows for termination).
@@ -206,6 +302,35 @@ runPipelinedPeerWithLimits tracer codec slimits tlimits channel peer =
206
302
in runPipelinedPeerWithDriver driver peer
207
303
208
304
305
+ -- | Like 'runPipelinedPeerWithLimits' but time limits have access to
306
+ -- a pseudorandom generator.
307
+ --
308
+ runPipelinedPeerWithLimitsRnd
309
+ :: forall ps (st :: ps ) pr failure bytes m a .
310
+ ( MonadAsync m
311
+ , MonadFork m
312
+ , MonadMask m
313
+ , MonadTimer m
314
+ , MonadThrow (STM m )
315
+ , ShowProxy ps
316
+ , forall (st' :: ps ) stok . stok ~ StateToken st' => Show stok
317
+ , Show failure
318
+ )
319
+ => Tracer m (TraceSendRecv ps )
320
+ -> StdGen
321
+ -> Codec ps failure m bytes
322
+ -> ProtocolSizeLimits ps bytes
323
+ -> (StdGen -> ProtocolTimeLimits ps )
324
+ -> Channel m bytes
325
+ -> PeerPipelined ps pr st m a
326
+ -> m (a , Maybe bytes )
327
+ runPipelinedPeerWithLimitsRnd tracer rnd codec slimits tlimits channel peer =
328
+ withTimeoutSerial $ \ timeoutFn ->
329
+ let driver = driverWithLimitsRnd tracer timeoutFn rnd codec slimits tlimits channel
330
+ in (\ (a, (trailing, _)) -> (a, trailing))
331
+ <$> runPipelinedPeerWithDriver driver peer
332
+
333
+
209
334
-- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'.
210
335
-- The client side is using 'driverWithLimits'.
211
336
--
@@ -248,6 +373,41 @@ runConnectedPeersWithLimits createChannels tracer codec slimits tlimits client s
248
373
tracerServer = contramap ((,) Server ) tracer
249
374
250
375
376
+ runConnectedPeersWithLimitsRnd
377
+ :: forall ps pr st failure bytes m a b .
378
+ ( MonadAsync m
379
+ , MonadFork m
380
+ , MonadMask m
381
+ , MonadTimer m
382
+ , MonadThrow (STM m )
383
+ , Exception failure
384
+ , ShowProxy ps
385
+ , forall (st' :: ps ) sing . sing ~ StateToken st' => Show sing
386
+ )
387
+ => m (Channel m bytes , Channel m bytes )
388
+ -> Tracer m (Role , TraceSendRecv ps )
389
+ -> StdGen
390
+ -> Codec ps failure m bytes
391
+ -> ProtocolSizeLimits ps bytes
392
+ -> (StdGen -> ProtocolTimeLimits ps )
393
+ -> Peer ps pr NonPipelined st m a
394
+ -> Peer ps (FlipAgency pr ) NonPipelined st m b
395
+ -> m (a , b )
396
+ runConnectedPeersWithLimitsRnd createChannels tracer rnd codec slimits tlimits client server =
397
+ createChannels >>= \ (clientChannel, serverChannel) ->
398
+
399
+ (do labelThisThread " client"
400
+ fst <$> runPeerWithLimitsRnd
401
+ tracerClient rnd codec slimits tlimits
402
+ clientChannel client)
403
+ `concurrently`
404
+ (do labelThisThread " server"
405
+ fst <$> runPeer tracerServer codec serverChannel server)
406
+ where
407
+ tracerClient = contramap ((,) Client ) tracer
408
+ tracerServer = contramap ((,) Server ) tracer
409
+
410
+
251
411
-- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'.
252
412
-- The client side is using 'driverWithLimits'.
253
413
--
@@ -286,3 +446,36 @@ runConnectedPipelinedPeersWithLimits createChannels tracer codec slimits tlimits
286
446
where
287
447
tracerClient = contramap ((,) Client ) tracer
288
448
tracerServer = contramap ((,) Server ) tracer
449
+
450
+
451
+ runConnectedPipelinedPeersWithLimitsRnd
452
+ :: forall ps pr st failure bytes m a b .
453
+ ( MonadAsync m
454
+ , MonadFork m
455
+ , MonadMask m
456
+ , MonadTimer m
457
+ , MonadThrow (STM m )
458
+ , Exception failure
459
+ , ShowProxy ps
460
+ , forall (st' :: ps ) sing . sing ~ StateToken st' => Show sing
461
+ )
462
+ => m (Channel m bytes , Channel m bytes )
463
+ -> Tracer m (Role , TraceSendRecv ps )
464
+ -> StdGen
465
+ -> Codec ps failure m bytes
466
+ -> ProtocolSizeLimits ps bytes
467
+ -> (StdGen -> ProtocolTimeLimits ps )
468
+ -> PeerPipelined ps pr st m a
469
+ -> Peer ps (FlipAgency pr ) NonPipelined st m b
470
+ -> m (a , b )
471
+ runConnectedPipelinedPeersWithLimitsRnd createChannels tracer rnd codec slimits tlimits client server =
472
+ createChannels >>= \ (clientChannel, serverChannel) ->
473
+
474
+ (fst <$> runPipelinedPeerWithLimitsRnd
475
+ tracerClient rnd codec slimits tlimits
476
+ clientChannel client)
477
+ `concurrently`
478
+ (fst <$> runPeer tracerServer codec serverChannel server)
479
+ where
480
+ tracerClient = contramap ((,) Client ) tracer
481
+ tracerServer = contramap ((,) Server ) tracer
0 commit comments