Skip to content

Commit 4b6207a

Browse files
committed
Refactor and optimize connection limit tests
Reduced the number of concurrent attempts in ConnectionTrackerTests for faster execution and replaced Barrier with TaskCompletionSource for task synchronization. Improved SmtpServerRaceConditionTests by streamlining connection handling, adding debug info, and ensuring proper cleanup and validation of connection limits.
1 parent ceb4a3c commit 4b6207a

2 files changed

Lines changed: 71 additions & 59 deletions

File tree

tests/Zetian.Tests/ConnectionTrackerTests.cs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,27 @@ public async Task TryAcquireAsync_ShouldBeThreadSafe()
113113
{
114114
// Arrange
115115
const int maxConnections = 5;
116-
const int totalAttempts = 20; // Reduced from 100 for faster test
116+
const int totalAttempts = 10; // Further reduced for faster test
117117
int successCount = 0;
118-
var handles = new ConcurrentBag<ConnectionTracker.ConnectionHandle>();
119-
var barrier = new Barrier(totalAttempts);
118+
ConcurrentBag<ConnectionTracker.ConnectionHandle> handles = new();
119+
TaskCompletionSource<bool> startSignal = new();
120120

121121
// Act - Many concurrent attempts
122122
Task[] tasks = Enumerable.Range(0, totalAttempts).Select(_ => Task.Run(async () =>
123123
{
124-
barrier.SignalAndWait(); // Synchronize all tasks to start at the same time
124+
await startSignal.Task; // Wait for signal to start simultaneously
125125

126126
ConnectionTracker.ConnectionHandle? handle = await _tracker.TryAcquireAsync(_testIp);
127127
if (handle != null)
128128
{
129129
Interlocked.Increment(ref successCount);
130130
handles.Add(handle);
131-
// No delay needed for this test
132131
}
133132
})).ToArray();
134133

134+
// Start all tasks at once
135+
startSignal.SetResult(true);
136+
135137
await Task.WhenAll(tasks);
136138

137139
// Assert - Exactly max connections should succeed
@@ -202,7 +204,7 @@ public void GetConnectionCount_ShouldReturnCorrectCount()
202204
public async Task TryAcquireAsync_ShouldRespectCancellationToken()
203205
{
204206
// Arrange
205-
using var cts = new CancellationTokenSource();
207+
using CancellationTokenSource cts = new();
206208
cts.Cancel();
207209

208210
// Act & Assert - TaskCanceledException derives from OperationCanceledException

tests/Zetian.Tests/SmtpServerRaceConditionTests.cs

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
using System.Net.Sockets;
44
using System.Text;
55
using Xunit;
6-
using System.Linq;
76

87
namespace Zetian.Tests
98
{
@@ -53,7 +52,7 @@ public async Task SmtpServer_ShouldEnforceMaxConnectionsPerIP()
5352
const int attemptCount = 20;
5453
ConcurrentBag<TcpClient> successfulConnections = new();
5554
Barrier barrier = new(attemptCount);
56-
var successCount = 0;
55+
int successCount = 0;
5756

5857
// Act - Try to connect many times simultaneously
5958
Task<bool>[] tasks = Enumerable.Range(0, attemptCount).Select(i => Task.Run(async () =>
@@ -119,82 +118,93 @@ public async Task SmtpServer_ShouldAllowNewConnectionsAfterDisconnect()
119118
// Arrange
120119
_server = await CreateAndStartServerAsync(TestPort + 1); // Use different port
121120

122-
// First, max out connections
121+
// First, quickly max out connections
123122
List<TcpClient> firstBatch = new();
124-
for (int i = 0; i < MaxConnectionsPerIp; i++)
123+
Task<TcpClient>[] connectionTasks = Enumerable.Range(0, MaxConnectionsPerIp).Select(async i =>
125124
{
126125
TcpClient client = new();
127126
await client.ConnectAsync("localhost", TestPort + 1);
127+
return client;
128+
}).ToArray();
128129

129-
// Verify we got the greeting
130+
TcpClient[] connectedClients = await Task.WhenAll(connectionTasks);
131+
132+
// Verify all got greetings
133+
foreach (TcpClient client in connectedClients)
134+
{
130135
byte[] buffer = new byte[1024];
136+
client.ReceiveTimeout = 1000;
131137
int bytes = await client.GetStream().ReadAsync(buffer, 0, buffer.Length);
132138
string response = Encoding.UTF8.GetString(buffer, 0, bytes);
133139
Assert.StartsWith("220", response);
134-
135140
firstBatch.Add(client);
136141
}
137142

138-
// Try one more - should fail to get a valid SMTP greeting
139-
// Note: TCP connection might succeed but should not get SMTP service
140-
using (TcpClient extraClient = new())
143+
// Small delay to ensure all connections are properly tracked
144+
await Task.Delay(100);
145+
146+
// Verify that we're at the limit by checking current count
147+
Assert.Equal(MaxConnectionsPerIp, firstBatch.Count);
148+
149+
// Try one more - should fail or timeout
150+
bool extraConnectionAccepted = false;
151+
string debugInfo = "";
152+
153+
try
141154
{
142-
var gotValidGreeting = false;
143-
try
144-
{
145-
extraClient.ReceiveTimeout = 500;
146-
await extraClient.ConnectAsync("localhost", TestPort + 1);
155+
using TcpClient extraClient = new();
156+
extraClient.ReceiveTimeout = 500; // Increased timeout
147157

148-
// Even if TCP connects, we shouldn't get a valid SMTP greeting
149-
// because the connection tracker should reject it
150-
await Task.Delay(50);
158+
// Try to connect
159+
await extraClient.ConnectAsync("localhost", TestPort + 1);
160+
debugInfo += "TCP connected. ";
151161

152-
if (extraClient.Available > 0)
162+
if (extraClient.Connected)
163+
{
164+
try
153165
{
154-
var buffer = new byte[1024];
155-
var bytes = await extraClient.GetStream().ReadAsync(buffer, 0, buffer.Length);
156-
var response = Encoding.UTF8.GetString(buffer, 0, bytes);
157-
gotValidGreeting = response.StartsWith("220");
166+
// Try to read SMTP greeting
167+
byte[] buffer = new byte[1024];
168+
int bytes = await extraClient.GetStream().ReadAsync(buffer, 0, buffer.Length);
169+
string response = Encoding.UTF8.GetString(buffer, 0, bytes);
170+
debugInfo += $"Got response: {response.Trim()}. ";
171+
172+
// If we got a valid greeting, the connection was incorrectly accepted
173+
if (response.StartsWith("220"))
174+
{
175+
extraConnectionAccepted = true;
176+
debugInfo += "SMTP greeting received - connection was accepted!";
177+
}
178+
}
179+
catch (Exception ex)
180+
{
181+
// Timeout or error is expected - connection was rejected
182+
debugInfo += $"Read failed: {ex.GetType().Name}. Connection rejected.";
158183
}
159184
}
160-
catch
161-
{
162-
// Connection failed - this is expected and OK
163-
}
164-
165-
Assert.False(gotValidGreeting, "Should not get SMTP greeting beyond connection limit");
166185
}
167-
168-
// Close all first batch connections
169-
foreach (TcpClient client in firstBatch)
186+
catch (Exception ex)
170187
{
171-
byte[] quit = Encoding.UTF8.GetBytes("QUIT\r\n");
172-
await client.GetStream().WriteAsync(quit, 0, quit.Length);
173-
client.Close();
188+
// Connection failed - this is expected
189+
debugInfo += $"Connect failed: {ex.GetType().Name}. Connection rejected.";
174190
}
175191

176-
// Wait a bit for cleanup
192+
Assert.False(extraConnectionAccepted, $"Extra connection should be rejected. Debug: {debugInfo}");
193+
194+
// Close all first batch connections quickly
195+
Parallel.ForEach(firstBatch, client => client.Close());
196+
197+
// Wait briefly for cleanup
177198
await Task.Delay(50);
178199

179-
// Now should be able to connect again
180-
List<TcpClient> secondBatch = new();
181-
for (int i = 0; i < MaxConnectionsPerIp; i++)
200+
// Now should be able to connect again - test with just one connection
201+
using (TcpClient newClient = new())
182202
{
183-
TcpClient client = new();
184-
await client.ConnectAsync("localhost", TestPort + 1);
185-
203+
await newClient.ConnectAsync("localhost", TestPort + 1);
186204
byte[] buffer = new byte[1024];
187-
int bytes = await client.GetStream().ReadAsync(buffer, 0, buffer.Length);
205+
int bytes = await newClient.GetStream().ReadAsync(buffer, 0, buffer.Length);
188206
string response = Encoding.UTF8.GetString(buffer, 0, bytes);
189207
Assert.StartsWith("220", response);
190-
191-
secondBatch.Add(client);
192-
}
193-
194-
// Cleanup
195-
foreach (TcpClient client in secondBatch)
196-
{
197-
client.Close();
198208
}
199209
}
200210

@@ -208,8 +218,8 @@ public async Task SmtpServer_ShouldHandleConcurrentSmtpClients()
208218
_server = await CreateAndStartServerAsync(TestPort + 2); // Use different port
209219
const int attemptCount = 10; // Reduced for faster test
210220
List<Task<bool>> tasks = new();
211-
var successCount = 0;
212-
var failureMessages = new ConcurrentBag<string>();
221+
int successCount = 0;
222+
ConcurrentBag<string> failureMessages = new();
213223
Barrier barrier = new(attemptCount);
214224

215225
// Act
@@ -253,7 +263,7 @@ public async Task SmtpServer_ShouldHandleConcurrentSmtpClients()
253263
// no more than MaxConnectionsPerIp are active
254264
if (successCount == 0)
255265
{
256-
var errorDetails = string.Join("\n", failureMessages.Take(5));
266+
string errorDetails = string.Join("\n", failureMessages.Take(5));
257267
Assert.True(false, $"No connections succeeded. Sample errors:\n{errorDetails}");
258268
}
259269
Assert.True(successCount > 0, "At least some connections should succeed");

0 commit comments

Comments
 (0)