Skip to content
88 changes: 73 additions & 15 deletions Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public sealed class MqttClientSessionsManager : ISubscriptionChangedNotification
// The _sessions dictionary contains all session, the _subscriberSessions hash set contains subscriber sessions only.
// See the MqttSubscription object for a detailed explanation.
readonly MqttSessionsStorage _sessionsStorage = new();
readonly HashSet<MqttSession> _subscriberSessions = [];
readonly HashSet<MqttSession> _subscriberSessionsWithWildcards = [];
readonly Dictionary<string, HashSet<MqttSession>> _simpleTopicsToSession = [];

public MqttClientSessionsManager(MqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, MqttServerEventContainer eventContainer, IMqttNetLogger logger)
{
Expand Down Expand Up @@ -77,7 +78,7 @@ public async Task DeleteSessionAsync(string clientId)
{
if (_sessionsStorage.TryRemoveSession(clientId, out session))
{
_subscriberSessions.Remove(session);
CleanupClientSessionUnsafe(session);
}
}
finally
Expand Down Expand Up @@ -161,11 +162,30 @@ public async Task<DispatchApplicationMessageResult> DispatchApplicationMessage(
await _retainedMessagesManager.UpdateMessage(senderId, applicationMessage).ConfigureAwait(false);
}

List<MqttSession> subscriberSessions;
HashSet<MqttSession> subscriberSessions;
_sessionsManagementLock.EnterReadLock();
try
{
subscriberSessions = _subscriberSessions.ToList();
if (_simpleTopicsToSession.TryGetValue(applicationMessage.Topic, out var sessionsWithSimpleTopics))
{
// Create the initial subscriberSessions from whichever set is larger to take advantage
// of the internal ConstructFrom other HashSet optimizations
if(sessionsWithSimpleTopics.Count > _subscriberSessionsWithWildcards.Count)
{
subscriberSessions = new HashSet<MqttSession>(sessionsWithSimpleTopics);
subscriberSessions.UnionWith(_subscriberSessionsWithWildcards);
}
else
{
subscriberSessions = new HashSet<MqttSession>(_subscriberSessionsWithWildcards);
subscriberSessions.UnionWith(sessionsWithSimpleTopics);
}
}
else
{
// Always include the sessions with wildcards. They need to be properly matched against the topic filter.
subscriberSessions = new HashSet<MqttSession>(_subscriberSessionsWithWildcards);
}
}
finally
{
Expand Down Expand Up @@ -451,15 +471,28 @@ public void OnSubscriptionsAdded(MqttSession clientSession, List<string> topics)
_sessionsManagementLock.EnterWriteLock();
try
{
if (!clientSession.HasSubscribedTopics)
{
// first subscribed topic
_subscriberSessions.Add(clientSession);
}

foreach (var topic in topics)
{
clientSession.AddSubscribedTopic(topic);
bool hasWildcard = MqttTopicFilterComparer.ContainsWildcards(topic);
if (hasWildcard)
{
if (!clientSession.HasSubscribedWildcardTopics)
{
_subscriberSessionsWithWildcards.Add(clientSession);
}
}
else
{
if (_simpleTopicsToSession.TryGetValue(topic, out var sessionsWithSimpleTopics))
{
sessionsWithSimpleTopics.Add(clientSession);
}
else
{
_simpleTopicsToSession[topic] = [clientSession];
}
}
clientSession.AddSubscribedTopic(topic, hasWildcard);
}
}
finally
Expand All @@ -475,13 +508,21 @@ public void OnSubscriptionsRemoved(MqttSession clientSession, List<string> subsc
{
foreach (var subscriptionTopic in subscriptionTopics)
{
if (_simpleTopicsToSession.TryGetValue(subscriptionTopic, out var sessionsWithSimpleTopics))
{
sessionsWithSimpleTopics.Remove(clientSession);
if (sessionsWithSimpleTopics.Count == 0)
{
_simpleTopicsToSession.Remove(subscriptionTopic);
}
}
clientSession.RemoveSubscribedTopic(subscriptionTopic);
}

if (!clientSession.HasSubscribedTopics)
if (!clientSession.HasSubscribedWildcardTopics)
{
// last subscription removed
_subscriberSessions.Remove(clientSession);
// Last wildcard subscription removed
_subscriberSessionsWithWildcards.Remove(clientSession);
}
}
finally
Expand Down Expand Up @@ -564,7 +605,7 @@ async Task<MqttConnectedClient> CreateClientConnection(
if (connectPacket.CleanSession)
{
_logger.Verbose("Deleting existing session of client '{0}' due to clean start", connectPacket.ClientId);
_subscriberSessions.Remove(oldSession);
CleanupClientSessionUnsafe(oldSession);
session = CreateSession(connectPacket, validatingConnectionEventArgs);
}
else
Expand Down Expand Up @@ -669,6 +710,23 @@ MqttSession GetClientSession(string clientId)
}
}

//* Must be called with the _sessionsManagementLock held.
void CleanupClientSessionUnsafe(MqttSession session)
{
_subscriberSessionsWithWildcards.Remove(session);
foreach (var simpleTopic in session.GetSimpleSubscribedTopics)
{
if (_simpleTopicsToSession.TryGetValue(simpleTopic, out var sessionsWithSimpleTopics))
{
sessionsWithSimpleTopics.Remove(session);
if (sessionsWithSimpleTopics.Count == 0)
{
_simpleTopicsToSession.Remove(simpleTopic);
}
}
}
}

async Task<MqttConnectPacket> ReceiveConnectPacket(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{
try
Expand Down
23 changes: 14 additions & 9 deletions Source/MQTTnet.Server/Internal/MqttSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ public sealed class MqttSession : IDisposable
// Do not use a dictionary in order to keep the ordering of the messages.
readonly List<MqttPublishPacket> _unacknowledgedPublishPackets = new();

// Bookkeeping to know if this is a subscribing client; lazy initialize later.
HashSet<string> _subscribedTopics;
readonly HashSet<string> _subscribedSimpleTopics = [];
readonly HashSet<string> _subscribedWildcardTopics = [];

public MqttSession(
MqttConnectPacket connectPacket,
Expand All @@ -50,7 +50,9 @@ public MqttSession(

public uint ExpiryInterval => _connectPacket.SessionExpiryInterval;

public bool HasSubscribedTopics => _subscribedTopics != null && _subscribedTopics.Count > 0;
public bool HasSubscribedWildcardTopics => _subscribedWildcardTopics.Count > 0;

public HashSet<string> GetSimpleSubscribedTopics => _subscribedSimpleTopics;

public string Id => _connectPacket.ClientId;

Expand Down Expand Up @@ -79,14 +81,16 @@ public MqttPublishPacket AcknowledgePublishPacket(ushort packetIdentifier)
return publishPacket;
}

public void AddSubscribedTopic(string topic)
public void AddSubscribedTopic(string topic, bool isWildcardTopic)
{
if (_subscribedTopics == null)
if (isWildcardTopic)
{
_subscribedTopics = new HashSet<string>();
_subscribedWildcardTopics.Add(topic);
}
else
{
_subscribedSimpleTopics.Add(topic);
}

_subscribedTopics.Add(topic);
}

public Task DeleteAsync()
Expand Down Expand Up @@ -208,7 +212,8 @@ public void Recover()

public void RemoveSubscribedTopic(string topic)
{
_subscribedTopics?.Remove(topic);
_subscribedSimpleTopics.Remove(topic);
_subscribedWildcardTopics.Remove(topic);
}

public Task<SubscribeResult> Subscribe(MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
Expand Down
14 changes: 14 additions & 0 deletions Source/MQTTnet/MqttTopicFilterComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,19 @@ public static unsafe MqttTopicFilterCompareResult Compare(string topic, string f

return MqttTopicFilterCompareResult.NoMatch;
}

public static bool ContainsWildcards(string topicFilter)
{
for (var i = 0; i < topicFilter.Length; i++)
{
var c = topicFilter[i];
if (c == MultiLevelWildcard || c == SingleLevelWildcard)
{
return true;
}
}

return false;
}
}
}