Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions cmd/rsocket-cli/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ func (r *Runner) runClientMode(ctx context.Context) (err error) {
_ = c.Close()
}()

initialRequest := payload.NewString("", "")

for i := 0; i < r.Ops; i++ {
if i > 0 {
logger.Infof("\n")
Expand All @@ -153,7 +155,7 @@ func (r *Runner) runClientMode(ctx context.Context) (err error) {
} else if r.Stream {
err = r.execRequestStream(ctx, c, first)
} else if r.Channel {
err = r.execRequestChannel(ctx, c, sendingPayloads)
err = r.execRequestChannel(ctx, c, initialRequest, sendingPayloads)
} else if r.MetadataPush {
err = r.execMetadataPush(ctx, c, first)
} else {
Expand Down Expand Up @@ -189,7 +191,7 @@ func (r *Runner) runServerMode(ctx context.Context) error {
r.showPayload(message)
return sendingPayloads
}))
options = append(options, rsocket.RequestChannel(func(messages flux.Flux) flux.Flux {
options = append(options, rsocket.RequestChannel(func(initialRequest payload.Payload, messages flux.Flux) flux.Flux {
messages.Subscribe(ctx, rx.OnNext(func(input payload.Payload) error {
r.showPayload(input)
return nil
Expand Down Expand Up @@ -245,12 +247,12 @@ func (r *Runner) execRequestResponse(ctx context.Context, c rsocket.Client, send
return
}

func (r *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, send flux.Flux) error {
func (r *Runner) execRequestChannel(ctx context.Context, c rsocket.Client, initialRequest payload.Payload, send flux.Flux) error {
var f flux.Flux
if r.N < rx.RequestMax {
f = c.RequestChannel(send).Take(r.N)
f = c.RequestChannel(initialRequest, send).Take(r.N)
} else {
f = c.RequestChannel(send)
f = c.RequestChannel(initialRequest, send)
}
return r.printFlux(ctx, f)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func responder() rsocket.RSocket {
emitter.Complete()
})
}),
rsocket.RequestChannel(func(payloads flux.Flux) flux.Flux {
rsocket.RequestChannel(func(initialRequest payload.Payload, payloads flux.Flux) flux.Flux {
//return payloads.(flux.Flux)
payloads.
//LimitRate(1).
Expand Down
5 changes: 3 additions & 2 deletions examples/word_counter/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {

func server(readyCh chan struct{}) {
// create a handler that will be called when the server receives the RequestChannel frame (FrameTypeRequestChannel - 0x07)
requestChannelHandler := rsocket.RequestChannel(func(requests flux.Flux) flux.Flux {
requestChannelHandler := rsocket.RequestChannel(func(initialRequest payload.Payload, requests flux.Flux) flux.Flux {
return flux.Create(func(ctx context.Context, s flux.Sink) {
requests.DoOnNext(func(elem payload.Payload) error {
// for each payload in a flux stream respond with a word count
Expand Down Expand Up @@ -70,6 +70,7 @@ func client() {
defer client.Close()

// strings to count the words
initialRequest := payload.NewString("", "")
sentences := []payload.Payload{
payload.NewString("", extension.TextPlain.String()),
payload.NewString("qux", extension.TextPlain.String()),
Expand All @@ -86,7 +87,7 @@ func client() {
counter := 0

// register handler for RequestChannel
client.RequestChannel(f).DoOnNext(func(input payload.Payload) error {
client.RequestChannel(initialRequest, f).DoOnNext(func(input payload.Payload) error {
// print word count
fmt.Println(sentences[counter].DataUTF8(), ":", input.DataUTF8())
counter = counter + 1
Expand Down
6 changes: 3 additions & 3 deletions internal/socket/abstract_socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type AbstractRSocket struct {
MP func(payload.Payload)
RR func(payload.Payload) mono.Mono
RS func(payload.Payload) flux.Flux
RC func(flux.Flux) flux.Flux
RC func(payload.Payload, flux.Flux) flux.Flux
}

// MetadataPush starts a request of MetadataPush.
Expand Down Expand Up @@ -60,9 +60,9 @@ func (a AbstractRSocket) RequestStream(message payload.Payload) flux.Flux {
}

// RequestChannel starts a request of RequestChannel.
func (a AbstractRSocket) RequestChannel(messages flux.Flux) flux.Flux {
func (a AbstractRSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux {
if a.RC == nil {
return flux.Error(errUnimplementedRequestChannel)
}
return a.RC(messages)
return a.RC(initialRequest, messages)
}
6 changes: 3 additions & 3 deletions internal/socket/abstract_socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ func TestAbstractRSocket_RequestStream(t *testing.T) {

func TestAbstractRSocket_RequestChannel(t *testing.T) {
s := &socket.AbstractRSocket{
RC: func(publisher flux.Flux) flux.Flux {
RC: func(initialRequest payload.Payload, publisher flux.Flux) flux.Flux {
return flux.Clone(publisher)
},
}
var res []payload.Payload
_, err := s.RequestChannel(flux.Just(fakeRequest)).
_, err := s.RequestChannel(fakeRequest, flux.Just(fakeRequest)).
DoOnNext(func(input payload.Payload) error {
res = append(res, input)
return nil
Expand All @@ -101,6 +101,6 @@ func TestAbstractRSocket_RequestChannel(t *testing.T) {
assert.Len(t, res, 1)
assert.Equal(t, fakeRequest, res[0])

_, err = emptyAbstractRSocket.RequestChannel(flux.Just(fakeRequest)).BlockFirst(context.Background())
_, err = emptyAbstractRSocket.RequestChannel(fakeRequest, flux.Just(fakeRequest)).BlockFirst(context.Background())
assert.Error(t, err, "should return an error")
}
4 changes: 2 additions & 2 deletions internal/socket/base_socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ func (p *BaseSocket) RequestStream(message payload.Payload) flux.Flux {
}

// RequestChannel sends RequestChannel request.
func (p *BaseSocket) RequestChannel(messages flux.Flux) flux.Flux {
func (p *BaseSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux {
if err := p.reqLease.allow(); err != nil {
return flux.Error(err)
}
return p.socket.RequestChannel(messages)
return p.socket.RequestChannel(initialRequest, messages)
}

// OnClose registers handler when socket closed.
Expand Down
2 changes: 1 addition & 1 deletion internal/socket/base_socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestBaseSocket(t *testing.T) {
s.FireAndForget(fakeRequest)
s.RequestResponse(fakeRequest)
s.RequestStream(fakeRequest)
s.RequestChannel(flux.Just(fakeRequest))
s.RequestChannel(fakeRequest, flux.Just(fakeRequest))
})

<-done
Expand Down
55 changes: 50 additions & 5 deletions internal/socket/duplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ func (dc *DuplexConnection) killCallback(sid uint32) {
}

// RequestChannel start a request of RequestChannel.
func (dc *DuplexConnection) RequestChannel(sending flux.Flux) (ret flux.Flux) {
func (dc *DuplexConnection) RequestChannel(request payload.Payload, sending flux.Flux) (ret flux.Flux) {
if dc.closed.Load() {
ret = flux.Error(errSocketClosed)
return
Expand Down Expand Up @@ -481,9 +481,56 @@ func (dc *DuplexConnection) RequestChannel(sending flux.Flux) (ret flux.Flux) {
return
}

// First request - send the initial REQUEST_CHANNEL frame with the request payload,
// then subscribe to the sending flux for subsequent payloads.
releasable, isReleasable := request.(common.Releasable)

if isReleasable {
releasable.IncRef()
}

data := request.Data()
metadata, _ := request.Metadata()

size := framing.CalcPayloadFrameSize(data, metadata) + 4
if !dc.shouldSplit(size) {
toBeSent := framing.NewWriteableRequestChannelFrame(sid, n, data, metadata, core.FlagNext)

if isReleasable {
toBeSent.HandleDone(func() {
releasable.Release()
})
}

if ok := dc.sendFrame(toBeSent); !ok {
dc.killCallback(sid)
return
}
} else {
dc.doSplitSkip(4, data, metadata, func(index int, result fragmentation.SplitResult) {
var toBeSent core.WriteableFrame
if index == 0 {
toBeSent = framing.NewWriteableRequestChannelFrame(sid, n, result.Data, result.Metadata, result.Flag|core.FlagNext)
} else {
toBeSent = framing.NewWriteablePayloadFrame(sid, result.Data, result.Metadata, result.Flag|core.FlagNext)
}

// Add release hook at last frame.
if !result.Flag.Check(core.FlagFollow) && isReleasable {
toBeSent.HandleDone(func() {
releasable.Release()
})
}

if ok := dc.sendFrame(toBeSent); !ok {
dc.killCallback(sid)
}
})
}

// Subscribe to sending flux for subsequent payloads
sub := &requestChannelSubscriber{
sid: sid,
n: n,
dc: dc,
rcv: receiving,
}
Expand Down Expand Up @@ -613,7 +660,7 @@ func (dc *DuplexConnection) respondRequestChannel(req fragmentation.HeaderAndPay
}
logger.Errorf("handle request-channel failed: %+v\n", err)
}()
resp = dc.responder.RequestChannel(receiving)
resp = dc.responder.RequestChannel(req, receiving)
if resp == nil {
err = framing.NewWriteableErrorFrame(sid, core.ErrorCodeApplicationError, unsupportedRequestChannel)
}
Expand Down Expand Up @@ -643,8 +690,6 @@ func (dc *DuplexConnection) respondRequestChannel(req fragmentation.HeaderAndPay
sending.SubscribeWith(dc.ctx, sub)
})

receivingProcessor.Next(req)

<-subscribed

return nil
Expand Down
31 changes: 4 additions & 27 deletions internal/socket/subscriber_request_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,20 @@ import (
"github.com/jjeffcaii/reactor-go"
"github.com/rsocket/rsocket-go/core"
"github.com/rsocket/rsocket-go/core/framing"
"github.com/rsocket/rsocket-go/internal/fragmentation"
"github.com/rsocket/rsocket-go/payload"
"github.com/rsocket/rsocket-go/rx"
"github.com/rsocket/rsocket-go/rx/flux"
"go.uber.org/atomic"
)

type requestChannelSubscriber struct {
sid uint32
n uint32
dc *DuplexConnection
requested atomic.Bool
rcv flux.Processor
sid uint32
dc *DuplexConnection
rcv flux.Processor
}

func (r *requestChannelSubscriber) OnNext(item payload.Payload) {
if !r.requested.CAS(false, true) {
r.dc.sendPayload(r.sid, item, core.FlagNext)
return
}
d := item.Data()
m, _ := item.Metadata()
size := framing.CalcPayloadFrameSize(d, m) + 4
if !r.dc.shouldSplit(size) {
metadata, _ := item.Metadata()
r.dc.sendFrame(framing.NewWriteableRequestChannelFrame(r.sid, r.n, item.Data(), metadata, core.FlagNext))
return
}
r.dc.doSplitSkip(4, d, m, func(index int, result fragmentation.SplitResult) {
var f core.WriteableFrame
if index == 0 {
f = framing.NewWriteableRequestChannelFrame(r.sid, r.n, result.Data, result.Metadata, result.Flag|core.FlagNext)
} else {
f = framing.NewWriteablePayloadFrame(r.sid, result.Data, result.Metadata, result.Flag|core.FlagNext)
}
r.dc.sendFrame(f)
})
r.dc.sendPayload(r.sid, item, core.FlagNext)
}

func (r *requestChannelSubscriber) OnError(err error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/socket/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Responder interface {
// RequestStream request a completable stream.
RequestStream(message payload.Payload) flux.Flux
// RequestChannel request a completable stream in both directions.
RequestChannel(messages flux.Flux) flux.Flux
RequestChannel(initialMessage payload.Payload, messages flux.Flux) flux.Flux
}

// ClientSocket represents a client-side socket.
Expand Down
4 changes: 2 additions & 2 deletions rsocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type (
// RequestStream request a completable stream.
RequestStream(message payload.Payload) flux.Flux
// RequestChannel request a completable stream in both directions.
RequestChannel(messages flux.Flux) flux.Flux
RequestChannel(initialMessage payload.Payload, messages flux.Flux) flux.Flux
}

// CloseableRSocket is RSocket which can be closed and handle close event.
Expand Down Expand Up @@ -115,7 +115,7 @@ func RequestStream(fn func(request payload.Payload) (responses flux.Flux)) OptAb
}

// RequestChannel register request handler for RequestChannel.
func RequestChannel(fn func(requests flux.Flux) (responses flux.Flux)) OptAbstractSocket {
func RequestChannel(fn func(initialRequest payload.Payload, requests flux.Flux) (responses flux.Flux)) OptAbstractSocket {
return func(opts *socket.AbstractRSocket) {
opts.RC = fn
}
Expand Down
5 changes: 3 additions & 2 deletions rsocket_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func ExampleReceive() {
s.Complete()
})
}),
rsocket.RequestChannel(func(requests flux.Flux) flux.Flux {
rsocket.RequestChannel(func(initialRequest payload.Payload, requests flux.Flux) flux.Flux {
return requests
}),
), nil
Expand Down Expand Up @@ -137,13 +137,14 @@ func ExampleConnect() {
s.Request(1)
}))
// Simple RequestChannel.
initialPayload := payload.NewString("This is a RequestChannel initial message.", "")
sendFlux := flux.Create(func(ctx context.Context, s flux.Sink) {
for i := 0; i < 3; i++ {
s.Next(payload.NewString(fmt.Sprintf("This is a RequestChannel message #%d.", i), ""))
}
s.Complete()
})
cli.RequestChannel(sendFlux).
cli.RequestChannel(initialPayload, sendFlux).
DoOnNext(func(elem payload.Payload) error {
log.Println("next element in channel:", elem)
return nil
Expand Down
10 changes: 6 additions & 4 deletions rsocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ func testAll(t *testing.T, proto string, clientTp transport.ClientTransporter, s
s.Complete()
})
}),
RequestChannel(func(inputs flux.Flux) flux.Flux {
RequestChannel(func(initialRequest payload.Payload, inputs flux.Flux) flux.Flux {
received := new(int32)
inputs.
DoOnNext(func(input payload.Payload) error {
Expand Down Expand Up @@ -471,6 +471,7 @@ func testRequestStreamOneByOne(ctx context.Context, cli Client, t *testing.T) {

func testRequestChannel(ctx context.Context, cli Client, t *testing.T) {
// RequestChannel
initialPayload := payload.NewString("This is a RequestChannel initial message.", "")
send := flux.Create(func(ctx context.Context, s flux.Sink) {
for i := 0; i < int(channelElements); i++ {
s.Next(payload.NewString(fakeData, fmt.Sprintf("%d", i)))
Expand All @@ -480,7 +481,7 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) {

var seq int

_, err := cli.RequestChannel(send).
_, err := cli.RequestChannel(initialPayload, send).
DoOnNext(func(elem payload.Payload) error {
//fmt.Println(elem)
m, _ := elem.MetadataUTF8()
Expand All @@ -495,6 +496,7 @@ func testRequestChannel(ctx context.Context, cli Client, t *testing.T) {

func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) {
// RequestChannel
initialPayload := payload.NewString("This is a RequestChannel initial message.", "")
send := flux.Create(func(ctx context.Context, s flux.Sink) {
for i := 0; i < int(channelElements); i++ {
s.Next(payload.NewString(fakeData, fmt.Sprintf("%d", i)))
Expand All @@ -508,7 +510,7 @@ func testRequestChannelOneByOne(ctx context.Context, cli Client, t *testing.T) {

var su rx.Subscription

cli.RequestChannel(send).
cli.RequestChannel(initialPayload, send).
DoFinally(func(s rx.SignalType) {
assert.Equal(t, rx.SignalComplete, s, "bad signal type")
close(done)
Expand Down Expand Up @@ -599,7 +601,7 @@ func (d delayedRSocket) RequestStream(message payload.Payload) flux.Flux {
panic("implement me")
}

func (d delayedRSocket) RequestChannel(messages flux.Flux) flux.Flux {
func (d delayedRSocket) RequestChannel(initialRequest payload.Payload, messages flux.Flux) flux.Flux {
panic("implement me")
}

Expand Down
Loading