Skip to content

Commit 4b47e98

Browse files
committed
WIP: Annotated decoder
This is an experiment to provide `runAnnotatedPeer`, which is like `runPeer' but allows us to run a decoder which has access to bytes used when decoding a message. This allows one to record ByteString from which a piece of data was decoded, e.g. for each `tx` inside `MsgReplyTxs`. The `Codec` type in `typed-protocols` was generalised for this purpose. The core functionality is implemented in `runAnnotatedDecoderWithChannel` which runs `AnnotatedCodec` against a `Channel` which does incremental decoding & recording bytes used so far. We also expose `runAnnotatedPeer` which runs a `Peer` against `Channel` using an `AnnotatedCodec` (using `annotatedDriverSimple`). TODO: * `runAnnotatedPeerWithLimits` * `runAnnotatedPipelinedPeerWithLimits` It's actually the last one that we will need in `tx-submission`. * Add ``` data WithBytes a { encoded :: ByteString, decoded :: a ``` and generalise `codecTxSubmission2` so that it can be used to used with annotator and without it - it might require two separate function, but I think it can be generated from one more general function (so we don't need to maintain two codecs). TODO: design & implement quickcheck properties
1 parent ed11046 commit 4b47e98

File tree

3 files changed

+168
-30
lines changed
  • ouroboros-network-framework/src/Ouroboros/Network/Driver
  • ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2

3 files changed

+168
-30
lines changed

cabal.project

+9
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,12 @@ package network-mux
5454
package ouroboros-network
5555
flags: +asserts +cddl
5656

57+
58+
source-repository-package
59+
type: git
60+
location: https://github.com/input-output-hk/typed-protocols
61+
tag: d0c0668048be5b9878917180d7a0641861216bec
62+
subdir: typed-protocols
63+
typed-protocols-cborg
64+
allow-newer: typed-protocols:io-classes
65+

ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs

+136-18
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
{-# LANGUAGE QuantifiedConstraints #-}
77
{-# LANGUAGE RankNTypes #-}
88
{-# LANGUAGE ScopedTypeVariables #-}
9-
{-# LANGUAGE StandaloneDeriving #-}
109
{-# LANGUAGE TypeFamilies #-}
1110
-- @UndecidableInstances@ extensions is required for defining @Show@ instance
1211
-- of @'TraceSendRecv'@.
@@ -19,10 +18,12 @@ module Ouroboros.Network.Driver.Simple
1918
-- $intro
2019
-- * Normal peers
2120
runPeer
21+
, runAnnotatedPeer
2222
, TraceSendRecv (..)
2323
, DecoderFailure (..)
2424
-- * Pipelined peers
2525
, runPipelinedPeer
26+
, runPipelinedAnnotatedPeer
2627
-- * Connected peers
2728
-- TODO: move these to a test lib
2829
, Role (..)
@@ -43,6 +44,9 @@ import Ouroboros.Network.Channel
4344
import Control.Monad.Class.MonadAsync
4445
import Control.Monad.Class.MonadThrow
4546
import Control.Tracer (Tracer (..), contramap, traceWith)
47+
import Data.Maybe (fromMaybe)
48+
import Data.Functor.Identity (Identity)
49+
import Control.Monad.Identity (Identity(..))
4650

4751

4852
-- $intro
@@ -107,18 +111,31 @@ instance Show DecoderFailure where
107111
instance Exception DecoderFailure where
108112

109113

110-
driverSimple :: forall ps failure bytes m.
111-
( MonadThrow m
112-
, Show failure
113-
, forall (st :: ps). Show (ClientHasAgency st)
114-
, forall (st :: ps). Show (ServerHasAgency st)
115-
, ShowProxy ps
116-
)
117-
=> Tracer m (TraceSendRecv ps)
118-
-> Codec ps failure m bytes
119-
-> Channel m bytes
120-
-> Driver ps (Maybe bytes) m
121-
driverSimple tracer Codec{encode, decode} channel@Channel{send} =
114+
mkSimpleDriver :: forall ps failure bytes m f annotator.
115+
( MonadThrow m
116+
, Show failure
117+
, forall (st :: ps). Show (ClientHasAgency st)
118+
, forall (st :: ps). Show (ServerHasAgency st)
119+
, ShowProxy ps
120+
)
121+
=> (forall a.
122+
Channel m bytes
123+
-> Maybe bytes
124+
-> DecodeStep bytes failure m (f a)
125+
-> m (Either failure (a, Maybe bytes))
126+
)
127+
-- ^ run incremental decoder against a channel
128+
129+
-> (forall st. annotator st -> f (SomeMessage st))
130+
-- ^ transform annotator to a container holding the decoded
131+
-- message
132+
133+
-> Tracer m (TraceSendRecv ps)
134+
-> Codec' ps failure m annotator bytes
135+
-> Channel m bytes
136+
-> Driver ps (Maybe bytes) m
137+
138+
mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{send} =
122139
Driver { sendMessage, recvMessage, startDState = Nothing }
123140
where
124141
sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps).
@@ -135,7 +152,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
135152
-> m (SomeMessage st, Maybe bytes)
136153
recvMessage stok trailing = do
137154
decoder <- decode stok
138-
result <- runDecoderWithChannel channel trailing decoder
155+
result <- runDecodeSteps channel trailing (nat <$> decoder)
139156
case result of
140157
Right x@(SomeMessage msg, _trailing') -> do
141158
traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg))
@@ -144,6 +161,36 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
144161
throwIO (DecoderFailure stok failure)
145162

146163

164+
simpleDriver :: forall ps failure bytes m.
165+
( MonadThrow m
166+
, Show failure
167+
, forall (st :: ps). Show (ClientHasAgency st)
168+
, forall (st :: ps). Show (ServerHasAgency st)
169+
, ShowProxy ps
170+
)
171+
=> Tracer m (TraceSendRecv ps)
172+
-> Codec ps failure m bytes
173+
-> Channel m bytes
174+
-> Driver ps (Maybe bytes) m
175+
simpleDriver = mkSimpleDriver runDecoderWithChannel Identity
176+
177+
178+
annotatedSimpleDriver
179+
:: forall ps failure bytes m.
180+
( MonadThrow m
181+
, Monoid bytes
182+
, Show failure
183+
, forall (st :: ps). Show (ClientHasAgency st)
184+
, forall (st :: ps). Show (ServerHasAgency st)
185+
, ShowProxy ps
186+
)
187+
=> Tracer m (TraceSendRecv ps)
188+
-> AnnotatedCodec ps failure m bytes
189+
-> Channel m bytes
190+
-> Driver ps (Maybe bytes) m
191+
annotatedSimpleDriver = mkSimpleDriver runAnnotatedDecoderWithChannel runAnnotator
192+
193+
147194
-- | Run a peer with the given channel via the given codec.
148195
--
149196
-- This runs the peer to completion (if the protocol allows for termination).
@@ -164,7 +211,31 @@ runPeer
164211
runPeer tracer codec channel peer =
165212
runPeerWithDriver driver peer (startDState driver)
166213
where
167-
driver = driverSimple tracer codec channel
214+
driver = simpleDriver tracer codec channel
215+
216+
217+
-- | Run a peer with the given channel via the given annotated codec.
218+
--
219+
-- This runs the peer to completion (if the protocol allows for termination).
220+
--
221+
runAnnotatedPeer
222+
:: forall ps (st :: ps) pr failure bytes m a .
223+
( MonadThrow m
224+
, Monoid bytes
225+
, Show failure
226+
, forall (st' :: ps). Show (ClientHasAgency st')
227+
, forall (st' :: ps). Show (ServerHasAgency st')
228+
, ShowProxy ps
229+
)
230+
=> Tracer m (TraceSendRecv ps)
231+
-> AnnotatedCodec ps failure m bytes
232+
-> Channel m bytes
233+
-> Peer ps pr st m a
234+
-> m (a, Maybe bytes)
235+
runAnnotatedPeer tracer codec channel peer =
236+
runPeerWithDriver driver peer (startDState driver)
237+
where
238+
driver = annotatedSimpleDriver tracer codec channel
168239

169240

170241
-- | Run a pipelined peer with the given channel via the given codec.
@@ -191,7 +262,35 @@ runPipelinedPeer
191262
runPipelinedPeer tracer codec channel peer =
192263
runPipelinedPeerWithDriver driver peer (startDState driver)
193264
where
194-
driver = driverSimple tracer codec channel
265+
driver = simpleDriver tracer codec channel
266+
267+
268+
-- | Run a pipelined peer with the given channel via the given annotated codec.
269+
--
270+
-- This runs the peer to completion (if the protocol allows for termination).
271+
--
272+
-- Unlike normal peers, running pipelined peers rely on concurrency, hence the
273+
-- 'MonadAsync' constraint.
274+
--
275+
runPipelinedAnnotatedPeer
276+
:: forall ps (st :: ps) pr failure bytes m a.
277+
( MonadAsync m
278+
, MonadThrow m
279+
, Monoid bytes
280+
, Show failure
281+
, forall (st' :: ps). Show (ClientHasAgency st')
282+
, forall (st' :: ps). Show (ServerHasAgency st')
283+
, ShowProxy ps
284+
)
285+
=> Tracer m (TraceSendRecv ps)
286+
-> AnnotatedCodec ps failure m bytes
287+
-> Channel m bytes
288+
-> PeerPipelined ps pr st m a
289+
-> m (a, Maybe bytes)
290+
runPipelinedAnnotatedPeer tracer codec channel peer =
291+
runPipelinedPeerWithDriver driver peer (startDState driver)
292+
where
293+
driver = annotatedSimpleDriver tracer codec channel
195294

196295

197296
--
@@ -204,17 +303,36 @@ runPipelinedPeer tracer codec channel peer =
204303
runDecoderWithChannel :: Monad m
205304
=> Channel m bytes
206305
-> Maybe bytes
207-
-> DecodeStep bytes failure m a
306+
-> DecodeStep bytes failure m (Identity a)
208307
-> m (Either failure (a, Maybe bytes))
209308

210309
runDecoderWithChannel Channel{recv} = go
211310
where
212-
go _ (DecodeDone x trailing) = return (Right (x, trailing))
311+
go _ (DecodeDone (Identity x) trailing) = return (Right (x, trailing))
213312
go _ (DecodeFail failure) = return (Left failure)
214313
go Nothing (DecodePartial k) = recv >>= k >>= go Nothing
215314
go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing
216315

217316

317+
runAnnotatedDecoderWithChannel
318+
:: forall m bytes failure a.
319+
( Monad m
320+
, Monoid bytes
321+
)
322+
=> Channel m bytes
323+
-> Maybe bytes
324+
-> DecodeStep bytes failure m (bytes -> a)
325+
-> m (Either failure (a, Maybe bytes))
326+
327+
runAnnotatedDecoderWithChannel Channel{recv} bs0 = go (fromMaybe mempty bs0) bs0
328+
where
329+
go :: bytes -> Maybe bytes -> DecodeStep bytes failure m (bytes -> a) -> m (Either failure (a, Maybe bytes))
330+
go bytes _ (DecodeDone f trailing) = return $ Right (f bytes, trailing)
331+
go _bytes _ (DecodeFail failure) = return (Left failure)
332+
go bytes Nothing (DecodePartial k) = recv >>= \bs -> k bs >>= go (bytes <> fromMaybe mempty bs) Nothing
333+
go bytes (Just trailing) (DecodePartial k) = k (Just trailing) >>= go (bytes <> trailing) Nothing
334+
335+
218336
data Role = Client | Server
219337

220338
-- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'.

ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs

+23-12
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ codecTxSubmission2
7070
-> (forall s . CBOR.Decoder s txid)
7171
-> (tx -> CBOR.Encoding)
7272
-> (forall s . CBOR.Decoder s tx)
73-
-> Codec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString
73+
-> AnnotatedCodec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString
7474
codecTxSubmission2 encodeTxId decodeTxId
7575
encodeTx decodeTx =
7676
mkCodecCborLazyBS
@@ -79,7 +79,7 @@ codecTxSubmission2 encodeTxId decodeTxId
7979
where
8080
decode :: forall (pr :: PeerRole) (st :: TxSubmission2 txid tx).
8181
PeerHasAgency pr st
82-
-> forall s. CBOR.Decoder s (SomeMessage st)
82+
-> forall s. CBOR.Decoder s (Annotator ByteString st)
8383
decode stok = do
8484
len <- CBOR.decodeListLen
8585
key <- CBOR.decodeWord
@@ -156,26 +156,26 @@ decodeTxSubmission2
156156
PeerHasAgency pr st
157157
-> Int
158158
-> Word
159-
-> CBOR.Decoder s (SomeMessage st))
159+
-> CBOR.Decoder s (Annotator ByteString st))
160160
decodeTxSubmission2 decodeTxId decodeTx = decode
161161
where
162162
decode :: forall (pr :: PeerRole) s (st :: TxSubmission2 txid tx).
163163
PeerHasAgency pr st
164164
-> Int
165165
-> Word
166-
-> CBOR.Decoder s (SomeMessage st)
166+
-> CBOR.Decoder s (Annotator ByteString st)
167167
decode stok len key = do
168168
case (stok, len, key) of
169169
(ClientAgency TokInit, 1, 6) ->
170-
return (SomeMessage MsgInit)
170+
return (Annotator $ \_ -> SomeMessage MsgInit)
171171
(ServerAgency TokIdle, 4, 0) -> do
172172
blocking <- CBOR.decodeBool
173173
ackNo <- NumTxIdsToAck <$> CBOR.decodeWord16
174174
reqNo <- NumTxIdsToReq <$> CBOR.decodeWord16
175175
return $!
176176
if blocking
177-
then SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo)
178-
else SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo)
177+
then Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo)
178+
else Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo)
179179

180180
(ClientAgency (TokTxIds b), 2, 1) -> do
181181
CBOR.decodeListLenIndef
@@ -187,11 +187,11 @@ decodeTxSubmission2 decodeTxId decodeTx = decode
187187
return (txid, SizeInBytes sz))
188188
case (b, txids) of
189189
(TokBlocking, t:ts) ->
190-
return $
190+
return $ Annotator $ \_ ->
191191
SomeMessage (MsgReplyTxIds (BlockingReply (t NonEmpty.:| ts)))
192192

193193
(TokNonBlocking, ts) ->
194-
return $
194+
return $ Annotator $ \_ ->
195195
SomeMessage (MsgReplyTxIds (NonBlockingReply ts))
196196

197197
(TokBlocking, []) ->
@@ -201,15 +201,26 @@ decodeTxSubmission2 decodeTxId decodeTx = decode
201201
(ServerAgency TokIdle, 2, 2) -> do
202202
CBOR.decodeListLenIndef
203203
txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTxId
204-
return (SomeMessage (MsgRequestTxs txids))
204+
return (Annotator $ \_ -> SomeMessage (MsgRequestTxs txids))
205205

206206
(ClientAgency TokTxs, 2, 3) -> do
207207
CBOR.decodeListLenIndef
208208
txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTx
209-
return (SomeMessage (MsgReplyTxs txids))
209+
-- ^ TODO: `txids -> txs` :grin:
210+
return (Annotator $
211+
-- TODO: here we have access to bytes from which the message was decoded.
212+
-- we can use `Codec.CBOR.Decoding.decodeWithByteSpan`
213+
-- around each `tx` and wrap each `tx` in `WithBytes`.
214+
--
215+
-- `decodeTxSubmission2` can be polymorphic by adding an
216+
-- extra argument of type
217+
-- `ByteString -> ByteOffSet -> ByteOffset -> tx -> a`
218+
-- this way we could wrap `tx` in `WithBytes` or just
219+
-- return `tx`.
220+
\_bytes -> SomeMessage (MsgReplyTxs txids))
210221

211222
(ClientAgency (TokTxIds TokBlocking), 1, 4) ->
212-
return (SomeMessage MsgDone)
223+
return (Annotator $ \_ -> SomeMessage MsgDone)
213224

214225
--
215226
-- failures per protocol state

0 commit comments

Comments
 (0)