diff --git a/ddtrace/_monkey.py b/ddtrace/_monkey.py index 7223dc59ffe..6c9efdb6242 100644 --- a/ddtrace/_monkey.py +++ b/ddtrace/_monkey.py @@ -1,6 +1,5 @@ import importlib import os -import threading from types import ModuleType from typing import TYPE_CHECKING # noqa:F401 from typing import Union @@ -9,6 +8,7 @@ from ddtrace.appsec._listeners import load_common_appsec_modules from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE +from ddtrace.internal.threads import Lock from ddtrace.settings._config import config from ddtrace.settings.asm import config as asm_config from ddtrace.vendor.debtcollector import deprecate @@ -130,7 +130,7 @@ } -_LOCK = threading.Lock() +_LOCK = Lock() _PATCHED_MODULES = set() # Module names that need to be patched for a given integration. If the module diff --git a/ddtrace/_trace/context.py b/ddtrace/_trace/context.py index 80f705c9122..cd6f4e86708 100644 --- a/ddtrace/_trace/context.py +++ b/ddtrace/_trace/context.py @@ -1,6 +1,5 @@ import base64 import re -import threading from typing import Any from typing import Dict from typing import List @@ -19,6 +18,7 @@ from ddtrace.internal.constants import W3C_TRACEPARENT_KEY from ddtrace.internal.constants import W3C_TRACESTATE_KEY from ddtrace.internal.logger import get_logger +from ddtrace.internal.threads import RLock from ddtrace.internal.utils.http import w3c_get_dd_list_member as _w3c_get_dd_list_member @@ -65,7 +65,7 @@ def __init__( sampling_priority: Optional[float] = None, meta: Optional[_MetaDictType] = None, metrics: Optional[_MetricDictType] = None, - lock: Optional[threading.RLock] = None, + lock: Optional[RLock] = None, span_links: Optional[List[SpanLink]] = None, baggage: Optional[Dict[str, Any]] = None, is_remote: bool = True, @@ -91,10 +91,7 @@ def __init__( if lock is not None: self._lock = lock else: - # DEV: A `forksafe.RLock` is not necessary here since Contexts - # are recreated by the tracer after fork - # 
https://github.com/DataDog/dd-trace-py/blob/a1932e8ddb704d259ea8a3188d30bf542f59fd8d/ddtrace/tracer.py#L489-L508 - self._lock = threading.RLock() + self._lock = RLock() def __getstate__(self) -> _ContextState: return ( @@ -121,7 +118,7 @@ def __setstate__(self, state: _ContextState) -> None: self._reactivate, ) = state # We cannot serialize and lock, so we must recreate it unless we already have one - self._lock = threading.RLock() + self._lock = RLock() def __enter__(self) -> "Context": self._lock.acquire() diff --git a/ddtrace/_trace/processor/__init__.py b/ddtrace/_trace/processor/__init__.py index 7fb655b313b..103e5e5886b 100644 --- a/ddtrace/_trace/processor/__init__.py +++ b/ddtrace/_trace/processor/__init__.py @@ -1,7 +1,6 @@ import abc from collections import defaultdict from itertools import chain -from threading import RLock from typing import Any from typing import DefaultDict from typing import Dict @@ -29,6 +28,7 @@ from ddtrace.internal.service import ServiceStatusError from ddtrace.internal.telemetry.constants import TELEMETRY_LOG_LEVEL from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE +from ddtrace.internal.threads import RLock from ddtrace.internal.writer import AgentResponse from ddtrace.internal.writer import create_trace_writer from ddtrace.settings._config import config @@ -280,7 +280,7 @@ def __init__( self.writer = create_trace_writer(response_callback=self._agent_response_callback) # Initialize the trace buffer and lock self._traces: DefaultDict[int, _Trace] = defaultdict(lambda: _Trace()) - self._lock: RLock = RLock() + self._lock = RLock() # Track telemetry span metrics by span api # ex: otel api, opentracing api, datadog api self._span_metrics: Dict[str, DefaultDict] = { diff --git a/ddtrace/_trace/tracer.py b/ddtrace/_trace/tracer.py index 53c49f9f44c..6979933b879 100644 --- a/ddtrace/_trace/tracer.py +++ b/ddtrace/_trace/tracer.py @@ -6,7 +6,6 @@ import logging import os from os import getpid -from threading import 
RLock from typing import Any from typing import Callable from typing import Dict @@ -53,6 +52,7 @@ from ddtrace.internal.processor.endpoint_call_counter import EndpointCallCounterProcessor from ddtrace.internal.runtime import get_runtime_id from ddtrace.internal.schema.processor import BaseServiceProcessor +from ddtrace.internal.threads import RLock from ddtrace.internal.utils import _get_metas_to_propagate from ddtrace.internal.utils.formats import format_trace_id from ddtrace.internal.writer import AgentWriterInterface diff --git a/ddtrace/appsec/_iast/_overhead_control_engine.py b/ddtrace/appsec/_iast/_overhead_control_engine.py index 5d7b3377c3d..8d0654c1660 100644 --- a/ddtrace/appsec/_iast/_overhead_control_engine.py +++ b/ddtrace/appsec/_iast/_overhead_control_engine.py @@ -3,11 +3,12 @@ limit. It will measure operations being executed in a request and it will deactivate detection (and therefore reduce the overhead to nearly 0) if a certain threshold is reached. """ + from ddtrace._trace.sampler import RateSampler from ddtrace._trace.span import Span from ddtrace.appsec._iast._utils import _is_iast_debug_enabled -from ddtrace.internal._unpatched import _threading as threading from ddtrace.internal.logger import get_logger +from ddtrace.internal.threads import Lock from ddtrace.settings.asm import config as asm_config @@ -24,7 +25,7 @@ class OverheadControl(object): The goal is to do sampling at different levels of the IAST analysis (per process, per request, etc) """ - _lock = threading.Lock() + _lock = Lock() _request_quota = asm_config._iast_max_concurrent_requests _sampler = RateSampler(sample_rate=get_request_sampling_value() / 100.0) diff --git a/ddtrace/contrib/internal/subprocess/patch.py b/ddtrace/contrib/internal/subprocess/patch.py index d4394538813..2ae2415a189 100644 --- a/ddtrace/contrib/internal/subprocess/patch.py +++ b/ddtrace/contrib/internal/subprocess/patch.py @@ -19,8 +19,8 @@ from ddtrace.contrib.internal.subprocess.constants import 
COMMANDS from ddtrace.ext import SpanTypes from ddtrace.internal import core -from ddtrace.internal.forksafe import RLock from ddtrace.internal.logger import get_logger +from ddtrace.internal.threads import RLock from ddtrace.settings._config import config from ddtrace.settings.asm import config as asm_config diff --git a/ddtrace/debugging/_encoding.py b/ddtrace/debugging/_encoding.py index 1a53bea93fe..24b09f62a86 100644 --- a/ddtrace/debugging/_encoding.py +++ b/ddtrace/debugging/_encoding.py @@ -18,9 +18,9 @@ from ddtrace.debugging._config import di_config from ddtrace.debugging._signal.log import LogSignal from ddtrace.debugging._signal.snapshot import Snapshot -from ddtrace.internal import forksafe from ddtrace.internal._encoding import BufferFull from ddtrace.internal.logger import get_logger +from ddtrace.internal.threads import Lock from ddtrace.internal.utils.formats import format_trace_id @@ -310,7 +310,7 @@ def __init__( ) -> None: self._encoder = encoder self._buffer = JsonBuffer(buffer_size) - self._lock = forksafe.Lock() + self._lock = Lock() self._on_full = on_full self.count = 0 self.max_size = buffer_size - self._buffer.size diff --git a/ddtrace/debugging/_probe/registry.py b/ddtrace/debugging/_probe/registry.py index 83a56ad40f9..3436afa4c95 100644 --- a/ddtrace/debugging/_probe/registry.py +++ b/ddtrace/debugging/_probe/registry.py @@ -8,8 +8,8 @@ from ddtrace.debugging._probe.model import Probe from ddtrace.debugging._probe.model import ProbeLocationMixin from ddtrace.debugging._probe.status import ProbeStatusLogger -from ddtrace.internal import forksafe from ddtrace.internal.logger import get_logger +from ddtrace.internal.threads import RLock logger = get_logger(__name__) @@ -68,7 +68,7 @@ def __init__(self, status_logger: ProbeStatusLogger, *args: Any, **kwargs: Any) # Used to keep track of probes pending installation self._pending: Dict[str, List[Probe]] = defaultdict(list) - self._lock = forksafe.RLock() + self._lock = RLock() def 
register(self, *probes: Probe) -> None: """Register a probe.""" diff --git a/ddtrace/internal/_encoding.pyx b/ddtrace/internal/_encoding.pyx index 72b1b173b90..a9efbc93576 100644 --- a/ddtrace/internal/_encoding.pyx +++ b/ddtrace/internal/_encoding.pyx @@ -4,7 +4,6 @@ from libc cimport stdint from libc.string cimport strlen from json import dumps as json_dumps -import threading from json import dumps as json_dumps from ._utils cimport PyBytesLike_Check @@ -26,6 +25,8 @@ from .constants import MAX_UINT_64BITS from .._trace._limits import MAX_SPAN_META_VALUE_LEN from .._trace._limits import TRUNCATED_SPAN_ATTRIBUTE_LEN from ..settings._agent import config as agent_config +from ddtrace.internal.threads import Lock +from ddtrace.internal.threads import RLock DEF MSGPACK_ARRAY_LENGTH_PREFIX_SIZE = 5 @@ -256,7 +257,7 @@ cdef class MsgpackStringTable(StringTable): self.max_size = max_size self.pk.length = MSGPACK_STRING_TABLE_LENGTH_PREFIX_SIZE self._sp_len = 0 - self._lock = threading.RLock() + self._lock = RLock() super(MsgpackStringTable, self).__init__() self.index(ORIGIN_KEY) @@ -371,7 +372,7 @@ cdef class BufferedEncoder(object): def __cinit__(self, size_t max_size, size_t max_item_size): self.max_size = max_size self.max_item_size = max_item_size - self._lock = threading.Lock() + self._lock = Lock() # ---- Abstract methods ---- @@ -443,7 +444,7 @@ cdef class MsgpackEncoderBase(BufferedEncoder): self.max_size = max_size self.pk.buf_size = buf_size self.max_item_size = max_item_size if max_item_size < max_size else max_size - self._lock = threading.RLock() + self._lock = RLock() self._reset_buffer() def __dealloc__(self): diff --git a/ddtrace/internal/_threads.cpp b/ddtrace/internal/_threads.cpp index d775544827b..ace50d912ee 100644 --- a/ddtrace/internal/_threads.cpp +++ b/ddtrace/internal/_threads.cpp @@ -85,6 +85,10 @@ class PyRef PyObject* _obj; }; +// ---------------------------------------------------------------------------- + +#include "_threads/lock.hpp" + 
// ---------------------------------------------------------------------------- class Event { @@ -511,6 +515,7 @@ static PyTypeObject PeriodicThreadType = { // ---------------------------------------------------------------------------- static PyMethodDef _threads_methods[] = { + { "reset_locks", (PyCFunction)lock_reset_locks, METH_NOARGS, "Reset all locks (generally after a fork)" }, { NULL, NULL, 0, NULL } /* Sentinel */ }; @@ -533,6 +538,12 @@ PyInit__threads(void) if (PyType_Ready(&PeriodicThreadType) < 0) return NULL; + if (PyType_Ready(&LockType) < 0) + return NULL; + + if (PyType_Ready(&RLockType) < 0) + return NULL; + _periodic_threads = PyDict_New(); if (_periodic_threads == NULL) return NULL; @@ -541,6 +552,7 @@ PyInit__threads(void) if (m == NULL) goto error; + // Periodic thread Py_INCREF(&PeriodicThreadType); if (PyModule_AddObject(m, "PeriodicThread", (PyObject*)&PeriodicThreadType) < 0) { Py_DECREF(&PeriodicThreadType); @@ -550,6 +562,20 @@ PyInit__threads(void) if (PyModule_AddObject(m, "periodic_threads", _periodic_threads) < 0) goto error; + // Lock + Py_INCREF(&LockType); + if (PyModule_AddObject(m, "Lock", (PyObject*)&LockType) < 0) { + Py_DECREF(&LockType); + goto error; + } + + // RLock + Py_INCREF(&RLockType); + if (PyModule_AddObject(m, "RLock", (PyObject*)&RLockType) < 0) { + Py_DECREF(&RLockType); + goto error; + } + return m; error: diff --git a/ddtrace/internal/_threads.pyi b/ddtrace/internal/_threads.pyi index 84a6ca0eb3f..33f5b4c5097 100644 --- a/ddtrace/internal/_threads.pyi +++ b/ddtrace/internal/_threads.pyi @@ -1,5 +1,20 @@ import typing as t +class _BaseLock: + def __init__(self, reentrant: bool = False) -> None: ... + def acquire(self, timeout: t.Optional[float] = None) -> bool: ... + def release(self) -> None: ... + def locked(self) -> bool: ... + def __enter__(self) -> None: ... + def __exit__(self, exc_type, exc_value, traceback) -> t.Literal[False]: ... + +class Lock(_BaseLock): ... +class RLock(_BaseLock): ... 
+
+def reset_locks() -> None: ...
+def begin_reset_locks() -> None: ...
+def end_reset_locks() -> None: ...
+
 class PeriodicThread:
     name: str
     ident: int
diff --git a/ddtrace/internal/_threads/lock.hpp b/ddtrace/internal/_threads/lock.hpp
new file mode 100644
index 00000000000..20b75cad572
--- /dev/null
+++ b/ddtrace/internal/_threads/lock.hpp
@@ -0,0 +1,403 @@
+#pragma once
+
+#define PY_SSIZE_T_CLEAN
+#include <Python.h>
+
+#include <atomic>
+#include <memory>
+#include <mutex>
+
+std::mutex _lock_set_mutex;
+
+// ----------------------------------------------------------------------------
+// Lock class
+// ----------------------------------------------------------------------------
+
+typedef struct lock
+{
+    PyObject_HEAD
+
+    std::atomic<int>
+      _locked = 0;
+
+    std::unique_ptr<std::timed_mutex> _mutex = nullptr;
+} Lock;
+
+std::set<Lock*> lock_set; // Global set of locks for reset after fork
+
+// ----------------------------------------------------------------------------
+static int
+Lock_init(Lock* self, PyObject* args, PyObject* kwargs)
+{
+    self->_mutex = std::make_unique<std::timed_mutex>();
+
+    // Register the lock for reset after fork
+    {
+        std::lock_guard guard(_lock_set_mutex);
+
+        lock_set.insert(self);
+    }
+
+    return 0;
+}
+
+// ----------------------------------------------------------------------------
+static inline void
+_Lock_maybe_leak(Lock* self)
+{
+    // This function is used to ensure that the mutex is not leaked if it is
+    // still locked when the lock object is deallocated.
+    if (self->_locked) {
+        self->_mutex.release(); // DEV: This releases the unique_ptr, not the mutex!
+    }
+}
+
+// ----------------------------------------------------------------------------
+static void
+Lock_dealloc(Lock* self)
+{
+    // Unregister the lock from the global set
+    {
+        std::lock_guard guard(_lock_set_mutex);
+
+        lock_set.erase(self);
+    }
+
+    _Lock_maybe_leak(self);
+
+    self->_mutex = nullptr;
+
+    Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+Lock_acquire(Lock* self, PyObject* args, PyObject* kwargs)
+{
+    // Get timeout argument
+    static const char* kwlist[] = { "timeout", NULL };
+    PyObject* timeout = Py_None;
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", (char**)kwlist, &timeout)) {
+        return NULL;
+    }
+
+    if (timeout == Py_None) {
+        AllowThreads _;
+
+        self->_mutex->lock();
+    } else {
+        double timeout_value = 0.0;
+        if (PyFloat_Check(timeout)) {
+            timeout_value = PyFloat_AsDouble(timeout);
+        } else if (PyLong_Check(timeout)) {
+            timeout_value = PyLong_AsDouble(timeout);
+        } else {
+            PyErr_SetString(PyExc_TypeError, "timeout must be a float or an int");
+            return NULL;
+        }
+
+        AllowThreads _;
+
+        if (!self->_mutex->try_lock_for(std::chrono::milliseconds((long long)(timeout_value * 1000)))) {
+            Py_RETURN_FALSE;
+        }
+    }
+
+    self->_locked = 1;
+
+    Py_RETURN_TRUE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+Lock_release(Lock* self)
+{
+    if (self->_locked <= 0) {
+        PyErr_SetString(PyExc_RuntimeError, "Lock is not acquired");
+        return NULL;
+    }
+
+    self->_mutex->unlock();
+    self->_locked = 0; // Reset the lock state
+
+    Py_RETURN_NONE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+Lock_locked(Lock* self)
+{
+    if (self->_locked > 0) {
+        Py_RETURN_TRUE;
+    }
+
+    Py_RETURN_FALSE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+Lock_enter(Lock* self, PyObject* args, PyObject* kwargs)
+{
+    AllowThreads _;
+
+    self->_mutex->lock();
+
+    self->_locked = 1;
+
+    Py_RETURN_NONE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+Lock_exit(Lock* self, PyObject* args, PyObject* kwargs)
+{
+    // This method is called when the lock is used in a "with" statement
+    if (Lock_release(self) == NULL) {
+        return NULL; // Propagate any error from release
+    }
+
+    Py_RETURN_FALSE;
+}
+
+static inline void
+Lock_reset(Lock* self)
+{
+    _Lock_maybe_leak(self);
+    self->_mutex = std::make_unique<std::timed_mutex>();
+    self->_locked = 0;
+}
+
+// ----------------------------------------------------------------------------
+static PyMethodDef Lock_methods[] = {
+    { "acquire", (PyCFunction)Lock_acquire, METH_VARARGS | METH_KEYWORDS, "Acquire the lock with an optional timeout" },
+    { "release", (PyCFunction)Lock_release, METH_NOARGS, "Release the lock" },
+    { "locked", (PyCFunction)Lock_locked, METH_NOARGS, "Return whether the lock is acquired" },
+    { "__enter__", (PyCFunction)Lock_enter, METH_NOARGS, "Enter the lock context" },
+    { "__exit__", (PyCFunction)Lock_exit, METH_VARARGS | METH_KEYWORDS, "Exit the lock context" },
+    { NULL } /* Sentinel */
+};
+
+// ----------------------------------------------------------------------------
+static PyMemberDef Lock_members[] = {
+    { NULL } /* Sentinel */
+};
+
+// ----------------------------------------------------------------------------
+static PyTypeObject LockType = {
+    .ob_base = PyVarObject_HEAD_INIT(NULL, 0).tp_name = "ddtrace.internal._threads.Lock",
+    .tp_basicsize = sizeof(Lock),
+    .tp_itemsize = 0,
+    .tp_dealloc = (destructor)Lock_dealloc,
+    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
+    .tp_doc = PyDoc_STR("Native lock implementation"),
+    .tp_methods = Lock_methods,
+    .tp_members = Lock_members,
+    .tp_init = (initproc)Lock_init,
+    .tp_new = PyType_GenericNew,
+};
+
+// ----------------------------------------------------------------------------
+// RLock class
+// ----------------------------------------------------------------------------
+
+typedef struct rlock
+{
+    PyObject_HEAD
+
+    std::atomic<int>
+      _locked = 0;
+
+    std::unique_ptr<std::recursive_timed_mutex> _mutex = nullptr;
+} RLock;
+
+std::set<RLock*> rlock_set; // Global set of re-entrant locks for reset after fork
+
+// ----------------------------------------------------------------------------
+static int
+RLock_init(RLock* self, PyObject* args, PyObject* kwargs)
+{
+    self->_mutex = std::make_unique<std::recursive_timed_mutex>();
+
+    // Register the re-entrant lock for reset after fork
+    {
+        std::lock_guard guard(_lock_set_mutex);
+
+        rlock_set.insert(self);
+    }
+
+    return 0;
+}
+
+// ----------------------------------------------------------------------------
+static inline void
+_RLock_maybe_leak(RLock* self)
+{
+    // This function is used to ensure that the mutex is not leaked if it is
+    // still locked when the re-entrant lock object is deallocated.
+    if (self->_locked) {
+        self->_mutex.release(); // DEV: This releases the unique_ptr, not the mutex!
+    }
+}
+
+// ----------------------------------------------------------------------------
+static void
+RLock_dealloc(RLock* self)
+{
+    {
+        std::lock_guard guard(_lock_set_mutex);
+
+        rlock_set.erase(self);
+    }
+
+    _RLock_maybe_leak(self);
+    self->_mutex = nullptr;
+
+    Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+RLock_acquire(RLock* self, PyObject* args, PyObject* kwargs)
+{
+    // Get timeout argument
+    static const char* kwlist[] = { "timeout", NULL };
+    PyObject* timeout = Py_None;
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", (char**)kwlist, &timeout)) {
+        return NULL;
+    }
+
+    if (timeout == Py_None) {
+        AllowThreads _;
+
+        self->_mutex->lock();
+    } else {
+        double timeout_value = 0.0;
+        if (PyFloat_Check(timeout)) {
+            timeout_value = PyFloat_AsDouble(timeout);
+        } else if (PyLong_Check(timeout)) {
+            timeout_value = PyLong_AsDouble(timeout);
+        } else {
+            PyErr_SetString(PyExc_TypeError, "timeout must be a float or an int");
+            return NULL;
+        }
+
+        AllowThreads _;
+
+        if (!self->_mutex->try_lock_for(std::chrono::milliseconds((long long)(timeout_value * 1000)))) {
+            Py_RETURN_FALSE;
+        }
+    }
+
+    self->_locked++;
+
+    Py_RETURN_TRUE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+RLock_release(RLock* self)
+{
+    if (self->_locked <= 0) {
+        PyErr_SetString(PyExc_RuntimeError, "Lock is not acquired");
+        return NULL;
+    }
+
+    self->_mutex->unlock();
+    self->_locked--;
+
+    Py_RETURN_NONE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+RLock_locked(RLock* self)
+{
+    if (self->_locked > 0) {
+        Py_RETURN_TRUE;
+    }
+
+    Py_RETURN_FALSE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+RLock_enter(RLock* self, PyObject* args, PyObject* kwargs)
+{
+    AllowThreads _;
+
+    self->_mutex->lock();
+
+    self->_locked++;
+
+    Py_RETURN_NONE;
+}
+
+// ----------------------------------------------------------------------------
+static PyObject*
+RLock_exit(RLock* self, PyObject* args, PyObject* kwargs)
+{
+    // This method is called when the lock is used in a "with" statement
+    if (RLock_release(self) == NULL) {
+        return NULL; // Propagate any error from release
+    }
+
+    Py_RETURN_FALSE;
+}
+
+static inline void
+RLock_reset(RLock* self)
+{
+    _RLock_maybe_leak(self);
+    self->_mutex = std::make_unique<std::recursive_timed_mutex>();
+    self->_locked = 0;
+}
+
+// ----------------------------------------------------------------------------
+static PyMethodDef RLock_methods[] = {
+    { "acquire",
+      (PyCFunction)RLock_acquire,
+      METH_VARARGS | METH_KEYWORDS,
+      "Acquire the lock with an optional timeout" },
+    { "release", (PyCFunction)RLock_release, METH_NOARGS, "Release the lock" },
+    { "locked", (PyCFunction)RLock_locked, METH_NOARGS, "Return whether the lock is acquired at least once" },
+    { "__enter__", (PyCFunction)RLock_enter, METH_NOARGS, "Enter the lock context" },
+    { "__exit__", (PyCFunction)RLock_exit, METH_VARARGS | METH_KEYWORDS, "Exit the lock context" },
+    { NULL } /* Sentinel */
+};
+
+// ----------------------------------------------------------------------------
+static PyMemberDef RLock_members[] = {
+    { NULL } /* Sentinel */
+};
+
+// ----------------------------------------------------------------------------
+static PyTypeObject RLockType = {
+    .ob_base = PyVarObject_HEAD_INIT(NULL, 0).tp_name = "ddtrace.internal._threads.RLock",
+    .tp_basicsize = sizeof(RLock),
+    .tp_itemsize = 0,
+    .tp_dealloc = (destructor)RLock_dealloc,
+    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
+    .tp_doc = PyDoc_STR("Native re-entrant lock implementation"),
+    .tp_methods = RLock_methods,
+    .tp_members = RLock_members,
+    .tp_init = (initproc)RLock_init,
+    .tp_new = PyType_GenericNew,
+};
+
+// ----------------------------------------------------------------------------
+static PyObject*
+lock_reset_locks(PyObject* Py_UNUSED(self), PyObject* Py_UNUSED(args)) +{ + // Reset all locks that have been registered for reset after a fork. This + // MUST be called in a single-thread scenario only, e.g. soon after the + // fork. + for (Lock* lock : lock_set) { + Lock_reset(lock); + } + + for (RLock* rlock : rlock_set) { + RLock_reset(rlock); + } + + Py_RETURN_NONE; +} diff --git a/ddtrace/internal/ci_visibility/encoder.py b/ddtrace/internal/ci_visibility/encoder.py index 8b8acec2eef..4f31df7be23 100644 --- a/ddtrace/internal/ci_visibility/encoder.py +++ b/ddtrace/internal/ci_visibility/encoder.py @@ -1,6 +1,5 @@ import json import os -import threading from typing import TYPE_CHECKING # noqa:F401 from uuid import uuid4 @@ -21,6 +20,7 @@ from ddtrace.internal.ci_visibility.telemetry.payload import record_endpoint_payload_events_serialization_time from ddtrace.internal.encoding import JSONEncoderV2 from ddtrace.internal.logger import get_logger +from ddtrace.internal.threads import RLock from ddtrace.internal.utils.time import StopWatch from ddtrace.internal.writer.writer import NoEncodableSpansError @@ -48,7 +48,7 @@ def __init__(self, *args): # DEV: args are not used here, but are used by BufferedEncoder's __cinit__() method, # which is called implicitly by Cython. 
super(CIVisibilityEncoderV01, self).__init__() - self._lock = threading.RLock() + self._lock = RLock() self._metadata = {} self._init_buffer() diff --git a/ddtrace/internal/datastreams/processor.py b/ddtrace/internal/datastreams/processor.py index 511d3543e0f..d678028789e 100644 --- a/ddtrace/internal/datastreams/processor.py +++ b/ddtrace/internal/datastreams/processor.py @@ -19,6 +19,7 @@ from ddtrace.internal.atexit import register_on_exit_signal from ddtrace.internal.constants import DEFAULT_SERVICE_NAME from ddtrace.internal.native import DDSketch +from ddtrace.internal.threads import Lock from ddtrace.internal.utils.retry import fibonacci_backoff_with_jitter from ddtrace.settings._agent import config as agent_config from ddtrace.settings._config import config @@ -26,7 +27,6 @@ from .._encoding import packb from ..agent import get_connection -from ..forksafe import Lock from ..hostname import get_hostname from ..logger import get_logger from ..periodic import PeriodicService diff --git a/ddtrace/internal/datastreams/schemas/schema_sampler.py b/ddtrace/internal/datastreams/schemas/schema_sampler.py index b9ff0a613b8..50c3036cee9 100644 --- a/ddtrace/internal/datastreams/schemas/schema_sampler.py +++ b/ddtrace/internal/datastreams/schemas/schema_sampler.py @@ -1,4 +1,4 @@ -import threading +from ddtrace.internal.threads import Lock class SchemaSampler: @@ -7,7 +7,7 @@ class SchemaSampler: def __init__(self): self.weight = 0 self.last_sample_millis = 0 - self.lock = threading.Lock() + self.lock = Lock() def try_sample(self, current_time_millis): if current_time_millis >= self.last_sample_millis + self.SAMPLE_INTERVAL_MILLIS: diff --git a/ddtrace/internal/forksafe.py b/ddtrace/internal/forksafe.py index 1f9d24cfd7e..2981aa45d66 100644 --- a/ddtrace/internal/forksafe.py +++ b/ddtrace/internal/forksafe.py @@ -133,21 +133,10 @@ def __init__( self._self_wrapped_class = wrapped_class _resetable_objects.add(self) - def _reset_object(self): - # type: (...) 
-> None + def _reset_object(self) -> None: self.__wrapped__ = self._self_wrapped_class() -def Lock(): - # type: (...) -> ResetObject[threading.Lock] - return ResetObject(threading.Lock) - - -def RLock(): - # type: (...) -> ResetObject[threading.RLock] - return ResetObject(threading.RLock) - - def Event(): # type: (...) -> ResetObject[threading.Event] return ResetObject(threading.Event) diff --git a/ddtrace/internal/periodic.py b/ddtrace/internal/periodic.py index d4c9d7e3d4c..2440e1601c6 100644 --- a/ddtrace/internal/periodic.py +++ b/ddtrace/internal/periodic.py @@ -1,31 +1,9 @@ # -*- encoding: utf-8 -*- -import atexit import typing # noqa:F401 from ddtrace.internal import forksafe from ddtrace.internal import service -from ddtrace.internal._threads import PeriodicThread -from ddtrace.internal._threads import periodic_threads - - -@atexit.register -def _(): - # If the interpreter is shutting down we need to make sure that the threads - # are stopped before the runtime is marked as finalising. This is because - # any attempt to acquire the GIL while the runtime is finalising will cause - # the acquiring thread to be terminated with pthread_exit (on Linux). This - # causes a SIGABRT with GCC that cannot be caught, so we need to avoid - # getting to that stage. 
- for thread in periodic_threads.values(): - thread._atexit() - - -@forksafe.register -def _(): - # No threads are running after a fork so we clean up the periodic threads - for thread in periodic_threads.values(): - thread._after_fork() - periodic_threads.clear() +from ddtrace.internal.threads import PeriodicThread class PeriodicService(service.Service): diff --git a/ddtrace/internal/processor/endpoint_call_counter.py b/ddtrace/internal/processor/endpoint_call_counter.py index f00b16528c6..825155e9d1c 100644 --- a/ddtrace/internal/processor/endpoint_call_counter.py +++ b/ddtrace/internal/processor/endpoint_call_counter.py @@ -5,8 +5,8 @@ from ddtrace._trace.processor import SpanProcessor from ddtrace._trace.span import Span # noqa:F401 from ddtrace.ext import SpanTypes -from ddtrace.internal import forksafe from ddtrace.internal.compat import ensure_text +from ddtrace.internal.threads import Lock EndpointCountsType = typing.Dict[str, int] @@ -21,9 +21,7 @@ class EndpointCallCounterProcessor(SpanProcessor): endpoint_to_span_ids: typing.Dict[str, typing.List[int]] = field( default_factory=dict, init=False, repr=False, compare=False ) - _endpoint_counts_lock: typing.ContextManager = field( - default_factory=forksafe.Lock, init=False, repr=False, compare=False - ) + _endpoint_counts_lock: typing.ContextManager = field(default_factory=Lock, init=False, repr=False, compare=False) _enabled: bool = field(default=False, repr=False, compare=False) def enable(self): diff --git a/ddtrace/internal/processor/stats.py b/ddtrace/internal/processor/stats.py index 45e4cd533f7..8f361a1ba9a 100644 --- a/ddtrace/internal/processor/stats.py +++ b/ddtrace/internal/processor/stats.py @@ -12,6 +12,7 @@ from ddtrace._trace.span import Span from ddtrace.internal import compat from ddtrace.internal.native import DDSketch +from ddtrace.internal.threads import Lock from ddtrace.internal.utils.retry import fibonacci_backoff_with_jitter from ddtrace.settings._config import config from 
ddtrace.version import get_version @@ -19,7 +20,6 @@ from ...constants import _SPAN_MEASURED_KEY from .. import agent from .._encoding import packb -from ..forksafe import Lock from ..hostname import get_hostname from ..logger import get_logger from ..periodic import PeriodicService diff --git a/ddtrace/internal/rate_limiter.py b/ddtrace/internal/rate_limiter.py index 9b514e5ff32..d58f9332a9d 100644 --- a/ddtrace/internal/rate_limiter.py +++ b/ddtrace/internal/rate_limiter.py @@ -3,12 +3,13 @@ from dataclasses import dataclass from dataclasses import field import random -import threading import time from typing import Any # noqa:F401 from typing import Callable # noqa:F401 from typing import Optional # noqa:F401 +from ddtrace.internal.threads import Lock + class RateLimiter(object): """ @@ -52,7 +53,7 @@ def __init__(self, rate_limit: int, time_window: float = 1e9): self.tokens_total = 0 self.prev_window_rate = None # type: Optional[float] - self._lock = threading.Lock() + self._lock = Lock() def is_allowed(self) -> bool: """ @@ -202,7 +203,7 @@ class BudgetRateLimiterWithJitter: budget: float = field(init=False) max_budget: float = field(init=False) last_time: float = field(init=False, default_factory=time.monotonic) - _lock: threading.Lock = field(init=False, default_factory=threading.Lock) + _lock: Lock = field(init=False, default_factory=Lock) def __post_init__(self): if self.limit_rate == float("inf"): diff --git a/ddtrace/internal/runtime/runtime_metrics.py b/ddtrace/internal/runtime/runtime_metrics.py index 124b97ae262..0747e702213 100644 --- a/ddtrace/internal/runtime/runtime_metrics.py +++ b/ddtrace/internal/runtime/runtime_metrics.py @@ -8,6 +8,7 @@ from ddtrace.internal import atexit from ddtrace.internal import forksafe from ddtrace.internal.constants import EXPERIMENTAL_FEATURES +from ddtrace.internal.threads import Lock from ddtrace.vendor.debtcollector import deprecate from ddtrace.vendor.dogstatsd import DogStatsd @@ -84,7 +85,7 @@ class 
RuntimeWorker(periodic.PeriodicService): enabled = False _instance = None # type: ClassVar[Optional[RuntimeWorker]] - _lock = forksafe.Lock() + _lock = Lock() def __init__(self, interval=_get_interval_or_default(), tracer=None, dogstatsd_url=None) -> None: super().__init__(interval=interval) diff --git a/ddtrace/internal/service.py b/ddtrace/internal/service.py index 20d82f8a192..7651e434ea3 100644 --- a/ddtrace/internal/service.py +++ b/ddtrace/internal/service.py @@ -2,7 +2,7 @@ import enum import typing # noqa:F401 -from . import forksafe +from ddtrace.internal.threads import Lock class ServiceStatus(enum.Enum): @@ -30,7 +30,7 @@ class Service(metaclass=abc.ABCMeta): def __init__(self) -> None: self.status: ServiceStatus = ServiceStatus.STOPPED - self._service_lock: typing.ContextManager = forksafe.Lock() + self._service_lock: typing.ContextManager = Lock() def __repr__(self): class_name = self.__class__.__name__ diff --git a/ddtrace/internal/telemetry/metrics_namespaces.pyx b/ddtrace/internal/telemetry/metrics_namespaces.pyx index 25060787aec..e30da84d4ef 100644 --- a/ddtrace/internal/telemetry/metrics_namespaces.pyx +++ b/ddtrace/internal/telemetry/metrics_namespaces.pyx @@ -4,7 +4,7 @@ import time from typing import Optional from typing import Tuple -from ddtrace.internal import forksafe +from ddtrace.internal.threads import Lock from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE from ddtrace.internal.telemetry.constants import TELEMETRY_TYPE_DISTRIBUTION from ddtrace.internal.telemetry.constants import TELEMETRY_TYPE_GENERATE_METRICS @@ -25,7 +25,7 @@ cdef class MetricNamespace: cdef public dict _metrics_data def __cinit__(self): - self._metrics_data_lock = forksafe.Lock() + self._metrics_data_lock = Lock() self._metrics_data = {} def flush(self, interval: float = None): diff --git a/ddtrace/internal/threads.py b/ddtrace/internal/threads.py new file mode 100644 index 00000000000..6f99b009e61 --- /dev/null +++ b/ddtrace/internal/threads.py 
@@ -0,0 +1,38 @@ +import atexit + +from ddtrace.internal import forksafe +from ddtrace.internal._threads import Lock +from ddtrace.internal._threads import PeriodicThread +from ddtrace.internal._threads import RLock +from ddtrace.internal._threads import periodic_threads +from ddtrace.internal._threads import reset_locks + + +__all__ = [ + "Lock", + "PeriodicThread", + "RLock", +] + + +@atexit.register +def _(): + # If the interpreter is shutting down we need to make sure that the threads + # are stopped before the runtime is marked as finalising. This is because + # any attempt to acquire the GIL while the runtime is finalising will cause + # the acquiring thread to be terminated with pthread_exit (on Linux). This + # causes a SIGABRT with GCC that cannot be caught, so we need to avoid + # getting to that stage. + for thread in periodic_threads.values(): + thread._atexit() + + +@forksafe.register +def _() -> None: + # No threads are running after a fork so we clean up the periodic threads + for thread in periodic_threads.values(): + thread._after_fork() + periodic_threads.clear() + + +forksafe.register(reset_locks) diff --git a/ddtrace/internal/utils/cache.py b/ddtrace/internal/utils/cache.py index 9c05a726315..5baa761a258 100644 --- a/ddtrace/internal/utils/cache.py +++ b/ddtrace/internal/utils/cache.py @@ -2,13 +2,14 @@ from inspect import FullArgSpec from inspect import getfullargspec from inspect import isgeneratorfunction -from threading import RLock from typing import Any # noqa:F401 from typing import Callable # noqa:F401 from typing import Optional # noqa:F401 from typing import Type # noqa:F401 from typing import TypeVar # noqa:F401 +from ddtrace.internal.threads import RLock + miss = object() diff --git a/ddtrace/internal/writer/writer.py b/ddtrace/internal/writer/writer.py index ed45c3cce31..51f7efbd41c 100644 --- a/ddtrace/internal/writer/writer.py +++ b/ddtrace/internal/writer/writer.py @@ -5,7 +5,6 @@ import logging import os import sys -import 
threading from typing import TYPE_CHECKING from typing import Callable from typing import Dict @@ -15,7 +14,7 @@ import ddtrace from ddtrace import config -import ddtrace.internal.utils.http +from ddtrace.internal.threads import RLock from ddtrace.internal.utils.retry import fibonacci_backoff_with_jitter from ddtrace.settings._agent import config as agent_config from ddtrace.settings.asm import config as asm_config @@ -184,7 +183,7 @@ def __init__( # The connection has to be locked since there exists a race between # the periodic thread of HTTPWriter and other threads that might # force a flush with `flush_queue()`. - self._conn_lck: threading.RLock = threading.RLock() + self._conn_lck = RLock() self._send_payload_with_backoff = fibonacci_backoff_with_jitter( # type ignore[assignment] attempts=self.RETRY_ATTEMPTS, diff --git a/ddtrace/llmobs/_evaluators/runner.py b/ddtrace/llmobs/_evaluators/runner.py index 7c00b9543e7..cb2e5e9c799 100644 --- a/ddtrace/llmobs/_evaluators/runner.py +++ b/ddtrace/llmobs/_evaluators/runner.py @@ -3,12 +3,12 @@ from typing import List from typing import Tuple -from ddtrace.internal import forksafe from ddtrace.internal.logger import get_logger from ddtrace.internal.periodic import PeriodicService from ddtrace.internal.service import ServiceStatus from ddtrace.internal.telemetry import telemetry_writer from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE +from ddtrace.internal.threads import RLock from ddtrace.llmobs._evaluators.ragas.answer_relevancy import RagasAnswerRelevancyEvaluator from ddtrace.llmobs._evaluators.ragas.context_precision import RagasContextPrecisionEvaluator from ddtrace.llmobs._evaluators.ragas.faithfulness import RagasFaithfulnessEvaluator @@ -38,7 +38,7 @@ class EvaluatorRunner(PeriodicService): def __init__(self, interval: float, llmobs_service=None, evaluators=None): super(EvaluatorRunner, self).__init__(interval=interval) - self._lock = forksafe.RLock() + self._lock = RLock() self._buffer: 
List[Tuple[LLMObsSpanEvent, Span]] = [] self._buffer_limit = 1000 diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index f01dbfd6ce4..db221294344 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -38,6 +38,7 @@ from ddtrace.internal.service import ServiceStatusError from ddtrace.internal.telemetry import telemetry_writer from ddtrace.internal.telemetry.constants import TELEMETRY_APM_PRODUCT +from ddtrace.internal.threads import RLock from ddtrace.internal.utils.formats import asbool from ddtrace.internal.utils.formats import format_trace_id from ddtrace.internal.utils.formats import parse_tags_str @@ -205,7 +206,7 @@ def __init__( self._link_tracker = LinkTracker() self._annotations: List[Tuple[str, str, Dict[str, Any]]] = [] - self._annotation_context_lock = forksafe.RLock() + self._annotation_context_lock = RLock() self._tool_call_tracker = ToolCallTracker() diff --git a/ddtrace/llmobs/_log_writer.py b/ddtrace/llmobs/_log_writer.py index 4a4b86dcdf9..4b7ba18d167 100644 --- a/ddtrace/llmobs/_log_writer.py +++ b/ddtrace/llmobs/_log_writer.py @@ -11,9 +11,9 @@ import http.client as httplib -from ddtrace.internal import forksafe from ddtrace.internal.logger import get_logger from ddtrace.internal.periodic import PeriodicService +from ddtrace.internal.threads import RLock logger = get_logger(__name__) @@ -50,7 +50,7 @@ class V2LogWriter(PeriodicService): def __init__(self, site, api_key, interval, timeout): # type: (str, str, float, float) -> None super(V2LogWriter, self).__init__(interval=interval) - self._lock = forksafe.RLock() + self._lock = RLock() self._buffer = [] # type: List[V2LogEvent] # match the API limit self._buffer_limit = 1000 diff --git a/ddtrace/llmobs/_writer.py b/ddtrace/llmobs/_writer.py index 6b04834b250..7e4967a310d 100644 --- a/ddtrace/llmobs/_writer.py +++ b/ddtrace/llmobs/_writer.py @@ -20,9 +20,9 @@ import ddtrace from ddtrace import config from ddtrace.internal import agent -from ddtrace.internal 
import forksafe from ddtrace.internal.logger import get_logger from ddtrace.internal.periodic import PeriodicService +from ddtrace.internal.threads import RLock from ddtrace.internal.utils.http import Response from ddtrace.internal.utils.http import get_connection from ddtrace.internal.utils.retry import fibonacci_backoff_with_jitter @@ -138,7 +138,7 @@ def __init__( _override_url: str = "", ) -> None: super(BaseLLMObsWriter, self).__init__(interval=interval) - self._lock = forksafe.RLock() + self._lock = RLock() self._buffer: List[Union[LLMObsSpanEvent, LLMObsEvaluationMetricEvent]] = [] self._buffer_size: int = 0 self._timeout: float = timeout diff --git a/ddtrace/opentracer/span.py b/ddtrace/opentracer/span.py index 75bb522d06f..8672a400001 100644 --- a/ddtrace/opentracer/span.py +++ b/ddtrace/opentracer/span.py @@ -1,4 +1,3 @@ -import threading from typing import TYPE_CHECKING # noqa:F401 from typing import Any # noqa:F401 from typing import Dict # noqa:F401 @@ -14,6 +13,7 @@ from ddtrace.constants import ERROR_TYPE from ddtrace.internal.compat import NumericType # noqa:F401 from ddtrace.internal.constants import SPAN_API_OPENTRACING +from ddtrace.internal.threads import Lock from ddtrace.trace import Context as DatadogContext # noqa:F401 from ddtrace.trace import Span as DatadogSpan @@ -41,7 +41,7 @@ def __init__(self, tracer, context, operation_name): super(Span, self).__init__(tracer, context) self.finished = False - self._lock = threading.Lock() + self._lock = Lock() # use a datadog span self._dd_span = DatadogSpan(operation_name, context=context._dd_context, span_api=SPAN_API_OPENTRACING) diff --git a/tests/internal/test_forksafe.py b/tests/internal/test_forksafe.py index 823de102212..97bf1257f11 100644 --- a/tests/internal/test_forksafe.py +++ b/tests/internal/test_forksafe.py @@ -5,6 +5,7 @@ import pytest from ddtrace.internal import forksafe +from ddtrace.internal import threads def test_forksafe(): @@ -173,7 +174,7 @@ def f3(): def test_lock_basic(): # 
type: (...) -> None """Check that a forksafe.Lock implements the correct threading.Lock interface""" - lock = forksafe.Lock() + lock = threads.Lock() assert lock.acquire() assert lock.release() is None with pytest.raises(lock_release_exc_type): @@ -185,7 +186,7 @@ def test_lock_fork(): This test fails with a regular threading.Lock. """ - lock = forksafe.Lock() + lock = threads.Lock() lock.acquire() pid = os.fork() @@ -208,7 +209,7 @@ def test_lock_fork(): def test_rlock_basic(): # type: (...) -> None """Check that a forksafe.RLock implements the correct threading.RLock interface""" - lock = forksafe.RLock() + lock = threads.RLock() assert lock.acquire() assert lock.acquire() assert lock.release() is None @@ -222,7 +223,7 @@ def test_rlock_fork(): This test fails with a regular threading.RLock. """ - lock = forksafe.RLock() + lock = threads.RLock() lock.acquire() lock.acquire()