Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/realtime/user_counter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ defmodule Realtime.UsersCounter do
@spec add(pid(), String.t()) :: :ok
def add(pid, tenant_id), do: tenant_id |> scope() |> :syn.join(tenant_id, pid)

@spec already_counted?(pid(), String.t()) :: boolean()
def already_counted?(pid, tenant_id) do
tenant_id |> scope() |> :syn.is_local_member(tenant_id, pid)
end

@doc """
Returns the count of all connected clients for a tenant for the cluster.
"""
Expand Down
63 changes: 34 additions & 29 deletions lib/realtime_web/channels/realtime_channel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ defmodule RealtimeWeb.RealtimeChannel do
alias RealtimeWeb.SocketDisconnect
alias DBConnection.Backoff

alias Realtime.Api.Tenant
alias Realtime.Crypto
alias Realtime.GenCounter
alias Realtime.Helpers
Expand All @@ -19,6 +20,7 @@ defmodule RealtimeWeb.RealtimeChannel do
alias Realtime.Tenants.Authorization.Policies
alias Realtime.Tenants.Authorization.Policies.BroadcastPolicies
alias Realtime.Tenants.Connect
alias Realtime.UsersCounter

alias RealtimeWeb.Channels.Payloads.Join
alias RealtimeWeb.ChannelsAuthorization
Expand Down Expand Up @@ -52,8 +54,6 @@ defmodule RealtimeWeb.RealtimeChannel do
socket =
socket
|> assign_access_token(params)
|> assign_counter()
|> assign_presence_counter()
|> assign(:private?, !!params["config"]["private"])
|> assign(:policies, nil)

Expand All @@ -67,10 +67,11 @@ defmodule RealtimeWeb.RealtimeChannel do
end

with :ok <- SignalHandler.shutdown_in_progress?(),
:ok <- only_private?(tenant_id, socket),
:ok <- limit_joins(socket),
:ok <- limit_channels(socket),
:ok <- limit_max_users(socket),
%Tenant{} = tenant <- Tenants.Cache.get_tenant_by_external_id(tenant_id),
:ok <- only_private?(tenant, socket),
:ok <- limit_max_users(tenant, transport_pid),
:ok <- limit_joins(tenant, socket),
:ok <- limit_channels(tenant, socket),
{:ok, claims, confirm_token_ref} <- confirm_token(socket),
socket = assign_authorization_context(socket, sub_topic, claims),
{:ok, db_conn} <- Connect.lookup_or_start_connection(tenant_id),
Expand Down Expand Up @@ -131,10 +132,15 @@ defmodule RealtimeWeb.RealtimeChannel do
presence_enabled?: presence_enabled?
}

socket =
socket
|> assign_counter(tenant)
|> assign_presence_counter(tenant)

# Start presence and add user if presence is enabled
if presence_enabled?, do: send(self(), :sync_presence)

Realtime.UsersCounter.add(transport_pid, tenant_id)
UsersCounter.add(transport_pid, tenant_id)
SocketDisconnect.add(tenant_id, socket)

{:ok, state, assign(socket, assigns)}
Expand Down Expand Up @@ -515,8 +521,8 @@ defmodule RealtimeWeb.RealtimeChannel do
wait
end

def limit_joins(%{assigns: %{tenant: tenant, limits: limits}} = socket) do
rate_args = Tenants.joins_per_second_rate(tenant, limits.max_joins_per_second)
def limit_joins(tenant, socket) do
rate_args = Tenants.joins_per_second_rate(tenant)

RateCounter.new(rate_args)

Expand All @@ -534,36 +540,35 @@ defmodule RealtimeWeb.RealtimeChannel do
end
end

def limit_channels(%{assigns: %{tenant: tenant, limits: limits}, transport_pid: pid}) do
def limit_channels(tenant, %{transport_pid: pid}) do
key = Tenants.channels_per_client_key(tenant)

if Registry.count_match(Realtime.Registry, key, pid) + 1 > limits.max_channels_per_client do
if Registry.count_match(Realtime.Registry, key, pid) + 1 > tenant.max_channels_per_client do
{:error, :too_many_channels}
else
Registry.register(Realtime.Registry, Tenants.channels_per_client_key(tenant), pid)
:ok
end
end

defp limit_max_users(%{assigns: %{limits: %{max_concurrent_users: max_conn_users}, tenant: tenant}}) do
conns = Realtime.UsersCounter.tenant_users(tenant)

if conns < max_conn_users,
do: :ok,
else: {:error, :too_many_connections}
defp limit_max_users(tenant, transport_pid) do
if !UsersCounter.already_counted?(transport_pid, tenant.external_id) and
UsersCounter.tenant_users(tenant.external_id) >= tenant.max_concurrent_users do
{:error, :too_many_connections}
else
:ok
end
end

defp assign_counter(%{assigns: %{tenant: tenant, limits: limits}} = socket) do
rate_args = Tenants.events_per_second_rate(tenant, limits.max_events_per_second)
defp assign_counter(socket, tenant) do
rate_args = Tenants.events_per_second_rate(tenant)

RateCounter.new(rate_args)
assign(socket, :rate_counter, rate_args)
end

defp assign_counter(socket), do: socket

defp assign_presence_counter(%{assigns: %{tenant: tenant, limits: limits}} = socket) do
rate_args = Tenants.presence_events_per_second_rate(tenant, limits.max_events_per_second)
defp assign_presence_counter(socket, tenant) do
rate_args = Tenants.presence_events_per_second_rate(tenant)

RateCounter.new(rate_args)

Expand Down Expand Up @@ -786,12 +791,12 @@ defmodule RealtimeWeb.RealtimeChannel do

defp maybe_assign_policies(_, _, socket), do: {:ok, assign(socket, policies: nil)}

defp only_private?(tenant_id, %{assigns: %{private?: private?}}) do
tenant = Tenants.Cache.get_tenant_by_external_id(tenant_id)

if tenant.private_only and !private?,
do: {:error, :private_only},
else: :ok
defp only_private?(tenant, %{assigns: %{private?: private?}}) do
if tenant.private_only and !private? do
{:error, :private_only}
else
:ok
end
end

defp maybe_replay_messages(%{"broadcast" => %{"replay" => _}}, _sub_topic, _db_conn, _tenant_id, false = _private?) do
Expand Down
17 changes: 1 addition & 16 deletions lib/realtime_web/channels/user_socket.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,27 +75,12 @@ defmodule RealtimeWeb.UserSocket do
{:ok, claims} <- ChannelsAuthorization.authorize_conn(token, jwt_secret_dec, jwt_jwks),
:ok <- TenantRateLimiters.check_tenant(tenant),
{:ok, postgres_cdc_module} <- PostgresCdc.driver(postgres_cdc_default) do
%Tenant{
extensions: extensions,
max_concurrent_users: max_conn_users,
max_events_per_second: max_events_per_second,
max_bytes_per_second: max_bytes_per_second,
max_joins_per_second: max_joins_per_second,
max_channels_per_client: max_channels_per_client,
postgres_cdc_default: postgres_cdc_default
} = tenant
%Tenant{extensions: extensions, postgres_cdc_default: postgres_cdc_default} = tenant

assigns = %RealtimeChannel.Assigns{
claims: claims,
jwt_secret: jwt_secret,
jwt_jwks: jwt_jwks,
limits: %{
max_concurrent_users: max_conn_users,
max_events_per_second: max_events_per_second,
max_bytes_per_second: max_bytes_per_second,
max_joins_per_second: max_joins_per_second,
max_channels_per_client: max_channels_per_client
},
postgres_extension: PostgresCdc.filter_settings(postgres_cdc_default, extensions),
postgres_cdc_module: postgres_cdc_module,
tenant: external_id,
Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defmodule Realtime.MixProject do
def project do
[
app: :realtime,
version: "2.69.2",
version: "2.69.3",
elixir: "~> 1.18",
elixirc_paths: elixirc_paths(Mix.env()),
start_permanent: Mix.env() == :prod,
Expand Down
70 changes: 58 additions & 12 deletions test/integration/rt_channel_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1731,17 +1731,37 @@ defmodule Realtime.Integration.RtChannelTest do
setup [:rls_context]

test "max_concurrent_users limit respected", %{tenant: tenant, serializer: serializer} do
%{max_concurrent_users: max_concurrent_users} = Tenants.get_tenant_by_external_id(tenant.external_id)
Tenants.get_tenant_by_external_id(tenant.external_id)
change_tenant_configuration(tenant, :max_concurrent_users, 1)

{socket, _} = get_connection(tenant, serializer, role: "authenticated")
{socket1, _} = get_connection(tenant, serializer, role: "authenticated")
{socket2, _} = get_connection(tenant, serializer, role: "authenticated")
config = %{broadcast: %{self: true}, private: false}
realtime_topic = "realtime:#{random_string()}"
WebsocketClient.join(socket, realtime_topic, %{config: config})
WebsocketClient.join(socket, realtime_topic, %{config: config})
topic1 = "realtime:#{random_string()}"
topic2 = "realtime:#{random_string()}"
WebsocketClient.join(socket1, topic1, %{config: config})
WebsocketClient.join(socket1, topic2, %{config: config})

assert_receive %Message{
event: "phx_reply",
topic: ^topic1,
payload: %{"response" => %{"postgres_changes" => []}, "status" => "ok"}
},
500

assert_receive %Message{
event: "phx_reply",
topic: ^topic2,
payload: %{"response" => %{"postgres_changes" => []}, "status" => "ok"}
},
500

topic3 = "realtime:#{random_string()}"
WebsocketClient.join(socket2, topic3, %{config: config})

assert_receive %Message{
event: "phx_reply",
topic: ^topic3,
payload: %{
"response" => %{
"reason" => "ConnectionRateLimitReached: Too many connected users"
Expand All @@ -1751,9 +1771,17 @@ defmodule Realtime.Integration.RtChannelTest do
},
500

assert_receive %Message{event: "phx_close"}
# Limit is updated now joining should succeed
Realtime.Tenants.Cache.update_cache(%{tenant | max_concurrent_users: 2})

WebsocketClient.join(socket2, topic3, %{config: config})

change_tenant_configuration(tenant, :max_concurrent_users, max_concurrent_users)
assert_receive %Message{
event: "phx_reply",
topic: ^topic3,
payload: %{"response" => %{"postgres_changes" => []}, "status" => "ok"}
},
500
end

test "max_events_per_second limit respected", %{tenant: tenant, serializer: serializer} do
Expand All @@ -1762,20 +1790,25 @@ defmodule Realtime.Integration.RtChannelTest do
log =
capture_log(fn ->
{socket, _} = get_connection(tenant, serializer, role: "authenticated")
config = %{broadcast: %{self: true}, private: false, presence: %{enabled: false}}
config = %{broadcast: %{self: true, ack: false}, private: false, presence: %{enabled: false}}
realtime_topic = "realtime:#{random_string()}"

WebsocketClient.join(socket, realtime_topic, %{config: config})
assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^realtime_topic}, 500

for _ <- 1..1000, Process.alive?(socket) do
WebsocketClient.send_event(socket, realtime_topic, "broadcast", %{})
Process.sleep(10)
assert_receive %Message{event: "broadcast", topic: ^realtime_topic}, 500
end

# Wait for the rate counter to run logger function
Process.sleep(1500)
assert_receive %Message{event: "phx_close"}
RateCounterHelper.tick_tenant_rate_counters!(tenant.external_id)

# One more to cause the WebSocket to close

WebsocketClient.send_event(socket, realtime_topic, "broadcast", %{})

assert_receive %Message{event: "phx_close"}, 1000
end)

assert log =~ "MessagePerSecondRateLimitReached"
Expand Down Expand Up @@ -1815,6 +1848,18 @@ defmodule Realtime.Integration.RtChannelTest do

refute_receive %Message{event: "phx_reply", topic: ^realtime_topic_2}, 500
refute_receive %Message{event: "presence_state", topic: ^realtime_topic_2}, 500

# Limit is updated now joining should succeed
Realtime.Tenants.Cache.update_cache(%{tenant | max_channels_per_client: 2})

WebsocketClient.join(socket, realtime_topic_2, %{config: config})

assert_receive %Message{
event: "phx_reply",
payload: %{"response" => %{"postgres_changes" => []}, "status" => "ok"},
topic: ^realtime_topic_2
},
500
end

test "max_joins_per_second limit respected", %{tenant: tenant, serializer: serializer} do
Expand All @@ -1827,10 +1872,11 @@ defmodule Realtime.Integration.RtChannelTest do
# Burst of joins that won't be blocked as RateCounter tick won't run
for _ <- 1..300 do
WebsocketClient.join(socket, realtime_topic, %{config: config})
assert_receive %Message{event: "phx_reply", payload: %{"status" => "ok"}, topic: ^realtime_topic}, 500
end

# Wait for RateCounter tick
Process.sleep(1000)
RateCounterHelper.tick_tenant_rate_counters!(tenant.external_id)

# These ones will be blocked
for _ <- 1..300 do
Expand Down
12 changes: 12 additions & 0 deletions test/realtime/user_counter_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ defmodule Realtime.UsersCounterTest do
%{tenant_id: tenant_id, count: count, nodes: nodes}
end

describe "already_counted?/2" do
test "returns true if pid already counted for tenant", %{tenant_id: tenant_id} do
pid = self()
assert UsersCounter.add(pid, tenant_id) == :ok
assert UsersCounter.already_counted?(pid, tenant_id) == true
end

test "returns false if pid not counted for tenant" do
assert UsersCounter.already_counted?(self(), random_string()) == false
end
end

describe "add/1" do
test "starts counter for tenant" do
assert UsersCounter.add(self(), random_string()) == :ok
Expand Down
37 changes: 21 additions & 16 deletions test/realtime_web/channels/realtime_channel_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,6 @@ defmodule RealtimeWeb.RealtimeChannelTest do
alias Realtime.RateCounter
alias RealtimeWeb.UserSocket

@default_limits %{
max_concurrent_users: 200,
max_events_per_second: 100,
max_joins_per_second: 100,
max_channels_per_client: 100,
max_bytes_per_second: 100_000
}

setup do
tenant = Containers.checkout_tenant(run_migrations: true)
Realtime.Tenants.Cache.update_cache(tenant)
Expand Down Expand Up @@ -594,25 +586,38 @@ defmodule RealtimeWeb.RealtimeChannelTest do
jwt = Generators.generate_jwt_token(tenant)
{:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt))

socket = Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: 1}})
Realtime.Tenants.Cache.update_cache(%{tenant | max_concurrent_users: 1})

assert {:ok, _, %Socket{}} = subscribe_and_join(socket, "realtime:test", %{})
end

test "reached", %{tenant: tenant} do
test "reached after connecting", %{tenant: tenant} do
jwt = Generators.generate_jwt_token(tenant)
{:ok, %Socket{} = socket} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt))

socket_at_capacity =
Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: 0}})
Realtime.Tenants.Cache.update_cache(%{tenant | max_concurrent_users: 1})

socket_over_capacity =
Socket.assign(socket, %{limits: %{@default_limits | max_concurrent_users: -1}})
pid = spawn_link(fn -> Process.sleep(:infinity) end)
Realtime.UsersCounter.add(pid, tenant.external_id)

assert {:error, %{reason: "ConnectionRateLimitReached: Too many connected users"}} =
subscribe_and_join(socket_at_capacity, "realtime:test", %{})
subscribe_and_join(socket, "realtime:test", %{})

pid = spawn_link(fn -> Process.sleep(:infinity) end)
Realtime.UsersCounter.add(pid, tenant.external_id)

assert {:error, %{reason: "ConnectionRateLimitReached: Too many connected users"}} =
subscribe_and_join(socket_over_capacity, "realtime:test", %{})
subscribe_and_join(socket, "realtime:test", %{})
end

test "reached before connecting", %{tenant: tenant} do
jwt = Generators.generate_jwt_token(tenant)

Realtime.Tenants.Cache.update_cache(%{tenant | max_concurrent_users: 1})

Realtime.UsersCounter.add(self(), tenant.external_id)

{:error, :too_many_connections} = connect(UserSocket, %{"log_level" => "warning"}, conn_opts(tenant, jwt))
end
end

Expand Down