Skip to content

Commit e3fd6b2

Browse files
committed
Make HTTP/2 Websocket call terminate/3 on socket close
The close reason will differ from HTTP/1.1 because we don't have access to the socket. Also trapping exits is required to process the 'EXIT' signal and call terminate/3.
1 parent e713a63 commit e3fd6b2

File tree

4 files changed

+107
-42
lines changed

4 files changed

+107
-42
lines changed

src/cowboy_websocket.erl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,12 @@ loop(State=#state{parent=Parent, socket=Socket, messages=Messages,
492492
before_loop(State, HandlerState, ParseState);
493493
%% System messages.
494494
{'EXIT', Parent, Reason} ->
495-
%% @todo We should exit gracefully.
496-
exit(Reason);
495+
%% The terminate reason will differ with HTTP/1.1
496+
%% since we don't have direct access to the socket.
497+
%% @todo Perhaps we can make cowboy_children:terminate
498+
%% receive the shutdown Reason and send {shutdown, Reason}
499+
%% instead of just 'shutdown' in this scenario.
500+
terminate(State, HandlerState, Reason);
497501
{system, From, Request} ->
498502
sys:handle_system_msg(Request, From, Parent, ?MODULE, [],
499503
{State, HandlerState, ParseState});

test/handlers/ws_terminate_h.erl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ init(Req, _) ->
2121
end,
2222
{cowboy_websocket, Req, #state{pid=Pid}, Opts}.
2323

24-
websocket_init(State) ->
24+
websocket_init(State=#state{pid=Pid}) ->
25+
Pid ! {ws_pid, self()},
26+
%% We must trap 'EXIT' signals for HTTP/2 to call terminate/3.
27+
process_flag(trap_exit, true),
2528
{ok, State}.
2629

2730
websocket_handle(_, State) ->

test/ws_SUITE.erl

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -573,39 +573,6 @@ ws_subprotocol(Config) ->
573573
{_, "foo"} = lists:keyfind("sec-websocket-protocol", 1, Headers),
574574
ok.
575575

576-
ws_terminate(Config) ->
577-
doc("The Req object is kept in a more compact form by default."),
578-
{ok, Socket, _} = do_handshake("/terminate",
579-
"x-test-pid: " ++ pid_to_list(self()) ++ "\r\n", Config),
580-
%% Send a close frame.
581-
ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 1:1, 0:7, 0:32 >>),
582-
{ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000),
583-
{error, closed} = gen_tcp:recv(Socket, 0, 6000),
584-
%% Confirm terminate/3 was called with a compacted Req.
585-
receive {terminate, _, Req} ->
586-
true = maps:is_key(path, Req),
587-
false = maps:is_key(headers, Req),
588-
ok
589-
after 1000 ->
590-
error(timeout)
591-
end.
592-
593-
ws_terminate_fun(Config) ->
594-
doc("A function can be given to filter the Req object."),
595-
{ok, Socket, _} = do_handshake("/terminate?req_filter",
596-
"x-test-pid: " ++ pid_to_list(self()) ++ "\r\n", Config),
597-
%% Send a close frame.
598-
ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 1:1, 0:7, 0:32 >>),
599-
{ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000),
600-
{error, closed} = gen_tcp:recv(Socket, 0, 6000),
601-
%% Confirm terminate/3 was called with a compacted Req.
602-
receive {terminate, _, Req} ->
603-
filtered = Req,
604-
ok
605-
after 1000 ->
606-
error(timeout)
607-
end.
608-
609576
ws_text_fragments(Config) ->
610577
doc("Client sends fragmented text frames."),
611578
{ok, Socket, _} = do_handshake("/ws_echo", Config),

test/ws_handler_SUITE.erl

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ init_dispatch(Name) ->
7474
{"/active", ws_active_commands_h, InitialState},
7575
{"/deflate", ws_deflate_commands_h, InitialState},
7676
{"/set_options", ws_set_options_commands_h, InitialState},
77-
{"/shutdown_reason", ws_shutdown_reason_commands_h, InitialState}
77+
{"/shutdown_reason", ws_shutdown_reason_commands_h, InitialState},
78+
{"/terminate", ws_terminate_h, InitialState}
7879
]}]).
7980

8081
%% Support functions for testing using Gun.
@@ -116,6 +117,15 @@ ensure_handle_is_called(ConnPid, StreamRef, "/handle") ->
116117
ensure_handle_is_called(_, _, _) ->
117118
ok.
118119

120+
do_receive(Tag) ->
121+
receive
122+
Msg when element(1, Msg) =:= Tag ->
123+
Msg
124+
after 1000 ->
125+
ct:pal("do_receive(~p): ~p", [Tag, process_info(self(), messages)]),
126+
error(timeout)
127+
end.
128+
119129
%% Tests.
120130

121131
websocket_init_nothing(Config) ->
@@ -134,7 +144,7 @@ do_nothing(Config, Path) ->
134144
{ok, ConnPid, StreamRef} = gun_open_ws(Config, Path, []),
135145
ensure_handle_is_called(ConnPid, StreamRef, Path),
136146
{error, timeout} = receive_ws(ConnPid, StreamRef),
137-
ok.
147+
gun:close(ConnPid).
138148

139149
websocket_init_invalid(Config) ->
140150
doc("The connection must be closed when websocket_init/1 returns an invalid command."),
@@ -178,7 +188,7 @@ do_one_frame(Config, Path) ->
178188
]),
179189
ensure_handle_is_called(ConnPid, StreamRef, Path),
180190
{ok, {text, <<"One frame!">>}} = receive_ws(ConnPid, StreamRef),
181-
ok.
191+
gun:close(ConnPid).
182192

183193
websocket_init_many_frames(Config) ->
184194
doc("Multiple frames are received when websocket_init/1 returns them as commands."),
@@ -200,7 +210,7 @@ do_many_frames(Config, Path) ->
200210
ensure_handle_is_called(ConnPid, StreamRef, Path),
201211
{ok, {text, <<"One frame!">>}} = receive_ws(ConnPid, StreamRef),
202212
{ok, {binary, <<"Two frames!">>}} = receive_ws(ConnPid, StreamRef),
203-
ok.
213+
gun:close(ConnPid).
204214

205215
websocket_init_close_frame(Config) ->
206216
doc("A single close frame is received when websocket_init/1 returns it as a command."),
@@ -266,7 +276,7 @@ websocket_active_false(Config) ->
266276
{ok, {binary, _}} = receive_ws(ConnPid, StreamRef),
267277
{ok, {text, <<"Not received until the handler enables active again.">>}}
268278
= receive_ws(ConnPid, StreamRef),
269-
ok.
279+
gun:close(ConnPid).
270280

271281
websocket_deflate_false(Config) ->
272282
doc("The {deflate, false} command temporarily disables compression. "
@@ -305,7 +315,7 @@ websocket_deflate_ignore_if_not_negotiated(Config) ->
305315
gun:ws_send(ConnPid, StreamRef, {text, <<"Hello.">>}),
306316
{ok, {text, <<"Hello.">>}} = receive_ws(ConnPid, StreamRef)
307317
end || _ <- lists:seq(1, 10)],
308-
ok.
318+
gun:close(ConnPid).
309319

310320
websocket_set_options_idle_timeout(Config) ->
311321
doc("The idle_timeout option can be modified using the "
@@ -390,3 +400,84 @@ websocket_shutdown_reason(Config) ->
390400
after 1000 ->
391401
error(timeout)
392402
end.
403+
404+
websocket_terminate_close_normal(Config) ->
405+
doc("Receiving a close frame results in a terminate/3 call. "
406+
"The Req object is kept in a more compact form by default."),
407+
ConnPid = gun_open(Config, #{http2_opts => #{notify_settings_changed => true}}),
408+
do_await_enable_connect_protocol(config(protocol, Config), ConnPid),
409+
StreamRef = gun:ws_upgrade(ConnPid, "/terminate", [
410+
{<<"x-test-pid">>, pid_to_list(self())}
411+
]),
412+
{upgrade, [<<"websocket">>], _} = gun:await(ConnPid, StreamRef),
413+
{ws_pid, WsPid} = do_receive(ws_pid),
414+
MRef = monitor(process, WsPid),
415+
gun:ws_send(ConnPid, StreamRef, close),
416+
{terminate, remote, Req} = do_receive(terminate),
417+
{'DOWN', MRef, process, WsPid, normal} = do_receive('DOWN'),
418+
%% Confirm terminate/3 was called with a compacted Req.
419+
true = maps:is_key(path, Req),
420+
false = maps:is_key(headers, Req),
421+
ok.
422+
423+
websocket_terminate_close_reason(Config) ->
424+
doc("Receiving a close frame results in a terminate/3 call. "
425+
"The Req object is kept in a more compact form by default."),
426+
ConnPid = gun_open(Config, #{http2_opts => #{notify_settings_changed => true}}),
427+
do_await_enable_connect_protocol(config(protocol, Config), ConnPid),
428+
StreamRef = gun:ws_upgrade(ConnPid, "/terminate", [
429+
{<<"x-test-pid">>, pid_to_list(self())}
430+
]),
431+
{upgrade, [<<"websocket">>], _} = gun:await(ConnPid, StreamRef),
432+
{ws_pid, WsPid} = do_receive(ws_pid),
433+
MRef = monitor(process, WsPid),
434+
gun:ws_send(ConnPid, StreamRef, {close, 4000, <<"test-close">>}),
435+
{terminate, {remote, 4000, <<"test-close">>}, Req} = do_receive(terminate),
436+
{'DOWN', MRef, process, WsPid, normal} = do_receive('DOWN'),
437+
%% Confirm terminate/3 was called with a compacted Req.
438+
true = maps:is_key(path, Req),
439+
false = maps:is_key(headers, Req),
440+
ok.
441+
442+
websocket_terminate_socket_close(Config) ->
443+
doc("The socket getting closed results in a terminate/3 call. "
444+
"The Req object is kept in a more compact form by default."),
445+
Protocol = config(protocol, Config),
446+
ConnPid = gun_open(Config, #{http2_opts => #{notify_settings_changed => true}}),
447+
do_await_enable_connect_protocol(Protocol, ConnPid),
448+
StreamRef = gun:ws_upgrade(ConnPid, "/terminate", [
449+
{<<"x-test-pid">>, pid_to_list(self())}
450+
]),
451+
{upgrade, [<<"websocket">>], _} = gun:await(ConnPid, StreamRef),
452+
{ws_pid, WsPid} = do_receive(ws_pid),
453+
MRef = monitor(process, WsPid),
454+
gun:close(ConnPid),
455+
%% Terminate reasons differ depending on the protocol.
456+
{terminate, Reason, Req} = do_receive(terminate),
457+
case Reason of
458+
{error, closed} when Protocol =:= http -> ok;
459+
shutdown when Protocol =:= http2 -> ok
460+
end,
461+
{'DOWN', MRef, process, WsPid, normal} = do_receive('DOWN'),
462+
%% Confirm terminate/3 was called with a compacted Req.
463+
true = maps:is_key(path, Req),
464+
false = maps:is_key(headers, Req),
465+
ok.
466+
467+
websocket_terminate_req_filter(Config) ->
468+
doc("Receiving a close frame results in a terminate/3 call. "
469+
"A function can be given to filter the Req object."),
470+
ConnPid = gun_open(Config, #{http2_opts => #{notify_settings_changed => true}}),
471+
do_await_enable_connect_protocol(config(protocol, Config), ConnPid),
472+
StreamRef = gun:ws_upgrade(ConnPid, "/terminate?req_filter", [
473+
{<<"x-test-pid">>, pid_to_list(self())}
474+
]),
475+
{upgrade, [<<"websocket">>], _} = gun:await(ConnPid, StreamRef),
476+
{ws_pid, WsPid} = do_receive(ws_pid),
477+
MRef = monitor(process, WsPid),
478+
gun:ws_send(ConnPid, StreamRef, close),
479+
{terminate, remote, Req} = do_receive(terminate),
480+
{'DOWN', MRef, process, WsPid, normal} = do_receive('DOWN'),
481+
%% Confirm terminate/3 was called with a filtered Req.
482+
filtered = Req,
483+
ok.

0 commit comments

Comments
 (0)