|
67 | 67 | -export([has_resp_header/2]). |
68 | 68 | -export([has_resp_body/1]). |
69 | 69 | -export([delete_resp_header/2]). |
| 70 | +-export([set_cors_headers/2]). |
| 71 | +-export([set_cors_preflight_headers/2]). |
70 | 72 | -export([reply/2]). |
71 | 73 | -export([reply/3]). |
72 | 74 | -export([reply/4]). |
|
86 | 88 | -export([lock/1]). |
87 | 89 | -export([to_list/1]). |
88 | 90 |
|
| 91 | +-type cors_allowed_origins() :: [binary()] | binary(). |
| 92 | +-type cors_allowed_methods() :: [binary()]. |
| 93 | +-type cors_allowed_headers() :: [binary()]. |
| 94 | +-type cors_max_age() :: non_neg_integer() | max. |
| 95 | +-type cors_header() :: {origins, cors_allowed_origins()} |
| 96 | + | {exposed_headers, cors_allowed_headers()} |
| 97 | + | {credentials, boolean()}. |
| 98 | +-export_type([cors_header/0]). |
| 99 | +-type cors_preflight_header() :: {origins, cors_allowed_origins()} |
| 100 | + | {methods, cors_allowed_methods()} |
| 101 | + | {headers, cors_allowed_headers()} |
| 102 | + | {credentials, boolean()} |
| 103 | + | {max_age, max | non_neg_integer()}. |
| 104 | +-export_type([cors_preflight_header/0]). |
| 105 | +-record(cors, { |
| 106 | + origins = [] :: cors_allowed_origins(), |
| 107 | + methods = [] :: cors_allowed_methods(), |
| 108 | + headers = [] :: cors_allowed_headers(), |
| 109 | + exposed_headers = [] :: cors_allowed_headers(), |
| 110 | + credentials = false :: boolean(), |
| 111 | + max_age :: cors_max_age() |
| 112 | +}). |
| 113 | +-type cors_state() :: #cors{}. |
| 114 | + |
89 | 115 | -type cookie_opts() :: cow_cookie:cookie_opts(). |
90 | 116 | -export_type([cookie_opts/0]). |
91 | 117 |
|
@@ -666,6 +692,137 @@ delete_resp_header(Name, Req=#http_req{resp_headers=RespHeaders}) -> |
666 | 692 | RespHeaders2 = lists:keydelete(Name, 1, RespHeaders), |
667 | 693 | Req#http_req{resp_headers=RespHeaders2}. |
668 | 694 |
|
| 695 | +-spec set_cors_headers([cors_header()], Req) -> Req when Req :: req(). |
| 696 | +set_cors_headers(Input, Req) -> |
| 697 | + try |
| 698 | + State = cors_state(Input), |
| 699 | + Origin = |
| 700 | + match_cors_origin( |
| 701 | + header(<<"origin">>, Req), |
| 702 | + State#cors.origins), |
| 703 | + |
| 704 | + Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req), |
| 705 | + set_cors_exposed_headers(State#cors.exposed_headers, Req2) |
| 706 | + catch throw:_Reason -> |
| 707 | + Req |
| 708 | + end. |
| 709 | + |
| 710 | +-spec set_cors_preflight_headers([cors_preflight_header()], Req) -> Req when Req :: req(). |
| 711 | +set_cors_preflight_headers(Input, Req) -> |
| 712 | + try |
| 713 | + State = cors_state(Input), |
| 714 | + Origin = |
| 715 | + match_cors_origin( |
| 716 | + header(<<"origin">>, Req), |
| 717 | + State#cors.origins), |
| 718 | + Method = |
| 719 | + match_cors_method( |
| 720 | + header(<<"access-control-request-method">>, Req), |
| 721 | + State#cors.methods), |
| 722 | + Headers = |
| 723 | + match_cors_headers( |
| 724 | + header(<<"access-control-request-headers">>, Req), |
| 725 | + State#cors.headers), |
| 726 | + |
| 727 | + Req2 = set_cors_allow_credentials(State#cors.credentials, Origin, Req), |
| 728 | + Req3 = set_cors_max_age(State#cors.max_age, Req2), |
| 729 | + Req4 = set_cors_allowed_methods([Method], Req3), |
| 730 | + set_cors_allowed_headers(Headers, Req4) |
| 731 | + catch throw:_Reason -> |
| 732 | + Req |
| 733 | + end. |
| 734 | + |
| 735 | +-spec set_cors_allow_credentials(boolean(), binary(), Req) -> Req when Req :: req(). |
| 736 | +set_cors_allow_credentials(Credentials, Origin, Req) -> |
| 737 | + case match_cors_credentials(Credentials, Origin) of |
| 738 | + true -> |
| 739 | + Req2 = set_resp_header(<<"access-control-allow-origin">>, Origin, Req), |
| 740 | + set_resp_header(<<"access-control-allow-credentials">>, <<"true">>, Req2); |
| 741 | + _ -> |
| 742 | + set_resp_header(<<"access-control-allow-origin">>, Origin, Req) |
| 743 | + end. |
| 744 | + |
| 745 | +-spec set_cors_max_age(cors_max_age(), Req) -> Req when Req :: req(). |
| 746 | +set_cors_max_age(undefined, Req) -> |
| 747 | + Req; |
| 748 | +set_cors_max_age(max, Req) -> |
| 749 | + set_resp_header(<<"access-control-max-age">>, <<"1728000">>, Req); |
| 750 | +set_cors_max_age(Val, Req) -> |
| 751 | + set_resp_header(<<"access-control-max-age">>, integer_to_binary(Val), Req). |
| 752 | + |
| 753 | +-spec set_cors_allowed_methods(cors_allowed_methods(), Req) -> Req when Req :: req(). |
| 754 | +%% NOTE: just to make dialyzer happy. We would need this statement |
| 755 | +%% if we decided to return an entire list of allowed methods |
| 756 | +%% instead of single one passed with the particular request. |
| 757 | +%% set_cors_allowed_methods([], Req) -> |
| 758 | +%% Req; |
| 759 | +set_cors_allowed_methods(Val, Req) -> |
| 760 | + set_resp_header(<<"access-control-allow-methods">>, binary_join(Val, <<$,>>), Req). |
| 761 | + |
| 762 | +-spec set_cors_allowed_headers(cors_allowed_headers(), Req) -> Req when Req :: req(). |
| 763 | +set_cors_allowed_headers([], Req) -> |
| 764 | + Req; |
| 765 | +set_cors_allowed_headers(Val, Req) -> |
| 766 | + set_resp_header(<<"access-control-allow-headers">>, binary_join(Val, <<$,>>), Req). |
| 767 | + |
| 768 | +-spec set_cors_exposed_headers(cors_allowed_headers(), Req) -> Req when Req :: req(). |
| 769 | +set_cors_exposed_headers([], Req) -> |
| 770 | + Req; |
| 771 | +set_cors_exposed_headers(L, Req) -> |
| 772 | + set_resp_header(<<"access-control-expose-headers">>, binary_join(L, <<$,>>), Req). |
| 773 | + |
| 774 | +-spec match_cors_origin(binary() | undefined, cors_allowed_origins()) -> binary(). |
| 775 | +match_cors_origin(undefined, Origins) -> |
| 776 | + throw({bad_origin, undefined, Origins}); |
| 777 | +match_cors_origin(Val, Val) -> |
| 778 | + Val; |
| 779 | +match_cors_origin(Val, <<$*>>) -> |
| 780 | + Val; |
| 781 | +match_cors_origin(Val, Origins) when is_list(Origins) -> |
| 782 | + case lists:member(Val, Origins) of |
| 783 | + true -> Val; |
| 784 | + _ -> throw({nomatch_origin, Val, Origins}) |
| 785 | + end; |
| 786 | +match_cors_origin(Val, Origins) -> |
| 787 | + throw({nomatch_origin, Val, Origins}). |
| 788 | + |
| 789 | +-spec match_cors_method(binary() | undefined, cors_allowed_methods()) -> binary(). |
| 790 | +match_cors_method(undefined, Methods) -> |
| 791 | + throw({bad_method, undefined, Methods}); |
| 792 | +match_cors_method(Val, Methods) -> |
| 793 | + case lists:member(Val, Methods) of |
| 794 | + true -> Val; |
| 795 | + _ -> throw({nomatch_method, Val, Methods}) |
| 796 | + end. |
| 797 | + |
| 798 | +-spec match_cors_headers(binary() | undefined, cors_allowed_headers()) -> cors_allowed_headers(). |
| 799 | +match_cors_headers(undefined, _) -> |
| 800 | + []; |
| 801 | +match_cors_headers(Val, Headers) -> |
| 802 | + lists:filter( |
| 803 | + fun(Header) -> lists:member(Header, Headers) end, |
| 804 | + binary:split(Val, [<<$,>>, <<", ">>], [global])). |
| 805 | + |
| 806 | +-spec match_cors_credentials(boolean(), binary()) -> boolean(). |
| 807 | +match_cors_credentials(true, <<$*>>) -> |
| 808 | + throw({bad_credentials, true, <<$*>>}); |
| 809 | +match_cors_credentials(Val, _) -> |
| 810 | + Val. |
| 811 | + |
| 812 | +-spec cors_state([cors_header() | cors_preflight_header()]) -> cors_state(). |
| 813 | +cors_state(Headers) -> |
| 814 | + cors_state(Headers, #cors{}). |
| 815 | + |
| 816 | +-spec cors_state([cors_header() | cors_preflight_header()], cors_state()) -> cors_state(). |
| 817 | +cors_state([{origins, Val}|T], State) -> cors_state(T, State#cors{origins = Val}); |
| 818 | +cors_state([{methods, Val}|T], State) -> cors_state(T, State#cors{methods = Val}); |
| 819 | +cors_state([{headers, Val}|T], State) -> cors_state(T, State#cors{headers = Val}); |
| 820 | +cors_state([{exposed_headers, Val}|T], State) -> cors_state(T, State#cors{exposed_headers = Val}); |
| 821 | +cors_state([{credentials, Val}|T], State) -> cors_state(T, State#cors{credentials = Val}); |
| 822 | +cors_state([{max_age, Val}|T], State) -> cors_state(T, State#cors{max_age = Val}); |
| 823 | +cors_state([_|T], State) -> cors_state(T, State); |
| 824 | +cors_state([], State) -> State. |
| 825 | + |
669 | 826 | -spec reply(cowboy:http_status(), Req) -> Req when Req::req(). |
670 | 827 | reply(Status, Req=#http_req{resp_body=Body}) -> |
671 | 828 | reply(Status, [], Body, Req). |
@@ -1244,6 +1401,15 @@ filter_constraints(Tail, Map, Key, Value, Constraints) -> |
1244 | 1401 | filter(Tail, Map#{Key => Value2}) |
1245 | 1402 | end. |
1246 | 1403 |
|
| 1404 | +-spec binary_join(binary() | [binary()], binary()) -> binary(). |
| 1405 | +binary_join([H|T], Sep) -> |
| 1406 | + lists:foldl( |
| 1407 | + fun(Val, Acc) -> |
| 1408 | + <<Acc/binary, Sep/binary, Val/binary>> |
| 1409 | + end, H, T). |
| 1410 | +%%binary_join([], _) -> <<>>; |
| 1411 | +%%binary_join(L, _) -> L. |
| 1412 | + |
1247 | 1413 | %% Tests. |
1248 | 1414 |
|
1249 | 1415 | -ifdef(TEST). |
@@ -1298,4 +1464,14 @@ merge_headers_test_() -> |
1298 | 1464 | {<<"server">>,<<"Cowboy">>}]} |
1299 | 1465 | ], |
1300 | 1466 | [fun() -> Res = merge_headers(L,R) end || {L, R, Res} <- Tests]. |
| 1467 | + |
| 1468 | +binary_join_test_() -> |
| 1469 | + Sep = <<$,>>, |
| 1470 | + Test = |
| 1471 | + [%%{<<$b>>, <<"b">>}, |
| 1472 | + %%{[], <<>>}, |
| 1473 | + {[<<$a>>], <<$a>>}, |
| 1474 | + {[<<$a>>, <<$b>>], <<"a,b">>}], |
| 1475 | + [fun() -> Output = binary_join(Input, Sep) end || {Input, Output} <- Test]. |
| 1476 | + |
1301 | 1477 | -endif. |
0 commit comments