diff --git a/server/client.go b/server/client.go index 54337195934..cbf7bbdf2db 100644 --- a/server/client.go +++ b/server/client.go @@ -3022,7 +3022,7 @@ func (c *client) processSubEx(subject, queue, bsid []byte, cb msgHandler, noForw return sub, nil } - if err := c.addShadowSubscriptions(acc, sub, true); err != nil { + if err := c.addShadowSubscriptions(acc, sub); err != nil { c.Errorf(err.Error()) } @@ -3052,10 +3052,7 @@ type ime struct { // If the client's account has stream imports and there are matches for this // subscription's subject, then add shadow subscriptions in the other accounts // that export this subject. -// -// enact=false allows MQTT clients to get the list of shadow subscriptions -// without enacting them, in order to first obtain matching "retained" messages. -func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact bool) error { +func (c *client) addShadowSubscriptions(acc *Account, sub *subscription) error { if acc == nil { return ErrMissingAccount } @@ -3158,7 +3155,7 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b for i := 0; i < len(ims); i++ { ime := &ims[i] // We will create a shadow subscription. - nsub, err := c.addShadowSub(sub, ime, enact) + nsub, err := c.addShadowSub(sub, ime) if err != nil { return err } @@ -3175,7 +3172,7 @@ func (c *client) addShadowSubscriptions(acc *Account, sub *subscription, enact b } // Add in the shadow subscription. 
-func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscription, error) { +func (c *client) addShadowSub(sub *subscription, ime *ime) (*subscription, error) { c.mu.Lock() nsub := *sub // copy c.mu.Unlock() @@ -3203,10 +3200,6 @@ func (c *client) addShadowSub(sub *subscription, ime *ime, enact bool) (*subscri } // Else use original subject - if !enact { - return &nsub, nil - } - c.Debugf("Creating import subscription on %q from account %q", nsub.subject, im.acc.Name) if err := im.acc.sl.Insert(&nsub); err != nil { @@ -5796,7 +5789,7 @@ func (c *client) processSubsOnConfigReload(awcsti map[string]struct{}) { oldShadows := sub.shadow sub.shadow = nil c.mu.Unlock() - c.addShadowSubscriptions(acc, sub, true) + c.addShadowSubscriptions(acc, sub) for _, nsub := range oldShadows { nsub.im.acc.sl.Remove(nsub) } diff --git a/server/leafnode.go b/server/leafnode.go index 5e233f7d99d..b939f9dc1cc 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -2832,7 +2832,7 @@ func (c *client) processLeafSub(argo []byte) (err error) { // Only add in shadow subs if a new sub or qsub. 
if osub == nil { - if err := c.addShadowSubscriptions(acc, sub, true); err != nil { + if err := c.addShadowSubscriptions(acc, sub); err != nil { c.Errorf(err.Error()) } } diff --git a/server/mqtt.go b/server/mqtt.go index d8e75a766f3..f67f9d33f28 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -192,6 +192,7 @@ const ( mqttDefaultRetainedCacheTTL = 2 * time.Minute mqttRetainedTransferTimeout = 10 * time.Second mqttDefaultJSAPITimeout = 5 * time.Second + mqttRetainedFlagDelMarker = '-' ) const ( @@ -236,6 +237,7 @@ var ( errMQTTPacketIdentifierIsZero = errors.New("packet identifier cannot be 0") errMQTTUnsupportedCharacters = errors.New("character ' ' not supported for MQTT topics") errMQTTInvalidSession = errors.New("invalid MQTT session") + errMQTTInvalidRetainFlags = errors.New("invalid retained message flags") ) type srvMQTT struct { @@ -250,8 +252,6 @@ type mqttSessionManager struct { sessions map[string]*mqttAccountSessionManager // key is account name } -var testDisableRMSCache = false - type mqttAccountSessionManager struct { mu sync.RWMutex sessions map[string]*mqttSession // key is MQTT client ID @@ -263,10 +263,7 @@ type mqttAccountSessionManager struct { retmsgs map[string]*mqttRetainedMsgRef // retained messages rmsCache *sync.Map // map[subject]mqttRetainedMsg jsa mqttJSA - rrmNum uint64 // Number of restored retained messages. - rrmTotal uint64 // Total of retained messages to restore. - rrmDoneCh chan struct{} // To notify the caller that all retained messages have been loaded. - domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject. + domainTk string // Domain (with trailing "."), or possibly empty. This is added to session subject. 
} type mqttJSAResponse struct { @@ -364,9 +361,8 @@ type mqttRetainedMsg struct { } type mqttRetainedMsgRef struct { - sseq uint64 - floor uint64 - sub *subscription + sseq uint64 + sub *subscription } // mqttSub contains fields associated with a MQTT subscription, and is added to @@ -1185,9 +1181,7 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc quitCh: quitCh, timeout: mqttJSAPITimeout, }, - } - if !testDisableRMSCache { - as.rmsCache = &sync.Map{} + rmsCache: &sync.Map{}, } // TODO record domain name in as here @@ -1283,12 +1277,10 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc }) // Start the go routine that will clean up cached retained messages that expired. - if as.rmsCache != nil { - s.startGoRoutine(func() { - defer s.grWG.Done() - as.cleanupRetainedMessageCache(s, closeCh) - }) - } + s.startGoRoutine(func() { + defer s.grWG.Done() + as.cleanupRetainedMessageCache(s, closeCh) + }) lookupStream := func(stream, txt string) (*StreamInfo, error) { si, err := jsa.lookupStream(stream) @@ -1476,15 +1468,6 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc return nil, err } - var rmTotal uint64 - var rmDoneCh chan struct{} - st := si.State - if rmTotal = st.Msgs; rmTotal > 0 { - rmDoneCh = make(chan struct{}) - as.rrmTotal = rmTotal - as.rrmDoneCh = rmDoneCh - } - // Opportunistically delete the old (legacy) consumer, from v2.10.10 and // before. Ignore any errors that might arise. 
rmLegacyDurName := mqttRetainedMsgsStreamName + "_" + jsa.id @@ -1507,19 +1490,6 @@ func (s *Server) mqttCreateAccountSessionManager(acc *Account, quitCh chan struc return nil, fmt.Errorf("create retained messages consumer for account %q: %v", accName, err) } - if rmTotal > 0 { - ttl := time.NewTimer(mqttJSAPITimeout) - defer ttl.Stop() - - select { - case <-rmDoneCh: - case <-ttl.C: - s.Warnf("Timing out waiting to load %v retained messages", rmTotal) - case <-quitCh: - return nil, ErrServerNotRunning - } - } - // Set this so that on defer we don't cleanup. success = true @@ -1674,8 +1644,7 @@ func (jsa *mqttJSA) newRequestExMulti(kind, subject, cidHash string, hdrs []int, } func (jsa *mqttJSA) sendAck(ackSubject string) { - // We pass -1 for the hdr so that the send loop does not need to - // add the "client info" header. This is not a JS API request per se. + // Send to the ack subject with no payload. jsa.sendMsg(ackSubject, nil) } @@ -1683,6 +1652,8 @@ func (jsa *mqttJSA) sendMsg(subj string, msg []byte) { if subj == _EMPTY_ { return } + // We pass -1 for the hdr so that the send loop does not need to + // add the "client info" header. This is not a JS API request per se. 
jsa.sendq.push(&mqttJSPubMsg{subj: subj, msg: msg, hdr: -1}) } @@ -1840,12 +1811,16 @@ func (jsa *mqttJSA) loadMsg(streamName string, seq uint64) (*StoredMsg, error) { return lmr.Message, lmr.ToError() } -func (jsa *mqttJSA) storeMsg(subject string, headers int, msg []byte) (*JSPubAckResponse, error) { - return jsa.storeMsgWithKind(mqttJSAMsgStore, subject, headers, msg) +func (jsa *mqttJSA) storeMsgNoWait(subject string, hdrLen int, msg []byte) { + jsa.sendq.push(&mqttJSPubMsg{ + subj: subject, + msg: msg, + hdr: hdrLen, + }) } -func (jsa *mqttJSA) storeMsgWithKind(kind, subject string, headers int, msg []byte) (*JSPubAckResponse, error) { - smri, err := jsa.newRequest(kind, subject, headers, msg) +func (jsa *mqttJSA) storeMsg(subject string, headers int, msg []byte) (*JSPubAckResponse, error) { + smri, err := jsa.newRequest(mqttJSAMsgStore, subject, headers, msg) if err != nil { return nil, err } @@ -1990,45 +1965,41 @@ func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *cl // // Run from various go routines (JS consumer, etc..). // No lock held on entry. -func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { +func (as *mqttAccountSessionManager) processRetainedMsg(sub *subscription, c *client, acc *Account, subject, reply string, rmsg []byte) { h, m := c.msgParts(rmsg) // We need to strip the trailing "\r\n". if l := len(m); l >= LEN_CR_LF { m = m[:l-LEN_CR_LF] } - rm, err := mqttDecodeRetainedMessage(h, m) + rm, err := mqttDecodeRetainedMessage(subject, h, m) if err != nil { return } - // If rrmTotal is 0 (nothing to recover, or done doing it) and this is - // from our own server, ignore. - as.mu.RLock() - if as.rrmTotal == 0 && rm.Origin == as.jsa.id { - as.mu.RUnlock() - return - } - as.mu.RUnlock() - - // At this point we either recover from our own server, or process a remote retained message. 
+ // The as.jsa.id is immutable, so no need to have a rlock here. + local := rm.Origin == as.jsa.id + // Get the stream sequence for this message. seq, _, _ := ackReplyInfo(reply) - - // Handle this retained message. The `rm.Msg` references some buffer owned - // by the caller. handleRetainedMsg() will take care of making a copy of - // `rm.Msg` it `rm` ends-up being stored in the cache. - as.handleRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm) - - // If we were recovering (rrmTotal > 0), then check if we are done. - as.mu.Lock() - if as.rrmTotal > 0 { - if as.rrmNum++; as.rrmNum == as.rrmTotal { - as.rrmTotal = 0 - close(as.rrmDoneCh) - as.rrmDoneCh = nil + if len(m) == 0 { + // An empty payload means that we need to remove the retained message. + rmSeq := as.removeRetainedMsg(rm.Subject, 0) + if local { + if rmSeq > 0 { + // This is for backward compatibility reasons. + // Should be removed in a future release. + as.notifyRetainedMsgDeleted(rm.Subject, rmSeq) + } + // Delete this very message we just processed, we don't need it anymore. + as.deleteRetainedMsg(seq) } + } else { + // Add this retained message. The `rm.Msg` references some buffer that we + // don't own. But addRetainedMsg() will take care of making a copy of + // `rm.Msg` if `rm` ends up being stored in the cache. + as.addRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm) } - as.mu.Unlock() } +// NOTE: This is maintained for backward compatibility reasons. Should be removed in 2.14/2.15? 
func (as *mqttAccountSessionManager) processRetainedMsgDel(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) { idHash := tokenAt(subject, 3) if idHash == _EMPTY_ || idHash == as.jsa.id { @@ -2042,7 +2013,7 @@ func (as *mqttAccountSessionManager) processRetainedMsgDel(_ *subscription, c *c if err := json.Unmarshal(msg, &drm); err != nil { return } - as.handleRetainedMsgDel(drm.Subject, drm.Seq) + as.removeRetainedMsg(drm.Subject, drm.Seq) } // This will receive all JS API replies for a request to store a session record, @@ -2292,7 +2263,7 @@ func (as *mqttAccountSessionManager) sendJSAPIrequests(s *Server, c *client, acc // If a message for this topic already existed, the existing record is updated // with the provided information. // Lock not held on entry. -func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg) { +func (as *mqttAccountSessionManager) addRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg) { as.mu.Lock() defer as.mu.Unlock() if as.retmsgs == nil { @@ -2300,23 +2271,9 @@ func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetai as.sl = NewSublistWithCache() } else { // Check if we already had one retained message. If so, update the existing one. - if erm, exists := as.retmsgs[key]; exists { - // If the new sequence is below the floor or the existing one, - // then ignore the new one. - if rf.sseq <= erm.sseq || rf.sseq <= erm.floor { - return - } - // Capture existing sequence number so we can return it as the old sequence. - erm.sseq = rf.sseq - // Clear the floor - erm.floor = 0 - // If sub is nil, it means that it was removed from sublist following a - // network delete. So need to add it now. - if erm.sub == nil { - erm.sub = &subscription{subject: []byte(key)} - as.sl.Insert(erm.sub) - } - + if erf, exists := as.retmsgs[key]; exists { + // Update the stream sequence with the new value. 
+ erf.sseq = rf.sseq // Update the in-memory retained message cache but only for messages // that are already in the cache, i.e. have been (recently) used. // If that is the case, we ask setCachedRetainedMsg() to make a copy @@ -2325,49 +2282,34 @@ func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetai return } } - rf.sub = &subscription{subject: []byte(key)} as.retmsgs[key] = rf as.sl.Insert(rf.sub) } -// Removes the retained message for the given `subject` if present, and returns the -// stream sequence it was stored at. It will be 0 if no retained message was removed. -// If a sequence is passed and not 0, then the retained message will be removed only -// if the given sequence is equal or higher to what is stored. -// -// No lock held on entry. -func (as *mqttAccountSessionManager) handleRetainedMsgDel(subject string, seq uint64) uint64 { - var seqToRemove uint64 +// Remove the retained message stored with the `subject` key from the map/cache. +// When invoked from the retained message stream's consumer, this function will +// be called with `seq == 0`, this is because add/remove are serialized in this +// stream and so the request is to remove the current retained message. +// But in some conditions, we will invoke this function from some other places +// with `seq > 0` which means that the retained message will be removed only if +// its sequence is the same as the provided one. +// This function returns the sequence associated with the existing retained +// message that is being removed (used with `seq == 0`) and returns 0 if the +// retained message was not removed from the map (not found or sequence did not +// match). 
+func (as *mqttAccountSessionManager) removeRetainedMsg(subject string, seq uint64) uint64 { as.mu.Lock() - if as.retmsgs == nil { - as.retmsgs = make(map[string]*mqttRetainedMsgRef) - as.sl = NewSublistWithCache() - } - if erm, ok := as.retmsgs[subject]; ok { - if as.rmsCache != nil { - as.rmsCache.Delete(subject) - } - if erm.sub != nil { - as.sl.Remove(erm.sub) - erm.sub = nil - } - // If processing a delete request from the network, then seq will be > 0. - // If that is the case and it is greater or equal to what we have, we need - // to record the floor for this subject. - if seq != 0 && seq >= erm.sseq { - erm.sseq = 0 - erm.floor = seq - } else if seq == 0 { - delete(as.retmsgs, subject) - seqToRemove = erm.sseq - } - } else if seq != 0 { - rf := &mqttRetainedMsgRef{floor: seq} - as.retmsgs[subject] = rf + defer as.mu.Unlock() + rm, ok := as.retmsgs[subject] + if !ok || (seq > 0 && rm.sseq != seq) { + return 0 } - as.mu.Unlock() - return seqToRemove + seq = rm.sseq + as.rmsCache.Delete(subject) + delete(as.retmsgs, subject) + as.sl.Remove(rm.sub) + return seq } // First check if this session's client ID is already in the "locked" map, @@ -2493,9 +2435,9 @@ func (sess *mqttSession) processSub( } if len(rms) > 0 { - for _, ss := range subs { - as.serializeRetainedMsgsForSub(rms, sess, c, ss, trace) - } + // Only deal with retained messages for the normal subscription, + // not the shadow one (which is for a different account and subject). + as.serializeRetainedMsgsForSub(rms, sess, c, sub, trace) } return sub, nil @@ -2518,10 +2460,6 @@ func (sess *mqttSession) processSub( func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, filters []*mqttFilter, fromSubProto, trace bool) ([]*subscription, error) { - c.mu.Lock() - acc := c.acc - c.mu.Unlock() - // Helper to determine if we need to create a separate top-level // subscription for a wildcard. 
fwc := func(subject string) (bool, string, string) { @@ -2536,7 +2474,7 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, return true, fwcsubject, fwcsid } - rmSubjects := map[string]struct{}{} + rmSubjects := map[string]uint64{} // Preload retained messages for all requested subscriptions. Also, since // it's the first iteration over the filter list, do some cleanup. for _, f := range filters { @@ -2568,43 +2506,16 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, // Find retained messages. if fromSubProto { - addRMSubjects := func(subject string) error { - sub := &subscription{ - client: c, - subject: []byte(subject), - sid: []byte(subject), - } - if err := c.addShadowSubscriptions(acc, sub, false); err != nil { - return err - } - - for _, sub := range append([]*subscription{sub}, sub.shadow...) { - as.addRetainedSubjectsForSubject(rmSubjects, bytesToString(sub.subject)) - for _, ss := range sub.shadow { - as.addRetainedSubjectsForSubject(rmSubjects, bytesToString(ss.subject)) - } - } - return nil - } - - if err := addRMSubjects(f.filter); err != nil { - f.qos = mqttSubAckFailure - continue - } + as.addRetainedSubjectsForSubject(rmSubjects, f.filter) if need, subject, _ := fwc(f.filter); need { - if err := addRMSubjects(subject); err != nil { - f.qos = mqttSubAckFailure - continue - } + as.addRetainedSubjectsForSubject(rmSubjects, subject) } } } - serializeRMS := len(rmSubjects) > 0 var rms map[string]*mqttRetainedMsg - if serializeRMS { - // Make the best effort to load retained messages. We will identify - // errors in the next pass. + if len(rmSubjects) > 0 { + // Make the best effort to load retained messages. rms = as.loadRetainedMessages(rmSubjects, c) } @@ -2725,13 +2636,13 @@ func (as *mqttAccountSessionManager) processSubs(sess *mqttSession, c *client, // Runs from the client's readLoop. // Account session manager lock held on entry. // Session lock held on entry. 
-func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]*mqttRetainedMsg, sess *mqttSession, c *client, sub *subscription, trace bool) error { +func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string]*mqttRetainedMsg, sess *mqttSession, c *client, sub *subscription, trace bool) { if len(as.retmsgs) == 0 || len(rms) == 0 { - return nil + return } result := as.sl.ReverseMatch(string(sub.subject)) if len(result.psubs) == 0 { - return nil + return } toTrace := []mqttPublish{} for _, psub := range result.psubs { @@ -2743,10 +2654,7 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string] continue } var pi uint16 - qos := mqttGetQoS(rm.Flags) - if qos > sub.mqtt.qos { - qos = sub.mqtt.qos - } + qos := min(mqttGetQoS(rm.Flags), sub.mqtt.qos) if c.mqtt.rejectQoS2Pub && qos == 2 { c.Warnf("Rejecting retained message with QoS2 for subscription %q, as configured", sub.subject) continue @@ -2780,33 +2688,35 @@ func (as *mqttAccountSessionManager) serializeRetainedMsgsForSub(rms map[string] for _, pp := range toTrace { c.traceOutOp("PUBLISH", []byte(mqttPubTrace(&pp))) } - return nil } // Appends the stored message subjects for all retained message records that // match the given subscription's `subject` (which could have wildcards). // // Account session manager NOT lock held on entry. 
-func (as *mqttAccountSessionManager) addRetainedSubjectsForSubject(list map[string]struct{}, topSubject string) bool { +func (as *mqttAccountSessionManager) addRetainedSubjectsForSubject(list map[string]uint64, topSubject string) { as.mu.RLock() if len(as.retmsgs) == 0 { as.mu.RUnlock() - return false + return } result := as.sl.ReverseMatch(topSubject) as.mu.RUnlock() - added := false for _, sub := range result.psubs { - subject := string(sub.subject) - if _, ok := list[subject]; ok { + if _, ok := list[string(sub.subject)]; ok { continue } - list[subject] = struct{}{} - added = true + var seq uint64 + as.mu.RLock() + if rm, ok := as.retmsgs[string(sub.subject)]; ok { + seq = rm.sseq + } + as.mu.RUnlock() + if seq > 0 { + list[string(sub.subject)] = seq + } } - - return added } type warner interface { @@ -2814,7 +2724,7 @@ type warner interface { } // Loads a list of retained messages given a list of stored message subjects. -func (as *mqttAccountSessionManager) loadRetainedMessages(subjects map[string]struct{}, w warner) map[string]*mqttRetainedMsg { +func (as *mqttAccountSessionManager) loadRetainedMessages(subjects map[string]uint64, w warner) map[string]*mqttRetainedMsg { rms := make(map[string]*mqttRetainedMsg, len(subjects)) ss := []string{} for s := range subjects { @@ -2829,6 +2739,11 @@ func (as *mqttAccountSessionManager) loadRetainedMessages(subjects map[string]st return rms } + // Although we have the stream sequence for a given subject, we still use + // the load with "last for subject" because it will cover the cases where a + // new retained message has arrived since we collected the subject/seq pair. + // If we were doing a load "by seq" and the message is not found, we would + // incorrectly remove the retained message from our map. results, err := as.jsa.loadLastMsgForMulti(mqttRetainedMsgsStreamName, ss) // If an error occurred, warn, but then proceed with what we got. 
if err != nil { @@ -2838,26 +2753,48 @@ func (as *mqttAccountSessionManager) loadRetainedMessages(subjects map[string]st if result == nil { continue // skip requests that timed out } - if result.ToError() != nil { - w.Warnf("failed to load retained message for subject %q: %v", ss[i], err) + if err := result.ToError(); err != nil { + // Skip the "$MQTT.rmsgs." prefix... + subj := ss[i][len(mqttRetainedMsgsStreamSubject):] + if IsNatsErr(err, JSNoMessageFoundErr) { + // If there is no message for that subject, delete from our map. + // The good thing here is that we handle the race where a retained + // message may just arrive and be replacing it in the map. The + // removeRetainedMsg() function below will not remove if the sequence + // does not match. + seq := subjects[subj] + as.removeRetainedMsg(subj, seq) + } + w.Warnf("failed to load retained message for subject %q: %v", subj, err) continue } - rm, err := mqttDecodeRetainedMessage(result.Message.Header, result.Message.Data) + rm, err := mqttDecodeRetainedMessage(result.Message.Subject, result.Message.Header, result.Message.Data) if err != nil { - w.Warnf("failed to decode retained message for subject %q: %v", ss[i], err) + // Unlikely that we can recover from that, so remove the message. + // (see comment above if failing to load the message). + subj := ss[i][len(mqttRetainedMsgsStreamSubject):] + seq := subjects[subj] + as.removeRetainedMsg(subj, seq) + w.Warnf("failed to decode retained message for subject %q: %v", subj, err) continue } // Add the loaded retained message to the cache, and to the results map. - key := ss[i][len(mqttRetainedMsgsStreamSubject):] - as.setCachedRetainedMsg(key, rm, false, false) - rms[key] = rm + // We don't need setCachedRetainedMsg() to clone the `rm.Msg` bytes slice + // since we own it. + as.setCachedRetainedMsg(rm.Subject, rm, false, false) + rms[rm.Subject] = rm } return rms } // Composes a NATS message for a storeable mqttRetainedMsg. 
+// If the body is empty, the flags are encoded in a way that will cause older +// servers to fail to decode the message in processRetainedMsg callback and +// will simply ignore it, which is what we want. func mqttEncodeRetainedMessage(rm *mqttRetainedMsg) (natsMsg []byte, headerLen int) { + delRM := len(rm.Msg) == 0 + // No need to encode the subject, we can restore it from topic. l := len(hdrLine) l += len(mqttNatsRetainedMessageTopic) + 1 + len(rm.Topic) + 2 // 1 byte for ':', 2 bytes for CRLF @@ -2869,7 +2806,11 @@ func mqttEncodeRetainedMessage(rm *mqttRetainedMsg) (natsMsg []byte, headerLen i } l += len(mqttNatsRetainedMessageFlags) + 1 + 2 + 2 // 1 byte for ':', 2 bytes for the flags, 2 bytes for CRLF l += 2 // 2 bytes for the extra CRLF after the header - l += len(rm.Msg) + if delRM { + l++ // Will add the delete marker before the flag + } else { + l += len(rm.Msg) + } buf := bytes.NewBuffer(make([]byte, 0, l)) @@ -2882,6 +2823,9 @@ func mqttEncodeRetainedMessage(rm *mqttRetainedMsg) (natsMsg []byte, headerLen i buf.WriteString(mqttNatsRetainedMessageFlags) buf.WriteByte(':') + if delRM { + buf.WriteByte(mqttRetainedFlagDelMarker) + } buf.WriteString(strconv.FormatUint(uint64(rm.Flags), 16)) buf.WriteString(_CRLF_) @@ -2905,30 +2849,111 @@ func mqttEncodeRetainedMessage(rm *mqttRetainedMsg) (natsMsg []byte, headerLen i return buf.Bytes(), headerLen } -func mqttDecodeRetainedMessage(h, m []byte) (*mqttRetainedMsg, error) { - fHeader := getHeader(mqttNatsRetainedMessageFlags, h) +func mqttSliceHeaders(headers map[string][]byte, hdr []byte) { + // Skip the hdrLine + if !bytes.HasPrefix(hdr, stringToBytes(hdrLine)) { + return + } + crLFAsBytes := stringToBytes(CR_LF) + for i := len(hdrLine); i < len(hdr); { + // Search for key/val delimiter. + del := bytes.IndexByte(hdr[i:], ':') + // Not found or key is length 0, we stop. + if del < 0 || del == i { + break + } + keyStart := i + // Walk back to remove spaces between the key and ':' if applicable. 
+ index := keyStart + del - 1 + for index > keyStart && hdr[index] == ' ' { + index-- + } + key := hdr[keyStart : index+1] + // If what we had is only spaces, we stop. + if len(key) == 0 { + break + } + i += del + 1 + valStart := i + // Search for `\r\n`. + nl := bytes.Index(hdr[valStart:], crLFAsBytes) + // If we don't find, we stop. + if nl < 0 { + break + } + // Look if the caller is interested in this key. + if _, ok := headers[bytesToString(key)]; ok { + index := valStart + // Remove possible spaces between the ':' and the value. + for index < valStart+nl && hdr[index] == ' ' { + index++ + } + // Create a slice and limit capacity to the value range. + val := hdr[index : valStart+nl : valStart+nl] + // Record in the caller's map the value for this key. + headers[bytesToString(key)] = val + } + // Reposition to past the `\r\n`. + i += nl + 2 + } +} + +// Decodes a retained message based on the content of the header `h`. +// The returned `*mqttRetainedMsg` object will hold a reference to `m`. +// If the buffer `m` is not owned by the caller, it is the caller +// responsibility to make a copy of the byte slice. +func mqttDecodeRetainedMessage(subject string, h, m []byte) (*mqttRetainedMsg, error) { + headers := map[string][]byte{ + mqttNatsRetainedMessageOrigin: nil, + mqttNatsRetainedMessageFlags: nil, + mqttNatsRetainedMessageSource: nil, + } + var rm *mqttRetainedMsg + // Retrieve the values for the above headers. + mqttSliceHeaders(headers, h) + // Get the flag header. + fHeader := headers[mqttNatsRetainedMessageFlags] + // If we don't, it could be that this is an old retained message that + // was JSON encoded. 
if len(fHeader) > 0 { - flags, err := strconv.ParseUint(string(fHeader), 16, 8) + if len(fHeader) > 1 && fHeader[0] == mqttRetainedFlagDelMarker { + fHeader = fHeader[1:] + } + flagsUint, err := strconv.ParseUint(bytesToString(fHeader), 16, 8) if err != nil { - return nil, fmt.Errorf("invalid retained message flags: %v", err) - } - topic := getHeader(mqttNatsRetainedMessageTopic, h) - subj, _ := mqttToNATSSubjectConversion(topic, false) - return &mqttRetainedMsg{ - Flags: byte(flags), - Subject: string(subj), - Topic: string(topic), - Origin: string(getHeader(mqttNatsRetainedMessageOrigin, h)), - Source: string(getHeader(mqttNatsRetainedMessageSource, h)), - Msg: m, - }, nil + // Since the error is currently not reported in the server, we + // will simply replace with this one. + return nil, errMQTTInvalidRetainFlags + } + rm = &mqttRetainedMsg{ + Flags: byte(flagsUint), + Origin: string(headers[mqttNatsRetainedMessageOrigin]), + Source: string(headers[mqttNatsRetainedMessageSource]), + Msg: m, + } } else { - var rm mqttRetainedMsg if err := json.Unmarshal(m, &rm); err != nil { return nil, err } - return &rm, nil } + // Now check that the values are correct. + // + // For "Flags", anything at or above binary (1111) is too big. + if rm.Flags >= mqttPacketFlagMask { + return nil, errMQTTInvalidRetainFlags + } + if qos := mqttGetQoS(rm.Flags); qos > 2 { + return nil, errMQTTInvalidRetainFlags + } + // We store `Topic` in the retained message because we used to store + // all retained messages under the same subject `$MQTT_rmsgs` in + // the retained messages stream. That is no longer the case, and to + // cover setups where the retained message stream is sourced from another + // account and has some subject transforms, simply reconstruct the + // topic/subject based on the `subject` passed to this function. 
+ rm.Subject = strings.TrimPrefix(subject, mqttRetainedMsgsStreamSubject) + rm.Topic = bytesToString(natsSubjectStrToMQTTTopic(rm.Subject)) + return rm, nil } // Creates the session stream (limit msgs of 1) for this client ID if it does @@ -2986,6 +3011,7 @@ func (as *mqttAccountSessionManager) deleteRetainedMsg(seq uint64) { // Sends a message indicating that a retained message on a given subject and stream sequence // is being removed. +// NOTE: This is maintained for backward compatibility reasons. Should be removed in 2.14/2.15? func (as *mqttAccountSessionManager) notifyRetainedMsgDeleted(subject string, seq uint64) { req := mqttRetMsgDel{ Subject: subject, @@ -3115,9 +3141,6 @@ func (as *mqttAccountSessionManager) transferRetainedToPerKeySubjectStream(log * } func (as *mqttAccountSessionManager) getCachedRetainedMsg(subject string) *mqttRetainedMsg { - if as.rmsCache == nil { - return nil - } v, ok := as.rmsCache.Load(subject) if !ok { return nil @@ -3141,7 +3164,7 @@ func (as *mqttAccountSessionManager) getCachedRetainedMsg(subject string) *mqttR // value (all `true` or all `false`) however we use different booleans to // better express the intent. func (as *mqttAccountSessionManager) setCachedRetainedMsg(subject string, rm *mqttRetainedMsg, onlyReplace, copyMsgBytes bool) { - if as.rmsCache == nil || rm == nil { + if rm == nil { return } rm.expiresFromCache = time.Now().Add(mqttRetainedCacheTTL) @@ -4364,13 +4387,14 @@ func (c *client) mqttHandlePubRetain() { // Spec [MQTT-3.3.1-11]. Payload of size 0 removes the retained message, but // should still be delivered as a normal message. - if pp.sz == 0 { - if seqToRemove := asm.handleRetainedMsgDel(key, 0); seqToRemove > 0 { - asm.deleteRetainedMsg(seqToRemove) - asm.notifyRetainedMsgDeleted(key, seqToRemove) - } - return - } + // + // We used to delete the message here from our map, the stream, and notify + // the network about the delete. We no longer do that. 
Instead, we store + // the message with an empty body. When servers get the empty body + // in processRetainedMsg, they will remove the message from their map. This + // effectively serializes all add/remove of retained messages without the + // need for "network" notifications about deletes (we still support that + // for backward compatibility but it will be pulled in future releases). rm := &mqttRetainedMsg{ Origin: asm.jsa.id, @@ -4403,11 +4427,13 @@ func (c *client) mqttHandlePubRetain() { // Store the retained message with the RETAIN flag set. rm.Flags |= mqttPubFlagRetain - // Copy the payload out of pp since we will be sending the message - // asynchronously. - msg := make([]byte, pp.sz) - copy(msg, pp.msg[:pp.sz]) - asm.jsa.sendMsg(key, msg) + if pp.sz > 0 { + // Copy the payload out of pp since we will be sending the message + // asynchronously. + msg := make([]byte, pp.sz) + copy(msg, pp.msg[:pp.sz]) + asm.jsa.sendMsg(key, msg) + } } else { // isRetained // Spec [MQTT-3.3.1-5]. Store the retained message with its QoS. @@ -4421,16 +4447,8 @@ func (c *client) mqttHandlePubRetain() { // $sparkplug subject for sparkB. rm.Subject = key rmBytes, hdr := mqttEncodeRetainedMessage(rm) // will copy the payload bytes - smr, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject+key, hdr, rmBytes) - if err == nil { - // Update the new sequence. - rf := &mqttRetainedMsgRef{ - sseq: smr.Sequence, - } - // Add/update the map. The `rm.Msg` bytes slice will be copied if the object - // happens to be stored in the rmsCache. 
- asm.handleRetainedMsg(key, rf, rm) - } else { + _, err := asm.jsa.storeMsg(mqttRetainedMsgsStreamSubject+key, hdr, rmBytes) + if err != nil { c.mu.Lock() acc := c.acc c.mu.Unlock() @@ -4490,21 +4508,23 @@ func (s *Server) mqttCheckPubRetainedPerms() { rmsg: rf, }) } + jsaID := asm.jsa.id asm.mu.RUnlock() slices.SortFunc(rms, func(i, j retainedMsg) int { return cmp.Compare(i.rmsg.sseq, j.rmsg.sseq) }) perms := map[string]*perm{} - deletes := map[string]uint64{} for _, rf := range rms { jsm, err := asm.jsa.loadMsg(mqttRetainedMsgsStreamName, rf.rmsg.sseq) if err != nil || jsm == nil { continue } - rm, err := mqttDecodeRetainedMessage(jsm.Header, jsm.Data) + rm, err := mqttDecodeRetainedMessage(jsm.Subject, jsm.Header, jsm.Data) if err != nil { continue } - if rm.Source == _EMPTY_ { + // We deal only with messages that have a source (the username that produced + // this message) and were produced on this server. + if rm.Source == _EMPTY_ || rm.Origin != jsaID { continue } // Lookup source from global users. @@ -4523,20 +4543,15 @@ func (s *Server) mqttCheckPubRetainedPerms() { } // Not present or permissions have changed such that the source can't - // publish on that subject anymore: remove it from the map. + // publish on that subject anymore: delete this retained message. if u == nil { - asm.mu.Lock() - delete(asm.retmsgs, rf.subj) - asm.sl.Remove(rf.rmsg.sub) - asm.mu.Unlock() - deletes[rf.subj] = rf.rmsg.sseq + // Set the payload to empty to notify that we are deleting this + // retained message. We will send this message async. 
+ rm.Msg = nil + rmBytes, hdrLen := mqttEncodeRetainedMessage(rm) + asm.jsa.storeMsgNoWait(mqttRetainedMsgsStreamSubject+rm.Subject, hdrLen, rmBytes) } } - - for subject, seq := range deletes { - asm.deleteRetainedMsg(seq) - asm.notifyRetainedMsgDeleted(subject, seq) - } } } diff --git a/server/mqtt_ex_bench_test.go b/server/mqtt_ex_bench_test.go index 2d14c3f78c8..8767e87aac6 100644 --- a/server/mqtt_ex_bench_test.go +++ b/server/mqtt_ex_bench_test.go @@ -187,8 +187,6 @@ func (bc mqttBenchContext) runAndReport(b *testing.B, name string, extraArgs ... func (bc *mqttBenchContext) startServer(b *testing.B, disableRMSCache bool) func() { b.Helper() b.StopTimer() - prevDisableRMSCache := testDisableRMSCache - testDisableRMSCache = disableRMSCache o := testMQTTDefaultOptions() s := testMQTTRunServer(b, o) @@ -198,15 +196,12 @@ func (bc *mqttBenchContext) startServer(b *testing.B, disableRMSCache bool) func mqttInitTestServer(b, mqttNewDial("", "", bc.Host, bc.Port, "")) return func() { testMQTTShutdownServer(s) - testDisableRMSCache = prevDisableRMSCache } } func (bc *mqttBenchContext) startCluster(b *testing.B, disableRMSCache bool) func() { b.Helper() b.StopTimer() - prevDisableRMSCache := testDisableRMSCache - testDisableRMSCache = disableRMSCache conf := ` listen: 127.0.0.1:-1 server_name: %s @@ -234,7 +229,6 @@ func (bc *mqttBenchContext) startCluster(b *testing.B, disableRMSCache bool) fun mqttInitTestServer(b, mqttNewDial("", "", bc.Host, bc.Port, "")) return func() { cl.shutdown() - testDisableRMSCache = prevDisableRMSCache } } diff --git a/server/mqtt_test.go b/server/mqtt_test.go index 225a9201e6b..17493a20aa4 100644 --- a/server/mqtt_test.go +++ b/server/mqtt_test.go @@ -795,6 +795,15 @@ func testMQTTGetClient(t testing.TB, s *Server, clientID string) *client { return mc } +func testMQTTGetAccountSessionManager(t *testing.T, s *Server, cid string) *mqttAccountSessionManager { + t.Helper() + c := testMQTTGetClient(t, s, cid) + require_NotNil(t, c) + asm := 
c.mqtt.asm + require_NotNil(t, asm) + return asm +} + func testMQTTRead(c net.Conn) ([]byte, error) { var buf [512]byte // Make sure that test does not block @@ -2023,6 +2032,11 @@ func testMQTTFlush(t testing.TB, c net.Conn, bw *bufio.Writer, r *mqttReader) { func testMQTTExpectNothing(t testing.TB, r *mqttReader) { t.Helper() + // First, check that we don't have buffered data. + if r.hasMore() { + t.Fatalf("Expected nothing, got %v", r.buf[r.pos:]) + } + // Then, try to read from the reader with some timeout. var buf [128]byte r.reader.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) if n, err := r.reader.Read(buf[:]); err == nil { @@ -3226,103 +3240,6 @@ func TestMQTTClusterRetainedMsg(t *testing.T) { testMQTTCheckPubMsg(t, mc, rc, "bar", mqttPubQos1|mqttPubFlagRetain, []byte("msg2")) } -func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) { - o := testMQTTDefaultOptions() - s := testMQTTRunServer(t, o) - defer testMQTTShutdownServer(s) - - mc, rc := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) - defer mc.Close() - testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false) - - c := testMQTTGetClient(t, s, "sub") - asm := c.mqtt.asm - - // For this test, we are going to simulate updates arriving in a - // mixed order and verify that we have the expected outcome. 
- check := func(t *testing.T, subject string, present bool, current, floor uint64) { - t.Helper() - asm.mu.RLock() - defer asm.mu.RUnlock() - erm, ok := asm.retmsgs[subject] - if present && !ok { - t.Fatalf("Subject %q not present", subject) - } else if !present && ok { - t.Fatalf("Subject %q should not be present", subject) - } else if !present { - return - } - if floor != erm.floor { - t.Fatalf("Expected floor to be %v, got %v", floor, erm.floor) - } - if erm.sseq != current { - t.Fatalf("Expected current sequence to be %v, got %v", current, erm.sseq) - } - } - - type action struct { - add bool - seq uint64 - } - for _, test := range []struct { - subject string - order []action - seq uint64 - floor uint64 - }{ - {"foo.1", []action{{true, 1}, {true, 2}, {true, 3}}, 3, 0}, - {"foo.2", []action{{true, 3}, {true, 1}, {true, 2}}, 3, 0}, - {"foo.3", []action{{true, 1}, {false, 1}, {true, 2}}, 2, 0}, - {"foo.4", []action{{false, 2}, {true, 1}, {true, 3}, {true, 2}}, 3, 0}, - {"foo.5", []action{{false, 2}, {true, 1}, {true, 2}}, 0, 2}, - {"foo.6", []action{{true, 1}, {true, 2}, {false, 2}}, 0, 2}, - } { - t.Run(test.subject, func(t *testing.T) { - for _, a := range test.order { - if a.add { - rf := &mqttRetainedMsgRef{sseq: a.seq} - asm.handleRetainedMsg(test.subject, rf, nil) - } else { - asm.handleRetainedMsgDel(test.subject, a.seq) - } - } - check(t, test.subject, true, test.seq, test.floor) - }) - } - - for _, subject := range []string{"foo.5", "foo.6"} { - t.Run("clear_"+subject, func(t *testing.T) { - // Now add a new message, which should clear the floor. - rf := &mqttRetainedMsgRef{sseq: 3} - asm.handleRetainedMsg(subject, rf, nil) - check(t, subject, true, 3, 0) - // Now do a non network delete and make sure it is gone. 
- asm.handleRetainedMsgDel(subject, 0) - check(t, subject, false, 0, 0) - }) - } -} - -func TestMQTTRetainedMsgDel(t *testing.T) { - o := testMQTTDefaultOptions() - s := testMQTTRunServer(t, o) - defer testMQTTShutdownServer(s) - mc, _ := testMQTTConnect(t, &mqttConnInfo{clientID: "sub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) - defer mc.Close() - - c := testMQTTGetClient(t, s, "sub") - asm := c.mqtt.asm - var i uint64 - for i = 0; i < 3; i++ { - rf := &mqttRetainedMsgRef{sseq: i} - asm.handleRetainedMsg("subject", rf, nil) - } - asm.handleRetainedMsgDel("subject", 2) - if asm.sl.count > 0 { - t.Fatalf("all retained messages subs should be removed, but %d still present", asm.sl.count) - } -} - func TestMQTTRetainedMsgMigration(t *testing.T) { o := testMQTTDefaultOptions() s := testMQTTRunServer(t, o) @@ -3369,6 +3286,16 @@ func TestMQTTRetainedMsgMigration(t *testing.T) { defer mc.Close() testMQTTCheckConnAck(t, rc, mqttConnAckRCConnectionAccepted, false) + as := testMQTTGetAccountSessionManager(t, s, "sub") + checkFor(t, time.Second, 10*time.Millisecond, func() error { + as.mu.RLock() + defer as.mu.RUnlock() + if n := len(as.retmsgs); n != N { + return fmt.Errorf("Got only %v retained messages", n) + } + return nil + }) + testMQTTSub(t, 1, mc, rc, []*mqttFilter{{filter: "+", qos: 0}}, []byte{0}) topics := map[string]struct{}{} for i := 0; i < N; i++ { @@ -3413,7 +3340,7 @@ func TestMQTTRetainedNoMsgBodyCorruption(t *testing.T) { defer testMQTTShutdownServer(s) // Send a retained message. 
- c, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + c, r := testMQTTConnect(t, &mqttConnInfo{clientID: "pub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) defer c.Close() testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) testMQTTPublish(t, c, r, 0, false, true, "foo/bar", 0, []byte("retained 1")) @@ -3466,13 +3393,8 @@ func TestMQTTRetainedNoMsgBodyCorruption(t *testing.T) { } }() - s.mu.RLock() - sm := &s.mqtt.sessmgr - s.mu.RUnlock() - sm.mu.RLock() - as := sm.sessions[globalAccountName] - sm.mu.RUnlock() - require_True(t, as != nil) + // Retrieve the account session manager using the "pub" client we have. + as := testMQTTGetAccountSessionManager(t, s, "pub") as.mu.RLock() cache := as.rmsCache as.mu.RUnlock() @@ -4540,7 +4462,7 @@ func TestMQTTPublishRetainPermViolation(t *testing.T) { Username: "mqtt3", Password: "pass", Permissions: &Permissions{ - Publish: &SubjectPermission{Allow: []string{"foo.bar", "baz"}}, + Publish: &SubjectPermission{Allow: []string{"foo.bar", "baz", "barbaz"}}, Subscribe: &SubjectPermission{Allow: []string{">"}}, }, }, @@ -4552,87 +4474,115 @@ func TestMQTTPublishRetainPermViolation(t *testing.T) { s := testMQTTRunServer(t, o) defer testMQTTShutdownServer(s) - pubRetained := func(user, pass, subject string) { + var asm *mqttAccountSessionManager + + pubRetained := func(user, subject string) { t.Helper() mc, rs := testMQTTConnect(t, &mqttConnInfo{ cleanSess: true, + clientID: "pub", user: user, - pass: pass, + pass: "pass", }, o.MQTT.Host, o.MQTT.Port) defer mc.Close() testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) testMQTTPublish(t, mc, rs, 0, false, true, subject, 0, []byte("retained")) testMQTTFlush(t, mc, nil, rs) + if asm == nil { + asm = testMQTTGetAccountSessionManager(t, s, "pub") + } testMQTTDisconnect(t, mc, nil) } - consumeRetained := func(user, pass, subject string) { - t.Helper() - mc, rs := testMQTTConnect(t, &mqttConnInfo{ - cleanSess: true, - user: 
user, - pass: pass, - }, o.MQTT.Host, o.MQTT.Port) - defer mc.Close() - testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) - testMQTTSub(t, 1, mc, rs, []*mqttFilter{{filter: subject, qos: 0}}, []byte{0}) - testMQTTCheckPubMsg(t, mc, rs, subject, mqttPubFlagRetain, []byte("retained")) - testMQTTDisconnect(t, mc, nil) - } - consumeRetainedFail := func(user, pass, subject string) { + consumeRetained := func(user, subject string, expected bool) { t.Helper() mc, rs := testMQTTConnect(t, &mqttConnInfo{ cleanSess: true, user: user, - pass: pass, + pass: "pass", }, o.MQTT.Host, o.MQTT.Port) defer mc.Close() testMQTTCheckConnAck(t, rs, mqttConnAckRCConnectionAccepted, false) testMQTTSub(t, 1, mc, rs, []*mqttFilter{{filter: subject, qos: 0}}, []byte{0}) - testMQTTExpectNothing(t, rs) + if expected { + testMQTTCheckPubMsg(t, mc, rs, subject, mqttPubFlagRetain, []byte("retained")) + } else { + testMQTTExpectNothing(t, rs) + } testMQTTDisconnect(t, mc, nil) } // With user "mqtt", publish a retained message on "bar". // Since this user has no permission, the server should not have stored it. - pubRetained("mqtt1", "pass", "bar") + pubRetained("mqtt1", "bar") // Verify that we can't get it with a new subscription. - consumeRetainedFail("mqtt1", "pass", "bar") + consumeRetained("mqtt1", "bar", false) // Use the user "mqtt2" that has permissions to publish on foo and bar. // Publish on "foo" and check retained message can be received. - pubRetained("mqtt2", "pass", "foo") - consumeRetained("mqtt2", "pass", "foo") + pubRetained("mqtt2", "foo") + consumeRetained("mqtt2", "foo", true) // For user "mqtt3", we will publish on "foo/bar" and check retained // message is properly received. - pubRetained("mqtt3", "pass", "foo/bar") - consumeRetained("mqtt3", "pass", "foo/bar") + pubRetained("mqtt3", "foo/bar") + consumeRetained("mqtt3", "foo/bar", true) + + // Simulate a message that would have been produced in a different server + // on subject "barbaz". 
We will use user "mqtt4" that has no publish restrictions + // since we need to send to low-level "$MQTT.rmsgs.barbaz" subject... + nc := natsConnect(t, s.ClientURL(), nats.UserInfo("mqtt4", "pass")) + defer nc.Close() + msg := nats.NewMsg("$MQTT.rmsgs.barbaz") + msg.Header.Set(mqttNatsRetainedMessageOrigin, "SomeOtherServer") + msg.Header.Set(mqttNatsRetainedMessageTopic, "barbaz") + msg.Header.Set(mqttNatsRetainedMessageFlags, "1") + msg.Data = []byte("retained") + nc.PublishMsg(msg) + natsFlush(t, nc) + + // Wait a bit to make sure it is processed. + time.Sleep(250 * time.Millisecond) + + // Then check that it can be received + consumeRetained("mqtt3", "barbaz", true) // Same with user "mqtt4" that does not have permissions defined, which // means allowed to pub/sub on everything. - pubRetained("mqtt4", "pass", "bat") - consumeRetained("mqtt4", "pass", "bat") + pubRetained("mqtt4", "bat") + consumeRetained("mqtt4", "bat", true) // Do a config reload and make sure that the server does not panic // and we can still get the retained messages. no := *o // Remove the "bar" publish permission from "mqtt2" no.Users[1].Permissions.Publish = &SubjectPermission{Allow: []string{"foo"}} - // And the "foo.bar" publish permission from "mqtt3" + // And the "foo.bar" and "barbaz" publish permissions from "mqtt3" no.Users[2].Permissions.Publish = &SubjectPermission{Allow: []string{"baz"}} err := s.ReloadOptions(&no) require_NoError(t, err) + checkFor(t, time.Second, 10*time.Millisecond, func() error { + asm.mu.RLock() + defer asm.mu.RUnlock() + if _, ok := asm.retmsgs["foo.bar"]; ok { + return errors.New("foo.bar subject still in map") + } + return nil + }) + // Still message on "bar" should not exist - consumeRetainedFail("mqtt1", "pass", "bar") + consumeRetained("mqtt1", "bar", false) // This one should still be able to be received - consumeRetained("mqtt2", "pass", "foo") + consumeRetained("mqtt2", "foo", true) // Retained message on "foo.bar" should have been removed. 
- consumeRetainedFail("mqtt3", "pass", "foo/bar") + consumeRetained("mqtt3", "foo/bar", false) + // However, message on "barbaz" should have been left alone since it + // was produced on a different server. + consumeRetained("mqtt3", "barbaz", true) // And finally, this user that had no permission should still be able // to get the retained message on "bat". - consumeRetained("mqtt4", "pass", "bat") + consumeRetained("mqtt4", "bat", true) } func TestMQTTPublishViolation(t *testing.T) { @@ -5076,8 +5026,7 @@ func TestMQTTFlappingSession(t *testing.T) { testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) // Let's get a handle on the asm to check things later. - cli := testMQTTGetClient(t, s, "flapper") - asm := cli.mqtt.asm + asm := testMQTTGetAccountSessionManager(t, s, "flapper") // Start a new connection with the same clientID, which should replace // the old one and put it in the flappers map. @@ -5323,8 +5272,7 @@ func TestMQTTRetainedMsgCleanup(t *testing.T) { time.Sleep(2 * mqttRetainedCacheTTL) // Make sure not in cache anymore - cli := testMQTTGetClient(t, s, "cache") - asm := cli.mqtt.asm + asm := testMQTTGetAccountSessionManager(t, s, "cache") if v, ok := asm.rmsCache.Load("foo"); ok { t.Fatalf("Should not be in cache, got %+v", v) } @@ -8208,6 +8156,412 @@ func TestMQTTMappingsQoS0(t *testing.T) { } } +func TestMQTTSliceHeadersAndDecodeRetainedMessage(t *testing.T) { + // First check low level mqttSliceHeaders + for _, test := range []struct { + name string + hdr string + expected []string + }{ + // Valid cases + {"one key and some random", hdrLine + "key1:val1\r\nsomeotherkey:someval\r\n", []string{"val1", _EMPTY_, _EMPTY_}}, + {"two keys", hdrLine + "key2:val2\r\nthisisnotkey1:someval\r\nkey1:val1\r\n", []string{"val1", "val2", _EMPTY_}}, + {"three keys", hdrLine + "key2:val2\r\nkey3:val3\r\nkey1:val1\r\n", []string{"val1", "val2", "val3"}}, + {"space before value", hdrLine + "somekey:someval\r\nkey2: val2withspacebefore\r\n", 
[]string{_EMPTY_, "val2withspacebefore", _EMPTY_}}, + {"space between key and colon sign", hdrLine + "key1:val1\r\nkey2 :val2\r\nkey3 : val3\r\n", []string{"val1", "val2", "val3"}}, + // Error cases + {"no hdr line", "key1:val1\r\n", []string{_EMPTY_, _EMPTY_, _EMPTY_}}, + {"key length 0", hdrLine + "key1:val1\r\n:val2\r\nkey3:val3\r\n", []string{"val1", _EMPTY_, _EMPTY_}}, + {"key is only spaces", hdrLine + "key1:val1\r\nkey2:val2\r\n :val3\r\n", []string{"val1", "val2", _EMPTY_}}, + {"value no crlf", hdrLine + "key1:val1\r\nkey2:val2", []string{"val1", _EMPTY_, _EMPTY_}}, + } { + t.Run(test.name, func(t *testing.T) { + headers := map[string][]byte{ + "key1": nil, + "key2": nil, + "key3": nil, + } + mqttSliceHeaders(headers, []byte(test.hdr)) + for i := range len(headers) { + key := fmt.Sprintf("key%d", i+1) + val := string(headers[key]) + if ev := test.expected[i]; ev != val { + t.Fatalf("For key %q, expected value to be %q, got %q", key, ev, val) + } + } + }) + } + + // Now test mqttDecodeRetainedMessage() itself. 
+ t.Run("flag with delete marker", func(t *testing.T) { + hdr := fmt.Appendf(nil, "%sNmqtt-RFlags:%c1\r\n\r\n", hdrLine, mqttRetainedFlagDelMarker) + rm, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", hdr, nil) + require_NoError(t, err) + require_Equal(t, rm.Flags, 1) + }) + t.Run("flag not a number", func(t *testing.T) { + hdr := fmt.Appendf(nil, "%sNmqtt-RFlags:bad\r\n\r\n", hdrLine) + _, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", hdr, []byte("msg")) + require_Error(t, err, errMQTTInvalidRetainFlags) + }) + t.Run("flag not a number with delete marker", func(t *testing.T) { + hdr := fmt.Appendf(nil, "%sNmqtt-RFlags:%cad\r\n\r\n", hdrLine, mqttRetainedFlagDelMarker) + _, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", hdr, []byte("msg")) + require_Error(t, err, errMQTTInvalidRetainFlags) + }) + t.Run("flag too big", func(t *testing.T) { + hdr := fmt.Appendf(nil, "%sNmqtt-RFlags:%c15\r\n\r\n", hdrLine, mqttRetainedFlagDelMarker) + _, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", hdr, []byte("msg")) + require_Error(t, err, errMQTTInvalidRetainFlags) + }) + t.Run("flag invalid qos", func(t *testing.T) { + hdr := fmt.Appendf(nil, "%sNmqtt-RFlags:%c7\r\n\r\n", hdrLine, mqttRetainedFlagDelMarker) + _, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", hdr, []byte("msg")) + require_Error(t, err, errMQTTInvalidRetainFlags) + }) + t.Run("decode retained msg with space before header value", func(t *testing.T) { + msg, hdrLen := mqttEncodeRetainedMessage(&mqttRetainedMsg{ + Topic: "foo/x", + Origin: " Origin", // Add spaces in front on purpose + Source: "Source", + Flags: 1, + Msg: []byte("msg1"), + }) + hdr := msg[:hdrLen] + msg = msg[hdrLen:] + rm, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.foo.x", hdr, msg) + require_NoError(t, err) + require_Equal(t, rm.Topic, "foo/x") + require_Equal(t, rm.Subject, "foo.x") + require_Equal(t, rm.Origin, "Origin") + require_Equal(t, rm.Source, "Source") + require_Equal(t, rm.Flags, 1) + 
require_Equal(t, string(rm.Msg), "msg1") + }) + t.Run("decode retained msg with subject transformed", func(t *testing.T) { + msg, hdrLen := mqttEncodeRetainedMessage(&mqttRetainedMsg{ + Topic: "foo/x", + Origin: "Origin", + Source: "Source", + Flags: 1, + Msg: []byte("msg2"), + }) + hdr := msg[:hdrLen] + msg = msg[hdrLen:] + // Use different subject when calling the function. Make sure the + // topic is properly reflecting the subject. + rm, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", hdr, msg) + require_NoError(t, err) + require_Equal(t, rm.Topic, "bar/x") + require_Equal(t, rm.Subject, "bar.x") + require_Equal(t, rm.Origin, "Origin") + require_Equal(t, rm.Source, "Source") + require_Equal(t, rm.Flags, 1) + require_Equal(t, string(rm.Msg), "msg2") + }) + t.Run("decode deleted retained message", func(t *testing.T) { + msg, hdrLen := mqttEncodeRetainedMessage(&mqttRetainedMsg{ + Topic: "foo/x", + Origin: "Origin", + Source: "Source", + Flags: 1, + Msg: nil, + }) + hdr := msg[:hdrLen] + msg = msg[hdrLen:] + // Use different subject too + rm, err := mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", hdr, msg) + require_NoError(t, err) + require_Equal(t, rm.Topic, "bar/x") + require_Equal(t, rm.Subject, "bar.x") + require_Equal(t, rm.Origin, "Origin") + require_Equal(t, rm.Source, "Source") + require_Equal(t, rm.Flags, 1) + require_Len(t, len(rm.Msg), 0) + }) + t.Run("decode retained message as JSON with bad flags", func(t *testing.T) { + rmo := &mqttRetainedMsg{Flags: 15} + msg, err := json.Marshal(rmo) + require_NoError(t, err) + _, err = mqttDecodeRetainedMessage("$MQTT.rmsgs.foo.x", nil, msg) + require_Error(t, err, errMQTTInvalidRetainFlags) + }) + t.Run("decode retained message as JSON subject transform", func(t *testing.T) { + rmo := &mqttRetainedMsg{ + Topic: "foo/x", + Origin: "Origin", + Source: "Source", + Flags: 1, + Msg: []byte("hello"), + } + msg, err := json.Marshal(rmo) + require_NoError(t, err) + rm, err := 
mqttDecodeRetainedMessage("$MQTT.rmsgs.bar.x", nil, msg) + require_NoError(t, err) + require_Equal(t, rm.Topic, "bar/x") + require_Equal(t, rm.Subject, "bar.x") + require_Equal(t, rm.Origin, "Origin") + require_Equal(t, rm.Source, "Source") + require_Equal(t, rm.Flags, 1) + require_Equal(t, string(rm.Msg), "hello") + }) +} + +func TestMQTTRetainedMsgRemovedFromMapIfNotInStream(t *testing.T) { + mqttRetainedCacheTTL = 250 * time.Millisecond + defer func() { mqttRetainedCacheTTL = mqttDefaultRetainedCacheTTL }() + + o := testMQTTDefaultOptions() + s := testMQTTRunServer(t, o) + defer testMQTTShutdownServer(s) + + c, r := testMQTTConnect(t, &mqttConnInfo{clientID: "pub", cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, c, r, 0, false, true, "foo", 0, []byte("msg1")) + testMQTTFlush(t, c, nil, r) + + checkRetained := func(expected string) { + t.Helper() + c, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo", qos: 0}}, []byte{0}) + if expected == _EMPTY_ { + testMQTTExpectNothing(t, r) + } else { + testMQTTCheckPubMsg(t, c, r, "foo", mqttPubFlagRetain, []byte(expected)) + } + } + checkRetained("msg1") + + testMQTTPublish(t, c, r, 0, false, true, "foo", 0, []byte("msg2")) + testMQTTFlush(t, c, nil, r) + + checkRetained("msg2") + + // Now we will get the current sequence for the retained message and + // remove it from the stream. We expect to get a warning that indicates + // that the load failed. Restarting the subscription should no longer + // cause this warning and the retained message should have been removed + // from the map/cache. 
+ l := &captureWarnLogger{warn: make(chan string, 10)} + s.SetLogger(l, false, false) + + asm := testMQTTGetAccountSessionManager(t, s, "pub") + // Make sure it is in the cache + rm := asm.getCachedRetainedMsg("foo") + require_NotNil(t, rm) + // Get the mqttRetainedMsgRef from the map + asm.mu.RLock() + rf, ok := asm.retmsgs["foo"] + asm.mu.RUnlock() + require_True(t, ok) + nc, js := jsClientConnect(t, s) + defer nc.Close() + err := js.DeleteMsg(mqttRetainedMsgsStreamName, rf.sseq) + require_NoError(t, err) + + // Wait for more than the cache TTL + time.Sleep(2 * mqttRetainedCacheTTL) + + checkRetained(_EMPTY_) + + // We should have got a warning. + select { + case w := <-l.warn: + if !strings.Contains(w, ApiErrors[JSNoMessageFoundErr].Description) { + t.Fatalf("Unexpected warning: %q", w) + } + case <-time.After(time.Second): + t.Fatalf("Test timed out") + } + + // But restarting it should not cause the server to try to load the retained + // message again. So we should not have a warning. + checkRetained(_EMPTY_) + + select { + case w := <-l.warn: + if strings.Contains(w, ApiErrors[JSNoMessageFoundErr].Description) { + t.Fatalf("Got the warning: %q", w) + } + case <-time.After(250 * time.Millisecond): + // OK + } + + // Finally, check that the retmsgs map is empty. 
+ asm.mu.RLock() + ok = len(asm.retmsgs) == 0 + asm.mu.RUnlock() + require_True(t, ok) +} + +func TestMQTTCrossAccountRetain(t *testing.T) { + for _, test := range []struct { + name string + importTransform string + sourceDest string + bDest string + getLastMsgSubj string + }{ + {"without transform", "", "", "foo/x", "$MQTT.rmsgs.foo.x"}, + {"with transform", `, to: "foobar.>"`, "$MQTT.rmsgs.foobar.>", "foobar/x", "$MQTT.rmsgs.foobar.x"}, + } { + t.Run(test.name, func(t *testing.T) { + td := t.TempDir() + dir := filepath.Join(td, "js") + conf := createConfFile(t, fmt.Appendf(nil, ` + server_name: server + listen: "127.0.0.1:-1" + jetstream { + domain: "MYDOMAIN" + store_dir: "%s" + } + mqtt { + listen: "127.0.0.1:-1" + } + accounts: { + A: { + jetstream: true + users: [ { user:a, password:x }] + exports: [ + { stream: "foo.>" } + { service: "$JS.API.>", response_type: stream } + { stream: "a2b.>" } + ] + } + B: { + jetstream: true + users: [ { user:b, password:x }] + imports: [ + { stream: { account: A, subject: "foo.>" }%s } + { service: { account: A, subject: "$JS.API.>"}, to: "A.$JS.API.>" } + { stream: { account: A, subject: "a2b.>" } } + ] + } + } + `, dir, test.importTransform)) + s, o := RunServerWithConfig(conf) + defer s.Shutdown() + + // Connect a user on "B" to create the MQTT assets. + c, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true, user: "b", pass: "x"}, "127.0.0.1", o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTDisconnect(t, c, nil) + + pubRetained := func(user, dest, msg string) { + t.Helper() + c, r := testMQTTConnect(t, &mqttConnInfo{ + cleanSess: true, + user: user, + pass: "x", + }, "127.0.0.1", o.MQTT.Port) + defer c.Close() + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTPublish(t, c, r, 0, false, true, dest, 0, []byte(msg)) + testMQTTFlush(t, c, nil, r) + } + + // Publish a retained message from account "A". 
+ retainInAMsg := "Retain in A" + pubRetained("a", "foo/x", retainInAMsg) + + // Now we are going to do something unusual, which is to update + // the MQTT retain stream in "B" to source from "A". + nc, js := jsClientConnect(t, s, nats.UserInfo("b", "x")) + defer nc.Close() + + si, err := js.StreamInfo(mqttRetainedMsgsStreamName) + require_NoError(t, err) + src := &nats.StreamSource{ + Name: mqttRetainedMsgsStreamName, + External: &nats.ExternalStream{ + APIPrefix: "A.$JS.API", + DeliverPrefix: "a2b", + }, + SubjectTransforms: []nats.SubjectTransformConfig{ + { + Source: mqttRetainedMsgsStreamSubject + "foo.>", + Destination: test.sourceDest, + }, + }, + } + si.Config.Sources = []*nats.StreamSource{src} + _, err = js.UpdateStream(&si.Config) + require_NoError(t, err) + + // Now wait to make sure that the "B" retained messages stream + // contains the message with body "Retain in A" + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + msg, err := js.GetLastMsg(mqttRetainedMsgsStreamName, test.getLastMsgSubj) + if err != nil { + return err + } + if !bytes.Contains(msg.Data, []byte(retainInAMsg)) { + return fmt.Errorf("Message is not from A: %q", msg.Data) + } + return nil + }) + + getRetained := func(user, dest, msg string) { + t.Helper() + c, r := testMQTTConnect(t, &mqttConnInfo{ + cleanSess: true, + user: user, + pass: "x", + }, "127.0.0.1", o.MQTT.Port) + testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false) + testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: dest, qos: 0}}, []byte{0}) + if msg == _EMPTY_ { + testMQTTExpectNothing(t, r) + } else { + testMQTTCheckPubMsg(t, c, r, dest, mqttPubFlagRetain, []byte(msg)) + } + } + + // The retained message in account A should of course be "Retain in A" + getRetained("a", "foo/x", retainInAMsg) + // But because of the sourcing, the retained message in "B" should be + // the retained message from the "A" account. 
+ getRetained("b", test.bDest, retainInAMsg) + + // Now publish a retained message from "B" account and make sure that + // it is correctly replacing "Retain in A". + retainInBMsg := "Retain in B" + pubRetained("b", test.bDest, retainInBMsg) + // Check that we can receive it. + getRetained("b", test.bDest, retainInBMsg) + // And "A" still has the "Retain in A" message. + getRetained("a", "foo/x", retainInAMsg) + + // Publish from "A" a new message: + retainInAMsg = "Retain in A2" + pubRetained("a", "foo/x", retainInAMsg) + // Make sure that this retained appears on "A" and "B". + getRetained("a", "foo/x", retainInAMsg) + getRetained("b", test.bDest, retainInAMsg) + + // Now publish an empty body retained message from "A". This + // should remove the retained message from both "A" and "B". + pubRetained("a", "foo/x", _EMPTY_) + + // We will check that the message gets removed from the stream. + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + _, err := js.GetLastMsg(mqttRetainedMsgsStreamName, test.getLastMsgSubj) + if err == nats.ErrMsgNotFound { + return nil + } + return fmt.Errorf("Message still present or unexpected error %v", err) + }) + // The helper will use "expect nothing" if the given string is empty. + getRetained("a", "foo/x", _EMPTY_) + getRetained("b", test.bDest, _EMPTY_) + }) + } +} + ////////////////////////////////////////////////////////////////////////// // // Benchmarks