diff --git a/spec/connection_rate_limit_spec.cr b/spec/connection_rate_limit_spec.cr new file mode 100644 index 0000000000..bca57f234e --- /dev/null +++ b/spec/connection_rate_limit_spec.cr @@ -0,0 +1,130 @@ +require "./spec_helper" + +describe LavinMQ::ConnectionRateLimiter do + describe "global rate limit" do + it "allows connections when disabled (limit=0)" do + config = LavinMQ::Config.new + config.connection_rate_limit = 0 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 100.times { limiter.allow?("127.0.0.1").should be_true } + end + + it "allows connections up to the limit" do + config = LavinMQ::Config.new + config.connection_rate_limit = 5 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 5.times { |i| limiter.allow?("127.0.0.1").should(be_true, "failed on attempt #{i}") } + limiter.allow?("127.0.0.1").should be_false + end + + it "replenishes tokens over time" do + config = LavinMQ::Config.new + config.connection_rate_limit = 10 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 10.times { limiter.allow?("127.0.0.1") } + limiter.allow?("127.0.0.1").should be_false + sleep 0.2.seconds + limiter.allow?("127.0.0.1").should be_true + end + end + + describe "per-IP rate limit" do + it "allows connections when disabled (limit=0)" do + config = LavinMQ::Config.new + config.connection_rate_limit_per_ip = 0 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 100.times { limiter.allow?("10.0.0.1").should be_true } + end + + it "limits each IP independently" do + config = LavinMQ::Config.new + config.connection_rate_limit_per_ip = 2 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 2.times { limiter.allow?("10.0.0.1").should be_true } + limiter.allow?("10.0.0.1").should be_false + 2.times { limiter.allow?("10.0.0.2").should be_true } + limiter.allow?("10.0.0.2").should be_false + end + + it "replenishes per-IP tokens over time" do + config = LavinMQ::Config.new + config.connection_rate_limit_per_ip = 10 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 10.times { limiter.allow?("10.0.0.1") } + limiter.allow?("10.0.0.1").should be_false + sleep 0.2.seconds + limiter.allow?("10.0.0.1").should be_true + end + end + + describe "combined global and per-IP limits" do + it "per-IP limit stops connections before global limit" do + config = LavinMQ::Config.new + config.connection_rate_limit = 10 + config.connection_rate_limit_per_ip = 2 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 2.times { limiter.allow?("10.0.0.1").should be_true } + # Per-IP limit reached, even though global has tokens + limiter.allow?("10.0.0.1").should be_false + # Different IP still works + 2.times { limiter.allow?("10.0.0.2").should be_true } + limiter.allow?("10.0.0.2").should be_false + end + + it "global limit stops connections even if per-IP has tokens" do + config = LavinMQ::Config.new + config.connection_rate_limit = 3 + config.connection_rate_limit_per_ip = 5 + limiter = LavinMQ::ConnectionRateLimiter.new(config) + 3.times { limiter.allow?("10.0.0.1").should be_true } + # Global limit reached + limiter.allow?("10.0.0.1").should be_false + # Different IP also blocked by global limit + limiter.allow?("10.0.0.2").should be_false + end + end +end + +describe "Server connection rate limiting" do + it "rejects connections exceeding the global rate limit" do + config = LavinMQ::Config.new + config.connection_rate_limit = 2 + with_amqp_server(config: config) do |s| + conn1 = AMQP::Client.new(port: amqp_port(s)).connect + conn2 = AMQP::Client.new(port: amqp_port(s)).connect + # Third connection should fail - server closes socket before handshake + expect_raises(Exception) do + AMQP::Client.new(port: amqp_port(s)).connect + end + conn1.close + conn2.close + end + end + + it "allows unlimited connections when rate limit is 0" do + config = LavinMQ::Config.new + config.connection_rate_limit = 0 + with_amqp_server(config: config) do |s| + conns = Array(AMQP::Client::Connection).new + 5.times do + conns << AMQP::Client.new(port: amqp_port(s)).connect + end + conns.size.should eq 5 + conns.each &.close + end + end + + it "rejects connections exceeding the per-IP rate limit" do + config = LavinMQ::Config.new + config.connection_rate_limit_per_ip = 2 + with_amqp_server(config: config) do |s| + conn1 = AMQP::Client.new(port: amqp_port(s)).connect + conn2 = AMQP::Client.new(port: amqp_port(s)).connect + expect_raises(Exception) do + AMQP::Client.new(port: amqp_port(s)).connect + end + conn1.close + conn2.close + end + end +end diff --git a/src/lavinmq/config/options.cr b/src/lavinmq/config/options.cr index e889c58685..3377c7f121 100644 --- a/src/lavinmq/config/options.cr +++ b/src/lavinmq/config/options.cr @@ -300,6 +300,12 @@ module LavinMQ @[IniOpt(section: "main", transform: ->ConsistentHashAlgorithm.parse(String))] property default_consistent_hash_algorithm : ConsistentHashAlgorithm = ConsistentHashAlgorithm::Ring + @[IniOpt(section: "main")] + property connection_rate_limit = 0 # max new connections per second, 0 = unlimited + + @[IniOpt(section: "main")] + property connection_rate_limit_per_ip = 0 # max new connections per second per source IP, 0 = unlimited + # Deprecated options - these forward to the primary option in [main] @[IniOpt(ini_name: tls_cert, section: "amqp", deprecated: "tls_cert in [main]")] diff --git a/src/lavinmq/connection_rate_limiter.cr b/src/lavinmq/connection_rate_limiter.cr new file mode 100644 index 0000000000..ef0a518dce --- /dev/null +++ b/src/lavinmq/connection_rate_limiter.cr @@ -0,0 +1,119 @@ +module LavinMQ + # Token bucket rate limiter for incoming connections. + # Thread-safe via a mutex. + class ConnectionRateLimiter + Log = LavinMQ::Log.for "rate_limiter" + MAX_TRACKED_IPS = 100_000 + + private record PerIPState, + tokens : Float64, + last_refill : Time::Instant + + @mu = Mutex.new + @global_tokens : Float64 + @global_last_refill : Time::Instant + @last_log_time : Time::Instant + @last_table_full_log_time : Time::Instant + + def initialize(@config : Config) + @global_tokens = @config.connection_rate_limit.to_f + @global_last_refill = Time.instant + @per_ip = Hash(String, PerIPState).new + @last_log_time = Time.instant + @last_table_full_log_time = Time.instant + end + + # Returns true if the connection should be allowed. + # Per-IP is checked first to avoid consuming a global token + # when the per-IP limit rejects the connection. Note: if per-IP + # passes but global rejects, the per-IP token is still consumed. + # This is acceptable since tokens refill continuously. + def allow?(remote_address : String) : Bool + return true unless rate_limiting_enabled? + @mu.synchronize { allow_per_ip?(remote_address) && allow_global? } + end + + private def rate_limiting_enabled? : Bool + @config.connection_rate_limit > 0 || + @config.connection_rate_limit_per_ip > 0 + end + + private def allow_global? : Bool + limit = @config.connection_rate_limit + return true if limit <= 0 + + now = Time.instant + elapsed = (now - @global_last_refill).total_seconds + @global_tokens = Math.min( + limit.to_f, + @global_tokens + elapsed * limit + ) + @global_last_refill = now + + if @global_tokens >= 1.0 + @global_tokens -= 1.0 + true + else + false + end + end + + private def allow_per_ip?(ip : String) : Bool + limit = @config.connection_rate_limit_per_ip + return true if limit <= 0 + + now = Time.instant + unless @per_ip.has_key?(ip) + if @per_ip.size >= MAX_TRACKED_IPS + evict_oldest_entry + end + @per_ip[ip] = PerIPState.new(limit.to_f, now) + end + + state = @per_ip[ip] + elapsed = (now - state.last_refill).total_seconds + tokens = Math.min(limit.to_f, state.tokens + elapsed * limit) + + if tokens >= 1.0 + @per_ip[ip] = PerIPState.new(tokens - 1.0, now) + true + else + @per_ip[ip] = PerIPState.new(tokens, now) + false + end + end + + def log_rate_limited(remote_address : String) + now = Time.instant + @mu.synchronize do + return unless (now - @last_log_time).total_seconds >= 1.0 + @last_log_time = now + end + Log.warn { "Connection rate limited: #{remote_address}" } + end + + # Remove stale per-IP entries to prevent unbounded growth. + # Called periodically (e.g. from stats_loop). + def cleanup_stale_entries + return if @config.connection_rate_limit_per_ip <= 0 + now = Time.instant + @mu.synchronize do + @per_ip.reject! do |_ip, state| + (now - state.last_refill).total_seconds > 60 + end + end + end + + # Evict the oldest-inserted entry. Crystal's Hash is + # insertion-ordered, so `first` is O(1). + private def evict_oldest_entry + now = Time.instant + if (now - @last_table_full_log_time).total_seconds >= 1.0 + @last_table_full_log_time = now + Log.warn { "Per-IP tracking limit reached (#{MAX_TRACKED_IPS}), evicting oldest entry" } + end + oldest_ip = @per_ip.first_key + @per_ip.delete(oldest_ip) + end + end +end diff --git a/src/lavinmq/server.cr b/src/lavinmq/server.cr index b01a95f763..ac25d9ccc5 100644 --- a/src/lavinmq/server.cr +++ b/src/lavinmq/server.cr @@ -20,6 +20,7 @@ require "./mqtt/connection_factory" require "./stats" require "./auth/chain" require "./auth/jwt/jwks_fetcher" +require "./connection_rate_limiter" module LavinMQ class Server @@ -38,6 +39,7 @@ module LavinMQ @listeners = Hash(Socket::Server, Protocol).new # Socket => protocol @connection_factories = Hash(Protocol, ConnectionFactory).new @replicator : Clustering::Replicator? + @rate_limiter : ConnectionRateLimiter Log = LavinMQ::Log.for "server" def initialize(@config : Config, @replicator = nil) @@ -54,6 +56,7 @@ module LavinMQ Protocol::AMQP => AMQP::ConnectionFactory.new(@authenticator, @vhosts), Protocol::MQTT => MQTT::ConnectionFactory.new(@authenticator, @mqtt_brokers, @config), } + @rate_limiter = ConnectionRateLimiter.new(@config) apply_parameter spawn stats_loop, name: "Server#stats_loop" end @@ -102,6 +105,7 @@ module LavinMQ @vhosts.load! @connection_factories[Protocol::AMQP] = AMQP::ConnectionFactory.new(@authenticator, @vhosts) @connection_factories[Protocol::MQTT] = MQTT::ConnectionFactory.new(@authenticator, @mqtt_brokers, @config) + @rate_limiter = ConnectionRateLimiter.new(@config) @parameters = ParameterStore(Parameter).new(@data_dir, "parameters.json", @replicator) apply_parameter @closed.set(false) @@ -127,8 +131,13 @@ module LavinMQ end private def accept_tcp(client, protocol) + remote_address = client.remote_address + unless @rate_limiter.allow?(remote_address.address) + @rate_limiter.log_rate_limited(remote_address.to_s) + client.close rescue nil + return + end spawn(name: "Accept TCP socket") do - remote_address = client.remote_address set_socket_options(client) set_buffer_size(client) conn_info = extract_conn_info(client) @@ -177,6 +186,11 @@ module LavinMQ end private def accept_unix(client, protocol) + unless @rate_limiter.allow?("unix") + @rate_limiter.log_rate_limited("unix") + client.close rescue nil + return + end spawn(name: "Accept UNIX socket") do remote_address = client.remote_address set_buffer_size(client) @@ -213,8 +227,13 @@ module LavinMQ end private def accept_tls(client, context, protocol) + remote_addr = client.remote_address + unless @rate_limiter.allow?(remote_addr.address) + @rate_limiter.log_rate_limited(remote_addr.to_s) + client.close rescue nil + return + end spawn(name: "Accept TLS socket") do - remote_addr = client.remote_address set_socket_options(client) if @config.tls_ktls? ssl_client = OpenSSL::SSL::NativeSocket::Server.new(client, context, sync_close: true) @@ -413,6 +432,7 @@ module LavinMQ @gc_stats = GC.prof_stats control_flow! + @rate_limiter.cleanup_stale_entries sleep @config.stats_interval.milliseconds end ensure