From ab8d7ffeb1d077f6d6fc724a5a22b7465eee67d5 Mon Sep 17 00:00:00 2001 From: Ondrej Belusky Date: Thu, 24 Apr 2025 21:40:39 +0200 Subject: [PATCH] MQTT: allow custom timeout for JS API calls --- server/mqtt.go | 80 ++++++++++++++++++++++++++------------------- server/mqtt_test.go | 17 ++++++++++ server/opts.go | 5 +++ 3 files changed, 68 insertions(+), 34 deletions(-) diff --git a/server/mqtt.go b/server/mqtt.go index e76b7168f0..331e688c87 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -189,6 +189,7 @@ const ( mqttProcessSubTooLong = 100 * time.Millisecond mqttDefaultRetainedCacheTTL = 2 * time.Minute mqttRetainedTransferTimeout = 10 * time.Second + mqttDefaultJSAPITimeout = 5 * time.Second ) const ( @@ -209,30 +210,30 @@ var ( mqttOldProtoName = []byte("MQIsdp") mqttSessJailDur = mqttSessFlappingJailDur mqttFlapCleanItvl = mqttSessFlappingCleanupInterval - mqttJSAPITimeout = 4 * time.Second mqttRetainedCacheTTL = mqttDefaultRetainedCacheTTL ) var ( - errMQTTNotWebsocketPort = errors.New("MQTT clients over websocket must connect to the Websocket port, not the MQTT port") - errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty") - errMQTTMalformedVarInt = errors.New("malformed variable int") - errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet") - errMQTTServerNameMustBeSet = errors.New("mqtt requires server name to be explicitly set") - errMQTTUserMixWithUsersNKeys = errors.New("mqtt authentication username not compatible with presence of users/nkeys") - errMQTTTokenMixWIthUsersNKeys = errors.New("mqtt authentication token not compatible with presence of users/nkeys") - errMQTTAckWaitMustBePositive = errors.New("ack wait must be a positive value") - errMQTTStandaloneNeedsJetStream = errors.New("mqtt requires JetStream to be enabled if running in standalone mode") - errMQTTConnFlagReserved = errors.New("connect flags reserved bit not set to 0") - errMQTTWillAndRetainFlag = errors.New("if Will flag is set to 0, Will Retain flag must be 0 too") - errMQTTPasswordFlagAndNoUser = errors.New("password flag set but username flag is not") - errMQTTCIDEmptyNeedsCleanFlag = errors.New("when client ID is empty, clean session flag must be set to 1") - errMQTTEmptyWillTopic = errors.New("empty Will topic not allowed") - errMQTTEmptyUsername = errors.New("empty user name not allowed") - errMQTTTopicIsEmpty = errors.New("topic cannot be empty") - errMQTTPacketIdentifierIsZero = errors.New("packet identifier cannot be 0") - errMQTTUnsupportedCharacters = errors.New("character ' ' not supported for MQTT topics") - errMQTTInvalidSession = errors.New("invalid MQTT session") + errMQTTNotWebsocketPort = errors.New("MQTT clients over websocket must connect to the Websocket port, not the MQTT port") + errMQTTTopicFilterCannotBeEmpty = errors.New("topic filter cannot be empty") + errMQTTMalformedVarInt = errors.New("malformed variable int") + errMQTTSecondConnectPacket = errors.New("received a second CONNECT packet") + errMQTTServerNameMustBeSet = errors.New("mqtt requires server name to be explicitly set") + errMQTTUserMixWithUsersNKeys = errors.New("mqtt authentication username not compatible with presence of users/nkeys") + errMQTTTokenMixWIthUsersNKeys = errors.New("mqtt authentication token not compatible with presence of users/nkeys") + errMQTTAckWaitMustBePositive = errors.New("ack wait must be a positive value") + errMQTTJSAPITimeoutMustBePositive = errors.New("JS API timeout must be a positive value") + errMQTTStandaloneNeedsJetStream = errors.New("mqtt requires JetStream to be enabled if running in standalone mode") + errMQTTConnFlagReserved = errors.New("connect flags reserved bit not set to 0") + errMQTTWillAndRetainFlag = errors.New("if Will flag is set to 0, Will Retain flag must be 0 too") + errMQTTPasswordFlagAndNoUser = errors.New("password flag set but username flag is not") + errMQTTCIDEmptyNeedsCleanFlag = errors.New("when client ID is empty, clean session flag must be set to 1") + errMQTTEmptyWillTopic = errors.New("empty Will topic not allowed") + errMQTTEmptyUsername = errors.New("empty user name not allowed") + errMQTTTopicIsEmpty = errors.New("topic cannot be empty") + errMQTTPacketIdentifierIsZero = errors.New("packet identifier cannot be 0") + errMQTTUnsupportedCharacters = errors.New("character ' ' not supported for MQTT topics") + errMQTTInvalidSession = errors.New("invalid MQTT session") ) type srvMQTT struct { @@ -281,6 +282,7 @@ type mqttJSA struct { quitCh chan struct{} domain string // Domain or possibly empty. This is added to session subject. domainSet bool // covers if domain was set, even to empty + timeout time.Duration } type mqttJSPubMsg struct { @@ -696,6 +698,9 @@ func validateMQTTOptions(o *Options) error { if mo.AckWait < 0 { return errMQTTAckWaitMustBePositive } + if mo.JSAPITimeout < 0 { + return errMQTTJSAPITimeoutMustBePositive + } // If strictly standalone and there is no JS enabled, then it won't work... // For leafnodes, we could either have remote(s) and it would be ok, or no // remote but accept from a remote side that has "hub" property set, which @@ -1152,6 +1157,12 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc c.acc = acc id := s.NodeName() + + mqttJSAPITimeout := opts.MQTT.JSAPITimeout + if mqttJSAPITimeout == 0 { + mqttJSAPITimeout = mqttDefaultJSAPITimeout + } + replicas := opts.MQTT.StreamReplicas if replicas <= 0 { replicas = s.mqttDetermineReplicas() @@ -1163,12 +1174,13 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc sessLocked: make(map[string]struct{}), flappers: make(map[string]int64), jsa: mqttJSA{ - id: id, - c: c, - rplyr: mqttJSARepliesPrefix + id + ".", - sendq: newIPQueue[*mqttJSPubMsg](s, qname+"send"), - nuid: nuid.New(), - quitCh: quitCh, + id: id, + c: c, + rplyr: mqttJSARepliesPrefix + id + ".", + sendq: newIPQueue[*mqttJSPubMsg](s, qname+"send"), + nuid: nuid.New(), + quitCh: quitCh, + timeout: mqttJSAPITimeout, }, } if !testDisableRMSCache { @@ -1546,7 +1558,7 @@ func (s *Server) mqttDetermineReplicas() int { ////////////////////////////////////////////////////////////////////////////// func (jsa *mqttJSA) newRequest(kind, subject string, hdr int, msg []byte) (any, error) { - return jsa.newRequestEx(kind, subject, _EMPTY_, hdr, msg, mqttJSAPITimeout) + return jsa.newRequestEx(kind, subject, _EMPTY_, hdr, msg) } func (jsa *mqttJSA) prefixDomain(subject string) string { @@ -1559,8 +1571,8 @@ func (jsa *mqttJSA) prefixDomain(subject string) string { return subject } -func (jsa *mqttJSA) newRequestEx(kind, subject, cidHash string, hdr int, msg []byte, timeout time.Duration) (any, error) { - responses, err := jsa.newRequestExMulti(kind, subject, cidHash, []int{hdr}, [][]byte{msg}, timeout) +func (jsa *mqttJSA) newRequestEx(kind, subject, cidHash string, hdr int, msg []byte) (any, error) { + responses, err := jsa.newRequestExMulti(kind, subject, cidHash, []int{hdr}, [][]byte{msg}) if err != nil { return nil, err } @@ -1578,7 +1590,7 @@ func (jsa *mqttJSA) newRequestEx(kind, subject, cidHash string, hdr int, msg []b // // Note that each response may represent an error and should be inspected as // such by the caller. -func (jsa *mqttJSA) newRequestExMulti(kind, subject, cidHash string, hdrs []int, msgs [][]byte, timeout time.Duration) ([]*mqttJSAResponse, error) { +func (jsa *mqttJSA) newRequestExMulti(kind, subject, cidHash string, hdrs []int, msgs [][]byte) ([]*mqttJSAResponse, error) { if len(hdrs) != len(msgs) { return nil, fmt.Errorf("unreachable: invalid number of messages (%d) or header offsets (%d)", len(msgs), len(hdrs)) } @@ -1630,7 +1642,7 @@ func (jsa *mqttJSA) newRequestExMulti(kind, subject, cidHash string, hdrs []int, c := 0 responses := make([]*mqttJSAResponse, len(msgs)) start := time.Now() - t := time.NewTimer(timeout) + t := time.NewTimer(jsa.timeout) defer t.Stop() for { select { @@ -1789,7 +1801,7 @@ func (jsa *mqttJSA) loadLastMsgForMulti(streamName string, subjects []string) ([ headerBytes = append(headerBytes, 0) } - all, err := jsa.newRequestExMulti(mqttJSAMsgLoad, fmt.Sprintf(JSApiMsgGetT, streamName), _EMPTY_, headerBytes, marshaled, mqttJSAPITimeout) + all, err := jsa.newRequestExMulti(mqttJSAMsgLoad, fmt.Sprintf(JSApiMsgGetT, streamName), _EMPTY_, headerBytes, marshaled) // all has the same order as subjects, preserve it as we unmarshal responses := make([]*JSApiMsgGetResponse, len(all)) for i, v := range all { @@ -1847,7 +1859,7 @@ func (jsa *mqttJSA) storeSessionMsg(domainTk, cidHash string, hdr int, msg []byt // Passing cidHash will add it to the JS reply subject, so that we can use // it in processSessionPersist. - smri, err := jsa.newRequestEx(mqttJSASessPersist, subject, cidHash, hdr, msg, mqttJSAPITimeout) + smri, err := jsa.newRequestEx(mqttJSASessPersist, subject, cidHash, hdr, msg) if err != nil { return nil, err } @@ -2982,7 +2994,7 @@ func (as *mqttAccountSessionManager) transferUniqueSessStreamsToMuxed(log *Serve }() jsa := &as.jsa - sni, err := jsa.newRequestEx(mqttJSAStreamNames, JSApiStreams, _EMPTY_, 0, nil, 5*time.Second) + sni, err := jsa.newRequestEx(mqttJSAStreamNames, JSApiStreams, _EMPTY_, 0, nil) if err != nil { log.Errorf("Unable to transfer MQTT session streams: %v", err) return diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 807069329d..349184f7f8 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -469,6 +469,11 @@ func TestMQTTValidateOptions(t *testing.T) { o.MQTT.AckWait = -10 * time.Second return o }, errMQTTAckWaitMustBePositive}, + {"js api timeout should be >=0", func() *Options { + o := mqtto.Clone() + o.MQTT.JSAPITimeout = -10 * time.Second + return o + }, errMQTTJSAPITimeoutMustBePositive}, } { t.Run(test.name, func(t *testing.T) { err := validateMQTTOptions(test.getOpts()) @@ -498,6 +503,7 @@ func TestMQTTParseOptions(t *testing.T) { {"ack wait", `mqtt: {ack_wait: abc}`, nil, "invalid duration"}, {"max ack pending", `mqtt: {max_ack_pending: abc}`, nil, "not int64"}, {"max ack pending too high", `mqtt: {max_ack_pending: 12345678}`, nil, "invalid value"}, + {"js_api_timeout bad duration", `mqtt: {js_api_timeout: abc}`, nil, "invalid duration"}, // Positive tests {"tls gen fails", ` mqtt { @@ -627,6 +633,17 @@ func TestMQTTParseOptions(t *testing.T) { } return nil }, ""}, + {"js_api_timeout", + ` + mqtt { + js_api_timeout: "60s" + } + `, func(o *MQTTOpts) error { + if o.JSAPITimeout != 60*time.Second { + return fmt.Errorf("Invalid JS API timeout: %v", o.JSAPITimeout) + } + return nil + }, ""}, } { t.Run(test.name, func(t *testing.T) { conf := createConfFile(t, []byte(test.content)) diff --git a/server/opts.go b/server/opts.go index 087c585cb1..f6f7fc1863 100644 --- a/server/opts.go +++ b/server/opts.go @@ -616,6 +616,9 @@ type MQTTOpts struct { // PubRels). AckWait time.Duration + // JSAPITimeout defines timeout for JetStream api calls (default is 5 seconds) + JSAPITimeout time.Duration + // MaxAckPending is the amount of QoS 1 and 2 messages (combined) the server // can send to a subscription without receiving any PUBACK for those // messages. The valid range is [0..65535]. @@ -5205,6 +5208,8 @@ func parseMQTT(v any, o *Options, errors *[]error, warnings *[]error) error { o.MQTT.NoAuthUser = mv.(string) case "ack_wait", "ackwait": o.MQTT.AckWait = parseDuration("ack_wait", tk, mv, errors, warnings) + case "js_api_timeout", "api_timeout": + o.MQTT.JSAPITimeout = parseDuration("js_api_timeout", tk, mv, errors, warnings) case "max_ack_pending", "max_pending", "max_inflight": tmp := int(mv.(int64)) if tmp < 0 || tmp > 0xFFFF {