Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions spec/connection_rate_limit_spec.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
require "./spec_helper"

describe LavinMQ::ConnectionRateLimiter do
describe "global rate limit" do
it "allows connections when disabled (limit=0)" do
config = LavinMQ::Config.new
config.connection_rate_limit = 0
limiter = LavinMQ::ConnectionRateLimiter.new(config)
100.times { limiter.allow?("127.0.0.1").should be_true }
end

it "allows connections up to the limit" do
config = LavinMQ::Config.new
config.connection_rate_limit = 5
limiter = LavinMQ::ConnectionRateLimiter.new(config)
5.times { |i| limiter.allow?("127.0.0.1").should(be_true, "failed on attempt #{i}") }
limiter.allow?("127.0.0.1").should be_false
end

it "replenishes tokens over time" do
config = LavinMQ::Config.new
config.connection_rate_limit = 10
limiter = LavinMQ::ConnectionRateLimiter.new(config)
10.times { limiter.allow?("127.0.0.1") }
limiter.allow?("127.0.0.1").should be_false
sleep 0.2.seconds
limiter.allow?("127.0.0.1").should be_true
end
end

describe "per-IP rate limit" do
it "allows connections when disabled (limit=0)" do
config = LavinMQ::Config.new
config.connection_rate_limit_per_ip = 0
limiter = LavinMQ::ConnectionRateLimiter.new(config)
100.times { limiter.allow?("10.0.0.1").should be_true }
end

it "limits each IP independently" do
config = LavinMQ::Config.new
config.connection_rate_limit_per_ip = 2
limiter = LavinMQ::ConnectionRateLimiter.new(config)
2.times { limiter.allow?("10.0.0.1").should be_true }
limiter.allow?("10.0.0.1").should be_false
2.times { limiter.allow?("10.0.0.2").should be_true }
limiter.allow?("10.0.0.2").should be_false
end

it "replenishes per-IP tokens over time" do
config = LavinMQ::Config.new
config.connection_rate_limit_per_ip = 10
limiter = LavinMQ::ConnectionRateLimiter.new(config)
10.times { limiter.allow?("10.0.0.1") }
limiter.allow?("10.0.0.1").should be_false
sleep 0.2.seconds
limiter.allow?("10.0.0.1").should be_true
end
end

describe "combined global and per-IP limits" do
it "per-IP limit stops connections before global limit" do
config = LavinMQ::Config.new
config.connection_rate_limit = 10
config.connection_rate_limit_per_ip = 2
limiter = LavinMQ::ConnectionRateLimiter.new(config)
2.times { limiter.allow?("10.0.0.1").should be_true }
# Per-IP limit reached, even though global has tokens
limiter.allow?("10.0.0.1").should be_false
# Different IP still works
2.times { limiter.allow?("10.0.0.2").should be_true }
limiter.allow?("10.0.0.2").should be_false
end

it "global limit stops connections even if per-IP has tokens" do
config = LavinMQ::Config.new
config.connection_rate_limit = 3
config.connection_rate_limit_per_ip = 5
limiter = LavinMQ::ConnectionRateLimiter.new(config)
3.times { limiter.allow?("10.0.0.1").should be_true }
# Global limit reached
limiter.allow?("10.0.0.1").should be_false
# Different IP also blocked by global limit
limiter.allow?("10.0.0.2").should be_false
end
end
end

describe "Server connection rate limiting" do
it "rejects connections exceeding the global rate limit" do
config = LavinMQ::Config.new
config.connection_rate_limit = 2
with_amqp_server(config: config) do |s|
conn1 = AMQP::Client.new(port: amqp_port(s)).connect
conn2 = AMQP::Client.new(port: amqp_port(s)).connect
# Third connection should fail - server closes socket before handshake
expect_raises(Exception) do
AMQP::Client.new(port: amqp_port(s)).connect
end
conn1.close
conn2.close
end
end

it "allows unlimited connections when rate limit is 0" do
config = LavinMQ::Config.new
config.connection_rate_limit = 0
with_amqp_server(config: config) do |s|
conns = Array(AMQP::Client::Connection).new
5.times do
conns << AMQP::Client.new(port: amqp_port(s)).connect
end
conns.size.should eq 5
conns.each &.close
end
end

it "rejects connections exceeding the per-IP rate limit" do
config = LavinMQ::Config.new
config.connection_rate_limit_per_ip = 2
with_amqp_server(config: config) do |s|
conn1 = AMQP::Client.new(port: amqp_port(s)).connect
conn2 = AMQP::Client.new(port: amqp_port(s)).connect
expect_raises(Exception) do
AMQP::Client.new(port: amqp_port(s)).connect
end
conn1.close
conn2.close
end
end
end
6 changes: 6 additions & 0 deletions src/lavinmq/config/options.cr
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,12 @@ module LavinMQ
@[IniOpt(section: "main", transform: ->ConsistentHashAlgorithm.parse(String))]
property default_consistent_hash_algorithm : ConsistentHashAlgorithm = ConsistentHashAlgorithm::Ring

@[IniOpt(section: "main")]
property connection_rate_limit = 0 # max new connections per second, 0 = unlimited

@[IniOpt(section: "main")]
property connection_rate_limit_per_ip = 0 # max new connections per second per source IP, 0 = unlimited

# Deprecated options - these forward to the primary option in [main]

@[IniOpt(ini_name: tls_cert, section: "amqp", deprecated: "tls_cert in [main]")]
Expand Down
119 changes: 119 additions & 0 deletions src/lavinmq/connection_rate_limiter.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
module LavinMQ
# Token bucket rate limiter for incoming connections.
# Thread-safe via a mutex.
class ConnectionRateLimiter
Log = LavinMQ::Log.for "rate_limiter"
MAX_TRACKED_IPS = 100_000

private record PerIPState,
tokens : Float64,
last_refill : Time::Instant

@mu = Mutex.new
@global_tokens : Float64
@global_last_refill : Time::Instant
@last_log_time : Time::Instant
@last_table_full_log_time : Time::Instant

def initialize(@config : Config)
@global_tokens = @config.connection_rate_limit.to_f
@global_last_refill = Time.instant
@per_ip = Hash(String, PerIPState).new
@last_log_time = Time.instant
@last_table_full_log_time = Time.instant
end

# Returns true if the connection should be allowed.
# Per-IP is checked first to avoid consuming a global token
# when the per-IP limit rejects the connection. Note: if per-IP
# passes but global rejects, the per-IP token is still consumed.
# This is acceptable since tokens refill continuously.
def allow?(remote_address : String) : Bool
return true unless rate_limiting_enabled?
@mu.synchronize { allow_per_ip?(remote_address) && allow_global? }
end

private def rate_limiting_enabled? : Bool
@config.connection_rate_limit > 0 ||
@config.connection_rate_limit_per_ip > 0
end

private def allow_global? : Bool
limit = @config.connection_rate_limit
return true if limit <= 0

now = Time.instant
elapsed = (now - @global_last_refill).total_seconds
@global_tokens = Math.min(
limit.to_f,
@global_tokens + elapsed * limit
)
@global_last_refill = now

if @global_tokens >= 1.0
@global_tokens -= 1.0
true
else
false
end
end

private def allow_per_ip?(ip : String) : Bool
limit = @config.connection_rate_limit_per_ip
return true if limit <= 0

now = Time.instant
unless @per_ip.has_key?(ip)
if @per_ip.size >= MAX_TRACKED_IPS
evict_oldest_entry
end
@per_ip[ip] = PerIPState.new(limit.to_f, now)
end

state = @per_ip[ip]
elapsed = (now - state.last_refill).total_seconds
tokens = Math.min(limit.to_f, state.tokens + elapsed * limit)

if tokens >= 1.0
@per_ip[ip] = PerIPState.new(tokens - 1.0, now)
true
else
@per_ip[ip] = PerIPState.new(tokens, now)
false
end
end

def log_rate_limited(remote_address : String)
now = Time.instant
@mu.synchronize do
return unless (now - @last_log_time).total_seconds >= 1.0
@last_log_time = now
end
Log.warn { "Connection rate limited: #{remote_address}" }
end

# Remove stale per-IP entries to prevent unbounded growth.
# Called periodically (e.g. from stats_loop).
def cleanup_stale_entries
return if @config.connection_rate_limit_per_ip <= 0
now = Time.instant
@mu.synchronize do
@per_ip.reject! do |_ip, state|
(now - state.last_refill).total_seconds > 60
end
end
end

# Evict the oldest-inserted entry. Crystal's Hash is
# insertion-ordered, so `first` is O(1).
private def evict_oldest_entry
now = Time.instant
if (now - @last_table_full_log_time).total_seconds >= 1.0
@last_table_full_log_time = now
Log.warn { "Per-IP tracking limit reached (#{MAX_TRACKED_IPS}), evicting oldest entry" }
end
oldest_ip = @per_ip.first_key
@per_ip.delete(oldest_ip)
end
end
end
24 changes: 22 additions & 2 deletions src/lavinmq/server.cr
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ require "./mqtt/connection_factory"
require "./stats"
require "./auth/chain"
require "./auth/jwt/jwks_fetcher"
require "./connection_rate_limiter"

module LavinMQ
class Server
Expand All @@ -38,6 +39,7 @@ module LavinMQ
@listeners = Hash(Socket::Server, Protocol).new # Socket => protocol
@connection_factories = Hash(Protocol, ConnectionFactory).new
@replicator : Clustering::Replicator?
@rate_limiter : ConnectionRateLimiter
Log = LavinMQ::Log.for "server"

def initialize(@config : Config, @replicator = nil)
Expand All @@ -54,6 +56,7 @@ module LavinMQ
Protocol::AMQP => AMQP::ConnectionFactory.new(@authenticator, @vhosts),
Protocol::MQTT => MQTT::ConnectionFactory.new(@authenticator, @mqtt_brokers, @config),
}
@rate_limiter = ConnectionRateLimiter.new(@config)
apply_parameter
spawn stats_loop, name: "Server#stats_loop"
end
Expand Down Expand Up @@ -102,6 +105,7 @@ module LavinMQ
@vhosts.load!
@connection_factories[Protocol::AMQP] = AMQP::ConnectionFactory.new(@authenticator, @vhosts)
@connection_factories[Protocol::MQTT] = MQTT::ConnectionFactory.new(@authenticator, @mqtt_brokers, @config)
@rate_limiter = ConnectionRateLimiter.new(@config)
@parameters = ParameterStore(Parameter).new(@data_dir, "parameters.json", @replicator)
apply_parameter
@closed.set(false)
Expand All @@ -127,8 +131,13 @@ module LavinMQ
end

private def accept_tcp(client, protocol)
remote_address = client.remote_address
unless @rate_limiter.allow?(remote_address.address)
@rate_limiter.log_rate_limited(remote_address.to_s)
client.close rescue nil
return
end
spawn(name: "Accept TCP socket") do
remote_address = client.remote_address
set_socket_options(client)
set_buffer_size(client)
conn_info = extract_conn_info(client)
Expand Down Expand Up @@ -177,6 +186,11 @@ module LavinMQ
end

private def accept_unix(client, protocol)
unless @rate_limiter.allow?("unix")
@rate_limiter.log_rate_limited("unix")
client.close rescue nil
return
end
spawn(name: "Accept UNIX socket") do
remote_address = client.remote_address
set_buffer_size(client)
Expand Down Expand Up @@ -213,8 +227,13 @@ module LavinMQ
end

private def accept_tls(client, context, protocol)
remote_addr = client.remote_address
unless @rate_limiter.allow?(remote_addr.address)
@rate_limiter.log_rate_limited(remote_addr.to_s)
client.close rescue nil
return
end
spawn(name: "Accept TLS socket") do
remote_addr = client.remote_address
set_socket_options(client)
if @config.tls_ktls?
ssl_client = OpenSSL::SSL::NativeSocket::Server.new(client, context, sync_close: true)
Expand Down Expand Up @@ -413,6 +432,7 @@ module LavinMQ
@gc_stats = GC.prof_stats

control_flow!
@rate_limiter.cleanup_stale_entries
sleep @config.stats_interval.milliseconds
end
ensure
Expand Down
Loading