Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deps/rabbitmq_aws/include/rabbitmq_aws.hrl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@
security_token :: security_token() | undefined,
region :: region() | undefined,
imdsv2_token:: imdsv2token() | undefined,
error :: atom() | string() | undefined}).
error :: atom() | string() | undefined,
gun_connections = #{} :: #{string() => pid()} % host -> gun_pid mapping
}).
-type state() :: #state{}.

-type scheme() :: atom().
Expand Down
177 changes: 155 additions & 22 deletions deps/rabbitmq_aws/src/rabbitmq_aws.erl
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,12 @@ init([]) ->
{ok, #state{}}.


terminate(_, _) ->
ok.

terminate(_, State) ->
%% Close all Gun connections
maps:fold(fun(_Host, ConnPid, _Acc) ->
gun:close(ConnPid)
end, ok, State#state.gun_connections),
ok.

code_change(_, _, State) ->
{ok, State}.
Expand Down Expand Up @@ -218,11 +221,15 @@ handle_msg({set_credentials, AccessKey, SecretAccessKey}, State) ->
error = undefined}};

handle_msg({set_credentials, NewState}, State) ->
{reply, ok, State#state{access_key = NewState#state.access_key,
secret_access_key = NewState#state.secret_access_key,
security_token = NewState#state.security_token,
expiration = NewState#state.expiration,
error = NewState#state.error}};
spawn(fun() -> maps:fold(fun(_Host, ConnPid, _Acc) ->
gun:close(ConnPid)
end, ok, State#state.gun_connections) end),
{reply, ok, State#state{access_key = NewState#state.access_key,
secret_access_key = NewState#state.secret_access_key,
security_token = NewState#state.security_token,
expiration = NewState#state.expiration,
error = NewState#state.error,
gun_connections = #{}}}; % Potentially new credentials, so clear the connection pool?

handle_msg({set_region, Region}, State) ->
{reply, ok, State#state{region = Region}};
Expand Down Expand Up @@ -292,7 +299,7 @@ get_content_type(Headers) ->
proplists:get_value("Content-Type", Headers, "text/xml");
Other -> Other
end,
parse_content_type(Value).
parse_content_type(Value).

-spec has_credentials() -> boolean().
has_credentials() ->
Expand Down Expand Up @@ -323,7 +330,7 @@ expired_credentials(Expiration) ->
%% - Credentials file
%% - EC2 Instance Metadata Service
%% @end
load_credentials(#state{region = Region}) ->
load_credentials(#state{region = Region, gun_connections = GunConnections}) ->
case rabbitmq_aws_config:credentials() of
{ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} ->
{ok, #state{region = Region,
Expand All @@ -332,7 +339,8 @@ load_credentials(#state{region = Region}) ->
secret_access_key = SecretAccessKey,
expiration = Expiration,
security_token = SecurityToken,
imdsv2_token = undefined}};
imdsv2_token = undefined,
gun_connections = GunConnections}};
{error, Reason} ->
error_logger:error_msg("Could not load AWS credentials from environment variables, AWS_CONFIG_FILE, AWS_SHARED_CREDENTIALS_FILE or EC2 metadata endpoint: ~tp. Will depend on config settings to be set~n", [Reason]),
{error, #state{region = Region,
Expand All @@ -341,7 +349,8 @@ load_credentials(#state{region = Region}) ->
secret_access_key = undefined,
expiration = undefined,
security_token = undefined,
imdsv2_token = undefined}}
imdsv2_token = undefined,
gun_connections = GunConnections}}
end.


Expand Down Expand Up @@ -382,7 +391,7 @@ parse_content_type(ContentType) ->
%% @doc Make the API request and return the formatted response.
%% @end
perform_request(State, Service, Method, Headers, Path, Body, Options, Host) ->
perform_request_has_creds(has_credentials(State), State, Service, Method,
perform_request_has_creds(has_credentials(State), State, Service, Method,
Headers, Path, Body, Options, Host).


Expand All @@ -396,7 +405,7 @@ perform_request(State, Service, Method, Headers, Path, Body, Options, Host) ->
%% otherwise return an error result.
%% @end
perform_request_has_creds(true, State, Service, Method, Headers, Path, Body, Options, Host) ->
perform_request_creds_expired(expired_credentials(State#state.expiration), State,
perform_request_creds_expired(expired_credentials(State#state.expiration), State,
Service, Method, Headers, Path, Body, Options, Host);
perform_request_has_creds(false, State, _, _, _, _, _, _, _) ->
perform_request_creds_error(State).
Expand All @@ -412,7 +421,7 @@ perform_request_has_creds(false, State, _, _, _, _, _, _, _) ->
%% credentials before performing the request.
%% @end
perform_request_creds_expired(false, State, Service, Method, Headers, Path, Body, Options, Host) ->
perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host);
perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host);
perform_request_creds_expired(true, State, _, _, _, _, _, _, _) ->
perform_request_creds_error(State#state{error = "Credentials expired!"}).

Expand All @@ -428,7 +437,7 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options,
URI = endpoint(State, Host, Service, Path),
SignedHeaders = sign_headers(State, Service, Method, URI, Headers, Body),
ContentType = proplists:get_value("content-type", SignedHeaders, undefined),
perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options).
perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options).


-spec perform_request_with_creds(State :: state(), Method :: method(), URI :: string(),
Expand All @@ -439,13 +448,12 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options,
%% expired, perform the request and return the response.
%% @end
perform_request_with_creds(State, Method, URI, Headers, undefined, "", Options0) ->
Options1 = ensure_timeout(Options0),
Response = httpc:request(Method, {URI, Headers}, Options1, []),
{format_response(Response), State};
{Response, NewState} = gun_request(State, Method, URI, Headers, <<>>, Options0),
{format_response(Response), NewState};
perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Options0) ->
Options1 = ensure_timeout(Options0),
Response = httpc:request(Method, {URI, Headers, ContentType, Body}, Options1, []),
{format_response(Response), State}.
GunHeaders = [{"content-type", ContentType} | Headers],
{Response, NewState} = gun_request(State, Method, URI, GunHeaders, Body, Options0),
{format_response(Response), NewState}.


-spec perform_request_creds_error(State :: state()) ->
Expand Down Expand Up @@ -566,3 +574,128 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) ->
timer:sleep(WaitTimeBetweenRetries),
api_get_request_with_retries(Service, Path, Retries - 1, WaitTimeBetweenRetries)
end.

%% Gun HTTP client functions
gun_request(State, Method, URI, Headers, Body, Options) ->
{Host, Port, Path} = parse_uri(URI),
{ConnPid, NewState} = get_or_create_gun_connection(State, Host, Port, Path, Options),
Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT),
try
StreamRef = do_gun_request(ConnPid, Method, Path, Headers, Body),
case gun:await(ConnPid, StreamRef, Timeout) of
{response, fin, Status, RespHeaders} ->
Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}},
{Response, NewState};
{response, nofin, Status, RespHeaders} ->
{ok, RespBody} = gun:await_body(ConnPid, StreamRef, Timeout),
Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}},
{Response, NewState};
{error, Reason} ->
{{error, Reason}, NewState}
end
catch
_:Error ->
% Connection failed, remove from pool and return error
HostKey = get_connection_key(Host, Port, Path, Options),
NewConnections = maps:remove(HostKey, NewState#state.gun_connections),
gun:close(ConnPid),
{{error, Error}, NewState#state{gun_connections = NewConnections}}
end.

do_gun_request(ConnPid, get, Path, Headers, _Body) ->
gun:get(ConnPid, Path, Headers);
do_gun_request(ConnPid, post, Path, Headers, Body) ->
gun:post(ConnPid, Path, Headers, Body, #{});
do_gun_request(ConnPid, put, Path, Headers, Body) ->
gun:put(ConnPid, Path, Headers, Body, #{});
do_gun_request(ConnPid, head, Path, Headers, _Body) ->
gun:head(ConnPid, Path, Headers, #{});
do_gun_request(ConnPid, delete, Path, Headers, _Body) ->
gun:delete(ConnPid, Path, Headers, #{});
do_gun_request(ConnPid, patch, Path, Headers, Body) ->
gun:patch(ConnPid, Path, Headers, Body, #{});
do_gun_request(ConnPid, options, Path, Headers, _Body) ->
gun:options(ConnPid, Path, Headers, #{}).

get_or_create_gun_connection(State, Host, Port, Path, Options) ->
HostKey = get_connection_key(Host, Port, Path, Options),
case maps:get(HostKey, State#state.gun_connections, undefined) of
undefined ->
create_gun_connection(State, Host, Port, Path, HostKey, Options);
ConnPid ->
case is_process_alive(ConnPid) andalso gun:info(ConnPid) =/= undefined of
true ->
{ConnPid, State};
false ->
% Connection is dead, create new one
gun:close(ConnPid),
create_gun_connection(State, Host, Port, Path, HostKey, Options)
end
end.

get_connection_key(Host, Port, Path, Options) ->
case proplists:get_value(connection_per_path, Options, false) of
true -> Host ++ ":" ++ integer_to_list(Port) ++ Path; % Per-path
false -> Host ++ ":" ++ integer_to_list(Port) % Per-host (default)
end.

create_gun_connection(State, Host, Port, Path, HostKey, Options) ->
% Map HTTP version to Gun protocols, always include http as fallback
HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"),
Protocols = case HttpVersion of
"HTTP/2" -> [http2, http];
"HTTP/2.0" -> [http2, http];
"HTTP/1.1" -> [http];
"HTTP/1.0" -> [http];
_ -> [http2, http] % Default: try HTTP/2, fallback to HTTP/1.1
end,
ConnectTimeout = proplists:get_value(connect_timeout, Options, 5000),
Opts = #{
transport => if Port == 443 -> tls; true -> tcp end,
protocols => Protocols,
connect_timeout => ConnectTimeout
},
application:ensure_all_started(gun),
case gun:open(Host, Port, Opts) of
{ok, ConnPid} ->
case gun:await_up(ConnPid, ConnectTimeout) of
{ok, _Protocol} ->
NewConnections = maps:put(HostKey, ConnPid, State#state.gun_connections),
NewState = State#state{gun_connections = NewConnections},
{ConnPid, NewState};
{error, Reason} ->
gun:close(ConnPid),
error({gun_connection_failed, Reason})
end;
{error, Reason} ->
error({gun_open_failed, Reason})
end.

parse_uri(URI) ->
case string:split(URI, "://", leading) of
[_Scheme, Rest] ->
case string:split(Rest, "/", leading) of
[HostPort] ->
{Host, Port} = parse_host_port(HostPort),
{Host, Port, "/"};
[HostPort, Path] ->
{Host, Port} = parse_host_port(HostPort),
{Host, Port, "/" ++ Path}
end
end.

parse_host_port(HostPort) ->
case string:split(HostPort, ":", trailing) of
[Host] ->
{Host, 443}; % Default HTTPS port
[Host, PortStr] ->
{Host, list_to_integer(PortStr)}
end.

status_text(200) -> "OK";
status_text(400) -> "Bad Request";
status_text(401) -> "Unauthorized";
status_text(403) -> "Forbidden";
status_text(404) -> "Not Found";
status_text(500) -> "Internal Server Error";
status_text(Code) -> integer_to_list(Code).
2 changes: 2 additions & 0 deletions deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
-include_lib("xmerl/include/xmerl.hrl").

-spec parse(Value :: string() | binary()) -> list().
parse(Value) when is_binary(Value) ->
parse(binary_to_list(Value));
parse(Value) ->
{Element, _} = xmerl_scan:string(Value),
parse_node(Element).
Expand Down