diff --git a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl index ab16d9ed49f4..8031d6fb9cb6 100644 --- a/deps/rabbitmq_aws/include/rabbitmq_aws.hrl +++ b/deps/rabbitmq_aws/include/rabbitmq_aws.hrl @@ -68,7 +68,9 @@ security_token :: security_token() | undefined, region :: region() | undefined, imdsv2_token:: imdsv2token() | undefined, - error :: atom() | string() | undefined}). + error :: atom() | string() | undefined, + gun_connections = #{} :: #{string() => pid()} % host -> gun_pid mapping + }). -type state() :: #state{}. -type scheme() :: atom(). diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws.erl b/deps/rabbitmq_aws/src/rabbitmq_aws.erl index 444121d76845..5f5bc6cf03af 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws.erl @@ -179,9 +179,12 @@ init([]) -> {ok, #state{}}. -terminate(_, _) -> - ok. - +terminate(_, State) -> + %% Close all Gun connections + maps:fold(fun(_Host, ConnPid, _Acc) -> + gun:close(ConnPid) + end, ok, State#state.gun_connections), + ok. code_change(_, _, State) -> {ok, State}. @@ -218,11 +221,15 @@ handle_msg({set_credentials, AccessKey, SecretAccessKey}, State) -> error = undefined}}; handle_msg({set_credentials, NewState}, State) -> - {reply, ok, State#state{access_key = NewState#state.access_key, - secret_access_key = NewState#state.secret_access_key, - security_token = NewState#state.security_token, - expiration = NewState#state.expiration, - error = NewState#state.error}}; + spawn(fun() -> maps:fold(fun(_Host, ConnPid, _Acc) -> + gun:close(ConnPid) + end, ok, State#state.gun_connections) end), + {reply, ok, State#state{access_key = NewState#state.access_key, + secret_access_key = NewState#state.secret_access_key, + security_token = NewState#state.security_token, + expiration = NewState#state.expiration, + error = NewState#state.error, + gun_connections = #{}}}; % Potentially new credentials, so clear the connection pool? handle_msg({set_region, Region}, State) -> {reply, ok, State#state{region = Region}}; @@ -292,7 +299,7 @@ get_content_type(Headers) -> proplists:get_value("Content-Type", Headers, "text/xml"); Other -> Other end, - parse_content_type(Value). + parse_content_type(Value). -spec has_credentials() -> boolean(). has_credentials() -> @@ -323,7 +330,7 @@ expired_credentials(Expiration) -> %% - Credentials file %% - EC2 Instance Metadata Service %% @end -load_credentials(#state{region = Region}) -> +load_credentials(#state{region = Region, gun_connections = GunConnections}) -> case rabbitmq_aws_config:credentials() of {ok, AccessKey, SecretAccessKey, Expiration, SecurityToken} -> {ok, #state{region = Region, @@ -332,7 +339,8 @@ load_credentials(#state{region = Region}) -> secret_access_key = SecretAccessKey, expiration = Expiration, security_token = SecurityToken, - imdsv2_token = undefined}}; + imdsv2_token = undefined, + gun_connections = GunConnections}}; {error, Reason} -> error_logger:error_msg("Could not load AWS credentials from environment variables, AWS_CONFIG_FILE, AWS_SHARED_CREDENTIALS_FILE or EC2 metadata endpoint: ~tp. Will depend on config settings to be set~n", [Reason]), {error, #state{region = Region, @@ -341,7 +349,8 @@ load_credentials(#state{region = Region}) -> secret_access_key = undefined, expiration = undefined, security_token = undefined, - imdsv2_token = undefined}} + imdsv2_token = undefined, + gun_connections = GunConnections}} end. @@ -382,7 +391,7 @@ parse_content_type(ContentType) -> %% @doc Make the API request and return the formatted response. %% @end perform_request(State, Service, Method, Headers, Path, Body, Options, Host) -> - perform_request_has_creds(has_credentials(State), State, Service, Method, + perform_request_has_creds(has_credentials(State), State, Service, Method, Headers, Path, Body, Options, Host). @@ -396,7 +405,7 @@ perform_request(State, Service, Method, Headers, Path, Body, Options, Host) -> %% otherwise return an error result. %% @end perform_request_has_creds(true, State, Service, Method, Headers, Path, Body, Options, Host) -> - perform_request_creds_expired(expired_credentials(State#state.expiration), State, + perform_request_creds_expired(expired_credentials(State#state.expiration), State, Service, Method, Headers, Path, Body, Options, Host); perform_request_has_creds(false, State, _, _, _, _, _, _, _) -> perform_request_creds_error(State). @@ -412,7 +421,7 @@ perform_request_has_creds(false, State, _, _, _, _, _, _, _) -> %% credentials before performing the request. %% @end perform_request_creds_expired(false, State, Service, Method, Headers, Path, Body, Options, Host) -> - perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host); + perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, Host); perform_request_creds_expired(true, State, _, _, _, _, _, _, _) -> perform_request_creds_error(State#state{error = "Credentials expired!"}). @@ -428,7 +437,7 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, URI = endpoint(State, Host, Service, Path), SignedHeaders = sign_headers(State, Service, Method, URI, Headers, Body), ContentType = proplists:get_value("content-type", SignedHeaders, undefined), - perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options). + perform_request_with_creds(State, Method, URI, SignedHeaders, ContentType, Body, Options). -spec perform_request_with_creds(State :: state(), Method :: method(), URI :: string(), @@ -439,13 +448,12 @@ perform_request_with_creds(State, Service, Method, Headers, Path, Body, Options, %% expired, perform the request and return the response. %% @end perform_request_with_creds(State, Method, URI, Headers, undefined, "", Options0) -> - Options1 = ensure_timeout(Options0), - Response = httpc:request(Method, {URI, Headers}, Options1, []), - {format_response(Response), State}; + {Response, NewState} = gun_request(State, Method, URI, Headers, <<>>, Options0), + {format_response(Response), NewState}; perform_request_with_creds(State, Method, URI, Headers, ContentType, Body, Options0) -> - Options1 = ensure_timeout(Options0), - Response = httpc:request(Method, {URI, Headers, ContentType, Body}, Options1, []), - {format_response(Response), State}. + GunHeaders = [{"content-type", ContentType} | Headers], + {Response, NewState} = gun_request(State, Method, URI, GunHeaders, Body, Options0), + {format_response(Response), NewState}. -spec perform_request_creds_error(State :: state()) -> @@ -566,3 +574,128 @@ api_get_request_with_retries(Service, Path, Retries, WaitTimeBetweenRetries) -> timer:sleep(WaitTimeBetweenRetries), api_get_request_with_retries(Service, Path, Retries - 1, WaitTimeBetweenRetries) end. + +%% Gun HTTP client functions +gun_request(State, Method, URI, Headers, Body, Options) -> + {Host, Port, Path} = parse_uri(URI), + {ConnPid, NewState} = get_or_create_gun_connection(State, Host, Port, Path, Options), + Timeout = proplists:get_value(timeout, Options, ?DEFAULT_HTTP_TIMEOUT), + try + StreamRef = do_gun_request(ConnPid, Method, Path, Headers, Body), + case gun:await(ConnPid, StreamRef, Timeout) of + {response, fin, Status, RespHeaders} -> + Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, <<>>}}, + {Response, NewState}; + {response, nofin, Status, RespHeaders} -> + {ok, RespBody} = gun:await_body(ConnPid, StreamRef, Timeout), + Response = {ok, {{http_version, Status, status_text(Status)}, RespHeaders, RespBody}}, + {Response, NewState}; + {error, Reason} -> + {{error, Reason}, NewState} + end + catch + _:Error -> + % Connection failed, remove from pool and return error + HostKey = get_connection_key(Host, Port, Path, Options), + NewConnections = maps:remove(HostKey, NewState#state.gun_connections), + gun:close(ConnPid), + {{error, Error}, NewState#state{gun_connections = NewConnections}} + end. + +do_gun_request(ConnPid, get, Path, Headers, _Body) -> + gun:get(ConnPid, Path, Headers); +do_gun_request(ConnPid, post, Path, Headers, Body) -> + gun:post(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, put, Path, Headers, Body) -> + gun:put(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, head, Path, Headers, _Body) -> + gun:head(ConnPid, Path, Headers, #{}); +do_gun_request(ConnPid, delete, Path, Headers, _Body) -> + gun:delete(ConnPid, Path, Headers, #{}); +do_gun_request(ConnPid, patch, Path, Headers, Body) -> + gun:patch(ConnPid, Path, Headers, Body, #{}); +do_gun_request(ConnPid, options, Path, Headers, _Body) -> + gun:options(ConnPid, Path, Headers, #{}). + +get_or_create_gun_connection(State, Host, Port, Path, Options) -> + HostKey = get_connection_key(Host, Port, Path, Options), + case maps:get(HostKey, State#state.gun_connections, undefined) of + undefined -> + create_gun_connection(State, Host, Port, Path, HostKey, Options); + ConnPid -> + case is_process_alive(ConnPid) andalso gun:info(ConnPid) =/= undefined of + true -> + {ConnPid, State}; + false -> + % Connection is dead, create new one + gun:close(ConnPid), + create_gun_connection(State, Host, Port, Path, HostKey, Options) + end + end. + +get_connection_key(Host, Port, Path, Options) -> + case proplists:get_value(connection_per_path, Options, false) of + true -> Host ++ ":" ++ integer_to_list(Port) ++ Path; % Per-path + false -> Host ++ ":" ++ integer_to_list(Port) % Per-host (default) + end. + +create_gun_connection(State, Host, Port, Path, HostKey, Options) -> + % Map HTTP version to Gun protocols, always include http as fallback + HttpVersion = proplists:get_value(version, Options, "HTTP/1.1"), + Protocols = case HttpVersion of + "HTTP/2" -> [http2, http]; + "HTTP/2.0" -> [http2, http]; + "HTTP/1.1" -> [http]; + "HTTP/1.0" -> [http]; + _ -> [http2, http] % Default: try HTTP/2, fallback to HTTP/1.1 + end, + ConnectTimeout = proplists:get_value(connect_timeout, Options, 5000), + Opts = #{ + transport => if Port == 443 -> tls; true -> tcp end, + protocols => Protocols, + connect_timeout => ConnectTimeout + }, + application:ensure_all_started(gun), + case gun:open(Host, Port, Opts) of + {ok, ConnPid} -> + case gun:await_up(ConnPid, ConnectTimeout) of + {ok, _Protocol} -> + NewConnections = maps:put(HostKey, ConnPid, State#state.gun_connections), + NewState = State#state{gun_connections = NewConnections}, + {ConnPid, NewState}; + {error, Reason} -> + gun:close(ConnPid), + error({gun_connection_failed, Reason}) + end; + {error, Reason} -> + error({gun_open_failed, Reason}) + end. + +parse_uri(URI) -> + case string:split(URI, "://", leading) of + [_Scheme, Rest] -> + case string:split(Rest, "/", leading) of + [HostPort] -> + {Host, Port} = parse_host_port(HostPort), + {Host, Port, "/"}; + [HostPort, Path] -> + {Host, Port} = parse_host_port(HostPort), + {Host, Port, "/" ++ Path} + end + end. + +parse_host_port(HostPort) -> + case string:split(HostPort, ":", trailing) of + [Host] -> + {Host, 443}; % Default HTTPS port + [Host, PortStr] -> + {Host, list_to_integer(PortStr)} + end. + +status_text(200) -> "OK"; +status_text(400) -> "Bad Request"; +status_text(401) -> "Unauthorized"; +status_text(403) -> "Forbidden"; +status_text(404) -> "Not Found"; +status_text(500) -> "Internal Server Error"; +status_text(Code) -> integer_to_list(Code). diff --git a/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl b/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl index fc3be5c642a8..4787ea82f270 100644 --- a/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl +++ b/deps/rabbitmq_aws/src/rabbitmq_aws_xml.erl @@ -11,6 +11,8 @@ -include_lib("xmerl/include/xmerl.hrl"). -spec parse(Value :: string() | binary()) -> list(). +parse(Value) when is_binary(Value) -> + parse(binary_to_list(Value)); parse(Value) -> {Element, _} = xmerl_scan:string(Value), parse_node(Element).