Skip to content

Commit da3c401

Browse files
committed
Make dict watcher API thread-safe for free-threaded builds
1 parent 7eb00ad commit da3c401

File tree

8 files changed

+157
-28
lines changed

8 files changed

+157
-28
lines changed

Include/internal/pycore_dict.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ _PyDict_NotifyEvent(PyDict_WatchEvent event,
288288
PyObject *value)
289289
{
290290
assert(Py_REFCNT((PyObject*)mp) > 0);
291-
int watcher_bits = mp->_ma_watcher_tag & DICT_WATCHER_MASK;
291+
uint64_t tag = FT_ATOMIC_LOAD_UINT64_RELAXED(mp->_ma_watcher_tag);
292+
int watcher_bits = tag & DICT_WATCHER_MASK;
292293
if (watcher_bits) {
293294
RARE_EVENT_STAT_INC(watched_dict_modification);
294295
_PyDict_SendEvent(watcher_bits, event, mp, key, value);
@@ -364,7 +365,8 @@ PyDictObject *_PyObject_MaterializeManagedDict_LockHeld(PyObject *);
364365
static inline Py_ssize_t
365366
_PyDict_UniqueId(PyDictObject *mp)
366367
{
367-
return (Py_ssize_t)(mp->_ma_watcher_tag >> DICT_UNIQUE_ID_SHIFT);
368+
uint64_t tag = FT_ATOMIC_LOAD_UINT64_RELAXED(mp->_ma_watcher_tag);
369+
return (Py_ssize_t)(tag >> DICT_UNIQUE_ID_SHIFT);
368370
}
369371

370372
static inline void

Include/internal/pycore_dict_state.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ extern "C" {
88
# error "this header requires Py_BUILD_CORE define"
99
#endif
1010

11+
#include "pycore_lock.h" // PyMutex
12+
1113
#define DICT_MAX_WATCHERS 8
1214
#define DICT_WATCHED_MUTATION_BITS 4
1315

1416
struct _Py_dict_state {
1517
uint32_t next_keys_version;
1618
PyDict_WatchCallback watchers[DICT_MAX_WATCHERS];
19+
PyMutex watcher_mutex; // Protects the watchers array (free-threaded builds)
1720
};
1821

1922
#define _dict_state_INIT \

Include/internal/pycore_pyatomic_ft_wrappers.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ extern "C" {
4949
_Py_atomic_load_uint16_relaxed(&value)
5050
#define FT_ATOMIC_LOAD_UINT32_RELAXED(value) \
5151
_Py_atomic_load_uint32_relaxed(&value)
52+
#define FT_ATOMIC_LOAD_UINT64_RELAXED(value) \
53+
_Py_atomic_load_uint64_relaxed(&value)
5254
#define FT_ATOMIC_LOAD_ULONG_RELAXED(value) \
5355
_Py_atomic_load_ulong_relaxed(&value)
5456
#define FT_ATOMIC_STORE_PTR_RELAXED(value, new_value) \
@@ -125,6 +127,8 @@ extern "C" {
125127
_Py_atomic_load_ullong_relaxed(&value)
126128
#define FT_ATOMIC_ADD_SSIZE(value, new_value) \
127129
(void)_Py_atomic_add_ssize(&value, new_value)
130+
#define FT_ATOMIC_ADD_UINT64(value, new_value) \
131+
(void)_Py_atomic_add_uint64(&value, new_value)
128132
#define FT_MUTEX_LOCK(lock) PyMutex_Lock(lock)
129133
#define FT_MUTEX_UNLOCK(lock) PyMutex_Unlock(lock)
130134

@@ -144,6 +148,7 @@ extern "C" {
144148
#define FT_ATOMIC_LOAD_UINT8_RELAXED(value) value
145149
#define FT_ATOMIC_LOAD_UINT16_RELAXED(value) value
146150
#define FT_ATOMIC_LOAD_UINT32_RELAXED(value) value
151+
#define FT_ATOMIC_LOAD_UINT64_RELAXED(value) value
147152
#define FT_ATOMIC_LOAD_ULONG_RELAXED(value) value
148153
#define FT_ATOMIC_STORE_PTR_RELAXED(value, new_value) value = new_value
149154
#define FT_ATOMIC_STORE_PTR_RELEASE(value, new_value) value = new_value
@@ -182,6 +187,7 @@ extern "C" {
182187
#define FT_ATOMIC_LOAD_ULLONG_RELAXED(value) value
183188
#define FT_ATOMIC_STORE_ULLONG_RELAXED(value, new_value) value = new_value
184189
#define FT_ATOMIC_ADD_SSIZE(value, new_value) (void)(value += new_value)
190+
#define FT_ATOMIC_ADD_UINT64(value, new_value) (void)(value += new_value)
185191
#define FT_MUTEX_LOCK(lock) do {} while (0)
186192
#define FT_MUTEX_UNLOCK(lock) do {} while (0)
187193

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import unittest
2+
3+
from test.support import import_helper, threading_helper
4+
5+
_testcapi = import_helper.import_module("_testcapi")
6+
7+
ITERS = 100
8+
NTHREADS = 20
9+
10+
11+
@threading_helper.requires_working_threading()
12+
class TestDictWatcherThreadSafety(unittest.TestCase):
13+
# Watcher kinds from _testcapi
14+
EVENTS = 0 # appends dict events as strings to global event list
15+
16+
def test_concurrent_add_clear_watchers(self):
17+
"""Race AddWatcher and ClearWatcher from multiple threads.
18+
19+
Uses more threads than available watcher slots (5 user slots out
20+
of DICT_MAX_WATCHERS=8).
21+
"""
22+
results = []
23+
24+
def worker():
25+
for _ in range(ITERS):
26+
try:
27+
wid = _testcapi.add_dict_watcher(self.EVENTS)
28+
except RuntimeError:
29+
continue # All slots taken
30+
self.assertGreaterEqual(wid, 0)
31+
results.append(wid)
32+
_testcapi.clear_dict_watcher(wid)
33+
34+
threading_helper.run_concurrently(worker, NTHREADS)
35+
36+
# Verify at least some watchers were successfully added
37+
self.assertGreater(len(results), 0)
38+
39+
def test_concurrent_watch_unwatch(self):
40+
"""Race Watch and Unwatch on the same dict from multiple threads."""
41+
wid = _testcapi.add_dict_watcher(self.EVENTS)
42+
dicts = [{} for _ in range(10)]
43+
44+
def worker():
45+
for _ in range(ITERS):
46+
for d in dicts:
47+
_testcapi.watch_dict(wid, d)
48+
for d in dicts:
49+
_testcapi.unwatch_dict(wid, d)
50+
51+
try:
52+
threading_helper.run_concurrently(worker, NTHREADS)
53+
54+
# Verify watching still works after concurrent watch/unwatch
55+
_testcapi.watch_dict(wid, dicts[0])
56+
dicts[0]["key"] = "value"
57+
events = _testcapi.get_dict_watcher_events()
58+
self.assertIn("new:key:value", events)
59+
finally:
60+
_testcapi.clear_dict_watcher(wid)
61+
62+
def test_concurrent_modify_watched_dict(self):
63+
"""Race dict mutations (triggering callbacks) with watch/unwatch."""
64+
wid = _testcapi.add_dict_watcher(self.EVENTS)
65+
d = {}
66+
_testcapi.watch_dict(wid, d)
67+
68+
def mutator():
69+
for i in range(ITERS):
70+
d[f"key_{i}"] = i
71+
d.pop(f"key_{i}", None)
72+
73+
def toggler():
74+
for i in range(ITERS):
75+
_testcapi.watch_dict(wid, d)
76+
d[f"toggler_{i}"] = i
77+
_testcapi.unwatch_dict(wid, d)
78+
79+
workers = [mutator, toggler] * (NTHREADS // 2)
80+
try:
81+
threading_helper.run_concurrently(workers)
82+
events = _testcapi.get_dict_watcher_events()
83+
self.assertGreater(len(events), 0)
84+
finally:
85+
_testcapi.clear_dict_watcher(wid)
86+
87+
88+
if __name__ == "__main__":
89+
unittest.main()

Modules/_testcapi/watchers.c

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "pycore_function.h" // FUNC_MAX_WATCHERS
1010
#include "pycore_interp_structs.h" // CODE_MAX_WATCHERS
1111
#include "pycore_context.h" // CONTEXT_MAX_WATCHERS
12+
#include "pycore_lock.h" // _PyOnceFlag
1213

1314
/*[clinic input]
1415
module _testcapi
@@ -18,6 +19,14 @@ module _testcapi
1819
// Test dict watching
1920
static PyObject *g_dict_watch_events = NULL;
2021
static int g_dict_watchers_installed = 0;
22+
static _PyOnceFlag g_dict_watch_once = {0};
23+
24+
static int
25+
_init_dict_watch_events(void *arg)
26+
{
27+
g_dict_watch_events = PyList_New(0);
28+
return g_dict_watch_events ? 0 : -1;
29+
}
2130

2231
static int
2332
dict_watch_callback(PyDict_WatchEvent event,
@@ -106,13 +115,10 @@ add_dict_watcher(PyObject *self, PyObject *kind)
106115
if (watcher_id < 0) {
107116
return NULL;
108117
}
109-
if (!g_dict_watchers_installed) {
110-
assert(!g_dict_watch_events);
111-
if (!(g_dict_watch_events = PyList_New(0))) {
112-
return NULL;
113-
}
118+
if (_PyOnceFlag_CallOnce(&g_dict_watch_once, _init_dict_watch_events, NULL) < 0) {
119+
return NULL;
114120
}
115-
g_dict_watchers_installed++;
121+
_Py_atomic_add_int(&g_dict_watchers_installed, 1);
116122
return PyLong_FromLong(watcher_id);
117123
}
118124

@@ -122,10 +128,8 @@ clear_dict_watcher(PyObject *self, PyObject *watcher_id)
122128
if (PyDict_ClearWatcher(PyLong_AsLong(watcher_id))) {
123129
return NULL;
124130
}
125-
g_dict_watchers_installed--;
126-
if (!g_dict_watchers_installed) {
127-
assert(g_dict_watch_events);
128-
Py_CLEAR(g_dict_watch_events);
131+
if (_Py_atomic_add_int(&g_dict_watchers_installed, -1) == 1) {
132+
PyList_Clear(g_dict_watch_events);
129133
}
130134
Py_RETURN_NONE;
131135
}
@@ -164,7 +168,7 @@ _testcapi_unwatch_dict_impl(PyObject *module, int watcher_id, PyObject *dict)
164168
static PyObject *
165169
get_dict_watcher_events(PyObject *self, PyObject *Py_UNUSED(args))
166170
{
167-
if (!g_dict_watch_events) {
171+
if (_Py_atomic_load_int(&g_dict_watchers_installed) <= 0) {
168172
PyErr_SetString(PyExc_RuntimeError, "no watchers active");
169173
return NULL;
170174
}

Objects/dictobject.c

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7842,13 +7842,19 @@ validate_watcher_id(PyInterpreterState *interp, int watcher_id)
78427842
PyErr_Format(PyExc_ValueError, "Invalid dict watcher ID %d", watcher_id);
78437843
return -1;
78447844
}
7845-
if (!interp->dict_state.watchers[watcher_id]) {
7845+
PyDict_WatchCallback cb = FT_ATOMIC_LOAD_PTR_RELAXED(
7846+
interp->dict_state.watchers[watcher_id]);
7847+
if (cb == NULL) {
78467848
PyErr_Format(PyExc_ValueError, "No dict watcher set for ID %d", watcher_id);
78477849
return -1;
78487850
}
78497851
return 0;
78507852
}
78517853

7854+
// In free-threaded builds, Add/Clear serialize on watcher_mutex and publish
7855+
// callbacks with release stores. SendEvent reads them lock-free using
7856+
// acquire loads.
7857+
78527858
int
78537859
PyDict_Watch(int watcher_id, PyObject* dict)
78547860
{
@@ -7860,7 +7866,9 @@ PyDict_Watch(int watcher_id, PyObject* dict)
78607866
if (validate_watcher_id(interp, watcher_id)) {
78617867
return -1;
78627868
}
7863-
((PyDictObject*)dict)->_ma_watcher_tag |= (1LL << watcher_id);
7869+
Py_BEGIN_CRITICAL_SECTION(dict);
7870+
((PyDictObject*)dict)->_ma_watcher_tag |= (1ULL << watcher_id);
7871+
Py_END_CRITICAL_SECTION();
78647872
return 0;
78657873
}
78667874

@@ -7875,36 +7883,47 @@ PyDict_Unwatch(int watcher_id, PyObject* dict)
78757883
if (validate_watcher_id(interp, watcher_id)) {
78767884
return -1;
78777885
}
7878-
((PyDictObject*)dict)->_ma_watcher_tag &= ~(1LL << watcher_id);
7886+
Py_BEGIN_CRITICAL_SECTION(dict);
7887+
((PyDictObject*)dict)->_ma_watcher_tag &= ~(1ULL << watcher_id);
7888+
Py_END_CRITICAL_SECTION();
78797889
return 0;
78807890
}
78817891

78827892
int
78837893
PyDict_AddWatcher(PyDict_WatchCallback callback)
78847894
{
7895+
int watcher_id = -1;
78857896
PyInterpreterState *interp = _PyInterpreterState_GET();
78867897

7898+
FT_MUTEX_LOCK(&interp->dict_state.watcher_mutex);
78877899
/* Some watchers are reserved for CPython, start at the first available one */
78887900
for (int i = FIRST_AVAILABLE_WATCHER; i < DICT_MAX_WATCHERS; i++) {
78897901
if (!interp->dict_state.watchers[i]) {
7890-
interp->dict_state.watchers[i] = callback;
7891-
return i;
7902+
FT_ATOMIC_STORE_PTR_RELEASE(interp->dict_state.watchers[i], callback);
7903+
watcher_id = i;
7904+
goto done;
78927905
}
78937906
}
7894-
78957907
PyErr_SetString(PyExc_RuntimeError, "no more dict watcher IDs available");
7896-
return -1;
7908+
done:
7909+
FT_MUTEX_UNLOCK(&interp->dict_state.watcher_mutex);
7910+
return watcher_id;
78977911
}
78987912

78997913
int
79007914
PyDict_ClearWatcher(int watcher_id)
79017915
{
7916+
int res = 0;
79027917
PyInterpreterState *interp = _PyInterpreterState_GET();
7918+
FT_MUTEX_LOCK(&interp->dict_state.watcher_mutex);
79037919
if (validate_watcher_id(interp, watcher_id)) {
7904-
return -1;
7920+
res = -1;
7921+
goto done;
79057922
}
7906-
interp->dict_state.watchers[watcher_id] = NULL;
7907-
return 0;
7923+
FT_ATOMIC_STORE_PTR_RELEASE(interp->dict_state.watchers[watcher_id], NULL);
7924+
done:
7925+
FT_MUTEX_UNLOCK(&interp->dict_state.watcher_mutex);
7926+
return res;
79087927
}
79097928

79107929
static const char *
@@ -7929,7 +7948,8 @@ _PyDict_SendEvent(int watcher_bits,
79297948
PyInterpreterState *interp = _PyInterpreterState_GET();
79307949
for (int i = 0; i < DICT_MAX_WATCHERS; i++) {
79317950
if (watcher_bits & 1) {
7932-
PyDict_WatchCallback cb = interp->dict_state.watchers[i];
7951+
PyDict_WatchCallback cb = FT_ATOMIC_LOAD_PTR_ACQUIRE(
7952+
interp->dict_state.watchers[i]);
79337953
if (cb && (cb(event, (PyObject*)mp, key, value) < 0)) {
79347954
// We don't want to resurrect the dict by potentially having an
79357955
// unraisablehook keep a reference to it, so we don't pass the

Python/optimizer_analysis.c

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "pycore_opcode_metadata.h"
1919
#include "pycore_opcode_utils.h"
2020
#include "pycore_pystate.h" // _PyInterpreterState_GET()
21+
#include "pycore_pyatomic_ft_wrappers.h" // FT_MUTEX_LOCK/UNLOCK
2122
#include "pycore_tstate.h" // _PyThreadStateImpl
2223
#include "pycore_uop_metadata.h"
2324
#include "pycore_long.h"
@@ -117,14 +118,15 @@ static int
117118
get_mutations(PyObject* dict) {
118119
assert(PyDict_CheckExact(dict));
119120
PyDictObject *d = (PyDictObject *)dict;
120-
return (d->_ma_watcher_tag >> DICT_MAX_WATCHERS) & ((1 << DICT_WATCHED_MUTATION_BITS)-1);
121+
uint64_t tag = FT_ATOMIC_LOAD_UINT64_RELAXED(d->_ma_watcher_tag);
122+
return (tag >> DICT_MAX_WATCHERS) & ((1 << DICT_WATCHED_MUTATION_BITS) - 1);
121123
}
122124

123125
static void
124126
increment_mutations(PyObject* dict) {
125127
assert(PyDict_CheckExact(dict));
126128
PyDictObject *d = (PyDictObject *)dict;
127-
d->_ma_watcher_tag += (1 << DICT_MAX_WATCHERS);
129+
FT_ATOMIC_ADD_UINT64(d->_ma_watcher_tag, 1ULL << DICT_MAX_WATCHERS);
128130
}
129131

130132
/* The first two dict watcher IDs are reserved for CPython,
@@ -467,8 +469,10 @@ optimize_uops(
467469

468470
// Make sure that watchers are set up
469471
PyInterpreterState *interp = _PyInterpreterState_GET();
470-
if (interp->dict_state.watchers[GLOBALS_WATCHER_ID] == NULL) {
471-
interp->dict_state.watchers[GLOBALS_WATCHER_ID] = globals_watcher_callback;
472+
if (FT_ATOMIC_LOAD_PTR_RELAXED(interp->dict_state.watchers[GLOBALS_WATCHER_ID]) == NULL) {
473+
FT_ATOMIC_STORE_PTR_RELEASE(
474+
interp->dict_state.watchers[GLOBALS_WATCHER_ID],
475+
globals_watcher_callback);
472476
interp->type_watchers[TYPE_WATCHER_ID] = type_watcher_callback;
473477
}
474478

Python/pystate.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ _Py_COMP_DIAG_POP
320320
&(runtime)->allocators.mutex, \
321321
&(runtime)->_main_interpreter.types.mutex, \
322322
&(runtime)->_main_interpreter.code_state.mutex, \
323+
&(runtime)->_main_interpreter.dict_state.watcher_mutex, \
323324
}
324325

325326
static void

0 commit comments

Comments
 (0)