Add regression tests for requests/urllib3 instrumentation with retries.

* tests/fixtures.py: add the `foohttpserver` fixture, a local HTTP server
  that replays a queue of canned (content, code, headers) responses.
* tests/instrumentation/requests_tests.py and urllib3_tests.py: add tests
  that issue a request whose first attempt is answered with a 429 and is
  then retried.

Notes:
* This patch deliberately carries no debug print() calls. In particular it
  does not touch elasticapm/instrumentation/packages/urllib3.py: printing
  the request headers from update_headers() would write to stdout on every
  instrumented request.
* The tests let request errors propagate instead of wrapping the request in
  a bare "except: pass".
* The post-image blob ids on the "index" lines below were not recomputed.

diff --git a/tests/fixtures.py b/tests/fixtures.py
index 94e89f961..f994c4548 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -303,6 +303,72 @@ def elasticapm_client_log_file(request):
     logger.setLevel(logging.NOTSET)
 
 
+@pytest.fixture
+def foohttpserver(request):
+    """
+    Local HTTP server that replays a queue of canned responses.
+
+    Responses queued with ``serve_responses`` are returned one per request, in
+    order. Once the queue is empty, the server falls back to its static
+    content (or the echoed form data of a POST when ``show_post_vars`` is set).
+    """
+    from pytest_localserver import http
+
+    class FooServer(http.ContentServer):
+        def __init__(self, host="127.0.0.1", port=0, ssl_context=None):
+            super().__init__(host, port, ssl_context=ssl_context)
+            # queued (content, code, headers) tuples, consumed in FIFO order
+            self.responses = []
+
+        def __call__(self, environ, start_response):
+            """
+            This is the WSGI application.
+            """
+            http_request = Request(environ)
+            self.requests.append(http_request)
+
+            # HACK: queued responses take precedence over the static content
+            if self.responses:
+                content, code, headers = self.responses.pop(0)
+            elif (
+                http_request.content_type == "application/x-www-form-urlencoded"
+                and http_request.method == "POST"
+                and self.show_post_vars
+            ):
+                content, code, headers = json.dumps(http_request.form), self.code, self.headers
+            else:
+                content, code, headers = self.content, self.code, self.headers
+
+            response = http.Response(response=content, status=code)
+            response.headers.clear()
+            response.headers.extend(headers)
+            return response(environ, start_response)
+
+        def serve_responses(self, responses):
+            """
+            Queue canned responses, given as (content, code, headers) tuples.
+            """
+            for content, code, headers in responses:
+                if not isinstance(content, (str, bytes, list, tuple)):
+                    # If content is an iterable which is not known to be a string,
+                    # bytes, or sequence, it might be something that can only be iterated
+                    # through once, in which case we need to cache it so it can be reused
+                    # to handle multiple requests.
+                    try:
+                        content = tuple(iter(content))
+                    except TypeError:
+                        # this probably means that content is not iterable, so just go
+                        # ahead in case it's some type that Response knows how to handle
+                        pass
+                self.responses.append((content, code, headers))
+
+    server = FooServer()
+    server.start()
+    request.addfinalizer(server.stop)
+    wait_for_open_port(server.server_address[1])
+    return server
+
+
 @pytest.fixture()
 def waiting_httpserver(httpserver):
     wait_for_open_port(httpserver.server_address[1])
diff --git a/tests/instrumentation/requests_tests.py b/tests/instrumentation/requests_tests.py
index 69983f147..b3ae4520f 100644
--- a/tests/instrumentation/requests_tests.py
+++ b/tests/instrumentation/requests_tests.py
@@ -34,6 +34,7 @@
 
 import urllib.parse
 
+import urllib3
 import requests
 from requests.exceptions import InvalidURL, MissingSchema
 
@@ -200,3 +201,53 @@ def test_url_sanitization(instrument, elasticapm_client, waiting_httpserver):
 
     assert "pass" not in span["context"]["http"]["url"]
     assert constants.MASK_URL in span["context"]["http"]["url"]
+
+
+def test_requests_instrumentation_handles_retries(instrument, elasticapm_client, foohttpserver):
+    """
+    Check the span and the propagated trace headers of a request that only
+    succeeds after urllib3 has retried it once.
+    """
+    # the first attempt is answered with a 429, which is in the retry forcelist
+    foohttpserver.serve_responses([("", 429, {}), ("", 200, {})])
+    url = foohttpserver.url + "/hello_world"
+    parsed_url = urllib.parse.urlparse(url)
+    elasticapm_client.begin_transaction("transaction.test")
+    with capture_span("test_request", "test"):
+        retries = urllib3.Retry(status=1, status_forcelist=[429])
+        with requests.Session() as session:
+            session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retries))
+            # let unexpected errors fail the test instead of swallowing them
+            response = session.get(url, allow_redirects=False)
+    elasticapm_client.end_transaction("MyView")
+
+    assert response.status_code == 200
+
+    transactions = elasticapm_client.events[TRANSACTION]
+    spans = elasticapm_client.spans_for_transaction(transactions[0])
+    assert spans[0]["name"].startswith("GET 127.0.0.1:")
+    assert spans[0]["type"] == "external"
+    assert spans[0]["subtype"] == "http"
+    assert spans[0]["context"]["http"]["url"] == url
+    assert spans[0]["context"]["http"]["status_code"] == 200
+    assert spans[0]["context"]["destination"]["service"] == {
+        "name": "",
+        "resource": "127.0.0.1:%d" % parsed_url.port,
+        "type": "",
+    }
+    assert spans[0]["context"]["service"]["target"]["type"] == "http"
+    assert spans[0]["context"]["service"]["target"]["name"] == f"127.0.0.1:{parsed_url.port}"
+    assert spans[0]["outcome"] == "success"
+
+    assert constants.TRACEPARENT_HEADER_NAME in foohttpserver.requests[0].headers
+    trace_parent = TraceParent.from_string(
+        foohttpserver.requests[0].headers[constants.TRACEPARENT_HEADER_NAME],
+        tracestate_string=foohttpserver.requests[0].headers[constants.TRACESTATE_HEADER_NAME],
+    )
+    assert trace_parent.trace_id == transactions[0]["trace_id"]
+    # Check that sample_rate was correctly placed in the tracestate
+    assert constants.TRACESTATE.SAMPLE_RATE in trace_parent.tracestate_dict
+
+    # this should be the span id of `requests`, not of urllib3
+    assert trace_parent.span_id == spans[0]["id"]
+    assert trace_parent.trace_options.recorded
diff --git a/tests/instrumentation/urllib3_tests.py b/tests/instrumentation/urllib3_tests.py
index 8cc21ceb0..2eee3639a 100644
--- a/tests/instrumentation/urllib3_tests.py
+++ b/tests/instrumentation/urllib3_tests.py
@@ -294,3 +294,54 @@ def test_instance_headers_are_respected(
         assert "kwargs" in request_headers
     if instance_headers and not (header_arg or header_kwarg):
         assert "instance" in request_headers
+
+
+def test_urllib3_retries(instrument, elasticapm_client, foohttpserver):
+    """
+    Check the spans of a urllib3 request that only succeeds after one retry.
+    """
+    # the first attempt is answered with a 429, which is in the retry forcelist
+    foohttpserver.serve_responses([("", 429, {}), ("", 200, {})])
+    url = foohttpserver.url + "/hello_world"
+    parsed_url = urllib.parse.urlparse(url)
+    elasticapm_client.begin_transaction("transaction")
+    expected_sig = "GET {0}".format(parsed_url.netloc)
+    with capture_span("test_name", "test_type"):
+        retries = urllib3.Retry(status=1, status_forcelist=[429])
+        pool = urllib3.PoolManager(timeout=0.1, retries=retries)
+
+        url = "http://{0}/hello_world".format(parsed_url.netloc)
+        # let unexpected errors fail the test instead of swallowing them
+        response = pool.request("GET", url)
+
+    elasticapm_client.end_transaction("MyView")
+
+    assert response.status == 200
+
+    transactions = elasticapm_client.events[TRANSACTION]
+    spans = elasticapm_client.spans_for_transaction(transactions[0])
+
+    expected_signatures = {"test_name", expected_sig}
+
+    assert {t["name"] for t in spans} == expected_signatures
+
+    assert len(spans) == 2
+
+    assert spans[0]["name"] == expected_sig
+    assert spans[0]["type"] == "external"
+    assert spans[0]["subtype"] == "http"
+    assert spans[0]["context"]["http"]["url"] == url
+    assert spans[0]["context"]["http"]["status_code"] == 200
+    assert spans[0]["context"]["destination"]["service"] == {
+        "name": "",
+        "resource": "127.0.0.1:%d" % parsed_url.port,
+        "type": "",
+    }
+    assert spans[0]["context"]["service"]["target"]["type"] == "http"
+    assert spans[0]["context"]["service"]["target"]["name"] == f"127.0.0.1:{parsed_url.port}"
+    assert spans[0]["parent_id"] == spans[1]["id"]
+    assert spans[0]["outcome"] == "success"
+
+    assert spans[1]["name"] == "test_name"
+    assert spans[1]["type"] == "test_type"
+    assert spans[1]["parent_id"] == transactions[0]["id"]