Skip to content

Commit cd0f9f1

Browse files
committed
Add support of CORS
1 parent dbb6360 commit cd0f9f1

File tree

1 file changed

+176
-0
lines changed

1 file changed

+176
-0
lines changed

src/cowboy_req.erl

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
-export([has_resp_header/2]).
6868
-export([has_resp_body/1]).
6969
-export([delete_resp_header/2]).
70+
-export([set_cors_headers/2]).
71+
-export([set_cors_preflight_headers/2]).
7072
-export([reply/2]).
7173
-export([reply/3]).
7274
-export([reply/4]).
@@ -86,6 +88,30 @@
8688
-export([lock/1]).
8789
-export([to_list/1]).
8890

91+
-type cors_allowed_origins() :: [binary()] | binary().
92+
-type cors_allowed_methods() :: [binary()].
93+
-type cors_allowed_headers() :: [binary()].
94+
-type cors_max_age() :: non_neg_integer() | max.
95+
-type cors_header() :: {origins, cors_allowed_origins()}
96+
| {exposed_headers, cors_allowed_headers()}
97+
| {credentials, boolean()}.
98+
-export_type([cors_header/0]).
99+
-type cors_preflight_header() :: {origins, cors_allowed_origins()}
100+
| {methods, cors_allowed_methods()}
101+
| {headers, cors_allowed_headers()}
102+
| {credentials, boolean()}
103+
| {max_age, max | non_neg_integer()}.
104+
-export_type([cors_preflight_header/0]).
105+
-record(cors, {
106+
origins = [] :: cors_allowed_origins(),
107+
methods = [] :: cors_allowed_methods(),
108+
headers = [] :: cors_allowed_headers(),
109+
exposed_headers = [] :: cors_allowed_headers(),
110+
credentials = false :: boolean(),
111+
max_age :: cors_max_age()
112+
}).
113+
-type cors_state() :: #cors{}.
114+
89115
-type cookie_opts() :: cow_cookie:cookie_opts().
90116
-export_type([cookie_opts/0]).
91117

@@ -666,6 +692,137 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) ->
666692
RespHeaders2 = lists:keydelete(Name, 1, RespHeaders),
667693
Req#http_req{resp_headers=RespHeaders2}.
668694

695+
-spec set_cors_headers([cors_header()], Req) -> Req when Req :: req().
696+
set_cors_headers(Input, Req) ->
697+
try
698+
State = cors_state(Input),
699+
Origin =
700+
match_cors_origin(
701+
header(<<"origin">>, Req),
702+
State#cors.origins),
703+
704+
Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req),
705+
set_cors_exposed_headers(State#cors.exposed_headers, Req2)
706+
catch throw:_Reason ->
707+
Req
708+
end.
709+
710+
-spec set_cors_preflight_headers([cors_preflight_header()], Req) -> Req when Req :: req().
711+
set_cors_preflight_headers(Input, Req) ->
712+
try
713+
State = cors_state(Input),
714+
Origin =
715+
match_cors_origin(
716+
header(<<"origin">>, Req),
717+
State#cors.origins),
718+
Method =
719+
match_cors_method(
720+
header(<<"access-control-request-method">>, Req),
721+
State#cors.methods),
722+
Headers =
723+
match_cors_headers(
724+
header(<<"access-control-request-headers">>, Req),
725+
State#cors.headers),
726+
727+
Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req),
728+
Req3 = set_cors_max_age(State#cors.max_age, Req2),
729+
Req4 = set_cors_allowed_methods([Method], Req3),
730+
set_cors_allowed_headers(Headers, Req4)
731+
catch throw:_Reason ->
732+
Req
733+
end.
734+
735+
-spec set_cors_allow_credentials(boolean(), binary(), Req) -> Req when Req :: req().
736+
set_cors_allow_credentials(Credentials, Origin, Req) ->
737+
case match_cors_credentials(Credentials, Origin) of
738+
true ->
739+
Req2 = set_resp_header(<<"access-control-allow-origin">>, Origin, Req),
740+
set_resp_header(<<"access-control-allow-credentials">>, <<"true">>, Req2);
741+
_ ->
742+
set_resp_header(<<"access-control-allow-origin">>, Origin, Req)
743+
end.
744+
745+
-spec set_cors_max_age(cors_max_age(), Req) -> Req when Req :: req().
746+
set_cors_max_age(undefined, Req) ->
747+
Req;
748+
set_cors_max_age(max, Req) ->
749+
set_resp_header(<<"access-control-max-age">>, <<"1728000">>, Req);
750+
set_cors_max_age(Val, Req) ->
751+
set_resp_header(<<"access-control-max-age">>, integer_to_binary(Val), Req).
752+
753+
-spec set_cors_allowed_methods(cors_allowed_methods(), Req) -> Req when Req :: req().
754+
%% NOTE: just to make dialyzer happy. We would need this statement
755+
%% if we decided to return an entire list of allowed methods
756+
%% instead of single one passed with the particular request.
757+
%% set_cors_allowed_methods([], Req) ->
758+
%% Req;
759+
set_cors_allowed_methods(Val, Req) ->
760+
set_resp_header(<<"access-control-allow-methods">>, binary_join(Val, <<$,>>), Req).
761+
762+
-spec set_cors_allowed_headers(cors_allowed_headers(), Req) -> Req when Req :: req().
763+
set_cors_allowed_headers([], Req) ->
764+
Req;
765+
set_cors_allowed_headers(Val, Req) ->
766+
set_resp_header(<<"access-control-allow-headers">>, binary_join(Val, <<$,>>), Req).
767+
768+
-spec set_cors_exposed_headers(cors_allowed_headers(), Req) -> Req when Req :: req().
769+
set_cors_exposed_headers([], Req) ->
770+
Req;
771+
set_cors_exposed_headers(L, Req) ->
772+
set_resp_header(<<"access-control-expose-headers">>, binary_join(L, <<$,>>), Req).
773+
774+
-spec match_cors_origin(binary() | undefined, cors_allowed_origins()) -> binary().
775+
match_cors_origin(undefined, Origins) ->
776+
throw({bad_origin, undefined, Origins});
777+
match_cors_origin(Val, Val) ->
778+
Val;
779+
match_cors_origin(Val, <<$*>>) ->
780+
Val;
781+
match_cors_origin(Val, Origins) when is_list(Origins) ->
782+
case lists:member(Val, Origins) of
783+
true -> Val;
784+
_ -> throw({nomatch_origin, Val, Origins})
785+
end;
786+
match_cors_origin(Val, Origins) ->
787+
throw({nomatch_origin, Val, Origins}).
788+
789+
-spec match_cors_method(binary() | undefined, cors_allowed_methods()) -> binary().
790+
match_cors_method(undefined, Methods) ->
791+
throw({bad_method, undefined, Methods});
792+
match_cors_method(Val, Methods) ->
793+
case lists:member(Val, Methods) of
794+
true -> Val;
795+
_ -> throw({nomatch_method, Val, Methods})
796+
end.
797+
798+
-spec match_cors_headers(binary() | undefined, cors_allowed_headers()) -> cors_allowed_headers().
799+
match_cors_headers(undefined, _) ->
800+
[];
801+
match_cors_headers(Val, Headers) ->
802+
lists:filter(
803+
fun(Header) -> lists:member(Header, Headers) end,
804+
binary:split(Val, [<<$,>>, <<", ">>], [global])).
805+
806+
-spec match_cors_credentials(boolean(), binary()) -> boolean().
807+
match_cors_credentials(true, <<$*>>) ->
808+
throw({bad_credentials, true, <<$*>>});
809+
match_cors_credentials(Val, _) ->
810+
Val.
811+
812+
-spec cors_state([cors_header() | cors_preflight_header()]) -> cors_state().
813+
cors_state(Headers) ->
814+
cors_state(Headers, #cors{}).
815+
816+
-spec cors_state([cors_header() | cors_preflight_header()], cors_state()) -> cors_state().
817+
cors_state([{origins, Val}|T], State) -> cors_state(T, State#cors{origins = Val});
818+
cors_state([{methods, Val}|T], State) -> cors_state(T, State#cors{methods = Val});
819+
cors_state([{headers, Val}|T], State) -> cors_state(T, State#cors{headers = Val});
820+
cors_state([{exposed_headers, Val}|T], State) -> cors_state(T, State#cors{exposed_headers = Val});
821+
cors_state([{credentials, Val}|T], State) -> cors_state(T, State#cors{credentials = Val});
822+
cors_state([{max_age, Val}|T], State) -> cors_state(T, State#cors{max_age = Val});
823+
cors_state([_|T], State) -> cors_state(T, State);
824+
cors_state([], State) -> State.
825+
669826
-spec reply(cowboy:http_status(), Req) -> Req when Req::req().
670827
reply(Status, Req=#http_req{resp_body=Body}) ->
671828
reply(Status, [], Body, Req).
@@ -1244,6 +1401,15 @@ filter_constraints(Tail, Map, Key, Value, Constraints) ->
12441401
filter(Tail, Map#{Key => Value2})
12451402
end.
12461403

1404+
-spec binary_join(binary() | [binary()], binary()) -> binary().
1405+
binary_join([H|T], Sep) ->
1406+
lists:foldl(
1407+
fun(Val, Acc) ->
1408+
<<Acc/binary, Sep/binary, Val/binary>>
1409+
end, H, T).
1410+
%%binary_join([], _) -> <<>>;
1411+
%%binary_join(L, _) -> L.
1412+
12471413
%% Tests.
12481414

12491415
-ifdef(TEST).
@@ -1298,4 +1464,14 @@ merge_headers_test_() ->
12981464
{<<"server">>,<<"Cowboy">>}]}
12991465
],
13001466
[fun() -> Res = merge_headers(L,R) end || {L, R, Res} <- Tests].
1467+
1468+
binary_join_test_() ->
1469+
Sep = <<$,>>,
1470+
Test =
1471+
[%%{<<$b>>, <<"b">>},
1472+
%%{[], <<>>},
1473+
{[<<$a>>], <<$a>>},
1474+
{[<<$a>>, <<$b>>], <<"a,b">>}],
1475+
[fun() -> Output = binary_join(Input, Sep) end || {Input, Output} <- Test].
1476+
13011477
-endif.

0 commit comments

Comments
 (0)