diff --git a/components/forwarder/envelope.go b/components/forwarder/envelope.go index b8de18763..a0bc3fd3e 100644 --- a/components/forwarder/envelope.go +++ b/components/forwarder/envelope.go @@ -1,8 +1,6 @@ package forwarder import ( - "encoding/json" - "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message" "github.com/pkg/errors" @@ -40,13 +38,17 @@ func (e *messageEnvelope) validate() error { return nil } -func wrapMessageInEnvelope(destinationTopic string, msg *message.Message) (*message.Message, error) { +func wrapMessageInEnvelope( + destinationTopic string, + msg *message.Message, + marshal func(any) ([]byte, error), +) (*message.Message, error) { envelope, err := newMessageEnvelope(destinationTopic, msg) if err != nil { return nil, errors.Wrap(err, "cannot envelope a message") } - envelopedMessage, err := json.Marshal(envelope) + envelopedMessage, err := marshal(envelope) if err != nil { return nil, errors.Wrap(err, "cannot marshal a message") } @@ -57,9 +59,12 @@ func wrapMessageInEnvelope(destinationTopic string, msg *message.Message) (*mess return wrappedMsg, nil } -func unwrapMessageFromEnvelope(msg *message.Message) (destinationTopic string, unwrappedMsg *message.Message, err error) { +func unwrapMessageFromEnvelope( + msg *message.Message, + unmarshal func(data []byte, v any) error, +) (destinationTopic string, unwrappedMsg *message.Message, err error) { envelopedMsg := messageEnvelope{} - if err := json.Unmarshal(msg.Payload, &envelopedMsg); err != nil { + if err := unmarshal(msg.Payload, &envelopedMsg); err != nil { return "", nil, errors.Wrap(err, "cannot unmarshal message wrapped in an envelope") } diff --git a/components/forwarder/envelope_test.go b/components/forwarder/envelope_test.go index 1d3a706e0..ea3380fc1 100644 --- a/components/forwarder/envelope_test.go +++ b/components/forwarder/envelope_test.go @@ -2,6 +2,7 @@ package forwarder import ( "context" + "encoding/json" "testing" "github.com/ThreeDotsLabs/watermill" @@ -14,6 +15,8 @@ type contextKey string func TestEnvelope(t *testing.T) { expectedUUID := watermill.NewUUID() + marshal := json.Marshal + unmarshal := json.Unmarshal expectedPayload := message.Payload("msg content") expectedMetadata := message.Metadata{"key": "value"} expectedDestinationTopic := "dest_topic" @@ -24,14 +27,14 @@ func TestEnvelope(t *testing.T) { msg.Metadata = expectedMetadata msg.SetContext(ctx) - wrappedMsg, err := wrapMessageInEnvelope(expectedDestinationTopic, msg) + wrappedMsg, err := wrapMessageInEnvelope(expectedDestinationTopic, msg, marshal) require.NoError(t, err) require.NotNil(t, wrappedMsg) v, ok := wrappedMsg.Context().Value(contextKey("key")).(string) require.True(t, ok) require.Equal(t, "value", v) - destinationTopic, unwrappedMsg, err := unwrapMessageFromEnvelope(wrappedMsg) + destinationTopic, unwrappedMsg, err := unwrapMessageFromEnvelope(wrappedMsg, unmarshal) require.NoError(t, err) require.NotNil(t, unwrappedMsg) assert.Equal(t, expectedUUID, unwrappedMsg.UUID) diff --git a/components/forwarder/forwarder.go b/components/forwarder/forwarder.go index 09fb26b56..2bc5f4675 100644 --- a/components/forwarder/forwarder.go +++ b/components/forwarder/forwarder.go @@ -2,6 +2,7 @@ package forwarder import ( "context" + "encoding/json" "time" "github.com/ThreeDotsLabs/watermill" @@ -29,7 +30,8 @@ type Config struct { // If not provided, a new router will be created. // // If router is provided, it's not necessary to call `Forwarder.Run()` if the router is started with `router.Run()`. - Router *message.Router + Router *message.Router + Unmarshal func(data []byte, v any) error } func (c *Config) setDefaults() { @@ -39,6 +41,9 @@ func (c *Config) setDefaults() { if c.ForwarderTopic == "" { c.ForwarderTopic = defaultForwarderTopic } + if c.Unmarshal == nil { + c.Unmarshal = json.Unmarshal + } } func (c *Config) Validate() error { @@ -64,7 +69,12 @@ type Forwarder struct { // // Note: Keep in mind that by default the forwarder will nack all messages which weren't sent using a decorated publisher. // You can change this behavior by passing a middleware which will ack them instead. -func NewForwarder(subscriberIn message.Subscriber, publisherOut message.Publisher, logger watermill.LoggerAdapter, config Config) (*Forwarder, error) { +func NewForwarder( + subscriberIn message.Subscriber, + publisherOut message.Publisher, + logger watermill.LoggerAdapter, + config Config, +) (*Forwarder, error) { config.setDefaults() routerConfig := message.RouterConfig{CloseTimeout: config.CloseTimeout} @@ -117,7 +127,7 @@ func (f *Forwarder) Running() chan struct{} { } func (f *Forwarder) forwardMessage(msg *message.Message) error { - destTopic, unwrappedMsg, err := unwrapMessageFromEnvelope(msg) + destTopic, unwrappedMsg, err := unwrapMessageFromEnvelope(msg, f.config.Unmarshal) if err != nil { f.logger.Error("Could not unwrap a message from an envelope", err, watermill.LogFields{ "uuid": msg.UUID, diff --git a/components/forwarder/publisher.go b/components/forwarder/publisher.go index 9d1836791..ef43048fe 100644 --- a/components/forwarder/publisher.go +++ b/components/forwarder/publisher.go @@ -1,6 +1,8 @@ package forwarder import ( + "encoding/json" + "github.com/ThreeDotsLabs/watermill/message" "github.com/pkg/errors" ) @@ -9,12 +11,16 @@ type PublisherConfig struct { // ForwarderTopic is a topic which the forwarder is listening to. Publisher will send enveloped messages to this topic. // Defaults to `forwarder_topic`. ForwarderTopic string + Marshal func(v any) ([]byte, error) } func (c *PublisherConfig) setDefaults() { if c.ForwarderTopic == "" { c.ForwarderTopic = defaultForwarderTopic } + if c.Marshal == nil { + c.Marshal = json.Marshal + } } func (c *PublisherConfig) Validate() error { @@ -44,7 +50,7 @@ func NewPublisher(publisher message.Publisher, config PublisherConfig) *Publishe func (p *Publisher) Publish(topic string, messages ...*message.Message) error { envelopedMessages := make([]*message.Message, 0, len(messages)) for _, msg := range messages { - envelopedMsg, err := wrapMessageInEnvelope(topic, msg) + envelopedMsg, err := wrapMessageInEnvelope(topic, msg, p.config.Marshal) if err != nil { return errors.Wrapf(err, "cannot wrap message, target topic: '%s', uuid: '%s'", topic, msg.UUID) }